Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nth_element for window functions #11158

Merged
merged 9 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ add_library(
src/rolling/range_window_bounds.cpp
src/rolling/rolling.cu
src/rolling/rolling_collect_list.cu
src/rolling/rolling_detail_fixed_window.cu
src/rolling/rolling_detail_variable_window.cu
src/round/round.cu
src/scalar/scalar.cpp
src/scalar/scalar_factories.cpp
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,9 @@ class nunique_aggregation final : public groupby_aggregation, public reduce_aggr
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
class nth_element_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public rolling_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ template std::unique_ptr<groupby_aggregation> make_nth_element_aggregation<group
size_type n, null_policy null_handling);
template std::unique_ptr<reduce_aggregation> make_nth_element_aggregation<reduce_aggregation>(
size_type n, null_policy null_handling);
template std::unique_ptr<rolling_aggregation> make_nth_element_aggregation<rolling_aggregation>(
size_type n, null_policy null_handling);

/// Factory to create a ROW_NUMBER aggregation
template <typename Base>
Expand Down
170 changes: 170 additions & 0 deletions cpp/src/rolling/nth_element.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/gather.cuh>
#include <cudf/detail/iterator.cuh>
#include <cudf/utilities/bit.hpp>

#include <limits>
#include <rmm/exec_policy.hpp>

namespace cudf::detail::rolling {

/**
 * @brief Device functor producing the gather-map entry for one output row of an
 * NTH_ELEMENT rolling aggregation.
 *
 * `N` is a zero-based offset into the window; a negative `N` counts backwards from the
 * window's end (`-1` addresses the last row). For output row `i`, the functor returns
 * `NULL_INDEX` (so that the row is nullified by the gather) when any of these hold:
 *   1. The window has fewer rows than `min_periods`.
 *   2. `N` falls outside the window, i.e. N ∉ [-window_size, window_size).
 *   3. Nulls are being excluded, and the window does not contain enough non-null rows
 *      to reach the requested one (`N + 1` of them for `N >= 0`, `-N` for `N < 0`).
 *
 * Otherwise the returned index depends on how nulls are treated:
 *   1. Nulls included (or the input has no nulls): `window_start + N`, with a negative
 *      `N` first wrapped around the window size.
 *   2. Nulls excluded: the row index of the `N`th non-null row, scanning forwards from
 *      the window's start for `N >= 0` and backwards from its end for `N < 0`.
 *
 * @tparam null_handling Whether null rows count (INCLUDE) or are skipped (EXCLUDE)
 * @tparam PrecedingIter Type of the iterator yielding each row's preceding bound
 * @tparam FollowingIter Type of the iterator yielding each row's following bound
 */
template <null_policy null_handling, typename PrecedingIter, typename FollowingIter>
struct gather_index_calculator {
  size_type n;                         ///< Requested (possibly negative) element offset
  bitmask_type const* input_nullmask;  ///< Null mask of the input column
  bool exclude_nulls;                  ///< True only if EXCLUDE was requested and nulls exist
  PrecedingIter preceding;             ///< Per-row preceding bound (includes the current row)
  FollowingIter following;             ///< Per-row following bound
  size_type min_periods;               ///< Minimum window size for a non-null result
  rmm::cuda_stream_view stream;
  rmm::mr::device_memory_resource* mr;

  /// Sentinel gather index used to request a null output row.
  static size_type constexpr NULL_INDEX = std::numeric_limits<size_type>::min();

  gather_index_calculator(size_type n,
                          column_view input,
                          PrecedingIter preceding,
                          FollowingIter following,
                          size_type min_periods,
                          rmm::cuda_stream_view stream,
                          rmm::mr::device_memory_resource* mr)
    : n{n},
      input_nullmask{input.null_mask()},
      exclude_nulls{null_handling == null_policy::EXCLUDE and input.has_nulls()},
      preceding{preceding},
      following{following},
      min_periods{min_periods},
      stream{stream},
      mr{mr}
  {
  }

  /**
   * @brief Walks `window_size` row indices starting at `begin` and returns the one that
   * holds the requested non-null row, or `NULL_INDEX` if the walk ends first.
   *
   * The caller picks the direction: a forward iterator for `n >= 0`, a reverse iterator
   * for `n < 0`.
   */
  template <typename Iter>
  size_type __device__ index_of_nth_non_null(Iter begin, size_type window_size) const
  {
    // Count of non-null rows that must be passed over before the wanted one is reached.
    auto valid_rows_to_skip = n >= 0 ? n : (-n - 1);
    auto const end          = begin + window_size;
    for (auto it = begin; it != end; ++it) {
      size_type const row = *it;
      // The counter is only decremented for non-null rows (short-circuit evaluation).
      if (cudf::bit_is_set(input_nullmask, row) && valid_rows_to_skip-- == 0) { return row; }
    }
    return NULL_INDEX;
  }

  /// Returns the gather index for output row `i` (see the struct documentation).
  size_type __device__ operator()(size_type i) const
  {
    // preceding[i] already counts the current row.
    auto const num_preceding = preceding[i];
    auto const num_rows      = num_preceding + following[i];
    if (num_rows < min_periods) { return NULL_INDEX; }

    // Translate a negative n into an offset from the window's start.
    auto const offset = n < 0 ? (num_rows + n) : n;
    if (offset < 0 || offset >= num_rows) {
      return NULL_INDEX;  // n addresses a row outside the window.
    }

    auto const first_row = i - num_preceding + 1;

    // Without null exclusion the answer is a plain offset from the window's start.
    if (not exclude_nulls) { return first_row + offset; }

    // Nulls must be skipped: inspect the window row by row, from the appropriate end.
    if (n >= 0) {
      return index_of_nth_non_null(thrust::make_counting_iterator(first_row), num_rows);
    }
    auto const one_past_last_row = first_row + num_rows;
    return index_of_nth_non_null(
      thrust::make_reverse_iterator(thrust::make_counting_iterator(one_past_last_row)),
      num_rows);
  }
};

/**
 * @brief Helper function for the NTH_ELEMENT window aggregation.
 *
 * For every row of `input`, a gather index is computed with `gather_index_calculator`
 * and the result column is produced by gathering `input` with those indices, nullifying
 * out-of-bounds entries.
 *
 * `n` is zero-based; a negative `n` counts backwards from the end of the window. The
 * result for a row is null in the following cases:
 *   1. The window has fewer elements than `min_periods`.
 *   2. `n` falls outside the window, i.e. n ∉ [-window_size, window_size).
 *   3. `null_handling == EXCLUDE`, and the window does not hold enough non-null elements
 *      to reach the requested one.
 *
 * If none of the above holds, how the value is determined depends on `null_handling`:
 *   1. `null_handling == INCLUDE`: the `n`th value counted from the window's start.
 *   2. `null_handling == EXCLUDE`: the `n`th *non-null* value of the window.
 *
 * @tparam null_handling Whether to include/exclude null rows in the window
 * @tparam PrecedingIter Type of iterator for preceding window
 * @tparam FollowingIter Type of iterator for following window
 * @param n The index of the element to be returned
 * @param input The input column
 * @param preceding Iterator specifying the preceding window bound
 * @param following Iterator specifying the following window bound
 * @param min_periods The minimum number of rows required in the window
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return A column containing the `n`th element of the specified window for each row
 */
template <null_policy null_handling, typename PrecedingIter, typename FollowingIter>
std::unique_ptr<column> nth_element(size_type n,
                                    column_view const& input,
                                    PrecedingIter preceding,
                                    FollowingIter following,
                                    size_type min_periods,
                                    rmm::cuda_stream_view stream,
                                    rmm::mr::device_memory_resource* mr)
{
  // Lazily evaluated sequence of gather indices, one per input row.
  auto const gather_iter = cudf::detail::make_counting_transform_iterator(
    0,
    gather_index_calculator<null_handling, PrecedingIter, FollowingIter>{
      n, input, preceding, following, min_periods, stream, mr});

  // Evaluate the indices once into a scratch buffer (note: `mr` is not passed here).
  // The element type matches the `size_type` indices produced by the calculator.
  auto gather_map = rmm::device_uvector<size_type>(input.size(), stream);
  thrust::copy(
    rmm::exec_policy(stream), gather_iter, gather_iter + input.size(), gather_map.begin());

  // Entries equal to the calculator's NULL_INDEX are out of bounds and therefore nullified.
  auto gathered = cudf::detail::gather(table_view{{input}},
                                       gather_map.begin(),
                                       gather_map.end(),
                                       cudf::out_of_bounds_policy::NULLIFY,
                                       stream,
                                       mr)
                    ->release();
  // A single-column table was gathered; hand its only column to the caller.
  return std::move(gathered[0]);
}

} // namespace cudf::detail::rolling
93 changes: 0 additions & 93 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,99 +22,6 @@
#include <thrust/iterator/constant_iterator.h>

namespace cudf {
namespace detail {

// Fixed-size rolling window over a column: every row uses the same preceding/following
// bounds. CUDA/PTX (user-defined) aggregations are forwarded to rolling_window_udf; all
// other aggregations go to the iterator-based rolling_window with constant iterators.
std::unique_ptr<column> rolling_window(column_view const& input,
                                       column_view const& default_outputs,
                                       size_type preceding_window,
                                       size_type following_window,
                                       size_type min_periods,
                                       rolling_aggregation const& agg,
                                       rmm::cuda_stream_view stream,
                                       rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();

  // Nothing to aggregate: return the empty result for this aggregation.
  if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }

  CUDF_EXPECTS(min_periods >= 0, "min_periods must be non-negative");

  CUDF_EXPECTS(default_outputs.is_empty() || default_outputs.size() == input.size(),
               "Defaults column must be either empty or have as many rows as the input column.");

  bool const is_udf = agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX;
  if (is_udf) {
    // The string arguments name the types of the window bounds passed alongside them.
    return cudf::detail::rolling_window_udf(input,
                                            preceding_window,
                                            "cudf::size_type",
                                            following_window,
                                            "cudf::size_type",
                                            min_periods,
                                            agg,
                                            stream,
                                            mr);
  }

  // Present the scalar bounds as iterators that yield the same value for every row.
  auto preceding_window_begin = thrust::make_constant_iterator(preceding_window);
  auto following_window_begin = thrust::make_constant_iterator(following_window);

  return cudf::detail::rolling_window(input,
                                      default_outputs,
                                      preceding_window_begin,
                                      following_window_begin,
                                      min_periods,
                                      agg,
                                      stream,
                                      mr);
}

// Variable-size rolling window over a column: each row takes its bounds from the
// corresponding rows of the INT32 columns `preceding_window` and `following_window`.
// CUDA/PTX (user-defined) aggregations are forwarded to rolling_window_udf; all other
// aggregations go to the iterator-based rolling_window with an empty defaults column.
std::unique_ptr<column> rolling_window(column_view const& input,
                                       column_view const& preceding_window,
                                       column_view const& following_window,
                                       size_type min_periods,
                                       rolling_aggregation const& agg,
                                       rmm::cuda_stream_view stream,
                                       rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();

  bool const nothing_to_do =
    preceding_window.is_empty() || following_window.is_empty() || input.is_empty();
  if (nothing_to_do) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); }

  CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 &&
                 following_window.type().id() == type_id::INT32,
               "preceding_window/following_window must have type_id::INT32 type");

  CUDF_EXPECTS(preceding_window.size() == input.size() && following_window.size() == input.size(),
               "preceding_window/following_window size must match input size");

  bool const is_udf = agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX;
  if (is_udf) {
    // The string arguments name the types of the window bounds passed alongside them.
    return cudf::detail::rolling_window_udf(input,
                                            preceding_window.begin<size_type>(),
                                            "cudf::size_type*",
                                            following_window.begin<size_type>(),
                                            "cudf::size_type*",
                                            min_periods,
                                            agg,
                                            stream,
                                            mr);
  }

  // No defaults are supplied by this overload: pass an empty column shaped like the input
  // (or like its indices child, when the input is a dictionary column).
  auto defaults_col =
    cudf::is_dictionary(input.type()) ? dictionary_column_view(input).indices() : input;
  auto const empty_defaults = empty_like(defaults_col);

  return cudf::detail::rolling_window(input,
                                      empty_defaults->view(),
                                      preceding_window.begin<size_type>(),
                                      following_window.begin<size_type>(),
                                      min_periods,
                                      agg,
                                      stream,
                                      mr);
}

} // namespace detail

// Applies a fixed-size rolling window function to the values in a column, with default output
// specified
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/rolling/rolling_collect_list.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,17 @@ std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
auto new_sizes = make_fixed_width_column(
data_type{type_to_id<size_type>()}, input.size(), mask_state::UNALLOCATED, stream);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
new_sizes->mutable_view().template begin<size_type>(),
[d_gather_map = gather_map.template begin<size_type>(),
d_old_offsets = offsets.template begin<size_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});
thrust::tabulate(rmm::exec_policy(stream),
new_sizes->mutable_view().template begin<size_type>(),
new_sizes->mutable_view().template end<size_type>(),
[d_gather_map = gather_map.template begin<offset_type>(),
d_old_offsets = offsets.template begin<offset_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});

auto new_offsets =
strings::detail::make_offsets_child_column(new_sizes->view().template begin<size_type>(),
Expand Down
21 changes: 20 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "lead_lag_nested_detail.cuh"
#include "nth_element.cuh"
#include "rolling/rolling_collect_list.cuh"
#include "rolling/rolling_detail.hpp"
#include "rolling/rolling_jit_detail.hpp"
Expand Down Expand Up @@ -604,7 +605,7 @@ struct DeviceRollingLag {
};

/**
* @brief Maps an `InputType and `aggregation::Kind` value to it's corresponding
* @brief Maps an `InputType` and `aggregation::Kind` value to its corresponding
* rolling window operator.
*
* @tparam InputType The input type to map to its corresponding operator
Expand Down Expand Up @@ -818,6 +819,13 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
aggs.push_back(agg.clone());
return aggs;
}

// NTH_ELEMENT needs no preprocessing: it is computed in finalize(), so no
// intermediate aggregations are requested for it here.
std::vector<std::unique_ptr<aggregation>> visit(
  data_type, cudf::detail::nth_element_aggregation const&) override
{
  return std::vector<std::unique_ptr<aggregation>>{};
}
};

/**
Expand Down Expand Up @@ -961,6 +969,17 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
}
}

// NTH_ELEMENT aggregation: select the rolling::nth_element instantiation that matches
// the aggregation's runtime null policy, and store its output as the result.
void visit(cudf::detail::nth_element_aggregation const& agg) override
{
  if (agg._null_handling == null_policy::EXCLUDE) {
    result = rolling::nth_element<null_policy::EXCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  } else {
    result = rolling::nth_element<null_policy::INCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  }
}

private:
column_view input;
column_view default_outputs;
Expand Down
Loading