Skip to content

Commit

Permalink
Expose stream parameter to public rolling APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
srinivasyadav18 committed Jun 7, 2024
1 parent dc829b8 commit 0044f31
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 69 deletions.
20 changes: 20 additions & 0 deletions cpp/include/cudf/rolling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace cudf {
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] agg The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -67,6 +68,7 @@ std::unique_ptr<column> rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -77,6 +79,7 @@ std::unique_ptr<column> rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& agg,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*
* @param default_outputs A column of per-row default values to be returned instead
Expand All @@ -90,6 +93,7 @@ std::unique_ptr<column> rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -227,6 +231,7 @@ struct window_bounds {
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] aggr The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -238,6 +243,7 @@ std::unique_ptr<column> grouped_rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -249,6 +255,7 @@ std::unique_ptr<column> grouped_rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*/
std::unique_ptr<column> grouped_rolling_window(
Expand All @@ -258,6 +265,7 @@ std::unique_ptr<column> grouped_rolling_window(
window_bounds following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -269,6 +277,7 @@ std::unique_ptr<column> grouped_rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,,
* rmm::device_async_resource_ref mr)
*
* @param default_outputs A column of per-row default values to be returned instead
Expand All @@ -283,6 +292,7 @@ std::unique_ptr<column> grouped_rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -295,6 +305,7 @@ std::unique_ptr<column> grouped_rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*/
std::unique_ptr<column> grouped_rolling_window(
Expand All @@ -305,6 +316,7 @@ std::unique_ptr<column> grouped_rolling_window(
window_bounds following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -387,6 +399,7 @@ std::unique_ptr<column> grouped_rolling_window(
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] aggr The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -400,6 +413,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
size_type following_window_in_days,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -415,6 +429,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
* size_type following_window_in_days,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*
* The `preceding_window_in_days` and `following_window_in_days` are specified as a `window_bounds`
Expand All @@ -429,6 +444,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
window_bounds following_window_in_days,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -536,6 +552,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] aggr The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -549,6 +566,7 @@ std::unique_ptr<column> grouped_range_rolling_window(
range_window_bounds const& following,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -582,6 +600,7 @@ std::unique_ptr<column> grouped_range_rolling_window(
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] agg The rolling window aggregation type (sum, max, min, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -592,6 +611,7 @@ std::unique_ptr<column> rolling_window(
column_view const& following_window,
size_type min_periods,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
Expand Down
129 changes: 70 additions & 59 deletions cpp/src/rolling/grouped_rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,59 +40,6 @@
#include <thrust/partition.h>

namespace cudf {
std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
column_view const& input,
size_type preceding_window,
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::device_async_resource_ref mr)
{
return grouped_rolling_window(group_keys,
input,
window_bounds::get(preceding_window),
window_bounds::get(following_window),
min_periods,
aggr,
mr);
}

std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
column_view const& input,
window_bounds preceding_window,
window_bounds following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::device_async_resource_ref mr)
{
return grouped_rolling_window(group_keys,
input,
empty_like(input)->view(),
preceding_window,
following_window,
min_periods,
aggr,
mr);
}

std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
column_view const& input,
column_view const& default_outputs,
size_type preceding_window,
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::device_async_resource_ref mr)
{
return grouped_rolling_window(group_keys,
input,
default_outputs,
window_bounds::get(preceding_window),
window_bounds::get(following_window),
min_periods,
aggr,
mr);
}

namespace detail {

Expand Down Expand Up @@ -237,8 +184,8 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,

if (group_keys.num_columns() == 0) {
// No Groupby columns specified. Treat as one big group.
return rolling_window(
input, default_outputs, preceding_window, following_window, min_periods, aggr, mr);
return detail::rolling_window(
input, default_outputs, preceding_window, following_window, min_periods, aggr, stream, mr);
}

using sort_groupby_helper = cudf::groupby::detail::sort::sort_groupby_helper;
Expand Down Expand Up @@ -306,6 +253,7 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
window_bounds following_window_bounds,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return detail::grouped_rolling_window(group_keys,
Expand All @@ -315,7 +263,67 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
following_window_bounds,
min_periods,
aggr,
cudf::get_default_stream(),
stream,
mr);
}

std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
column_view const& input,
size_type preceding_window,
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return grouped_rolling_window(group_keys,
input,
window_bounds::get(preceding_window),
window_bounds::get(following_window),
min_periods,
aggr,
stream,
mr);
}

std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
column_view const& input,
window_bounds preceding_window,
window_bounds following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return detail::grouped_rolling_window(group_keys,
input,
empty_like(input)->view(),
preceding_window,
following_window,
min_periods,
aggr,
stream,
mr);
}

std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
column_view const& input,
column_view const& default_outputs,
size_type preceding_window,
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return detail::grouped_rolling_window(group_keys,
input,
default_outputs,
window_bounds::get(preceding_window),
window_bounds::get(following_window),
min_periods,
aggr,
stream,
mr);
}

Expand Down Expand Up @@ -1199,6 +1207,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(table_view const& grou
size_type following_window_in_days,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
Expand All @@ -1213,7 +1222,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(table_view const& grou
following,
min_periods,
aggr,
cudf::get_default_stream(),
stream,
mr);
}

Expand All @@ -1237,6 +1246,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(table_view const& grou
window_bounds following_window_in_days,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
Expand All @@ -1253,7 +1263,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(table_view const& grou
following,
min_periods,
aggr,
cudf::get_default_stream(),
stream,
mr);
}

Expand All @@ -1277,6 +1287,7 @@ std::unique_ptr<column> grouped_range_rolling_window(table_view const& group_key
range_window_bounds const& following,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
Expand All @@ -1288,7 +1299,7 @@ std::unique_ptr<column> grouped_range_rolling_window(table_view const& group_key
following,
min_periods,
aggr,
cudf::get_default_stream(),
stream,
mr);
}

Expand Down
Loading

0 comments on commit 0044f31

Please sign in to comment.