Skip to content

Commit

Permalink
Support nth_element for window functions
Browse files Browse the repository at this point in the history
This is to address spark-rapids/issues/4005 and
spark-rapids/issues/5061.

This change adds support for `NTH_ELEMENT` window aggregations.
This should allow for the implementation of `FIRST()`, `LAST()`,
and `NTH_VALUE()` window functions in Spark SQL.

`NTH_ELEMENT`, used as a window function, returns the Nth element from the
specified window for each row in a column. `N` is deemed to be
zero based, so `NTH_ELEMENT(0)` translates to the first element
in a window. Similarly, `NTH_ELEMENT(-1)` translates to the last.

If, for any window of size `W`, the specified `N` falls outside
the range `[ -W, W-1 ]`, a null element is returned for that row.
  • Loading branch information
mythrocks committed Jun 27, 2022
1 parent b5a59cf commit 7d02895
Show file tree
Hide file tree
Showing 8 changed files with 781 additions and 18 deletions.
4 changes: 3 additions & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,9 @@ class nunique_aggregation final : public groupby_aggregation, public reduce_aggr
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
class nth_element_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public rolling_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ template std::unique_ptr<groupby_aggregation> make_nth_element_aggregation<group
size_type n, null_policy null_handling);
template std::unique_ptr<reduce_aggregation> make_nth_element_aggregation<reduce_aggregation>(
size_type n, null_policy null_handling);
template std::unique_ptr<rolling_aggregation> make_nth_element_aggregation<rolling_aggregation>(
size_type n, null_policy null_handling);

/// Factory to create a ROW_NUMBER aggregation
template <typename Base>
Expand Down
95 changes: 95 additions & 0 deletions cpp/src/rolling/nth_element.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/gather.cuh>
#include <cudf/detail/iterator.cuh>
#include <cudf/utilities/bit.hpp>

#include <limits>
#include <rmm/exec_policy.hpp>

namespace cudf::detail::rolling {

// Sentinel gather index (the smallest `size_type` value), used for nullifying with gather.
constexpr auto NULL_INDEX = std::numeric_limits<size_type>::min();

/**
 * @brief Selects the `n`th element of each row's rolling window, implemented as a gather.
 *
 * For row `i`, the window spans `preceding[i] + following[i]` rows, starting at row
 * `i - preceding[i] + 1` (i.e. `preceding[i]` counts the current row). A gather index is
 * computed per row and handed to `cudf::detail::gather` with
 * `out_of_bounds_policy::NULLIFY`; rows for which no element can be selected are mapped
 * to `NULL_INDEX`.
 *
 * `NULL_INDEX` is produced for row `i` when:
 *  - the window holds fewer than `min_periods` rows,
 *  - `n` lies outside `[-window_size, window_size - 1]`, or
 *  - nulls are excluded and the window does not contain enough valid rows.
 *
 * @tparam null_handling Whether null rows are counted (`INCLUDE`) or skipped (`EXCLUDE`)
 *                       when locating the `n`th element.
 * @tparam PrecedingIter Iterator yielding the preceding window extent for each row.
 * @tparam FollowingIter Iterator yielding the following window extent for each row.
 *
 * @param n           Zero-based position within the window. Negative values count back
 *                    from the end of the window (`-1` is the last element).
 * @param input       Column to select elements from.
 * @param preceding   Per-row preceding extents (including the current row).
 * @param following   Per-row following extents.
 * @param min_periods Minimum window size required to produce a result for a row.
 * @param stream      CUDA stream passed through to the gather.
 * @param mr          Memory resource passed through to the gather.
 * @return Column produced by gathering `input` with the computed per-row indices.
 */
template <null_policy null_handling, typename PrecedingIter, typename FollowingIter>
std::unique_ptr<column> nth_element(size_type n,
                                    column_view const& input,
                                    PrecedingIter preceding,
                                    FollowingIter following,
                                    size_type min_periods,
                                    rmm::cuda_stream_view stream,
                                    rmm::mr::device_memory_resource* mr)
{
  auto const gather_iter = cudf::detail::make_counting_transform_iterator(
    0,
    [exclude_nulls = null_handling == null_policy::EXCLUDE and input.nullable(),
     preceding,
     following,
     min_periods,
     n,
     input_nullmask = input.null_mask(),
     input_offset   = input.offset()] __device__(size_type i) {
      // preceding[i] includes the current row.
      auto const window_size = preceding[i] + following[i];
      if (min_periods > window_size) { return NULL_INDEX; }

      auto const wrapped_n = n >= 0 ? n : (window_size + n);
      if (wrapped_n < 0 || wrapped_n > (window_size - 1)) {
        return NULL_INDEX;  // n lies outside the window.
      }

      auto const window_start = i - preceding[i] + 1;

      if (not exclude_nulls) { return window_start + wrapped_n; }

      // Must exclude nulls, and n is in range [-window_size, window_size-1].
      // Depending on n >= 0, count forwards from window_start, or backwards from window_end.
      auto const window_end = window_start + window_size;

      // Number of valid rows to step over before reaching the requested one.
      auto valid_rows_to_skip = n >= 0 ? n : (-n - 1);

      // The raw null mask pointer is not adjusted for the view's offset, so the offset is
      // added to each view-relative row index before testing its validity bit. The returned
      // gather indices remain relative to the view.
      if (n >= 0) {  // Search forwards from window_start.
        for (auto j = window_start; j < window_end; ++j) {
          if (cudf::bit_is_set(input_nullmask, input_offset + j) && valid_rows_to_skip-- == 0) {
            return j;
          }
        }
      } else {  // Search backwards from window_end.
        for (auto j = window_end - 1; j >= window_start; --j) {
          if (cudf::bit_is_set(input_nullmask, input_offset + j) && valid_rows_to_skip-- == 0) {
            return j;
          }
        }
      }
      return NULL_INDEX;  // Too few valid rows in the window.
    });

  auto gathered = cudf::detail::gather(table_view{{input}},
                                       gather_iter,
                                       gather_iter + input.size(),
                                       cudf::out_of_bounds_policy::NULLIFY,
                                       stream,
                                       mr)
                    ->release();
  return std::move(gathered[0]);
}

} // namespace cudf::detail::rolling
29 changes: 25 additions & 4 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
"Defaults column must be either empty or have as many rows as the input column.");

if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) {
// TODO: In future, might need to clamp preceding/following to column boundaries.
return cudf::detail::rolling_window_udf(input,
preceding_window,
"cudf::size_type",
Expand All @@ -54,8 +55,16 @@ std::unique_ptr<column> rolling_window(column_view const& input,
stream,
mr);
} else {
auto preceding_window_begin = thrust::make_constant_iterator(preceding_window);
auto following_window_begin = thrust::make_constant_iterator(following_window);
// Clamp preceding/following to column boundaries.
// E.g. If preceding_window == 2, then for a column of 5 elements, preceding_window will be:
// [1, 2, 2, 2, 2] (preceding includes the current row, so only the first row is clamped).
auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator(
0,
[preceding_window] __device__(size_type i) { return thrust::min(i + 1, preceding_window); });
auto const following_window_begin = cudf::detail::make_counting_transform_iterator(
0, [col_size = input.size(), following_window] __device__(size_type i) {
return thrust::min(col_size - i - 1, following_window);
});

return cudf::detail::rolling_window(input,
default_outputs,
Expand Down Expand Up @@ -91,6 +100,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
"preceding_window/following_window size must match input size");

if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) {
// TODO: In future, might need to clamp preceding/following to column boundaries.
return cudf::detail::rolling_window_udf(input,
preceding_window.begin<size_type>(),
"cudf::size_type*",
Expand All @@ -103,10 +113,21 @@ std::unique_ptr<column> rolling_window(column_view const& input,
} else {
auto defaults_col =
cudf::is_dictionary(input.type()) ? dictionary_column_view(input).indices() : input;
// Clamp preceding/following to column boundaries.
// E.g. If preceding_window == [2, 2, 2, 2, 2] for a column of 5 elements, the new
// preceding_window will be: [1, 2, 2, 2, 2]
auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator(
0, [preceding = preceding_window.begin<size_type>()] __device__(size_type i) {
return thrust::min(i + 1, preceding[i]);
});
auto const following_window_begin = cudf::detail::make_counting_transform_iterator(
0,
[col_size = input.size(), following = following_window.begin<size_type>()] __device__(
size_type i) { return thrust::min(col_size - i - 1, following[i]); });
return cudf::detail::rolling_window(input,
empty_like(defaults_col)->view(),
preceding_window.begin<size_type>(),
following_window.begin<size_type>(),
preceding_window_begin,
following_window_begin,
min_periods,
agg,
stream,
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/rolling/rolling_collect_list.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,17 @@ std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
auto new_sizes = make_fixed_width_column(
data_type{type_to_id<size_type>()}, input.size(), mask_state::UNALLOCATED, stream);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
new_sizes->mutable_view().template begin<size_type>(),
[d_gather_map = gather_map.template begin<size_type>(),
d_old_offsets = offsets.template begin<size_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});
thrust::tabulate(rmm::exec_policy(stream),
new_sizes->mutable_view().template begin<size_type>(),
new_sizes->mutable_view().template end<size_type>(),
[d_gather_map = gather_map.template begin<size_type>(),
d_old_offsets = offsets.template begin<size_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});

auto new_offsets =
strings::detail::make_offsets_child_column(new_sizes->view().template begin<size_type>(),
Expand Down
21 changes: 20 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "lead_lag_nested_detail.cuh"
#include "nth_element.cuh"
#include "rolling/rolling_collect_list.cuh"
#include "rolling/rolling_detail.hpp"
#include "rolling/rolling_jit_detail.hpp"
Expand Down Expand Up @@ -604,7 +605,7 @@ struct DeviceRollingLag {
};

/**
* @brief Maps an `InputType and `aggregation::Kind` value to it's corresponding
* @brief Maps an `InputType` and `aggregation::Kind` value to its corresponding
* rolling window operator.
*
* @tparam InputType The input type to map to its corresponding operator
Expand Down Expand Up @@ -818,6 +819,13 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
aggs.push_back(agg.clone());
return aggs;
}

// NTH_ELEMENT needs no preprocessing step, since it is computed in finalize().
// An empty list of aggregations is returned.
std::vector<std::unique_ptr<aggregation>> visit(
  data_type, cudf::detail::nth_element_aggregation const&) override
{
  return std::vector<std::unique_ptr<aggregation>>{};
}
};

/**
Expand Down Expand Up @@ -961,6 +969,17 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
}
}

// NTH_ELEMENT aggregation: the null policy is a template argument of rolling::nth_element,
// so the runtime value in `agg` selects which instantiation fills `result`.
void visit(cudf::detail::nth_element_aggregation const& agg) override
{
  if (agg._null_handling == null_policy::EXCLUDE) {
    result = rolling::nth_element<null_policy::EXCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  } else {
    result = rolling::nth_element<null_policy::INCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  }
}

private:
column_view input;
column_view default_outputs;
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ ConfigureTest(
rolling/empty_input_test.cpp
rolling/grouped_rolling_test.cpp
rolling/lead_lag_test.cpp
rolling/nth_element_test.cpp
rolling/range_rolling_window_test.cpp
rolling/range_window_bounds_test.cpp
rolling/rolling_test.cpp
Expand Down
Loading

0 comments on commit 7d02895

Please sign in to comment.