diff --git a/cpp/src/groupby/groupby.cu b/cpp/src/groupby/groupby.cu index a5fd6d6f9bb..f132d6b1511 100644 --- a/cpp/src/groupby/groupby.cu +++ b/cpp/src/groupby/groupby.cu @@ -79,6 +79,44 @@ std::pair, std::vector> groupby::disp groupby::~groupby() = default; namespace { + +/** + * @brief Factory to construct empty result columns. + * + * Adds special handling for COLLECT_LIST/COLLECT_SET, because: + * 1. `make_empty_column()` does not support construction of nested columns. + * 2. Empty lists need empty child columns, to persist type information. + */ +struct empty_column_constructor { + column_view values; + + template + std::unique_ptr operator()() const + { + using namespace cudf; + using namespace cudf::detail; + + if constexpr (k == aggregation::Kind::COLLECT_LIST || k == aggregation::Kind::COLLECT_SET) { + return make_lists_column( + 0, make_empty_column(data_type{type_to_id()}), empty_like(values), 0, {}); + } + + // If `values` is LIST typed, and the aggregation results match the type, + // construct empty results based on `values`. + // Most generally, this applies if input type matches output type. + // + // Note: `target_type_t` is not recursive, and `ValuesType` does not consider children. + // It is important that `COLLECT_LIST` and `COLLECT_SET` are handled before this + // point, because `COLLECT_LIST(LIST)` produces `LIST`, but `target_type_t` + // wouldn't know the difference. + if constexpr (std::is_same_v, ValuesType>) { + return empty_like(values); + } + + return make_empty_column(target_type(values.type(), k)); + } +}; + /// Make an empty table with appropriate types for requested aggs auto empty_results(host_span requests) { @@ -93,7 +131,8 @@ auto empty_results(host_span requests) request.aggregations.end(), std::back_inserter(results), [&request](auto const& agg) { - return make_empty_column(cudf::detail::target_type(request.values.type(), agg->kind)); + return cudf::detail::dispatch_type_and_aggregation( + request.values.type(), agg->kind, empty_column_constructor{request.values}); }); return aggregation_result{std::move(results)}; diff --git a/cpp/tests/groupby/collect_list_tests.cpp b/cpp/tests/groupby/collect_list_tests.cpp index 7580c1c4e3b..9d2141c913c 100644 --- a/cpp/tests/groupby/collect_list_tests.cpp +++ b/cpp/tests/groupby/collect_list_tests.cpp @@ -86,6 +86,21 @@ TYPED_TEST(groupby_collect_list_test, CollectWithNullExclusion) test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); } +TYPED_TEST(groupby_collect_list_test, CollectOnEmptyInput) +{ + using K = int32_t; + using V = TypeParam; + + fixed_width_column_wrapper keys{}; + fixed_width_column_wrapper values{}; + + fixed_width_column_wrapper expect_keys{}; + lists_column_wrapper expect_vals{}; + + auto agg = cudf::make_collect_list_aggregation(null_policy::EXCLUDE); + test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); +} + TYPED_TEST(groupby_collect_list_test, CollectLists) { using K = int32_t; @@ -124,6 +139,61 @@ TYPED_TEST(groupby_collect_list_test, CollectListsWithNullExclusion) test_single_agg(keys, values, expect_keys, expect_vals, std::move(agg)); } +TYPED_TEST(groupby_collect_list_test, CollectOnEmptyInputLists) +{ + using K = int32_t; + using V = TypeParam; + + using LCW = cudf::test::lists_column_wrapper; + + auto offsets = data_type{type_to_id()}; + + fixed_width_column_wrapper keys{}; + auto values = cudf::make_lists_column(0, make_empty_column(offsets), LCW{}.release(), 0, {}); + + fixed_width_column_wrapper expect_keys{}; + + auto expect_child = + cudf::make_lists_column(0, make_empty_column(offsets), LCW{}.release(), 0, {}); + auto expect_values = + cudf::make_lists_column(0, make_empty_column(offsets), std::move(expect_child), 0, {}); + + auto agg = cudf::make_collect_list_aggregation(); + test_single_agg(keys, values->view(), expect_keys, expect_values->view(), std::move(agg)); +} + +TYPED_TEST(groupby_collect_list_test, CollectOnEmptyInputListsOfStructs) +{ + using K = int32_t; + using V = TypeParam; + + using LCW = cudf::test::lists_column_wrapper; + + fixed_width_column_wrapper keys{}; + auto struct_child = LCW{}; + auto struct_column = structs_column_wrapper{{struct_child}}; + + auto values = cudf::make_lists_column( + 0, make_empty_column(data_type{type_to_id()}), struct_column.release(), 0, {}); + + fixed_width_column_wrapper expect_keys{}; + + auto expect_struct_child = LCW{}; + auto expect_struct_column = structs_column_wrapper{{expect_struct_child}}; + + auto expect_child = + cudf::make_lists_column(0, + make_empty_column(data_type{type_to_id()}), + expect_struct_column.release(), + 0, + {}); + auto expect_values = cudf::make_lists_column( + 0, make_empty_column(data_type{type_to_id()}), std::move(expect_child), 0, {}); + + auto agg = cudf::make_collect_list_aggregation(); + test_single_agg(keys, values->view(), expect_keys, expect_values->view(), std::move(agg)); +} + TYPED_TEST(groupby_collect_list_test, dictionary) { using K = int32_t; diff --git a/cpp/tests/groupby/collect_set_tests.cpp b/cpp/tests/groupby/collect_set_tests.cpp index ce3a9a49372..d5a881a1993 100644 --- a/cpp/tests/groupby/collect_set_tests.cpp +++ b/cpp/tests/groupby/collect_set_tests.cpp @@ -58,8 +58,7 @@ TYPED_TEST_CASE(CollectSetTypedTest, FixedWidthTypesNotBool); TYPED_TEST(CollectSetTypedTest, TrivialInput) { // Empty input - // TODO: Enable this test after issue#7611 has been fixed - // test_single_agg(COL_K{}, COL_V{}, COL_K{}, COL_V{}, COLLECT_SET); + test_single_agg(COL_K{}, COL_V{}, COL_K{}, LCL_V{}, CollectSetTest::collect_set()); // Single key input { diff --git a/cpp/tests/groupby/nth_element_tests.cpp b/cpp/tests/groupby/nth_element_tests.cpp index ec0265a3023..5630cba09da 100644 --- a/cpp/tests/groupby/nth_element_tests.cpp +++ b/cpp/tests/groupby/nth_element_tests.cpp @@ -362,5 +362,45 @@ TEST_F(groupby_nth_element_string_test, dictionary) keys, vals, expect_keys, expect_vals->view(), cudf::make_nth_element_aggregation(2)); } +template +struct groupby_nth_element_lists_test : BaseFixture { +}; + +TYPED_TEST_CASE(groupby_nth_element_lists_test, FixedWidthTypesWithoutFixedPoint); + +TYPED_TEST(groupby_nth_element_lists_test, Basics) +{ + using K = int32_t; + using V = TypeParam; + + using lists = cudf::test::lists_column_wrapper; + + auto keys = fixed_width_column_wrapper{1, 1, 2, 2, 3, 3}; + auto values = lists{{1, 2}, {3, 4}, {5, 6, 7}, lists{}, {9, 10}, {11}}; + + auto expected_keys = fixed_width_column_wrapper{1, 2, 3}; + auto expected_values = lists{{1, 2}, {5, 6, 7}, {9, 10}}; + + test_single_agg( + keys, values, expected_keys, expected_values, cudf::make_nth_element_aggregation(0)); +} + +TYPED_TEST(groupby_nth_element_lists_test, EmptyInput) +{ + using K = int32_t; + using V = TypeParam; + + using lists = cudf::test::lists_column_wrapper; + + auto keys = fixed_width_column_wrapper{}; + auto values = lists{}; + + auto expected_keys = fixed_width_column_wrapper{}; + auto expected_values = lists{}; + + test_single_agg( + keys, values, expected_keys, expected_values, cudf::make_nth_element_aggregation(2)); +} + } // namespace test } // namespace cudf