Implement groupby MERGE_LISTS and MERGE_SETS aggregates (#8436)
Groupby aggregations can be performed for distributed computing by the following approach:
 * Divide the dataset into batches
 * Run separate (distributed) aggregations over those batches on the distributed nodes
 * Merge the results of the step above into one final result by calling `groupby::aggregate` a final time on the master node

This PR adds the `MERGE_LISTS` and `MERGE_SETS` groupby aggregations, which merge the lists produced by distributed `collect_list` and `collect_set` aggregations.
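
The three steps above can be sketched on the host with standard containers. This is only an illustration of the intended `MERGE_LISTS` semantics, not the cudf implementation: the names `batch`, `partial_result`, `collect_list_sketch`, `merge_lists_sketch`, and `distributed_collect_list` are hypothetical, and a `std::map` stands in for a keys column paired with a lists column.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a batch of (key, value) rows, and a per-key list result.
using batch          = std::vector<std::pair<std::string, int>>;
using partial_result = std::map<std::string, std::vector<int>>;

// Step 2: a per-batch COLLECT_LIST, gathering the values of each key into one list.
partial_result collect_list_sketch(batch const& rows)
{
  partial_result out;
  for (auto const& [key, value] : rows) { out[key].push_back(value); }
  return out;
}

// Step 3: MERGE_LISTS semantics. Lists that share a key are concatenated into
// a single list instead of being nested into a list of lists.
partial_result merge_lists_sketch(std::vector<partial_result> const& partials)
{
  partial_result out;
  for (auto const& partial : partials) {
    for (auto const& [key, list] : partial) {
      auto& merged = out[key];
      merged.insert(merged.end(), list.begin(), list.end());
    }
  }
  return out;
}

// Steps 1-3 together: the batches are assumed to be already divided (step 1).
partial_result distributed_collect_list(std::vector<batch> const& batches)
{
  std::vector<partial_result> partials;
  partials.reserve(batches.size());
  for (auto const& rows : batches) { partials.push_back(collect_list_sketch(rows)); }
  return merge_lists_sketch(partials);
}
```

In this sketch the merged list for a key contains the same values that a single `collect_list_sketch` call over all rows would produce, which is the property the final merge step is meant to provide.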

Closes #7839.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)
  - Mark Harris (https://github.com/harrism)

URL: #8436
ttnghia authored Jun 22, 2021
1 parent bbf375b commit a9a95f3
Showing 10 changed files with 1,086 additions and 24 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -198,6 +198,7 @@ add_library(cudf
src/groupby/sort/group_argmin.cu
src/groupby/sort/aggregate.cpp
src/groupby/sort/group_collect.cu
src/groupby/sort/group_merge_lists.cu
src/groupby/sort/group_count.cu
src/groupby/sort/group_max.cu
src/groupby/sort/group_min.cu
49 changes: 44 additions & 5 deletions cpp/include/cudf/aggregation.hpp
@@ -78,6 +78,8 @@ class aggregation {
ROW_NUMBER, ///< get row-number of current index (relative to rolling window)
COLLECT_LIST, ///< collect values into a list
COLLECT_SET, ///< collect values into a list without duplicate entries
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
LEAD, ///< window function, accesses row at specified offset following current row
LAG, ///< window function, accesses row at specified offset preceding current row
PTX, ///< PTX UDF based reduction
@@ -250,7 +252,7 @@ std::unique_ptr<Base> make_collect_list_aggregation(
null_policy null_handling = null_policy::INCLUDE);

/**
* @brief Factory to create a COLLECT_SET aggregation
* @brief Factory to create a COLLECT_SET aggregation.
*
* `COLLECT_SET` returns a lists column of all included elements in the group/series. Within each
* list, the duplicated entries are dropped out such that each entry appears only once.
@@ -259,16 +261,53 @@ std::unique_ptr<Base> make_collect_list_aggregation(
* of the list rows.
*
* @param null_handling Indicates whether to include/exclude nulls during collection
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal.
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal.
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);

/**
* @brief Factory to create a MERGE_LISTS aggregation.
*
* Given a lists column, this aggregation merges all the lists corresponding to the same key value
* into one list. It is designed specifically to merge the partial results of multiple (distributed)
* groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. As such, it requires the
* input lists column to be non-nullable (the child column containing list entries is not subject
* to this requirement).
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_lists_aggregation();

/**
* @brief Factory to create a MERGE_SETS aggregation.
*
* Given a lists column, this aggregation first merges all the lists corresponding to the same key
* value into one list, then drops all the duplicate entries in each list, producing a lists
* column containing non-repeated entries.
*
* This aggregation is designed specifically to merge the partial results of multiple (distributed)
* groupby `COLLECT_LIST` or `COLLECT_SET` aggregations into a final `COLLECT_SET` result. As such,
* it requires the input lists column to be non-nullable (the child column containing list entries
* is not subject to this requirement).
*
* In practice, the input (partial results) to this aggregation should be generated by (distributed)
* `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily removing duplicate entries
* for the partial results.
*
* @param nulls_equal Flag to specify whether nulls within each list should be considered equal
* during dropping duplicate list entries.
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal during dropping duplicate list entries.
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);
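
The merge-then-deduplicate behavior described above, and the effect of the `nans_equal` flag, can be illustrated with a small host-side sketch. It is an assumption-laden illustration rather than the cudf code path: `merge_sets_sketch` is a hypothetical name, a plain `bool` stands in for `nan_equality`, nulls are not modeled, and the output order (ascending, NaNs last) is a choice made for this sketch, not a statement about the order cudf produces.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Merge all partial lists of one key, then drop duplicate entries.
// `nans_equal == true` treats every NaN as the same entry (one NaN is kept);
// `nans_equal == false` treats each NaN as distinct (all NaNs are kept).
std::vector<double> merge_sets_sketch(std::vector<std::vector<double>> const& lists_for_key,
                                      bool nans_equal)
{
  // MERGE_LISTS step: concatenate the partial lists.
  std::vector<double> merged;
  for (auto const& list : lists_for_key) {
    merged.insert(merged.end(), list.begin(), list.end());
  }

  // Move NaNs to the back so that sorting only sees comparable values.
  auto const first_nan = std::stable_partition(
    merged.begin(), merged.end(), [](double v) { return !std::isnan(v); });
  auto const nan_count = static_cast<std::size_t>(merged.end() - first_nan);

  // Deduplicate the non-NaN entries.
  std::sort(merged.begin(), first_nan);
  auto const last_unique = std::unique(merged.begin(), first_nan);
  merged.erase(last_unique, first_nan);

  // Collapse the trailing NaNs into one entry only when NaNs compare equal.
  if (nans_equal && nan_count > 1) { merged.resize(merged.size() - (nan_count - 1)); }
  return merged;
}
```

Because duplicates are dropped after the merge anyway, deduplicating each partial list beforehand would not change the result of this sketch, which is why the documentation above recommends `COLLECT_LIST` for the partial results.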

/// Factory to create a LAG aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_lag_aggregation(size_type offset);
84 changes: 83 additions & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -75,6 +75,10 @@ class simple_aggregations_collector { // Declares the interface for the simple
data_type col_type, class collect_list_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class collect_set_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_lists_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_sets_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class lead_lag_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
@@ -105,6 +109,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class row_number_aggregation const& agg);
virtual void visit(class collect_list_aggregation const& agg);
virtual void visit(class collect_set_aggregation const& agg);
virtual void visit(class merge_lists_aggregation const& agg);
virtual void visit(class merge_sets_aggregation const& agg);
virtual void visit(class lead_lag_aggregation const& agg);
virtual void visit(class udf_aggregation const& agg);
};
@@ -627,6 +633,66 @@ class collect_set_aggregation final : public rolling_aggregation {
}
};

/**
* @brief Derived aggregation class for specifying MERGE_LISTS aggregation
*/
class merge_lists_aggregation final : public aggregation {
public:
explicit merge_lists_aggregation() : aggregation{MERGE_LISTS} {}

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<merge_lists_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying MERGE_SETS aggregation
*/
class merge_sets_aggregation final : public aggregation {
public:
explicit merge_sets_aggregation(null_equality nulls_equal, nan_equality nans_equal)
: aggregation{MERGE_SETS}, _nulls_equal(nulls_equal), _nans_equal(nans_equal)
{
}

null_equality _nulls_equal; ///< whether to consider nulls as equal value
nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to
///< floating point types)

bool is_equal(aggregation const& _other) const override
{
if (!this->aggregation::is_equal(_other)) { return false; }
auto const& other = dynamic_cast<merge_sets_aggregation const&>(_other);
return (_nulls_equal == other._nulls_equal && _nans_equal == other._nans_equal);
}

size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); }

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<merge_sets_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, cudf::detail::simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }

protected:
size_t hash_impl() const
{
return std::hash<int>{}(static_cast<int>(_nulls_equal) ^ static_cast<int>(_nans_equal));
}
};

/**
* @brief Derived aggregation class for specifying LEAD/LAG window aggregations
*/
@@ -904,6 +970,18 @@ struct target_type_impl<Source, aggregation::COLLECT_SET> {
using type = cudf::list_view;
};

// Always use list for MERGE_LISTS
template <typename Source>
struct target_type_impl<Source, aggregation::MERGE_LISTS> {
using type = cudf::list_view;
};

// Always use list for MERGE_SETS
template <typename Source>
struct target_type_impl<Source, aggregation::MERGE_SETS> {
using type = cudf::list_view;
};

// Always use Source for LEAD
template <typename Source>
struct target_type_impl<Source, aggregation::LEAD> {
@@ -1005,6 +1083,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()<aggregation::COLLECT_LIST>(std::forward<Ts>(args)...);
case aggregation::COLLECT_SET:
return f.template operator()<aggregation::COLLECT_SET>(std::forward<Ts>(args)...);
case aggregation::MERGE_LISTS:
return f.template operator()<aggregation::MERGE_LISTS>(std::forward<Ts>(args)...);
case aggregation::MERGE_SETS:
return f.template operator()<aggregation::MERGE_SETS>(std::forward<Ts>(args)...);
case aggregation::LEAD:
return f.template operator()<aggregation::LEAD>(std::forward<Ts>(args)...);
case aggregation::LAG:
@@ -1107,4 +1189,4 @@ constexpr inline bool is_valid_aggregation()
bool is_valid_aggregation(data_type source, aggregation::Kind k);

} // namespace detail
} // namespace cudf
} // namespace cudf
40 changes: 40 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
@@ -154,6 +154,18 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, merge_lists_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, merge_sets_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, lead_lag_aggregation const& agg)
{
@@ -270,6 +282,16 @@ void aggregation_finalizer::visit(collect_set_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(merge_lists_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(merge_sets_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(lead_lag_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
@@ -471,6 +493,24 @@ template std::unique_ptr<aggregation> make_collect_set_aggregation<aggregation>(
template std::unique_ptr<rolling_aggregation> make_collect_set_aggregation<rolling_aggregation>(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);

/// Factory to create a MERGE_LISTS aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_lists_aggregation()
{
return std::make_unique<detail::merge_lists_aggregation>();
}
template std::unique_ptr<aggregation> make_merge_lists_aggregation<aggregation>();

/// Factory to create a MERGE_SETS aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal,
nan_equality nans_equal)
{
return std::make_unique<detail::merge_sets_aggregation>(nulls_equal, nans_equal);
}
template std::unique_ptr<aggregation> make_merge_sets_aggregation<aggregation>(null_equality,
nan_equality);

/// Factory to create a LAG aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_lag_aggregation(size_type offset)
88 changes: 78 additions & 10 deletions cpp/src/groupby/sort/aggregate.cpp
@@ -366,36 +366,32 @@ void aggregate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation
template <>
void aggregate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregation const& agg)
{
auto null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
agg.do_hash();

if (cache.has_result(col_idx, agg)) return;
if (cache.has_result(col_idx, agg)) { return; }

auto const null_handling =
dynamic_cast<cudf::detail::collect_list_aggregation const&>(agg)._null_handling;
auto result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);

cache.add_result(col_idx, agg, std::move(result));
};

template <>
void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation const& agg)
{
auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;

if (cache.has_result(col_idx, agg)) { return; }

auto const null_handling =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_handling;
auto const collect_result = detail::group_collect(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
null_handling,
stream,
mr);
rmm::mr::get_current_device_resource());
auto const nulls_equal =
dynamic_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
@@ -406,6 +402,78 @@ void aggregate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};

/**
* @brief Perform merging for the lists that correspond to the same key value.
*
* This aggregation is similar to `COLLECT_LIST` with the following differences:
* - It requires the input values to be a non-nullable lists column, and
* - The values (lists) corresponding to the same key will not result in a list of lists as output
* from `COLLECT_LIST`. Instead, those lists will result in a list generated by merging them
* together.
*
* In practice, this aggregation is used to merge the partial results of multiple (distributed)
* groupby `COLLECT_LIST` aggregations into a final `COLLECT_LIST` result. Those distributed
* aggregations were executed on different values columns partitioned from the original values
* column, then their results were (vertically) concatenated before being given as the values
* column for this aggregation.
*/
template <>
void aggregate_result_functor::operator()<aggregation::MERGE_LISTS>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) { return; }

cache.add_result(
col_idx,
agg,
detail::group_merge_lists(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};

/**
* @brief Merge the lists corresponding to the same key value, then drop duplicate list
* entries.
*
* This aggregation is similar to `COLLECT_SET` with the following differences:
* - It requires the input values to be a non-nullable lists column, and
* - The values (lists) corresponding to the same key will result in a list generated by merging
* them together then dropping duplicate entries.
*
* In practice, this aggregation is used to merge the partial results of multiple (distributed)
* groupby `COLLECT_LIST` or `COLLECT_SET` aggregations into a final `COLLECT_SET` result. Those
* distributed aggregations were executed on different values columns partitioned from the original
* values column, then their results were (vertically) concatenated before being given as the
* values column for this aggregation.
*
* First, this aggregation performs `MERGE_LISTS` to concatenate the input lists (corresponding to
* the same key) into intermediate lists, then it calls `lists::drop_list_duplicates` on them to
* remove duplicate list entries. As such, the input (partial results) to this aggregation should be
* generated by (distributed) `COLLECT_LIST` aggregations, not `COLLECT_SET`, to avoid unnecessarily
* removing duplicate entries for the partial results.
*
* Since duplicate list entries will be removed, the parameters `null_equality` and `nan_equality`
* are needed for the call to `lists::drop_list_duplicates`.
*/
template <>
void aggregate_result_functor::operator()<aggregation::MERGE_SETS>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) { return; }

auto const merged_result = detail::group_merge_lists(get_grouped_values(),
helper.group_offsets(stream),
helper.num_groups(stream),
stream,
rmm::mr::get_current_device_resource());
auto const merge_sets_agg = dynamic_cast<cudf::detail::merge_sets_aggregation const&>(agg);
cache.add_result(col_idx,
agg,
lists::detail::drop_list_duplicates(lists_column_view(merged_result->view()),
merge_sets_agg._nulls_equal,
merge_sets_agg._nans_equal,
stream,
mr));
};

} // namespace detail

// Sort-based groupby