Skip to content

Commit

Permalink
update range_window_bounds API to accept stream parameter
Browse files Browse the repository at this point in the history
range_window_bounds API functions uses scalars internally, which should also the same stream parameter.

Signed-off-by: srinivasyadav18 <[email protected]>
  • Loading branch information
srinivasyadav18 committed Jun 7, 2024
1 parent db9d306 commit 4e5e4a8
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 34 deletions.
16 changes: 12 additions & 4 deletions cpp/include/cudf/rolling/range_window_bounds.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,22 @@ struct range_window_bounds {
* @brief Factory method to construct a bounded window boundary.
*
* @param boundary Finite window boundary
* @param stream CUDA stream used for device memory operations and kernel launches
* @return A bounded window boundary object
*/
static range_window_bounds get(scalar const& boundary);
static range_window_bounds get(scalar const& boundary,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Factory method to construct a window boundary
* limited to the value of the current row
*
* @param type The datatype of the window boundary
* @param stream CUDA stream used for device memory operations and kernel launches
* @return A "current row" window boundary object
*/
static range_window_bounds current_row(data_type type);
static range_window_bounds current_row(data_type type,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Whether or not the window is bounded to the current row
Expand All @@ -81,9 +85,11 @@ struct range_window_bounds {
* @brief Factory method to construct an unbounded window boundary.
*
* @param type The datatype of the window boundary
* @param stream CUDA stream used for device memory operations and kernel launches
* @return An unbounded window boundary object
*/
static range_window_bounds unbounded(data_type type);
static range_window_bounds unbounded(data_type type,
rmm::cuda_stream_view stream = cudf::get_default_stream());

/**
* @brief Whether or not the window is unbounded
Expand All @@ -107,7 +113,9 @@ struct range_window_bounds {
extent_type _extent{extent_type::UNBOUNDED};
std::shared_ptr<scalar> _range_scalar{nullptr}; // To enable copy construction/assignment.

range_window_bounds(extent_type extent_, std::unique_ptr<scalar> range_scalar_);
range_window_bounds(extent_type extent_,
std::unique_ptr<scalar> range_scalar_,
rmm::cuda_stream_view = cudf::get_default_stream());
};

/** @} */ // end of group
Expand Down
29 changes: 17 additions & 12 deletions cpp/src/rolling/grouped_rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1055,14 +1055,15 @@ struct dispatch_grouped_range_rolling_window {
*/
struct to_duration_bounds {
template <typename OrderBy, std::enable_if_t<cudf::is_timestamp<OrderBy>(), void>* = nullptr>
range_window_bounds operator()(size_type num_days) const
range_window_bounds operator()(size_type num_days, rmm::cuda_stream_view stream) const
{
using DurationT = typename OrderBy::duration;
return range_window_bounds::get(duration_scalar<DurationT>{duration_D{num_days}, true});
return range_window_bounds::get(duration_scalar<DurationT>{duration_D{num_days}, true, stream},
stream);
}

template <typename OrderBy, std::enable_if_t<!cudf::is_timestamp<OrderBy>(), void>* = nullptr>
range_window_bounds operator()(size_type) const
range_window_bounds operator()(size_type, rmm::cuda_stream_view) const
{
CUDF_FAIL("Expected timestamp orderby column.");
}
Expand Down Expand Up @@ -1093,9 +1094,11 @@ data_type get_duration_type_for(cudf::data_type timestamp_type)
* @param timestamp_type Data-type of the orderby column to which the `num_days` is to be adapted.
* @return range_window_bounds A `range_window_bounds` to be used with the new API.
*/
range_window_bounds to_range_bounds(cudf::size_type num_days, cudf::data_type timestamp_type)
range_window_bounds to_range_bounds(cudf::size_type num_days,
cudf::data_type timestamp_type,
rmm::cuda_stream_view stream)
{
return cudf::type_dispatcher(timestamp_type, to_duration_bounds{}, num_days);
return cudf::type_dispatcher(timestamp_type, to_duration_bounds{}, num_days, stream);
}

/**
Expand All @@ -1109,11 +1112,13 @@ range_window_bounds to_range_bounds(cudf::size_type num_days, cudf::data_type ti
* @return range_window_bounds A `range_window_bounds` to be used with the new API.
*/
range_window_bounds to_range_bounds(cudf::window_bounds const& days_bounds,
cudf::data_type timestamp_type)
cudf::data_type timestamp_type,
rmm::cuda_stream_view stream)
{
return days_bounds.is_unbounded()
? range_window_bounds::unbounded(get_duration_type_for(timestamp_type))
: cudf::type_dispatcher(timestamp_type, to_duration_bounds{}, days_bounds.value());
? range_window_bounds::unbounded(get_duration_type_for(timestamp_type), stream)
: cudf::type_dispatcher(
timestamp_type, to_duration_bounds{}, days_bounds.value(), stream);
}

} // namespace
Expand Down Expand Up @@ -1211,8 +1216,8 @@ std::unique_ptr<column> grouped_time_range_rolling_window(table_view const& grou
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
auto preceding = to_range_bounds(preceding_window_in_days, timestamp_column.type());
auto following = to_range_bounds(following_window_in_days, timestamp_column.type());
auto preceding = to_range_bounds(preceding_window_in_days, timestamp_column.type(), stream);
auto following = to_range_bounds(following_window_in_days, timestamp_column.type(), stream);

return detail::grouped_range_rolling_window(group_keys,
timestamp_column,
Expand Down Expand Up @@ -1251,9 +1256,9 @@ std::unique_ptr<column> grouped_time_range_rolling_window(table_view const& grou
{
CUDF_FUNC_RANGE();
range_window_bounds preceding =
to_range_bounds(preceding_window_in_days, timestamp_column.type());
to_range_bounds(preceding_window_in_days, timestamp_column.type(), stream);
range_window_bounds following =
to_range_bounds(following_window_in_days, timestamp_column.type());
to_range_bounds(following_window_in_days, timestamp_column.type(), stream);

return detail::grouped_range_rolling_window(group_keys,
timestamp_column,
Expand Down
39 changes: 23 additions & 16 deletions cpp/src/rolling/range_window_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,59 +32,66 @@ namespace {
*/
struct range_scalar_constructor {
template <typename T, CUDF_ENABLE_IF(not detail::is_supported_range_type<T>())>
std::unique_ptr<scalar> operator()(scalar const& range_scalar_) const
std::unique_ptr<scalar> operator()(scalar const& range_scalar_,
rmm::cuda_stream_view stream) const
{
CUDF_FAIL(
"Unsupported range type. "
"Only durations, fixed-point, and non-boolean numeric range types are allowed.");
}

template <typename T, CUDF_ENABLE_IF(cudf::is_duration<T>())>
std::unique_ptr<scalar> operator()(scalar const& range_scalar_) const
std::unique_ptr<scalar> operator()(scalar const& range_scalar_,
rmm::cuda_stream_view stream) const
{
return std::make_unique<duration_scalar<T>>(
static_cast<duration_scalar<T> const&>(range_scalar_));
static_cast<duration_scalar<T> const&>(range_scalar_), stream);
}

template <typename T, CUDF_ENABLE_IF(cudf::is_numeric<T>() && not cudf::is_boolean<T>())>
std::unique_ptr<scalar> operator()(scalar const& range_scalar_) const
std::unique_ptr<scalar> operator()(scalar const& range_scalar_,
rmm::cuda_stream_view stream) const
{
return std::make_unique<numeric_scalar<T>>(
static_cast<numeric_scalar<T> const&>(range_scalar_));
return std::make_unique<numeric_scalar<T>>(static_cast<numeric_scalar<T> const&>(range_scalar_),
stream);
}

template <typename T, CUDF_ENABLE_IF(cudf::is_fixed_point<T>())>
std::unique_ptr<scalar> operator()(scalar const& range_scalar_) const
std::unique_ptr<scalar> operator()(scalar const& range_scalar_,
rmm::cuda_stream_view stream) const
{
return std::make_unique<fixed_point_scalar<T>>(
static_cast<fixed_point_scalar<T> const&>(range_scalar_));
static_cast<fixed_point_scalar<T> const&>(range_scalar_), stream);
}
};
} // namespace

range_window_bounds::range_window_bounds(extent_type extent_, std::unique_ptr<scalar> range_scalar_)
range_window_bounds::range_window_bounds(extent_type extent_,
std::unique_ptr<scalar> range_scalar_,
rmm::cuda_stream_view stream)
: _extent{extent_}, _range_scalar{std::move(range_scalar_)}
{
CUDF_EXPECTS(_range_scalar.get(), "Range window scalar cannot be null.");
CUDF_EXPECTS(_extent == extent_type::UNBOUNDED || _extent == extent_type::CURRENT_ROW ||
_range_scalar->is_valid(),
_range_scalar->is_valid(stream),
"Bounded Range window scalar must be valid.");
}

range_window_bounds range_window_bounds::unbounded(data_type type)
range_window_bounds range_window_bounds::unbounded(data_type type, rmm::cuda_stream_view stream)
{
return {extent_type::UNBOUNDED, make_default_constructed_scalar(type)};
return {extent_type::UNBOUNDED, make_default_constructed_scalar(type, stream), stream};
}

range_window_bounds range_window_bounds::current_row(data_type type)
range_window_bounds range_window_bounds::current_row(data_type type, rmm::cuda_stream_view stream)
{
return {extent_type::CURRENT_ROW, make_default_constructed_scalar(type)};
return {extent_type::CURRENT_ROW, make_default_constructed_scalar(type, stream), stream};
}

range_window_bounds range_window_bounds::get(scalar const& boundary)
range_window_bounds range_window_bounds::get(scalar const& boundary, rmm::cuda_stream_view stream)
{
return {extent_type::BOUNDED,
cudf::type_dispatcher(boundary.type(), range_scalar_constructor{}, boundary)};
cudf::type_dispatcher(boundary.type(), range_scalar_constructor{}, boundary, stream),
stream};
}

} // namespace cudf
6 changes: 4 additions & 2 deletions cpp/tests/streams/rolling_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,12 @@ TEST_F(GroupedRangeRollingTest, RangeWindowBounds)
{0, 0, 0, 0, 1, 1, 1, 1, 1, 1}};

cudf::range_window_bounds preceding = cudf::range_window_bounds::get(
cudf::numeric_scalar<int>{int{1}, true, cudf::test::get_default_stream()});
cudf::numeric_scalar<int>{int{1}, true, cudf::test::get_default_stream()},
cudf::test::get_default_stream());

cudf::range_window_bounds following = cudf::range_window_bounds::get(
cudf::numeric_scalar<int>{int{1}, true, cudf::test::get_default_stream()});
cudf::numeric_scalar<int>{int{1}, true, cudf::test::get_default_stream()},
cudf::test::get_default_stream());

auto const min_periods = cudf::size_type{1};

Expand Down

0 comments on commit 4e5e4a8

Please sign in to comment.