Skip to content

Commit

Permalink
Support nth_element for window functions (rapidsai#11158)
Browse files Browse the repository at this point in the history
Fixes rapidsai#9643.

This is to address spark-rapids/issues/4005 and
spark-rapids/issues/5061.

This change adds support for `NTH_ELEMENT` window aggregations.
This should allow for the implementation of `FIRST()`, `LAST()`,
and `NTH_VALUE()` window functions in Spark SQL.

When used as a window function, `NTH_ELEMENT` returns the Nth element of the
specified window for each row in a column. `N` is zero-based, so
`NTH_ELEMENT(0)` yields the first element in a window. Similarly,
`NTH_ELEMENT(-1)` yields the last.

For any window of size `W`, if the specified `N` falls outside
the range `[-W, W-1]`, a null element is returned for that row.

Authors:
  - MithunR (https://github.com/mythrocks)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - David Wendt (https://github.com/davidwendt)
  - Nghia Truong (https://github.com/ttnghia)

URL: rapidsai#11158
  • Loading branch information
mythrocks authored Jul 7, 2022
1 parent 58f46a6 commit cdd4e03
Show file tree
Hide file tree
Showing 12 changed files with 1,032 additions and 112 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ add_library(
src/rolling/range_window_bounds.cpp
src/rolling/rolling.cu
src/rolling/rolling_collect_list.cu
src/rolling/rolling_detail_fixed_window.cu
src/rolling/rolling_detail_variable_window.cu
src/round/round.cu
src/scalar/scalar.cpp
src/scalar/scalar_factories.cpp
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,9 @@ class nunique_aggregation final : public groupby_aggregation, public reduce_aggr
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
class nth_element_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public rolling_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ template std::unique_ptr<groupby_aggregation> make_nth_element_aggregation<group
size_type n, null_policy null_handling);
template std::unique_ptr<reduce_aggregation> make_nth_element_aggregation<reduce_aggregation>(
size_type n, null_policy null_handling);
template std::unique_ptr<rolling_aggregation> make_nth_element_aggregation<rolling_aggregation>(
size_type n, null_policy null_handling);

/// Factory to create a ROW_NUMBER aggregation
template <typename Base>
Expand Down
167 changes: 167 additions & 0 deletions cpp/src/rolling/nth_element.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/utilities/bit.hpp>

#include <limits>
#include <rmm/exec_policy.hpp>

namespace cudf::detail::rolling {

/**
 * @brief Functor to construct gather-map indices for NTH_ELEMENT rolling aggregation.
 *
 * Given a row index `i`, `operator()` returns the (view-relative) index of the row whose value
 * is the `N`th element of row `i`'s window, or `NULL_INDEX` when the result must be null.
 *
 * The `N`th element is deemed null (i.e. the gather index is set to `NULL_INDEX`) when:
 *   1. The window has fewer elements than `min_periods`.
 *   2. N falls outside the window, i.e. N ∉ [-window_size, window_size).
 *   3. Nulls are being excluded, and the window does not hold enough non-null elements to
 *      reach the requested position.
 *
 * Otherwise the result is non-null. How the value is determined depends on `null_handling`:
 *   1. `null_handling == INCLUDE`: The required value is the `N`th value from the window's start,
 *      i.e. the gather index is window_start + N (adjusted for negative N).
 *   2. `null_handling == EXCLUDE`: The required value is the `N`th non-null value, counted from
 *      the window's start for `N >= 0`, and from the window's end for `N < 0`.
 *
 * NOTE(review): no bounds clamping is done here. `preceding`/`following` appear to be expected
 * to be already clamped to the column's extent by the caller; confirm at the call sites. When
 * nulls are excluded, the window's row indices are used to read the null mask directly.
 *
 * @tparam null_handling Whether null rows in the window are included or excluded
 * @tparam PrecedingIter Type of iterator for the preceding window bound
 * @tparam FollowingIter Type of iterator for the following window bound
 */
template <null_policy null_handling, typename PrecedingIter, typename FollowingIter>
struct gather_index_calculator {
  size_type n;                         ///< Requested element position (negative counts from end)
  bitmask_type const* input_nullmask;  ///< Raw null mask of the input (NOT offset-adjusted)
  size_type input_offset;              ///< Offset of the input view into `input_nullmask`
  bool exclude_nulls;                  ///< True only if EXCLUDE is requested and nulls exist
  PrecedingIter preceding;             ///< Per-row preceding bound (includes the current row)
  FollowingIter following;             ///< Per-row following bound
  size_type min_periods;               ///< Minimum window size for a non-null result
  rmm::cuda_stream_view stream;        ///< Stored from the constructor; not read by this functor

  static size_type constexpr NULL_INDEX =
    std::numeric_limits<size_type>::min();  // For nullifying with gather.

  gather_index_calculator(size_type n,
                          column_view input,
                          PrecedingIter preceding,
                          FollowingIter following,
                          size_type min_periods,
                          rmm::cuda_stream_view stream)
    : n{n},
      input_nullmask{input.null_mask()},
      // `column_view::null_mask()` does not account for `offset()`, so the offset is kept
      // alongside the mask and applied whenever a row's validity bit is tested.
      input_offset{input.offset()},
      exclude_nulls{null_handling == null_policy::EXCLUDE and input.has_nulls()},
      preceding{preceding},
      following{following},
      min_periods{min_periods},
      stream{stream}
  {
  }

  /**
   * @brief For `null_policy::EXCLUDE`, find the gather index of the `N`th non-null value.
   *
   * @tparam Iter Iterator over view-relative row indices, in the order they must be scanned
   * @param begin Iterator to the first row index to examine
   * @param window_size Number of rows to examine
   * @return Row index of the required non-null value, or `NULL_INDEX` if the window does not
   *         contain enough non-null rows
   */
  template <typename Iter>
  size_type __device__ index_of_nth_non_null(Iter begin, size_type window_size) const
  {
    // Number of valid rows that must be skipped before the required one is reached.
    auto reqd_valid_count = n >= 0 ? n : (-n - 1);
    // Stateful predicate: counts down on every valid row; relies on the sequential,
    // in-order evaluation of `thrust::seq`.
    auto const pred_nth_valid = [&reqd_valid_count,
                                 input_nullmask = input_nullmask,
                                 input_offset   = input_offset](size_type j) {
      // `j` is relative to the (possibly sliced) view; the mask is relative to the allocation.
      return cudf::bit_is_set(input_nullmask, input_offset + j) && reqd_valid_count-- == 0;
    };
    auto const end   = begin + window_size;
    auto const found = thrust::find_if(thrust::seq, begin, end, pred_nth_valid);
    return found == end ? NULL_INDEX : *found;
  }

  /**
   * @brief Computes the gather index for row `i`.
   *
   * @param i Index of the row whose window is being evaluated
   * @return Index of the row to gather, or `NULL_INDEX` if the result for row `i` is null
   */
  size_type __device__ operator()(size_type i) const
  {
    // preceding[i] includes the current row.
    auto const window_size = preceding[i] + following[i];
    if (min_periods > window_size) { return NULL_INDEX; }

    auto const wrapped_n = n >= 0 ? n : (window_size + n);
    if (wrapped_n < 0 || wrapped_n > (window_size - 1)) {
      return NULL_INDEX;  // n lies outside the window.
    }

    // Out of short-circuit exits.
    // If nulls don't need to be excluded, a fixed window offset calculation is sufficient.
    auto const window_start = i - preceding[i] + 1;
    if (not exclude_nulls) { return window_start + wrapped_n; }

    // Must exclude nulls. Must examine each row in the window:
    // forwards from the window's start for n >= 0, backwards from its end for n < 0.
    auto const window_end = window_start + window_size;
    return n >= 0 ? index_of_nth_non_null(thrust::make_counting_iterator(window_start), window_size)
                  : index_of_nth_non_null(
                      thrust::make_reverse_iterator(thrust::make_counting_iterator(window_end)),
                      window_size);
  }
};

/**
 * @brief Computes the NTH_ELEMENT window aggregation for every row of `input`.
 *
 * For each row, a gather index is computed by `gather_index_calculator`, and the result column
 * is produced by gathering `input` with those indices. Rows whose gather index is the
 * calculator's "null" index come out as null.
 *
 * The result for a row is null when:
 *   1. The window has fewer elements than `min_periods`.
 *   2. N falls outside the window, i.e. N ∉ [-window_size, window_size).
 *   3. `null_handling == EXCLUDE`, and the window does not hold enough non-null elements.
 *
 * Otherwise the result is non-null:
 *   1. `null_handling == INCLUDE`: the `N`th value of the window.
 *   2. `null_handling == EXCLUDE`: the `N`th *non-null* value of the window.
 *
 * @tparam null_handling Whether to include/exclude null rows in the window
 * @tparam PrecedingIter Type of iterator for preceding window
 * @tparam FollowingIter Type of iterator for following window
 * @param n The index of the element to be returned
 * @param input The input column
 * @param preceding Iterator specifying the preceding window bound
 * @param following Iterator specifying the following window bound
 * @param min_periods The minimum number of rows required in the window
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return A column holding the `n`th element of the specified window for each row
 */
template <null_policy null_handling, typename PrecedingIter, typename FollowingIter>
std::unique_ptr<column> nth_element(size_type n,
                                    column_view const& input,
                                    PrecedingIter preceding,
                                    FollowingIter following,
                                    size_type min_periods,
                                    rmm::cuda_stream_view stream,
                                    rmm::mr::device_memory_resource* mr)
{
  using index_calculator = gather_index_calculator<null_handling, PrecedingIter, FollowingIter>;

  auto const num_rows = input.size();

  // Lazily maps each row index to the index of the row that holds its Nth window element.
  auto const indices_begin = cudf::detail::make_counting_transform_iterator(
    0, index_calculator{n, input, preceding, following, min_periods, stream});

  // Materialize the gather map in device memory before handing it to gather().
  auto row_indices = rmm::device_uvector<offset_type>(num_rows, stream);
  thrust::copy(
    rmm::exec_policy(stream), indices_begin, indices_begin + num_rows, row_indices.begin());

  // Out-of-bounds indices (including the calculator's null index) are nullified by gather().
  auto result_columns = cudf::detail::gather(table_view{{input}},
                                             row_indices,
                                             cudf::out_of_bounds_policy::NULLIFY,
                                             negative_index_policy::NOT_ALLOWED,
                                             stream,
                                             mr)
                          ->release();
  return std::move(result_columns.front());
}

} // namespace cudf::detail::rolling
93 changes: 0 additions & 93 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,99 +22,6 @@
#include <thrust/iterator/constant_iterator.h>

namespace cudf {
namespace detail {

// Fixed-size rolling window: the same scalar `preceding_window` / `following_window` bounds
// are applied to every row of `input`.
std::unique_ptr<column> rolling_window(column_view const& input,
                                       column_view const& default_outputs,
                                       size_type preceding_window,
                                       size_type following_window,
                                       size_type min_periods,
                                       rolling_aggregation const& agg,
                                       rmm::cuda_stream_view stream,
                                       rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();

  // Nothing to aggregate: return the empty result appropriate for `agg`.
  if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }

  CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative");

  CUDF_EXPECTS((default_outputs.is_empty() || default_outputs.size() == input.size()),
               "Defaults column must be either empty or have as many rows as the input column.");

  // User-defined (CUDA / PTX) aggregations are routed to the UDF path, which receives the
  // scalar bounds together with the name of their type.
  bool const is_udf_agg = agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX;
  if (is_udf_agg) {
    return cudf::detail::rolling_window_udf(input,
                                            preceding_window,
                                            "cudf::size_type",
                                            following_window,
                                            "cudf::size_type",
                                            min_periods,
                                            agg,
                                            stream,
                                            mr);
  }

  // All other aggregations go through the iterator-based overload; constant iterators
  // present the scalar bounds as per-row sequences.
  auto preceding_iter = thrust::make_constant_iterator(preceding_window);
  auto following_iter = thrust::make_constant_iterator(following_window);

  return cudf::detail::rolling_window(
    input, default_outputs, preceding_iter, following_iter, min_periods, agg, stream, mr);
}

// Variable-size rolling window: `preceding_window` and `following_window` are INT32 columns
// supplying the window bounds for each row of `input`.
std::unique_ptr<column> rolling_window(column_view const& input,
                                       column_view const& preceding_window,
                                       column_view const& following_window,
                                       size_type min_periods,
                                       rolling_aggregation const& agg,
                                       rmm::cuda_stream_view stream,
                                       rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();

  // If any of the three columns is empty, return the empty result appropriate for `agg`.
  bool const nothing_to_do =
    preceding_window.is_empty() || following_window.is_empty() || input.is_empty();
  if (nothing_to_do) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }

  CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 &&
                 following_window.type().id() == type_id::INT32,
               "preceding_window/following_window must have type_id::INT32 type");

  CUDF_EXPECTS(preceding_window.size() == input.size() && following_window.size() == input.size(),
               "preceding_window/following_window size must match input size");

  auto const preceding_begin = preceding_window.begin<size_type>();
  auto const following_begin = following_window.begin<size_type>();

  // User-defined (CUDA / PTX) aggregations are routed to the UDF path, which receives the
  // bound pointers together with the name of their type.
  bool const is_udf_agg = agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX;
  if (is_udf_agg) {
    return cudf::detail::rolling_window_udf(input,
                                            preceding_begin,
                                            "cudf::size_type*",
                                            following_begin,
                                            "cudf::size_type*",
                                            min_periods,
                                            agg,
                                            stream,
                                            mr);
  }

  // No defaults are supplied by this overload: pass an empty column shaped like the input
  // (or like its indices child, for dictionary inputs).
  auto defaults_source =
    cudf::is_dictionary(input.type()) ? dictionary_column_view(input).indices() : input;
  auto const empty_defaults = empty_like(defaults_source);

  return cudf::detail::rolling_window(input,
                                      empty_defaults->view(),
                                      preceding_begin,
                                      following_begin,
                                      min_periods,
                                      agg,
                                      stream,
                                      mr);
}

} // namespace detail

// Applies a fixed-size rolling window function to the values in a column, with default output
// specified
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/rolling/rolling_collect_list.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,17 @@ std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
auto new_sizes = make_fixed_width_column(
data_type{type_to_id<size_type>()}, input.size(), mask_state::UNALLOCATED, stream);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
new_sizes->mutable_view().template begin<size_type>(),
[d_gather_map = gather_map.template begin<size_type>(),
d_old_offsets = offsets.template begin<size_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});
thrust::tabulate(rmm::exec_policy(stream),
new_sizes->mutable_view().template begin<size_type>(),
new_sizes->mutable_view().template end<size_type>(),
[d_gather_map = gather_map.template begin<offset_type>(),
d_old_offsets = offsets.template begin<offset_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});

auto new_offsets =
strings::detail::make_offsets_child_column(new_sizes->view().template begin<size_type>(),
Expand Down
21 changes: 20 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "lead_lag_nested_detail.cuh"
#include "nth_element.cuh"
#include "rolling/rolling_collect_list.cuh"
#include "rolling/rolling_detail.hpp"
#include "rolling/rolling_jit_detail.hpp"
Expand Down Expand Up @@ -604,7 +605,7 @@ struct DeviceRollingLag {
};

/**
* @brief Maps an `InputType and `aggregation::Kind` value to it's corresponding
* @brief Maps an `InputType` and `aggregation::Kind` value to its corresponding
* rolling window operator.
*
* @tparam InputType The input type to map to its corresponding operator
Expand Down Expand Up @@ -818,6 +819,13 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
aggs.push_back(agg.clone());
return aggs;
}

// NTH_ELEMENT requires no preprocessing aggregations: its result is computed in finalize(),
// so an empty list is returned here.
std::vector<std::unique_ptr<aggregation>> visit(
  data_type, cudf::detail::nth_element_aggregation const&) override
{
  return std::vector<std::unique_ptr<aggregation>>{};
}
};

/**
Expand Down Expand Up @@ -961,6 +969,17 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
}
}

// NTH_ELEMENT aggregation: the runtime null policy carried by `agg` selects the
// compile-time specialization of rolling::nth_element that produces `result`.
void visit(cudf::detail::nth_element_aggregation const& agg) override
{
  if (agg._null_handling == null_policy::EXCLUDE) {
    result = rolling::nth_element<null_policy::EXCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  } else {
    result = rolling::nth_element<null_policy::INCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  }
}

private:
column_view input;
column_view default_outputs;
Expand Down
Loading

0 comments on commit cdd4e03

Please sign in to comment.