Skip to content

Commit

Permalink
Support collect_set on rolling window (#7881)
Browse files Browse the repository at this point in the history
This pull request is to support collect_set on rolling window, which is required in #7809.

Authors:
  - Alfred Xu (https://github.com/sperlingxx)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Nghia Truong (https://github.com/ttnghia)

URL: #7881
  • Loading branch information
sperlingxx authored May 26, 2021
1 parent 773fc7a commit be05a00
Show file tree
Hide file tree
Showing 6 changed files with 794 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ class collect_list_aggregation final : public rolling_aggregation {
/**
* @brief Derived aggregation class for specifying COLLECT_SET aggregation
*/
class collect_set_aggregation final : public aggregation {
class collect_set_aggregation final : public rolling_aggregation {
public:
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ std::unique_ptr<Base> make_collect_set_aggregation(null_policy null_handling,
}
template std::unique_ptr<aggregation> make_collect_set_aggregation<aggregation>(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);
template std::unique_ptr<rolling_aggregation> make_collect_set_aggregation<rolling_aggregation>(
null_policy null_handling, null_equality nulls_equal, nan_equality nans_equal);

/// Factory to create a LAG aggregation
template <typename Base = aggregation>
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/rolling/rolling_collect_list.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ std::unique_ptr<column> rolling_collect_list(column_view const& input,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
rolling_aggregation const& agg,
null_policy null_handling,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
Expand Down Expand Up @@ -321,7 +321,6 @@ std::unique_ptr<column> rolling_collect_list(column_view const& input,

// If gather_map collects null elements, and null_policy == EXCLUDE,
// those elements must be filtered out, and offsets recomputed.
auto null_handling = dynamic_cast<collect_list_aggregation const&>(agg)._null_handling;
if (null_handling == null_policy::EXCLUDE && input.has_nulls()) {
auto num_child_nulls = count_child_nulls(input, gather_map, stream);
if (num_child_nulls != 0) {
Expand Down
30 changes: 29 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <cudf/detail/valid_if.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/dictionary_factories.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/rolling.hpp>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -581,6 +582,14 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
return {};
}

// COLLECT_SET aggregations do not peform a rolling operation at all. They get processed
// entirely in the finalize() step.
std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, cudf::detail::collect_set_aggregation const& agg) override
{
return {};
}

// LEAD and LAG have custom behaviors for non fixed-width types.
std::vector<std::unique_ptr<aggregation>> visit(
data_type col_type, cudf::detail::lead_lag_aggregation const& agg) override
Expand Down Expand Up @@ -678,11 +687,30 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
preceding_window_begin,
following_window_begin,
min_periods,
agg,
agg._null_handling,
stream,
mr);
}

// perform the actual COLLECT_SET operation entirely.
void visit(cudf::detail::collect_set_aggregation const& agg) override
{
auto const collected_list = rolling_collect_list(input,
default_outputs,
preceding_window_begin,
following_window_begin,
min_periods,
agg._null_handling,
stream,
rmm::mr::get_current_device_resource());

result = lists::detail::drop_list_duplicates(lists_column_view(collected_list->view()),
null_equality::EQUAL,
nan_equality::UNEQUAL,
stream,
mr);
}

std::unique_ptr<column> get_result()
{
CUDF_EXPECTS(result != nullptr,
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ ConfigureTest(ROLLING_TEST
rolling/lead_lag_test.cpp
rolling/range_window_bounds_test.cpp
rolling/range_rolling_window_test.cpp
rolling/collect_list_test.cpp)
rolling/collect_ops_test.cpp
)

###################################################################################################
# - filling test ----------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit be05a00

Please sign in to comment.