diff --git a/cpp/cmake/thirdparty/patches/cub_segmented_sort_with_bool_key.diff b/cpp/cmake/thirdparty/patches/cub_segmented_sort_with_bool_key.diff new file mode 100644 index 00000000000..7c40fd4287d --- /dev/null +++ b/cpp/cmake/thirdparty/patches/cub_segmented_sort_with_bool_key.diff @@ -0,0 +1,14 @@ +diff --git a/dependencies/cub/cub/agent/agent_sub_warp_merge_sort.cuh b/dependencies/cub/cub/agent/agent_sub_warp_merge_sort.cuh +index ad65f2a3..ad45a21e 100644 +--- a/dependencies/cub/cub/agent/agent_sub_warp_merge_sort.cuh ++++ b/dependencies/cub/cub/agent/agent_sub_warp_merge_sort.cuh +@@ -221,7 +221,8 @@ public: + using UnsignedBitsT = typename Traits::UnsignedBits; + UnsignedBitsT default_key_bits = IS_DESCENDING ? Traits::LOWEST_KEY + : Traits::MAX_KEY; +- KeyT oob_default = reinterpret_cast(default_key_bits); ++ KeyT oob_default = std::is_same_v ? !IS_DESCENDING ++ : reinterpret_cast(default_key_bits); + + WarpLoadKeysT(storage.load_keys) + .Load(keys_input, keys, segment_size, oob_default); diff --git a/cpp/cmake/thirdparty/patches/thrust_override.json b/cpp/cmake/thirdparty/patches/thrust_override.json index f1908a64719..ded2b90eeba 100644 --- a/cpp/cmake/thirdparty/patches/thrust_override.json +++ b/cpp/cmake/thirdparty/patches/thrust_override.json @@ -27,6 +27,11 @@ "file" : "${current_json_dir}/thrust_faster_scan_compile_times.diff", "issue" : "Improve Thrust scan compile times by reducing the number of kernels generated [https://github.com/rapidsai/cudf/pull/8183]", "fixed_in" : "" + }, + { + "file" : "${current_json_dir}/cub_segmented_sort_with_bool_key.diff", + "issue" : "Fix an error in CUB DeviceSegmentedSort when the keys are bool type [https://github.com/NVIDIA/cub/issues/594]", + "fixed_in" : "2.1" } ] } diff --git a/cpp/src/sort/segmented_sort.cu b/cpp/src/sort/segmented_sort.cu index dc87d5ea326..685d8aa3ec1 100644 --- a/cpp/src/sort/segmented_sort.cu +++ b/cpp/src/sort/segmented_sort.cu @@ -52,7 +52,7 @@ struct column_fast_sort_fn { static bool is_fast_sort_supported(column_view const& col) { return !col.has_nulls() and - ((cudf::is_integral(col.type()) && !cudf::is_boolean(col.type())) || + (cudf::is_integral(col.type()) || (cudf::is_fixed_point(col.type()) and (col.type().id() != type_id::DECIMAL128))); }