diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 8c9aa30ba16..678f202d106 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -198,6 +198,7 @@ add_library(cudf
src/groupby/sort/group_argmin.cu
src/groupby/sort/aggregate.cpp
src/groupby/sort/group_collect.cu
+ src/groupby/sort/group_merge_lists.cu
src/groupby/sort/group_count.cu
src/groupby/sort/group_max.cu
src/groupby/sort/group_min.cu
diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp
index 2600926d363..5fab284d506 100644
--- a/cpp/include/cudf/aggregation.hpp
+++ b/cpp/include/cudf/aggregation.hpp
@@ -78,6 +78,8 @@ class aggregation {
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
+ MERGE_LISTS, ///< merge multiple lists values into one list
+ MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
@@ -250,7 +252,7 @@ std::unique_ptr make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);
/**
- * @brief Factory to create a COLLECT_SET aggregation
+ * @brief Factory to create a COLLECT_SET aggregation.
*
* `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each
* list, the duplicated entries are dropped out such that each entry appears only once.
@@ -259,16 +261,53 @@ std::unique_ptr make_collect_list_aggregation(
* of the list rows.
*
* @param null_handling Indicates whether to include/exclude nulls during collection
- * @param nulls_equal Flag to specify whether null entries within each list should be considered
- * equal
- * @param nans_equal Flag to specify whether NaN values in floating point column should be
- * considered equal
+ * @param nulls_equal Flag to specify whether null entries within each list should be considered
+ * equal.
+ * @param nans_equal Flag to specify whether NaN values in floating point column should be
+ * considered equal.
*/
template
std::unique_ptr make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);
+/**
+ * @brief Factory to create a MERGE_LISTS aggregation.
+ *
+ * Given a lists column, this aggregation merges all the lists corresponding to the same key value
+ * into one list. It is designed specificly to merge the partial results of multiple (distributed)
+ * groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. As such, it requires the
+ * input lists column to be non-nullable (the child column containing list entries is not subjected
+ * to this requirement).
+ */
+template
+std::unique_ptr make_merge_lists_aggregation();
+
+/**
+ * @brief Factory to create a MERGE_SETS aggregation.
+ *
+ * Given a lists column, this aggregation firstly merges all the lists corresponding to the same key
+ * value into one list, then it drops all the duplicate entries in each lists, producing a lists
+ * column containing non-repeated entries.
+ *
+ * This aggregation is designed specificly to merge the partial results of multiple (distributed)
+ * groupby `COLLECT_LIST` or `COLLECT_SET` aggregations into a final `COLLECT_SET` result. As such,
+ * it requires the input lists column to be non-nullable (the child column containing list entries
+ * is not subjected to this requirement).
+ *
+ * In practice, the input (partial results) to this aggregation should be generated by (distributed)
+ * `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily removing duplicate entries
+ * for the partial results.
+ *
+ * @param nulls_equal Flag to specify whether nulls within each list should be considered equal
+ * during dropping duplicate list entries.
+ * @param nans_equal Flag to specify whether NaN values in floating point column should be
+ * considered equal during dropping duplicate list entries.
+ */
+template
+std::unique_ptr make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL,
+ nan_equality nans_equal = nan_equality::UNEQUAL);
+
/// Factory to create a LAG aggregation
template
std::unique_ptr make_lag_aggregation(size_type offset);
diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp
index e230ce0b757..373d695a5b5 100644
--- a/cpp/include/cudf/detail/aggregation/aggregation.hpp
+++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -75,6 +75,10 @@ class simple_aggregations_collector { // Declares the interface for the simple
data_type col_type, class collect_list_aggregation const& agg);
virtual std::vector> visit(data_type col_type,
class collect_set_aggregation const& agg);
+ virtual std::vector> visit(data_type col_type,
+ class merge_lists_aggregation const& agg);
+ virtual std::vector> visit(data_type col_type,
+ class merge_sets_aggregation const& agg);
virtual std::vector> visit(data_type col_type,
class lead_lag_aggregation const& agg);
virtual std::vector> visit(data_type col_type,
@@ -105,6 +109,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class row_number_aggregation const& agg);
virtual void visit(class collect_list_aggregation const& agg);
virtual void visit(class collect_set_aggregation const& agg);
+ virtual void visit(class merge_lists_aggregation const& agg);
+ virtual void visit(class merge_sets_aggregation const& agg);
virtual void visit(class lead_lag_aggregation const& agg);
virtual void visit(class udf_aggregation const& agg);
};
@@ -627,6 +633,66 @@ class collect_set_aggregation final : public rolling_aggregation {
}
};
+/**
+ * @brief Derived aggregation class for specifying MERGE_LISTs aggregation
+ */
+class merge_lists_aggregation final : public aggregation {
+ public:
+ explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {}
+
+ std::unique_ptr clone() const override
+ {
+ return std::make_unique(*this);
+ }
+ std::vector> get_simple_aggregations(
+ data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override
+ {
+ return collector.visit(col_type, *this);
+ }
+ void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
+};
+
+/**
+ * @brief Derived aggregation class for specifying MERGE_SETs aggregation
+ */
+class merge_sets_aggregation final : public aggregation {
+ public:
+ explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal)
+ : aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal)
+ {
+ }
+
+ null_equality _nulls_equal; ///< whether to consider nulls as equal value
+ nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to
+ ///< floating point types)
+
+ bool is_equal(aggregation const& _other) const override
+ {
+ if (!this->aggregation::is_equal(_other)) { return false; }
+ auto const& other = dynamic_cast(_other);
+ return (_nulls_equal == other._nulls_equal && _nans_equal == other._nans_equal);
+ }
+
+ size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); }
+
+ std::unique_ptr clone() const override
+ {
+ return std::make_unique(*this);
+ }
+ std::vector> get_simple_aggregations(
+ data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override
+ {
+ return collector.visit(col_type, *this);
+ }
+ void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
+
+ protected:
+ size_t hash_impl() const
+ {
+ return std::hash{}(static_cast(_nulls_equal) ^ static_cast(_nans_equal));
+ }
+};
+
/**
* @brief Derived aggregation class for specifying LEAD/LAG window aggregations
*/
@@ -904,6 +970,18 @@ struct target_type_impl {
using type = cudf::list_view;
};
+// Always use list for MERGE_LISTS
+template
+struct target_type_impl {
+ using type = cudf::list_view;
+};
+
+// Always use list for MERGE_SETS
+template
+struct target_type_impl {
+ using type = cudf::list_view;
+};
+
// Always use Source for LEAD
template
struct target_type_impl {
@@ -1005,6 +1083,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()(std::forward(args)...);
case aggregation::COLLECT_SET:
return f.template operator()(std::forward(args)...);
+ case aggregation::MERGE_LISTS:
+ return f.template operator()(std::forward(args)...);
+ case aggregation::MERGE_SETS:
+ return f.template operator()(std::forward(args)...);
case aggregation::LEAD:
return f.template operator()(std::forward(args)...);
case aggregation::LAG:
@@ -1107,4 +1189,4 @@ constexpr inline bool is_valid_aggregation()
bool is_valid_aggregation(data_type source, aggregation::Kind k);
} // namespace detail
-} // namespace cudf
\ No newline at end of file
+} // namespace cudf
diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp
index a878dbe1535..f0fd865f685 100644
--- a/cpp/src/aggregation/aggregation.cpp
+++ b/cpp/src/aggregation/aggregation.cpp
@@ -154,6 +154,18 @@ std::vector> simple_aggregations_collector::visit(
return visit(col_type, static_cast(agg));
}
+std::vector> simple_aggregations_collector::visit(
+ data_type col_type, merge_lists_aggregation const& agg)
+{
+ return visit(col_type, static_cast(agg));
+}
+
+std::vector> simple_aggregations_collector::visit(
+ data_type col_type, merge_sets_aggregation const& agg)
+{
+ return visit(col_type, static_cast(agg));
+}
+
std::vector> simple_aggregations_collector::visit(
data_type col_type, lead_lag_aggregation const& agg)
{
@@ -270,6 +282,16 @@ void aggregation_finalizer::visit(collect_set_aggregation const& agg)
visit(static_cast(agg));
}
+void aggregation_finalizer::visit(merge_lists_aggregation const& agg)
+{
+ visit(static_cast(agg));
+}
+
+void aggregation_finalizer::visit(merge_sets_aggregation const& agg)
+{
+ visit(static_cast(agg));
+}
+
void aggregation_finalizer::visit(lead_lag_aggregation const& agg)
{
visit(static_cast(agg));
@@ -471,6 +493,24 @@ template std::unique_ptr make_collect_set_aggregation(
template std::unique_ptr make_collect_set_aggregation(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);
+/// Factory to create a MERGE_LISTS aggregation
+template
+std::unique_ptr make_merge_lists_aggregation()
+{
+ return std::make_unique();
+}
+template std::unique_ptr make_merge_lists_aggregation();
+
+/// Factory to create a MERGE_SETS aggregation
+template
+std::unique_ptr make_merge_sets_aggregation(null_equality nulls_equal,
+ nan_equality nans_equal)
+{
+ return std::make_unique(nulls_equal, nans_equal);
+}
+template std::unique_ptr make_merge_sets_aggregation(null_equality,
+ nan_equality);
+
/// Factory to create a LAG aggregation
template
std::unique_ptr make_lag_aggregation(size_type offset)
diff --git a/cpp/src/groupby/sort/aggregate.cpp b/cpp/src/groupby/sort/aggregate.cpp
index 9d8f145a7c9..5e202b9ef3f 100644
--- a/cpp/src/groupby/sort/aggregate.cpp
+++ b/cpp/src/groupby/sort/aggregate.cpp
@@ -366,36 +366,32 @@ void aggregate_result_functor::operator()(aggregation
template <>
void aggregate_result_functor::operator()(aggregation const& agg)
{
- auto null_handling =
- dynamic_cast(agg)._null_handling;
- agg.do_hash();
-
- if (cache.has_result(col_idx, agg)) return;
+ if (cache.has_result(col_idx, agg)) { return; }
+ auto const null_handling =
+ dynamic_cast(agg)._null_handling;
auto result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);
-
cache.add_result(col_idx, agg, std::move(result));
};
template <>
void aggregate_result_functor::operator()(aggregation const& agg)
{
- auto const null_handling =
- dynamic_cast(agg)._null_handling;
-
if (cache.has_result(col_idx, agg)) { return; }
+ auto const null_handling =
+ dynamic_cast(agg)._null_handling;
auto const collect_result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
- mr);
+ rmm::mr::get_current_device_resource());
auto const nulls_equal =
dynamic_cast(agg)._nulls_equal;
auto const nans_equal =
@@ -406,6 +402,78 @@ void aggregate_result_functor::operator()(aggregation
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};
+
+/**
+ * @brief Perform merging for the lists that correspond to the same key value.
+ *
+ * This aggregation is similar to `COLLECT_LIST` with the following differences:
+ * - It requires the input values to be a non-nullable lists column, and
+ * - The values (lists) corresponding to the same key will not result in a list of lists as output
+ * from `COLLECT_LIST`. Instead, those lists will result in a list generated by merging them
+ * together.
+ *
+ * In practice, this aggregation is used to merge the partial results of multiple (distributed)
+ * groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. Those distributed
+ * aggregations were executed on different values columns partitioned from the original values
+ * column, then their results were (vertically) concatenated before given as the values column for
+ * this aggregation.
+ */
+template <>
+void aggregate_result_functor::operator()(aggregation const& agg)
+{
+ if (cache.has_result(col_idx, agg)) { return; }
+
+ cache.add_result(
+ col_idx,
+ agg,
+ detail::group_merge_lists(
+ get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
+};
+
+/**
+ * @brief Perform merging for the lists corresponding to the same key value, then dropping duplicate
+ * list entries.
+ *
+ * This aggregation is similar to `COLLECT_SET` with the following differences:
+ * - It requires the input values to be a non-nullable lists column, and
+ * - The values (lists) corresponding to the same key will result in a list generated by merging
+ * them together then dropping duplicate entries.
+ *
+ * In practice, this aggregation is used to merge the partial results of multiple (distributed)
+ * groupby `COLLECT_LIST` or `COLLECT_SET` aggregations into a final `COLLECT_SET` result. Those
+ * distributed aggregations were executed on different values columns partitioned from the original
+ * values column, then their results were (vertically) concatenated before given as the values
+ * column for this aggregation.
+ *
+ * Firstly, this aggregation performs `MERGE_LISTS` to concatenate the input lists (corresponding to
+ * the same key) into intermediate lists, then it calls `lists::drop_list_duplicates` on them to
+ * remove duplicate list entries. As such, the input (partial results) to this aggregation should be
+ * generated by (distributed) `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily
+ * removing duplicate entries for the partial results.
+ *
+ * Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality`
+ * are needed for calling to `lists::drop_list_duplicates`.
+ */
+template <>
+void aggregate_result_functor::operator()(aggregation const& agg)
+{
+ if (cache.has_result(col_idx, agg)) { return; }
+
+ auto const merged_result = detail::group_merge_lists(get_grouped_values(),
+ helper.group_offsets(stream),
+ helper.num_groups(stream),
+ stream,
+ rmm::mr::get_current_device_resource());
+ auto const merge_sets_agg = dynamic_cast(agg);
+ cache.add_result(col_idx,
+ agg,
+ lists::detail::drop_list_duplicates(lists_column_view(merged_result->view()),
+ merge_sets_agg._nulls_equal,
+ merge_sets_agg._nans_equal,
+ stream,
+ mr));
+};
+
} // namespace detail
// Sort-based groupby
diff --git a/cpp/src/groupby/sort/group_merge_lists.cu b/cpp/src/groupby/sort/group_merge_lists.cu
new file mode 100644
index 00000000000..3043d107635
--- /dev/null
+++ b/cpp/src/groupby/sort/group_merge_lists.cu
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+namespace cudf {
+namespace groupby {
+namespace detail {
+std::unique_ptr group_merge_lists(column_view const& values,
+ cudf::device_span group_offsets,
+ size_type num_groups,
+ rmm::cuda_stream_view stream,
+ rmm::mr::device_memory_resource* mr)
+{
+ CUDF_EXPECTS(values.type().id() == type_id::LIST,
+ "Input to `group_merge_lists` must be a lists column.");
+ CUDF_EXPECTS(!values.nullable(),
+ "Input to `group_merge_lists` must be a non-nullable lists column.");
+
+ auto offsets_column = make_numeric_column(
+ data_type(type_to_id()), num_groups + 1, mask_state::UNALLOCATED, stream, mr);
+
+ // Generate offsets of the output lists column by gathering from the provided group offsets and
+ // the input list offsets.
+ //
+ // For example:
+ // values = [[2, 1], [], [4, -1, -2], [], [, 4, ]]
+ // list_offsets = [0, 2, 2, 5, 5 8]
+ // group_offsets = [0, 3, 5]
+ //
+ // then, the output offsets_column is [0, 5, 8].
+ //
+ thrust::gather(rmm::exec_policy(stream),
+ group_offsets.begin(),
+ group_offsets.end(),
+ lists_column_view(values).offsets_begin(),
+ offsets_column->mutable_view().template begin());
+
+ // The child column of the output lists column is just copied from the input column.
+ auto child_column =
+ std::make_unique(lists_column_view(values).get_sliced_child(stream), stream, mr);
+
+ return make_lists_column(num_groups,
+ std::move(offsets_column),
+ std::move(child_column),
+ 0,
+ rmm::device_buffer{},
+ stream,
+ mr);
+}
+
+} // namespace detail
+} // namespace groupby
+} // namespace cudf
diff --git a/cpp/src/groupby/sort/group_reductions.hpp b/cpp/src/groupby/sort/group_reductions.hpp
index 7cc0aea8362..3390af29330 100644
--- a/cpp/src/groupby/sort/group_reductions.hpp
+++ b/cpp/src/groupby/sort/group_reductions.hpp
@@ -348,19 +348,19 @@ std::unique_ptr group_nth_element(column_view const& values,
*
* @code{.pseudo}
* values = [2, 1, 4, -1, -2, , 4, ]
- * group_offsets = [0, 3, 5, 7, 8]
+ * group_offsets = [0, 3, 5, 7, 8]
* num_groups = 4
*
- * group_collect = [[2, 1, 4], [-1, -2] [, 4], []]
+ * group_collect(...) = [[2, 1, 4], [-1, -2], [, 4], []]
* @endcode
*
- * @param values Grouped values to collect
- * @param group_offsets Offsets of groups' starting points within @p values
- * @param num_groups Number of groups
+ * @param values Grouped values to collect.
+ * @param group_offsets Offsets of groups' starting points within @p values.
+ * @param num_groups Number of groups.
* @param null_handling Exclude nulls while counting if null_policy::EXCLUDE,
- * Include nulls if null_policy::INCLUDE.
+ * include nulls if null_policy::INCLUDE.
* @param stream CUDA stream used for device memory operations and kernel launches.
- * @param mr Device memory resource used to allocate the returned column's device memory
+ * @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr group_collect(column_view const& values,
cudf::device_span group_offsets,
@@ -369,6 +369,29 @@ std::unique_ptr group_collect(column_view const& values,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr);
+/**
+ * @brief Internal API to merge grouped lists into one list.
+ *
+ * @code{.pseudo}
+ * values = [[2, 1], [], [4, -1, -2], [], [, 4, ]]
+ * group_offsets = [0, 3, 5]
+ * num_groups = 2
+ *
+ * group_merge_lists(...) = [[2, 1, 4, -1, -2], [, 4, ]]
+ * @endcode
+ *
+ * @param values Grouped values (lists column) to collect.
+ * @param group_offsets Offsets of groups' starting points within @p values.
+ * @param num_groups Number of groups.
+ * @param stream CUDA stream used for device memory operations and kernel launches.
+ * @param mr Device memory resource used to allocate the returned column's device memory.
+ */
+std::unique_ptr group_merge_lists(column_view const& values,
+ cudf::device_span group_offsets,
+ size_type num_groups,
+ rmm::cuda_stream_view stream,
+ rmm::mr::device_memory_resource* mr);
+
/** @endinternal
*
*/
diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt
index 2ac9382bd1e..215d93d4119 100644
--- a/cpp/tests/CMakeLists.txt
+++ b/cpp/tests/CMakeLists.txt
@@ -1,4 +1,4 @@
-#=============================================================================
+#=============================================================================
# Copyright (c) 2018-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -67,6 +67,8 @@ ConfigureTest(GROUPBY_TEST
groupby/max_tests.cpp
groupby/mean_tests.cpp
groupby/median_tests.cpp
+ groupby/merge_lists_tests.cpp
+ groupby/merge_sets_tests.cpp
groupby/min_scan_tests.cpp
groupby/nth_element_tests.cpp
groupby/nunique_tests.cpp
diff --git a/cpp/tests/groupby/merge_lists_tests.cpp b/cpp/tests/groupby/merge_lists_tests.cpp
new file mode 100644
index 00000000000..7851565d86a
--- /dev/null
+++ b/cpp/tests/groupby/merge_lists_tests.cpp
@@ -0,0 +1,388 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+using namespace cudf::test::iterators;
+
+namespace {
+constexpr bool print_all{false}; // For debugging
+constexpr int32_t null{0}; // Mark for null elements
+
+using vcol_views = std::vector;
+
+auto merge_lists(vcol_views const& keys_cols, vcol_views const& values_cols)
+{
+ // Append all the keys and lists together.
+ auto const keys = cudf::concatenate(keys_cols);
+ auto const values = cudf::concatenate(values_cols);
+
+ std::vector requests;
+ requests.emplace_back(cudf::groupby::aggregation_request());
+ requests[0].values = *values;
+ requests[0].aggregations.emplace_back(cudf::make_merge_lists_aggregation());
+
+ auto gb_obj = cudf::groupby::groupby(cudf::table_view({*keys}));
+ auto result = gb_obj.aggregate(requests);
+ return std::make_pair(std::move(result.first->release()[0]),
+ std::move(result.second[0].results[0]));
+}
+
+} // namespace
+
+template
+struct GroupbyMergeListsTypedTest : public cudf::test::BaseFixture {
+};
+
+using FixedWidthTypesNotBool = cudf::test::Concat;
+TYPED_TEST_CASE(GroupbyMergeListsTypedTest, FixedWidthTypesNotBool);
+
+TYPED_TEST(GroupbyMergeListsTypedTest, InvalidInput)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys = keys_col{1, 2, 3};
+
+ // The input lists column must NOT be nullable.
+ auto const lists = lists_col{{lists_col{1}, lists_col{} /*NULL*/, lists_col{2}}, null_at(1)};
+ EXPECT_THROW(merge_lists({keys}, {lists}), cudf::logic_error);
+
+ // The input column must be a lists column.
+ auto const non_lists = keys_col{1, 2, 3, 4, 5};
+ EXPECT_THROW(merge_lists({keys}, {non_lists}), cudf::logic_error);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, EmptyInput)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ // Keys and lists columns are all empty.
+ auto const keys = keys_col{};
+ auto const lists0 = lists_col{{1, 2, 3}, {4, 5, 6}};
+ auto const lists = cudf::empty_like(lists0);
+
+ auto const [out_keys, out_lists] = merge_lists(vcol_views{keys}, vcol_views{*lists});
+ auto const expected_keys = keys_col{};
+ auto const expected_lists = cudf::empty_like(lists0);
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, InputWithoutNull)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ {1, 2, 3}, // key = 1
+ {4, 5, 6} // key = 2
+ };
+ auto const lists2 = lists_col{
+ {10, 11}, // key = 1
+ {11, 12} // key = 3
+ };
+ auto const lists3 = lists_col{
+ {20, 21, 22}, // key = 2
+ {23, 24, 25}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ {1, 2, 3, 10, 11}, // key = 1
+ {4, 5, 6, 20, 21, 22}, // key = 2
+ {11, 12, 23, 24, 25}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, InputHasNulls)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ // Note that the null elements here are not sorted, while the results from current collect_list
+ // are sorted.
+ auto const lists1 = lists_col{
+ lists_col{{1, null, 3}, null_at(1)}, // key = 1
+ lists_col{4, 5, 6} // key = 2
+ };
+ auto const lists2 = lists_col{
+ lists_col{10, 11}, // key = 1
+ lists_col{{null, null, null}, all_nulls()} // key = 3
+ };
+ auto const lists3 = lists_col{
+ lists_col{20, 21, 22}, // key = 2
+ lists_col{{null, 24, null}, nulls_at({0, 2})}, // key = 3
+ lists_col{{24, 25, 26}, no_nulls()} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ lists_col{{1, null, 3, 10, 11}, null_at(1)}, // key = 1
+ lists_col{4, 5, 6, 20, 21, 22}, // key = 2
+ lists_col{{null, null, null, null, 24, null}, nulls_at({0, 1, 2, 3, 5})}, // key = 3
+ lists_col{24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, InputHasEmptyLists)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ {1, 2, 3}, // key = 1
+ {} // key = 2
+ };
+ auto const lists2 = lists_col{
+ {}, // key = 1
+ {11, 12} // key = 3
+ };
+ auto const lists3 = lists_col{
+ {}, // key = 2
+ {}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ {1, 2, 3}, // key = 1
+ {}, // key = 2
+ {11, 12}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, InputHasNullsAndEmptyLists)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2, 3};
+ auto const keys2 = keys_col{1, 3, 4};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ // Note that the null elements here are not sorted, while the results from current collect_list
+ // are sorted.
+ auto const lists1 = lists_col{
+ lists_col{{1, null, 3}, null_at(1)}, // key = 1
+ lists_col{}, // key = 2
+ lists_col{4, 5} // key = 3
+ };
+ auto const lists2 = lists_col{
+ lists_col{10, 11}, // key = 1
+ lists_col{{null, null, null}, all_nulls()}, // key = 3
+ lists_col{} // key = 4
+ };
+ auto const lists3 = lists_col{
+ lists_col{20, 21, 22}, // key = 2
+ lists_col{{null, 24, null}, nulls_at({0, 2})}, // key = 3
+ lists_col{{24, 25, 26}, no_nulls()} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ lists_col{{1, null, 3, 10, 11}, null_at(1)}, // key = 1
+ lists_col{20, 21, 22}, // key = 2
+ lists_col{{4, 5, null, null, null, null, 24, null}, nulls_at({2, 3, 4, 5, 7})}, // key = 3
+ lists_col{24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, InputHasListsOfLists)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ lists_col{lists_col{1, 2, 3}, lists_col{4}, lists_col{5, 6}}, // key = 1
+ lists_col{lists_col{}, lists_col{7}} // key = 2
+ };
+ auto const lists2 = lists_col{
+ lists_col{lists_col{}, lists_col{8, 9}}, // key = 1
+ lists_col{lists_col{11}, lists_col{12, 13}} // key = 3
+ };
+ auto const lists3 = lists_col{
+ lists_col{lists_col{14}, lists_col{15, 16, 17, 18}}, // key = 2
+ lists_col{lists_col{}}, // key = 3
+ lists_col{lists_col{17, 18, 19, 20, 21}, lists_col{18, 19, 20}} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ lists_col{
+ lists_col{1, 2, 3}, lists_col{4}, lists_col{5, 6}, lists_col{}, lists_col{8, 9}}, // key = 1
+ lists_col{lists_col{}, lists_col{7}, lists_col{14}, lists_col{15, 16, 17, 18}}, // key = 2
+ lists_col{lists_col{11}, lists_col{12, 13}, lists_col{}}, // key = 3
+ lists_col{lists_col{17, 18, 19, 20, 21}, lists_col{18, 19, 20}} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeListsTypedTest, SlicedColumnsInput)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1_original = keys_col{1, 2, 4, 5, 6, 7, 8, 9, 10};
+ auto const keys2_original = keys_col{0, 0, 1, 1, 1, 3, 4, 5, 6};
+ auto const keys3_original = keys_col{0, 1, 2, 3, 4, 5, 6, 7, 8};
+
+ auto const keys1 = cudf::slice(keys1_original, {0, 2})[0]; // { 1, 2 }
+ auto const keys2 = cudf::slice(keys2_original, {4, 6})[0]; // { 1, 3 }
+ auto const keys3 = cudf::slice(keys3_original, {2, 5})[0]; // { 2, 3, 4 }
+
+ auto const lists1_original = lists_col{
+ {10, 11, 12},
+ {12, 13, 14},
+ {1, 2, 3}, // key = 1
+ {4, 5, 6} // key = 2
+ };
+ auto const lists2_original = lists_col{{1, 2},
+ {10, 11}, // key = 1
+ {11, 12}, // key = 3
+ {13},
+ {14},
+ {15, 16}};
+ auto const lists3_original = lists_col{{20, 21, 22}, // key = 2
+ {23, 24, 25}, // key = 3
+ {24, 25, 26}, // key = 4
+ {1, 2, 3, 4, 5}};
+
+ auto const lists1 = cudf::slice(lists1_original, {2, 4})[0];
+ auto const lists2 = cudf::slice(lists2_original, {1, 3})[0];
+ auto const lists3 = cudf::slice(lists3_original, {0, 3})[0];
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ {1, 2, 3, 10, 11}, // key = 1
+ {4, 5, 6, 20, 21, 22}, // key = 2
+ {11, 12, 23, 24, 25}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+struct GroupbyMergeListsTest : public cudf::test::BaseFixture {
+};
+
+TEST_F(GroupbyMergeListsTest, StringsColumnInput)
+{
+ using strings_col = cudf::test::strings_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = strings_col{"dog", "unknown"};
+ auto const keys2 = strings_col{"banana", "unknown", "dog"};
+ auto const keys3 = strings_col{"apple", "dog", "water melon"};
+
+ auto const lists1 = lists_col{
+ lists_col{"Poodle", "Golden Retriever", "Corgi"}, // key = "dog"
+ lists_col{{"Whale", "" /*NULL*/, "Polar Bear"}, null_at(1)} // key = "unknown"
+ };
+ auto const lists2 = lists_col{
+ lists_col{"Green", "Yellow"}, // key = "banana"
+ lists_col{}, // key = "unknown"
+ lists_col{{"" /*NULL*/, "" /*NULL*/}, all_nulls()} // key = "dog"
+ };
+ auto const lists3 = lists_col{
+ lists_col{"Fuji", "Red Delicious"}, // key = "apple"
+ lists_col{{"" /*NULL*/, "German Shepherd", "" /*NULL*/}, nulls_at({0, 2})}, // key = "dog"
+ lists_col{{"Seeedless", "Mini"}, no_nulls()} // key = "water melon"
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_lists(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = strings_col{"apple", "banana", "dog", "unknown", "water melon"};
+ auto const expected_lists = lists_col{
+ lists_col{"Fuji", "Red Delicious"}, // key = "apple"
+ lists_col{"Green", "Yellow"}, // key = "banana"
+ lists_col{{
+ "Poodle",
+ "Golden Retriever",
+ "Corgi",
+ "" /*NULL*/,
+ "" /*NULL*/,
+ "" /*NULL*/,
+ "German Shepherd",
+ "" /*NULL*/
+ },
+ nulls_at({3, 4, 5, 7})}, // key = "dog"
+ lists_col{{"Whale", "" /*NULL*/, "Polar Bear"}, null_at(1)}, // key = "unknown"
+ lists_col{{"Seeedless", "Mini"}, no_nulls()} // key = "water melon"
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
diff --git a/cpp/tests/groupby/merge_sets_tests.cpp b/cpp/tests/groupby/merge_sets_tests.cpp
new file mode 100644
index 00000000000..1365245c8af
--- /dev/null
+++ b/cpp/tests/groupby/merge_sets_tests.cpp
@@ -0,0 +1,345 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+using namespace cudf::test::iterators;
+
+namespace {
+constexpr bool print_all{false}; // For debugging
+constexpr int32_t null{0}; // Mark for null elements
+
+using vcol_views = std::vector;
+
+auto merge_sets(vcol_views const& keys_cols, vcol_views const& values_cols)
+{
+ // Append all the keys and lists together.
+ auto const keys = cudf::concatenate(keys_cols);
+ auto const values = cudf::concatenate(values_cols);
+
+ std::vector requests;
+ requests.emplace_back(cudf::groupby::aggregation_request());
+ requests[0].values = *values;
+ requests[0].aggregations.emplace_back(cudf::make_merge_sets_aggregation());
+
+ auto gb_obj = cudf::groupby::groupby(cudf::table_view({*keys}));
+ auto result = gb_obj.aggregate(requests);
+ return std::make_pair(std::move(result.first->release()[0]),
+ std::move(result.second[0].results[0]));
+}
+
+} // namespace
+
+template
+struct GroupbyMergeSetsTypedTest : public cudf::test::BaseFixture {
+};
+
+using FixedWidthTypesNotBool = cudf::test::Concat;
+TYPED_TEST_CASE(GroupbyMergeSetsTypedTest, FixedWidthTypesNotBool);
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, InvalidInput)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys = keys_col{1, 2, 3};
+
+ // The input lists column must NOT be nullable.
+ auto const lists = lists_col{{lists_col{1}, lists_col{} /*NULL*/, lists_col{2}}, null_at(1)};
+ EXPECT_THROW(merge_sets({keys}, {lists}), cudf::logic_error);
+
+ // The input column must be a lists column.
+ auto const non_lists = keys_col{1, 2, 3, 4, 5};
+ EXPECT_THROW(merge_sets({keys}, {non_lists}), cudf::logic_error);
+}
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, EmptyInput)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ // Keys and lists columns are all empty.
+ auto const keys = keys_col{};
+ auto const lists0 = lists_col{{1, 2, 3}, {4, 5, 6}};
+ auto const lists = cudf::empty_like(lists0);
+
+ auto const [out_keys, out_lists] = merge_sets(vcol_views{keys}, vcol_views{*lists});
+ auto const expected_keys = keys_col{};
+ auto const expected_lists = cudf::empty_like(lists0);
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(*expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, InputWithoutNull)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ {1, 2, 3, 4, 5, 6}, // key = 1
+ {10, 11, 12, 13, 14, 15} // key = 2
+ };
+ auto const lists2 = lists_col{
+ {4, 5, 6, 7, 8, 9}, // key = 1
+ {20, 21, 22, 23, 24, 25} // key = 3
+ };
+ auto const lists3 = lists_col{
+ {11, 12}, // key = 2
+ {23, 24, 25, 26, 27, 28}, // key = 3
+ {30, 31, 32} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ {1, 2, 3, 4, 5, 6, 7, 8, 9}, // key = 1
+ {10, 11, 12, 13, 14, 15}, // key = 2
+ {20, 21, 22, 23, 24, 25, 26, 27, 28}, // key = 3
+ {30, 31, 32} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasNulls)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ // Note that the null elements here are not sorted, while the results from current collect_list
+ // and collect_set are sorted.
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ lists_col{{1, null, null, null, 5, 6}, nulls_at({1, 2, 3})}, // key = 1
+ lists_col{10, 11, 12, 13, 14, 15} // key = 2
+ };
+ auto const lists2 = lists_col{
+ lists_col{{null, null, 6, 7, 8, 9}, nulls_at({0, 1})}, // key = 1
+ lists_col{{null, 21, 22, 23, 24, 25}, null_at(0)} // key = 3
+ };
+ auto const lists3 = lists_col{
+ lists_col{11, 12}, // key = 2
+ lists_col{23, 24, 25, 26, 27, 28}, // key = 3
+ lists_col{{30, null, 32}, null_at(1)} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ lists_col{{1, 5, 6, 7, 8, 9, null}, null_at(6)}, // key = 1
+ lists_col{10, 11, 12, 13, 14, 15}, // key = 2
+ lists_col{{21, 22, 23, 24, 25, 26, 27, 28, null}, null_at(8)}, // key = 3
+ lists_col{{30, 32, null}, null_at(2)} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasEmptyLists)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = keys_col{1, 2};
+ auto const keys2 = keys_col{1, 3};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ {1, 2, 3}, // key = 1
+ {} // key = 2
+ };
+ auto const lists2 = lists_col{
+ {0, 1, 2, 3, 4, 5}, // key = 1
+ {11, 12, 12, 12, 12, 12} // key = 3
+ };
+ auto const lists3 = lists_col{
+ {}, // key = 2
+ {}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ {0, 1, 2, 3, 4, 5}, // key = 1
+ {}, // key = 2
+ {11, 12}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, InputHasNullsAndEmptyLists)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ // Note that the null elements here are not sorted, while the results from current collect_list
+ // and collect_set are sorted.
+ auto const keys1 = keys_col{1, 2, 3};
+ auto const keys2 = keys_col{1, 3, 4};
+ auto const keys3 = keys_col{2, 3, 4};
+
+ auto const lists1 = lists_col{
+ lists_col{{null, 1, 2, 3}, null_at(0)}, // key = 1
+ lists_col{}, // key = 2
+ lists_col{} // key = 3
+ };
+ auto const lists2 = lists_col{
+ lists_col{0, 1, 2, 3, 4, 5}, // key = 1
+ lists_col{{null, 11, null, 12, 12, 12, 12, 12}, nulls_at({0, 2})}, // key = 3
+ lists_col{20} // key = 4
+ };
+ auto const lists3 = lists_col{
+ lists_col{}, // key = 2
+ lists_col{}, // key = 3
+ lists_col{{24, 25, null, null, null, 26}, nulls_at({2, 3, 4})} // key = 4
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ lists_col{{0, 1, 2, 3, 4, 5, null}, null_at(6)}, // key = 1
+ lists_col{}, // key = 2
+ lists_col{{11, 12, null}, null_at(2)}, // key = 3
+ lists_col{{20, 24, 25, 26, null}, null_at(4)} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+TYPED_TEST(GroupbyMergeSetsTypedTest, SlicedColumnsInput)
+{
+ using keys_col = cudf::test::fixed_width_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1_original = keys_col{1, 2, 4, 5, 6, 7, 8, 9, 10};
+ auto const keys2_original = keys_col{0, 0, 1, 1, 1, 3, 4, 5, 6};
+ auto const keys3_original = keys_col{0, 1, 2, 3, 4, 5, 6, 7, 8};
+
+ auto const keys1 = cudf::slice(keys1_original, {0, 2})[0]; // { 1, 2 }
+ auto const keys2 = cudf::slice(keys2_original, {4, 6})[0]; // { 1, 3 }
+ auto const keys3 = cudf::slice(keys3_original, {2, 5})[0]; // { 2, 3, 4 }
+
+ auto const lists1_original = lists_col{
+ {10, 11, 12, 10, 11, 12, 10, 11, 12},
+ {12, 13, 12, 13, 12, 13, 12, 13, 14},
+ {1, 2, 3, 1, 2, 3, 1, 2, 3}, // key = 1
+ {4, 5, 6, 4, 5, 6, 4, 5, 6} // key = 2
+ };
+ auto const lists2_original = lists_col{{1, 1, 1, 1, 1, 1, 1, 2},
+ {10, 11, 11, 11, 11, 11, 12}, // key = 1
+ {11, 12, 13, 12, 13, 12, 13, 12, 13, 14, 15}, // key = 3
+ {13, 14, 15},
+ {14, 15, 16},
+ {15, 16}};
+ auto const lists3_original = lists_col{{20, 21, 20, 21, 20, 21, 20, 21, 22}, // key = 2
+ {23, 24, 25, 23, 24, 25}, // key = 3
+ {24, 25, 26}, // key = 4
+ {1, 2, 3, 4, 5}};
+
+ auto const lists1 = cudf::slice(lists1_original, {2, 4})[0];
+ auto const lists2 = cudf::slice(lists2_original, {1, 3})[0];
+ auto const lists3 = cudf::slice(lists3_original, {0, 3})[0];
+
+ auto const [out_keys, out_lists] =
+ merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = keys_col{1, 2, 3, 4};
+ auto const expected_lists = lists_col{
+ {1, 2, 3, 10, 11, 12}, // key = 1
+ {4, 5, 6, 20, 21, 22}, // key = 2
+ {11, 12, 13, 14, 15, 23, 24, 25}, // key = 3
+ {24, 25, 26} // key = 4
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}
+
+struct GroupbyMergeSetsTest : public cudf::test::BaseFixture {
+};
+
+TEST_F(GroupbyMergeSetsTest, StringsColumnInput)
+{
+ using strings_col = cudf::test::strings_column_wrapper;
+ using lists_col = cudf::test::lists_column_wrapper;
+
+ auto const keys1 = strings_col{"apple", "dog", "unknown"};
+ auto const keys2 = strings_col{"banana", "unknown", "dog"};
+ auto const keys3 = strings_col{"apple", "dog", "water melon"};
+
+ auto const lists1 = lists_col{
+ lists_col{"Fuji", "Honey Bee"}, // key = "apple"
+ lists_col{"Poodle", "Golden Retriever", "Corgi"}, // key = "dog"
+ lists_col{{"Whale", "" /*NULL*/, "Polar Bear"}, null_at(1)} // key = "unknown"
+ };
+ auto const lists2 = lists_col{
+ lists_col{"Green", "Yellow"}, // key = "banana"
+ lists_col{}, // key = "unknown"
+ lists_col{{"" /*NULL*/, "" /*NULL*/, "" /*NULL*/}, all_nulls()} // key = "dog"
+ };
+ auto const lists3 = lists_col{
+ lists_col{"Fuji", "Red Delicious"}, // key = "apple"
+ lists_col{{"" /*NULL*/, "Corgi", "German Shepherd", "" /*NULL*/, "Golden Retriever"},
+ nulls_at({0, 3})}, // key = "dog"
+ lists_col{{"Seeedless", "Mini"}, no_nulls()} // key = "water melon"
+ };
+
+ auto const [out_keys, out_lists] =
+ merge_sets(vcol_views{keys1, keys2, keys3}, vcol_views{lists1, lists2, lists3});
+ auto const expected_keys = strings_col{"apple", "banana", "dog", "unknown", "water melon"};
+ auto const expected_lists = lists_col{
+ lists_col{"Fuji", "Honey Bee", "Red Delicious"}, // key = "apple"
+ lists_col{"Green", "Yellow"}, // key = "banana"
+ lists_col{{
+ "Corgi", "German Shepherd", "Golden Retriever", "Poodle", "" /*NULL*/
+ },
+ null_at(4)}, // key = "dog"
+ lists_col{{"Polar Bear", "Whale", "" /*NULL*/}, null_at(2)}, // key = "unknown"
+ lists_col{{"Mini", "Seeedless"}, no_nulls()} // key = "water melon"
+ };
+
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_keys, *out_keys, print_all);
+ CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lists, *out_lists, print_all);
+}