Refactor of rolling_window implementation. (#8158)
This is an attempt to significantly reduce the complexity of the SFINAE logic and the various functors/functions inside rolling_detail.cuh. There are two major components:

- It introduces the idea of device "rolling operators". These operators are essentially just the implementations of what were formerly the `process_rolling_window()` functions. However, they provide the key mechanism for removing the complex SFINAE from the core logic: each operator carries its own validation and can throw for an invalid aggregation/type pair at construction time, internally. (A minimal sketch appears after this list.)

- It refactors the type- and aggregation-dispatched functors to use the collector/finalize paradigm used by groupby. Specifically, the rolling operation is broken down into three parts: 1) preprocess incoming aggregation/type pairs, potentially transforming them into different operations; 2) perform the rolling window operation on the transformed inputs; 3) postprocess the output of the rolling window operation to obtain the final result. (A second, self-contained sketch of this flow also appears below.)
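A minimal sketch of the first idea, with hypothetical names (the real operators in `rolling_detail.cuh` are shaped differently): the operator owns both the validity check and the construction-time failure, so the surrounding dispatch no longer needs SFINAE to reject invalid aggregation/type pairs.

```cpp
// Hypothetical device rolling operator; illustrative only, not this PR's code.
#include <stdexcept>
#include <type_traits>

template <typename InputType>
struct device_rolling_sum_sketch {
  // Stripped-down, type-specific validity check for this one aggregation.
  static constexpr bool is_supported() { return std::is_arithmetic_v<InputType>; }

  explicit device_rolling_sum_sketch(int min_periods) : min_periods{min_periods}
  {
    // An invalid aggregation/type pair fails here, at construction time,
    // instead of being rejected by SFINAE in the dispatch layer.
    if (!is_supported()) {
      throw std::logic_error("Rolling SUM is not supported for this input type");
    }
  }

  int min_periods;  // e.g. minimum observations required to emit a valid row
};
```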
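And a self-contained illustration of the second idea (plain C++, not cudf code): a rolling MEAN is preprocessed into a rolling SUM (the diff below maps `aggregation::MEAN` to `DeviceSum` via `corresponding_operator`), computed as a SUM, and finalized by dividing each window's sum by its count.

```cpp
// Plain-C++ illustration of the preprocess / compute / postprocess split.
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
  std::vector<double> const input{1, 2, 3, 4, 5};
  int const preceding = 2;  // window covers the current row plus 1 row before it
  int const following = 0;

  // Phase 1 (preprocess): MEAN is transformed into SUM plus a finalize step.
  // Phase 2 (compute): run the transformed aggregation, a plain rolling SUM.
  std::vector<double> sums(input.size(), 0.0);
  std::vector<int> counts(input.size(), 0);
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    int const begin = std::max(0, i - preceding + 1);
    int const end   = std::min(static_cast<int>(input.size()), i + following + 1);
    for (int j = begin; j < end; ++j) { sums[i] += input[j]; }
    counts[i] = end - begin;
  }

  // Phase 3 (postprocess): finalize SUM into MEAN using the window counts.
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    std::printf("mean[%d] = %.2f\n", i, sums[i] / counts[i]);
  }
  return 0;
}
```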

Combined, these two changes dramatically reduce the amount of dispatch and GPU rolling implementation code one has to read through.

The implementation of the collect list rolling operation has been moved into `rolling_collect_list.cuh`.

There are a couple of other things worth mentioning:

- Each device rolling operator implements an `is_supported()` constexpr function, which is a stripped-down, type-specific version of the old `is_rolling_supported()` global function. It might be possible to eliminate these with more fundamental type traits. Looking for opinions here. (A sketch of how dispatch code might consume `is_supported()` follows this list.)

- `is_rolling_supported()` has been removed from the library code; however, the various tests relied on it pretty heavily, so for now it has been transplanted into the test code in a common place. It's definitely not an ideal solution, but maybe OK for now.

- It might be worth moving the device rolling operators into their own module to further shrink `rolling_detail.cuh`.  Also looking for opinions here.
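As a follow-up to the `is_supported()` bullet above, a hedged sketch of how a caller might consume it (hypothetical; this is not the PR's dispatch code). It reuses the `device_rolling_sum_sketch` operator from the first sketch:

```cpp
// Hypothetical dispatch fragment: each operator carries its own validity
// check, so one generic branch replaces per-aggregation SFINAE overloads.
#include <stdexcept>
#include <utility>

template <typename Op, typename... Args>
void launch_rolling_sketch(Args&&... args)
{
  if constexpr (Op::is_supported()) {
    Op op{std::forward<Args>(args)...};  // construction can also validate/throw
    // ... launch the rolling kernel with `op` ...
    (void)op;
  } else {
    throw std::logic_error("Unsupported aggregation/type pair for rolling window");
  }
}

// Usage: launch_rolling_sketch<device_rolling_sum_sketch<int>>(/*min_periods=*/1);
```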

Authors:
  - https://github.com/nvdbaranec

Approvers:
  - Mike Wilson (https://github.com/hyperbolic2346)
  - MithunR (https://github.com/mythrocks)
  - Jake Hemstad (https://github.com/jrhemstad)

URL: #8158
nvdbaranec authored May 24, 2021
1 parent 5c0a75b commit 691dd11
Showing 8 changed files with 1,238 additions and 1,136 deletions.
18 changes: 18 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.cuh
@@ -53,6 +53,14 @@ struct corresponding_operator<aggregation::MAX> {
using type = DeviceMax;
};
template <>
struct corresponding_operator<aggregation::ARGMIN> {
using type = DeviceMin;
};
template <>
struct corresponding_operator<aggregation::ARGMAX> {
using type = DeviceMax;
};
template <>
struct corresponding_operator<aggregation::ANY> {
using type = DeviceMax;
};
@@ -81,6 +89,10 @@ struct corresponding_operator<aggregation::VARIANCE> {
using type = DeviceSum;
};
template <>
struct corresponding_operator<aggregation::MEAN> {
using type = DeviceSum;
};
template <>
struct corresponding_operator<aggregation::COUNT_VALID> {
using type = DeviceCount;
};
@@ -92,6 +104,12 @@ struct corresponding_operator<aggregation::COUNT_ALL> {
template <aggregation::Kind k>
using corresponding_operator_t = typename corresponding_operator<k>::type;

template <aggregation::Kind k>
constexpr bool has_corresponding_operator()
{
return !std::is_same<typename corresponding_operator<k>::type, void>::value;
}

template <typename Source,
aggregation::Kind k,
bool target_has_nulls,
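As a usage note for the trait added above (a hypothetical caller, not code from this PR, assuming cudf's headers for `aggregation` and the `CUDF_FAIL` macro from `cudf/utilities/error.hpp`): because `corresponding_operator<k>::type` is `void` for aggregation kinds with no device-operator mapping, `has_corresponding_operator<k>()` can guard instantiation at compile time.

```cpp
// Hypothetical compile-time guard built on the new trait.
template <aggregation::Kind k>
void reduce_sketch()
{
  if constexpr (has_corresponding_operator<k>()) {
    using Op = corresponding_operator_t<k>;  // e.g. DeviceSum for aggregation::MEAN
    // ... run the device reduction with Op{} ...
  } else {
    CUDF_FAIL("No corresponding device operator for this aggregation kind");
  }
}
```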
104 changes: 37 additions & 67 deletions cpp/src/rolling/lead_lag_nested_detail.cuh
@@ -27,55 +27,6 @@

namespace cudf::detail {
namespace {
/**
* @brief Functor to calculate the gather map used for calculating LEAD/LAG.
*
* @tparam op Aggregation Kind (LEAD vs LAG)
* @tparam PrecedingIterator Iterator to retrieve preceding window bounds
* @tparam FollowingIterator Iterator to retrieve following window bounds
*/
template <aggregation::Kind op, typename PrecedingIterator, typename FollowingIterator>
class lead_lag_gather_map_builder {
public:
lead_lag_gather_map_builder(size_type input_size,
size_type row_offset,
PrecedingIterator preceding,
FollowingIterator following)
: _input_size{input_size},
_null_index{input_size}, // Out of input range. Gather returns null.
_row_offset{row_offset},
_preceding{preceding},
_following{following}
{
}

template <aggregation::Kind o = op, CUDF_ENABLE_IF(o == aggregation::LEAD)>
size_type __device__ operator()(size_type i)
{
// Note: grouped_*rolling_window() trims preceding/following to
// the beginning/end of the group. `rolling_window()` does not.
// Must trim _following[i] so as not to go past the column end.
auto following = min(_following[i], _input_size - i - 1);
return (_row_offset > following) ? _null_index : (i + _row_offset);
}

template <aggregation::Kind o = op, CUDF_ENABLE_IF(o == aggregation::LAG)>
size_type __device__ operator()(size_type i)
{
// Note: grouped_*rolling_window() trims preceding/following to
// the beginning/end of the group. `rolling_window()` does not.
// Must trim _preceding[i] so as not to go past the column start.
auto preceding = min(_preceding[i], i + 1);
return (_row_offset > (preceding - 1)) ? _null_index : (i - _row_offset);
}

private:
size_type const _input_size; // Number of rows in input to LEAD/LAG.
size_type const _null_index; // Index value to use to output NULL for LEAD/LAG calculation.
size_type const _row_offset; // LEAD/LAG offset. E.g. For LEAD(2), _row_offset == 2.
PrecedingIterator _preceding; // Iterator to retrieve preceding window offset.
FollowingIterator _following; // Iterator to retrieve following window offset.
};

/**
* @brief Predicate to find indices at which LEAD/LAG evaluated to null.
@@ -110,33 +61,31 @@ is_null_index_predicate_impl<GatherMapIter> is_null_index_predicate(size_type in
/**
* @brief Helper function to calculate LEAD/LAG for nested-type input columns.
*
* @tparam op The sort of aggregation being done (LEAD vs LAG)
* @tparam InputType The datatype of the input column being aggregated
* @tparam PrecedingIterator Iterator-type that returns the preceding bounds
* @tparam FollowingIterator Iterator-type that returns the following bounds
* @param[in] op Aggregation kind.
* @param[in] input Nested-type input column for LEAD/LAG calculation
* @param[in] default_outputs Default values to use as outputs, if LEAD/LAG
* offset crosses column/group boundaries
* @param[in] preceding Iterator to retrieve preceding window bounds
* @param[in] following Iterator to retrieve following window bounds
* @param[in] offset Lead/Lag offset, indicating which row after/before
* the current row is to be returned
* @param[in] row_offset Lead/Lag offset, indicating which row after/before
* the current row is to be returned
* @param[in] stream CUDA stream for device memory operations/allocations
* @param[in] mr device_memory_resource for device memory allocations
*/
template <aggregation::Kind op,
typename InputType,
typename PrecedingIter,
typename FollowingIter,
CUDF_ENABLE_IF(!cudf::is_fixed_width<InputType>())>
std::unique_ptr<column> compute_lead_lag_for_nested(column_view const& input,
template <typename PrecedingIter, typename FollowingIter>
std::unique_ptr<column> compute_lead_lag_for_nested(aggregation::Kind op,
column_view const& input,
column_view const& default_outputs,
PrecedingIter preceding,
FollowingIter following,
size_type offset,
size_type row_offset,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
CUDF_EXPECTS(op == aggregation::LEAD || op == aggregation::LAG,
"Unexpected aggregation type in compute_lead_lag_for_nested");
CUDF_EXPECTS(default_outputs.type().id() == input.type().id(),
"Defaults column type must match input column."); // Because LEAD/LAG.

@@ -145,7 +94,7 @@ std::unique_ptr<column> compute_lead_lag_for_nested(column_view const& input,

// For LEAD(0)/LAG(0), no computation need be performed.
// Return copy of input.
if (offset == 0) { return std::make_unique<column>(input, stream, mr); }
if (row_offset == 0) { return std::make_unique<column>(input, stream, mr); }

// Algorithm:
//
@@ -174,12 +123,33 @@ std::unique_ptr<column> compute_lead_lag_for_nested(column_view const& input,
make_numeric_column(size_data_type, input.size(), mask_state::UNALLOCATED, stream);
auto gather_map = gather_map_column->mutable_view();

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator(size_type{0}),
thrust::make_counting_iterator(size_type{input.size()}),
gather_map.begin<size_type>(),
lead_lag_gather_map_builder<op, PrecedingIter, FollowingIter>{
input.size(), offset, preceding, following});
auto const input_size = input.size();
auto const null_index = input.size();
if (op == aggregation::LEAD) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator(size_type{0}),
thrust::make_counting_iterator(size_type{input.size()}),
gather_map.begin<size_type>(),
[following, input_size, null_index, row_offset] __device__(size_type i) {
// Note: grouped_*rolling_window() trims preceding/following to
// the beginning/end of the group. `rolling_window()` does not.
// Must trim _following[i] so as not to go past the column end.
auto _following = min(following[i], input_size - i - 1);
return (row_offset > _following) ? null_index : (i + row_offset);
});
} else {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator(size_type{0}),
thrust::make_counting_iterator(size_type{input.size()}),
gather_map.begin<size_type>(),
[preceding, input_size, null_index, row_offset] __device__(size_type i) {
// Note: grouped_*rolling_window() trims preceding/following to
// the beginning/end of the group. `rolling_window()` does not.
// Must trim _preceding[i] so as not to go past the column start.
auto _preceding = min(preceding[i], i + 1);
return (row_offset > (_preceding - 1)) ? null_index : (i - row_offset);
});
}

auto output_with_nulls =
cudf::detail::gather(table_view{std::vector<column_view>{input}},
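For intuition, a host-side rendering of the LEAD gather-map arithmetic in the lambda above (illustrative only; the real computation runs in the device lambda, and the function name here is hypothetical):

```cpp
// Host-side sketch of the LEAD(row_offset) gather index, mirroring the
// device lambda above: a gather index equal to input_size yields a null.
#include <algorithm>
#include <cstdio>

int lead_gather_index(int i, int row_offset, int following_i, int input_size)
{
  // Trim `following` so the window cannot extend past the column end.
  int const following  = std::min(following_i, input_size - i - 1);
  int const null_index = input_size;  // out of range, so gather produces null
  return (row_offset > following) ? null_index : (i + row_offset);
}

int main()
{
  // LEAD(2) over a 5-row column with a large following window:
  // rows 0..2 gather rows 2..4; rows 3..4 have no row 2 ahead, so they get null.
  for (int i = 0; i < 5; ++i) {
    std::printf("row %d -> %d\n", i, lead_gather_index(i, 2, 5, 5));
  }
  return 0;
}
```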
(Diffs for the remaining six changed files are not shown.)
