Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose stream parameter to public rolling APIs #15865

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions cpp/include/cudf/rolling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ namespace cudf {
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] agg The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -67,6 +68,7 @@ std::unique_ptr<column> rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -77,6 +79,7 @@ std::unique_ptr<column> rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& agg,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*
* @param default_outputs A column of per-row default values to be returned instead
Expand All @@ -90,6 +93,7 @@ std::unique_ptr<column> rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -227,6 +231,7 @@ struct window_bounds {
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] aggr The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -238,6 +243,7 @@ std::unique_ptr<column> grouped_rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -249,6 +255,7 @@ std::unique_ptr<column> grouped_rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*/
std::unique_ptr<column> grouped_rolling_window(
Expand All @@ -258,6 +265,7 @@ std::unique_ptr<column> grouped_rolling_window(
window_bounds following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -269,6 +277,7 @@ std::unique_ptr<column> grouped_rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,,
* rmm::device_async_resource_ref mr)
*
* @param default_outputs A column of per-row default values to be returned instead
Expand All @@ -283,6 +292,7 @@ std::unique_ptr<column> grouped_rolling_window(
size_type following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -295,6 +305,7 @@ std::unique_ptr<column> grouped_rolling_window(
* size_type following_window,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*/
std::unique_ptr<column> grouped_rolling_window(
Expand All @@ -305,6 +316,7 @@ std::unique_ptr<column> grouped_rolling_window(
window_bounds following_window,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -387,6 +399,7 @@ std::unique_ptr<column> grouped_rolling_window(
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] aggr The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -400,6 +413,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
size_type following_window_in_days,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -415,6 +429,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
* size_type following_window_in_days,
* size_type min_periods,
* rolling_aggregation const& aggr,
* rmm::cuda_stream_view stream,
* rmm::device_async_resource_ref mr)
*
* The `preceding_window_in_days` and `following_window_in_days` are specified as a `window_bounds`
Expand All @@ -429,6 +444,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
window_bounds following_window_in_days,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -536,6 +552,7 @@ std::unique_ptr<column> grouped_time_range_rolling_window(
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] aggr The rolling window aggregation type (SUM, MAX, MIN, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -549,6 +566,7 @@ std::unique_ptr<column> grouped_range_rolling_window(
range_window_bounds const& following,
size_type min_periods,
rolling_aggregation const& aggr,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
Expand Down Expand Up @@ -582,6 +600,7 @@ std::unique_ptr<column> grouped_range_rolling_window(
* @param[in] min_periods Minimum number of observations in window required to have a value,
* otherwise element `i` is null.
* @param[in] agg The rolling window aggregation type (sum, max, min, etc.)
* @param[in] stream CUDA stream used for device memory operations and kernel launches
* @param[in] mr Device memory resource used to allocate the returned column's device memory
*
* @returns A nullable output column containing the rolling window results
Expand All @@ -592,6 +611,7 @@ std::unique_ptr<column> rolling_window(
column_view const& following_window,
size_type min_periods,
rolling_aggregation const& agg,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
Expand Down
16 changes: 12 additions & 4 deletions cpp/include/cudf/rolling/range_window_bounds.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,22 @@ struct range_window_bounds {
* @brief Factory method to construct a bounded window boundary.
*
* @param boundary Finite window boundary
* @param stream CUDA stream used for device memory operations and kernel launches
* @return A bounded window boundary object
*/
static range_window_bounds get(scalar const& boundary);
static range_window_bounds get(scalar const& boundary,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Factory method to construct a window boundary
* limited to the value of the current row
*
* @param type The datatype of the window boundary
* @param stream CUDA stream used for device memory operations and kernel launches
* @return A "current row" window boundary object
*/
static range_window_bounds current_row(data_type type);
static range_window_bounds current_row(data_type type,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Whether or not the window is bounded to the current row
Expand All @@ -81,9 +85,11 @@ struct range_window_bounds {
* @brief Factory method to construct an unbounded window boundary.
*
* @param type The datatype of the window boundary
* @param stream CUDA stream used for device memory operations and kernel launches
* @return An unbounded window boundary object
*/
static range_window_bounds unbounded(data_type type);
static range_window_bounds unbounded(data_type type,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Whether or not the window is unbounded
Expand All @@ -107,7 +113,9 @@ struct range_window_bounds {
extent_type _extent{extent_type::UNBOUNDED};
std::shared_ptr<scalar> _range_scalar{nullptr}; // To enable copy construction/assignment.

range_window_bounds(extent_type extent_, std::unique_ptr<scalar> range_scalar_);
range_window_bounds(extent_type extent_,
std::unique_ptr<scalar> range_scalar_,
rmm::cuda_stream_view = cudf::get_default_stream());
};

/** @} */ // end of group
Expand Down
Loading
Loading