Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support collect_set on rolling window #7881

Merged
merged 14 commits into from
May 26, 2021
73 changes: 61 additions & 12 deletions cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <cudf/detail/valid_if.cuh>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/dictionary_factories.hpp>
#include <cudf/lists/detail/drop_list_duplicates.hpp>
#include <cudf/rolling.hpp>
#include <cudf/strings/detail/utilities.cuh>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -312,7 +313,8 @@ template <typename InputType,
std::enable_if_t<!std::is_same<InputType, cudf::string_view>::value and
!(op == aggregation::COUNT_VALID || op == aggregation::COUNT_ALL ||
op == aggregation::ROW_NUMBER || op == aggregation::LEAD ||
op == aggregation::LAG || op == aggregation::COLLECT_LIST)>* = nullptr>
op == aggregation::LAG || op == aggregation::COLLECT_LIST ||
op == aggregation::COLLECT_SET)>* = nullptr>
bool __device__ process_rolling_window(column_device_view input,
column_device_view ignored_default_outputs,
mutable_column_device_view output,
Expand Down Expand Up @@ -811,7 +813,7 @@ struct rolling_window_launcher {
typename PrecedingWindowIterator,
typename FollowingWindowIterator>
std::enable_if_t<!(op == aggregation::MEAN || op == aggregation::LEAD || op == aggregation::LAG ||
op == aggregation::COLLECT_LIST),
op == aggregation::COLLECT_LIST || op == aggregation::COLLECT_SET),
std::unique_ptr<column>>
operator()(column_view const& input,
column_view const& default_outputs,
Expand Down Expand Up @@ -1135,16 +1137,15 @@ struct rolling_window_launcher {
std::move(new_gather_map), std::move(new_offsets));
}

template <aggregation::Kind op, typename PrecedingIter, typename FollowingIter>
std::enable_if_t<(op == aggregation::COLLECT_LIST), std::unique_ptr<column>> operator()(
column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
std::unique_ptr<aggregation> const& agg,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
template <typename PrecedingIter, typename FollowingIter>
std::unique_ptr<column> collect_list(column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
std::unique_ptr<aggregation> const& agg,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(default_outputs.is_empty(),
"COLLECT_LIST window function does not support default values.");
Expand Down Expand Up @@ -1212,6 +1213,54 @@ struct rolling_window_launcher {
stream,
mr);
}

template <aggregation::Kind op, typename PrecedingIter, typename FollowingIter>
std::enable_if_t<(op == aggregation::COLLECT_LIST), std::unique_ptr<column>> operator()(
column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
std::unique_ptr<aggregation> const& agg,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return collect_list(input,
default_outputs,
preceding_begin_raw,
following_begin_raw,
min_periods,
agg,
stream,
mr);
}

template <aggregation::Kind op, typename PrecedingIter, typename FollowingIter>
std::enable_if_t<(op == aggregation::COLLECT_SET), std::unique_ptr<column>> operator()(
column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding_begin_raw,
FollowingIter following_begin_raw,
size_type min_periods,
std::unique_ptr<aggregation> const& agg,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto collect_result = collect_list(input,
default_outputs,
preceding_begin_raw,
following_begin_raw,
min_periods,
agg,
stream,
mr);

return lists::detail::drop_list_duplicates(lists_column_view(collect_result->view()),
null_equality::EQUAL,
nan_equality::UNEQUAL,
stream,
sperlingxx marked this conversation as resolved.
Show resolved Hide resolved
mr);
}
};

struct dispatch_rolling {
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ ConfigureTest(ROLLING_TEST
rolling/rolling_test.cpp
rolling/grouped_rolling_test.cpp
rolling/lead_lag_test.cpp
rolling/collect_list_test.cpp
rolling/collect_ops_test.cpp
)

###################################################################################################
Expand Down
Loading