diff --git a/cpp/src/groupby/sort/group_scan_util.cuh b/cpp/src/groupby/sort/group_scan_util.cuh index 013ea924cce..b565e8dc6d8 100644 --- a/cpp/src/groupby/sort/group_scan_util.cuh +++ b/cpp/src/groupby/sort/group_scan_util.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -221,16 +221,18 @@ struct group_scan_functor(0); if (values.has_nulls()) { - auto const binop = row_arg_minmax_fn(values.size(), - *d_flattened_values_ptr, - flattened_null_precedences.data(), - K == aggregation::MIN); + auto const binop = + cudf::reduction::detail::row_arg_minmax_fn(values.size(), + *d_flattened_values_ptr, + flattened_null_precedences.data(), + K == aggregation::MIN); do_scan(count_iter, map_begin, binop); } else { - auto const binop = row_arg_minmax_fn(values.size(), - *d_flattened_values_ptr, - flattened_null_precedences.data(), - K == aggregation::MIN); + auto const binop = + cudf::reduction::detail::row_arg_minmax_fn(values.size(), + *d_flattened_values_ptr, + flattened_null_precedences.data(), + K == aggregation::MIN); do_scan(count_iter, map_begin, binop); } diff --git a/cpp/src/groupby/sort/group_single_pass_reduction_util.cuh b/cpp/src/groupby/sort/group_single_pass_reduction_util.cuh index 4e0820af236..decb127b264 100644 --- a/cpp/src/groupby/sort/group_single_pass_reduction_util.cuh +++ b/cpp/src/groupby/sort/group_single_pass_reduction_util.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -271,10 +271,11 @@ struct group_reduction_functor< auto const count_iter = thrust::make_counting_iterator(0); auto const result_begin = result->mutable_view().template begin(); if (values.has_nulls()) { - auto const binop = row_arg_minmax_fn(values.size(), - *d_flattened_values_ptr, - flattened_null_precedences.data(), - K == aggregation::ARGMIN); + auto const binop = + cudf::reduction::detail::row_arg_minmax_fn(values.size(), + *d_flattened_values_ptr, + flattened_null_precedences.data(), + K == aggregation::ARGMIN); do_reduction(count_iter, result_begin, binop); // Generate bitmask for the output by segmented reduction of the input bitmask. @@ -288,10 +289,11 @@ struct group_reduction_functor< validity.begin(), validity.end(), thrust::identity{}, stream, mr); result->set_null_mask(std::move(null_mask), null_count); } else { - auto const binop = row_arg_minmax_fn(values.size(), - *d_flattened_values_ptr, - flattened_null_precedences.data(), - K == aggregation::ARGMIN); + auto const binop = + cudf::reduction::detail::row_arg_minmax_fn(values.size(), + *d_flattened_values_ptr, + flattened_null_precedences.data(), + K == aggregation::ARGMIN); do_reduction(count_iter, result_begin, binop); } diff --git a/cpp/src/groupby/sort/group_util.cuh b/cpp/src/reductions/arg_minmax_util.cuh similarity index 98% rename from cpp/src/groupby/sort/group_util.cuh rename to cpp/src/reductions/arg_minmax_util.cuh index 31ff29ed4c3..40df23bcd8e 100644 --- a/cpp/src/groupby/sort/group_util.cuh +++ b/cpp/src/reductions/arg_minmax_util.cuh @@ -19,7 +19,7 @@ #include namespace cudf { -namespace groupby { +namespace reduction { namespace detail { /** @@ -62,5 +62,5 @@ struct row_arg_minmax_fn { }; } // namespace detail -} // namespace groupby +} // namespace reduction } // namespace cudf diff --git a/cpp/src/reductions/simple.cuh b/cpp/src/reductions/simple.cuh index 13dfe5cb26c..7dd54e9250a 100644 --- a/cpp/src/reductions/simple.cuh +++ b/cpp/src/reductions/simple.cuh @@ -16,9 +16,13 @@ #pragma once +#include + #include #include +#include #include +#include #include #include #include @@ -28,6 +32,9 @@ #include #include +#include + +#include namespace cudf { namespace reduction { @@ -252,8 +259,7 @@ struct same_element_type_dispatcher { template static constexpr bool is_supported() { - return !(cudf::is_dictionary() || std::is_same_v || - std::is_same_v); + return !(cudf::is_dictionary() || std::is_same_v); } template () && - not cudf::is_fixed_point()>* = nullptr> + std::enable_if_t && + (std::is_same_v || + std::is_same_v)>* = nullptr> + std::unique_ptr operator()(column_view const& input, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + if (input.is_empty()) { return cudf::make_empty_scalar_like(input, stream, mr); } + + auto constexpr is_min_op = std::is_same_v; + + // We will do reduction to find the ARGMIN/ARGMAX index, then return the element at that index. + // When finding ARGMIN, we need to consider nulls as larger than non-null elements, and the + // opposite for ARGMAX. + auto constexpr null_precedence = is_min_op ? cudf::null_order::AFTER : cudf::null_order::BEFORE; + auto const flattened_input = cudf::structs::detail::flatten_nested_columns( + table_view{{input}}, {}, std::vector{null_precedence}); + auto const d_flattened_input_ptr = table_device_view::create(flattened_input, stream); + auto const flattened_null_precedences = + is_min_op ? cudf::detail::make_device_uvector_async(flattened_input.null_orders(), stream) + : rmm::device_uvector(0, stream); + + // Perform reduction to find ARGMIN/ARGMAX. + auto const do_reduction = [&](auto const& binop) { + return thrust::reduce(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.size()), + size_type{0}, + binop); + }; + + auto const minmax_idx = [&] { + if (input.has_nulls()) { + auto const binop = cudf::reduction::detail::row_arg_minmax_fn( + input.size(), *d_flattened_input_ptr, flattened_null_precedences.data(), is_min_op); + return do_reduction(binop); + } else { + auto const binop = cudf::reduction::detail::row_arg_minmax_fn( + input.size(), *d_flattened_input_ptr, flattened_null_precedences.data(), is_min_op); + return do_reduction(binop); + } + }(); + + return cudf::detail::get_element(input, minmax_idx, stream, mr); + } + + template () && !cudf::is_fixed_point() && + !std::is_same_v>* = nullptr> std::unique_ptr operator()(column_view const& col, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) diff --git a/cpp/tests/reductions/reduction_tests.cpp b/cpp/tests/reductions/reduction_tests.cpp index 376f5ce5dd2..2c9279260e7 100644 --- a/cpp/tests/reductions/reduction_tests.cpp +++ b/cpp/tests/reductions/reduction_tests.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -2055,7 +2056,7 @@ TEST_F(ListReductionTest, NonValidListReductionNthElement) struct StructReductionTest : public cudf::test::BaseFixture { using SCW = cudf::test::structs_column_wrapper; - void reduction_test(SCW const& struct_column, + void reduction_test(cudf::column_view const& struct_column, cudf::table_view const& expected_value, bool succeeded_condition, bool is_valid, @@ -2066,7 +2067,7 @@ struct StructReductionTest : public cudf::test::BaseFixture { cudf::reduce(struct_column, agg, cudf::data_type(cudf::type_id::STRUCT)); auto struct_result = dynamic_cast(result.get()); EXPECT_EQ(is_valid, struct_result->is_valid()); - if (is_valid) { CUDF_TEST_EXPECT_TABLES_EQUAL(expected_value, struct_result->view()); } + if (is_valid) { CUDF_TEST_EXPECT_TABLES_EQUIVALENT(expected_value, struct_result->view()); } }; if (succeeded_condition) { @@ -2210,4 +2211,130 @@ TEST_F(StructReductionTest, NonValidStructReductionNthElement) cudf::make_nth_element_aggregation(0, cudf::null_policy::INCLUDE)); } +TEST_F(StructReductionTest, StructReductionMinMaxNoNull) +{ + using INTS_CW = cudf::test::fixed_width_column_wrapper; + using STRINGS_CW = cudf::test::strings_column_wrapper; + using STRUCTS_CW = cudf::test::structs_column_wrapper; + + auto const input = [] { + auto child1 = STRINGS_CW{"año", "bit", "₹1", "aaa", "zit", "bat", "aab", "$1", "€1", "wut"}; + auto child2 = INTS_CW{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; + return STRUCTS_CW{{child1, child2}}; + }(); + + { + auto const expected_child1 = STRINGS_CW{"$1"}; + auto const expected_child2 = INTS_CW{8}; + this->reduction_test(input, + cudf::table_view{{expected_child1, expected_child2}}, + true, + true, + cudf::make_min_aggregation()); + } + + { + auto const expected_child1 = STRINGS_CW{"₹1"}; + auto const expected_child2 = INTS_CW{3}; + this->reduction_test(input, + cudf::table_view{{expected_child1, expected_child2}}, + true, + true, + cudf::make_max_aggregation()); + } +} + +TEST_F(StructReductionTest, StructReductionMinMaxSlicedInput) +{ + using INTS_CW = cudf::test::fixed_width_column_wrapper; + using STRINGS_CW = cudf::test::strings_column_wrapper; + using STRUCTS_CW = cudf::test::structs_column_wrapper; + constexpr int32_t dont_care{1}; + + auto const input_original = [] { + auto child1 = STRINGS_CW{"$dont_care", + "$dont_care", + "año", + "bit", + "₹1", + "aaa", + "zit", + "bat", + "aab", + "$1", + "€1", + "wut", + "₹dont_care"}; + auto child2 = INTS_CW{dont_care, dont_care, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, dont_care}; + return STRUCTS_CW{{child1, child2}}; + }(); + + auto const input = cudf::slice(input_original, {2, 12})[0]; + + { + auto const expected_child1 = STRINGS_CW{"$1"}; + auto const expected_child2 = INTS_CW{8}; + this->reduction_test(input, + cudf::table_view{{expected_child1, expected_child2}}, + true, + true, + cudf::make_min_aggregation()); + } + + { + auto const expected_child1 = STRINGS_CW{"₹1"}; + auto const expected_child2 = INTS_CW{3}; + this->reduction_test(input, + cudf::table_view{{expected_child1, expected_child2}}, + true, + true, + cudf::make_max_aggregation()); + } +} + +TEST_F(StructReductionTest, StructReductionMinMaxWithNulls) +{ + using INTS_CW = cudf::test::fixed_width_column_wrapper; + using STRINGS_CW = cudf::test::strings_column_wrapper; + using STRUCTS_CW = cudf::test::structs_column_wrapper; + using cudf::test::iterators::nulls_at; + + auto const input = [] { + auto child1 = STRINGS_CW{{"año", + "bit", + "₹1" /*NULL*/, + "aaa" /*NULL*/, + "zit", + "bat", + "aab", + "$1" /*NULL*/, + "€1" /*NULL*/, + "wut"}, + nulls_at({2, 7})}; + auto child2 = INTS_CW{{1, 2, 3 /*NULL*/, 4 /*NULL*/, 5, 6, 7, 8 /*NULL*/, 9 /*NULL*/, 10}, + nulls_at({2, 7})}; + return STRUCTS_CW{{child1, child2}, nulls_at({3, 8})}; + }(); + + { + auto const expected_child1 = STRINGS_CW{"aab"}; + auto const expected_child2 = INTS_CW{7}; + this->reduction_test(input, + cudf::table_view{{expected_child1, expected_child2}}, + true, + true, + cudf::make_min_aggregation()); + } + + { + auto const expected_child1 = STRINGS_CW{"zit"}; + auto const expected_child2 = INTS_CW{5}; + this->reduction_test(input, + cudf::table_view{{expected_child1, expected_child2}}, + true, + true, + cudf::make_max_aggregation()); + } +} + CUDF_TEST_PROGRAM_MAIN()