From cd6dff34c7ca0810d535ba3a0ee6b3fa3e788187 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Mon, 21 Nov 2022 18:12:50 -0500 Subject: [PATCH] Workaround for CUB segmented-sort bug with boolean keys --- cpp/src/sort/segmented_sort.cu | 2 +- cpp/tests/sort/segmented_sort_tests.cpp | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/cpp/src/sort/segmented_sort.cu b/cpp/src/sort/segmented_sort.cu index 685d8aa3ec1..dc87d5ea326 100644 --- a/cpp/src/sort/segmented_sort.cu +++ b/cpp/src/sort/segmented_sort.cu @@ -52,7 +52,7 @@ struct column_fast_sort_fn { static bool is_fast_sort_supported(column_view const& col) { return !col.has_nulls() and - (cudf::is_integral(col.type()) || + ((cudf::is_integral(col.type()) && !cudf::is_boolean(col.type())) || (cudf::is_fixed_point(col.type()) and (col.type().id() != type_id::DECIMAL128))); } diff --git a/cpp/tests/sort/segmented_sort_tests.cpp b/cpp/tests/sort/segmented_sort_tests.cpp index c1a742e63b8..ad905b6d04f 100644 --- a/cpp/tests/sort/segmented_sort_tests.cpp +++ b/cpp/tests/sort/segmented_sort_tests.cpp @@ -274,5 +274,24 @@ TEST_F(SegmentedSortInt, ErrorsMismatchArgSizes) CUDF_EXPECT_NO_THROW(cudf::segmented_sort_by_key(input1, input1, segments)); } +TEST_F(SegmentedSortInt, Bool) +{ + cudf::test::fixed_width_column_wrapper col1{ + {true, false, false, true, true, true, true, true, true, true, true, true, true, false, + false, false, false, true, false, false, true, true, true, true, true, true, true, false, + true, false, true, true, true, true, true, true, false, true, false, false}}; + + cudf::test::fixed_width_column_wrapper segments{{0, 5, 10, 15, 20, 25, 30, 40}}; + + auto test_col = cudf::column_view{col1}; + auto result = cudf::segmented_sorted_order(cudf::table_view({test_col}), segments); + + cudf::test::fixed_width_column_wrapper expected( + {1, 2, 0, 3, 4, 5, 6, 7, 8, 9, 13, 14, 10, 11, 12, 15, 16, 18, 19, 17, + 20, 21, 22, 23, 24, 27, 29, 25, 26, 28, 36, 38, 39, 30, 31, 32, 33, 34, 35, 37}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} + } // namespace test } // namespace cudf