Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nth_element for window functions #11158

Merged
merged 9 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,9 @@ class nunique_aggregation final : public groupby_aggregation, public reduce_aggr
/**
* @brief Derived class for specifying a nth element aggregation
*/
class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation {
class nth_element_aggregation final : public groupby_aggregation,
public reduce_aggregation,
public rolling_aggregation {
public:
nth_element_aggregation(size_type n, null_policy null_handling)
: aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,8 @@ template std::unique_ptr<groupby_aggregation> make_nth_element_aggregation<group
size_type n, null_policy null_handling);
template std::unique_ptr<reduce_aggregation> make_nth_element_aggregation<reduce_aggregation>(
size_type n, null_policy null_handling);
template std::unique_ptr<rolling_aggregation> make_nth_element_aggregation<rolling_aggregation>(
size_type n, null_policy null_handling);

/// Factory to create a ROW_NUMBER aggregation
template <typename Base>
Expand Down
95 changes: 95 additions & 0 deletions cpp/src/rolling/nth_element.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/aggregation.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/detail/gather.cuh>
#include <cudf/detail/iterator.cuh>
#include <cudf/utilities/bit.hpp>

#include <limits>
#include <rmm/exec_policy.hpp>

namespace cudf::detail::rolling {

auto constexpr NULL_INDEX = std::numeric_limits<size_type>::min(); // For nullifying with gather.
mythrocks marked this conversation as resolved.
Show resolved Hide resolved

/**
 * @brief Computes the NTH_ELEMENT rolling aggregation by gathering one input row per window.
 *
 * For every row `i`, the window covers rows `[i - preceding[i] + 1, i + following[i]]`
 * (`preceding[i]` includes the current row). A gather index is computed per row, and the
 * output column is produced with a single `gather` using `out_of_bounds_policy::NULLIFY`,
 * so every row mapped to `NULL_INDEX` is null in the result.
 *
 * Row `i` is mapped to `NULL_INDEX` when any of the following holds:
 *   1. The window has fewer than `min_periods` rows.
 *   2. `n` lies outside `[-window_size, window_size)`.
 *   3. Nulls are excluded, and the window does not contain enough valid rows to reach the
 *      requested one.
 *
 * @tparam null_handling Whether null rows are counted (`INCLUDE`) or skipped (`EXCLUDE`)
 *                       when locating the `n`th row of a window.
 * @tparam PrecedingIter Iterator type yielding the preceding window size for each row.
 * @tparam FollowingIter Iterator type yielding the following window size for each row.
 *
 * @param n Zero-based position within the window. Negative values count from the end of the
 *          window (`-1` is the last row).
 * @param input Column to select values from.
 * @param preceding Per-row preceding window sizes (including the current row).
 * @param following Per-row following window sizes.
 * @param min_periods Minimum number of rows a window must have to yield a non-null result.
 * @param stream CUDA stream forwarded to `gather`.
 * @param mr Device memory resource forwarded to `gather`.
 * @return Column with `input.size()` rows holding the selected value (or null) for each window.
 */
template <null_policy null_handling, typename PrecedingIter, typename FollowingIter>
std::unique_ptr<column> nth_element(size_type n,
                                    column_view const& input,
                                    PrecedingIter preceding,
                                    FollowingIter following,
                                    size_type min_periods,
                                    rmm::cuda_stream_view stream,
                                    rmm::mr::device_memory_resource* mr)
{
  auto const gather_iter = cudf::detail::make_counting_transform_iterator(
    0,
    [exclude_nulls = null_handling == null_policy::EXCLUDE and input.nullable(),
     preceding,
     following,
     min_periods,
     n,
     input_nullmask = input.null_mask(),
     // null_mask() is the raw bitmask pointer and is not adjusted for the view's offset, so
     // the offset must be added whenever a view-relative row index is used to test a bit.
     input_offset = input.offset()] __device__(size_type i) {
      // preceding[i] includes the current row.
      auto const window_size = preceding[i] + following[i];
      if (min_periods > window_size) { return NULL_INDEX; }

      auto const wrapped_n = n >= 0 ? n : (window_size + n);
      if (wrapped_n < 0 || wrapped_n > (window_size - 1)) {
        return NULL_INDEX;  // n lies outside the window.
      }

      auto const window_start = i - preceding[i] + 1;

      if (not exclude_nulls) { return window_start + wrapped_n; }

      // Must exclude nulls, and n is in range [-window_size, window_size-1].
      // For n >= 0, walk forwards from window_start; otherwise walk backwards from the last
      // row of the window. `valid_rows_to_skip` is the number of valid rows that must be
      // passed over before reaching the requested one.
      auto const window_end   = window_start + window_size;  // One past the last row.
      auto const step         = n >= 0 ? 1 : -1;
      auto valid_rows_to_skip = n >= 0 ? n : (-n - 1);
      auto row                = n >= 0 ? window_start : (window_end - 1);

      for (auto visited = 0; visited < window_size; ++visited, row += step) {
        if (cudf::bit_is_set(input_nullmask, row + input_offset)) {
          if (valid_rows_to_skip == 0) { return row; }
          --valid_rows_to_skip;
        }
      }
      return NULL_INDEX;  // Too few valid rows in the window.
    });

  // Rows whose gather index is NULL_INDEX are out of bounds, and are nullified by gather.
  auto gathered = cudf::detail::gather(table_view{{input}},
                                       gather_iter,
                                       gather_iter + input.size(),
                                       cudf::out_of_bounds_policy::NULLIFY,
                                       stream,
                                       mr)
                    ->release();
  return std::move(gathered[0]);
}

} // namespace cudf::detail::rolling
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 25 additions & 4 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
"Defaults column must be either empty or have as many rows as the input column.");

if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) {
// TODO: In future, might need to clamp preceding/following to column boundaries.
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
return cudf::detail::rolling_window_udf(input,
preceding_window,
"cudf::size_type",
Expand All @@ -54,8 +55,16 @@ std::unique_ptr<column> rolling_window(column_view const& input,
stream,
mr);
} else {
auto preceding_window_begin = thrust::make_constant_iterator(preceding_window);
auto following_window_begin = thrust::make_constant_iterator(following_window);
// Clamp preceding/following to column boundaries.
// E.g. If preceding_window == 2 (which includes the current row), then for a column of
// 5 elements, the clamped preceding_window will be: [1, 2, 2, 2, 2]
auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator(
0,
[preceding_window] __device__(size_type i) { return thrust::min(i + 1, preceding_window); });
auto const following_window_begin = cudf::detail::make_counting_transform_iterator(
0, [col_size = input.size(), following_window] __device__(size_type i) {
return thrust::min(col_size - i - 1, following_window);
});

return cudf::detail::rolling_window(input,
default_outputs,
Expand Down Expand Up @@ -91,6 +100,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
"preceding_window/following_window size must match input size");

if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) {
// TODO: In future, might need to clamp preceding/following to column boundaries.
return cudf::detail::rolling_window_udf(input,
preceding_window.begin<size_type>(),
"cudf::size_type*",
Expand All @@ -103,10 +113,21 @@ std::unique_ptr<column> rolling_window(column_view const& input,
} else {
auto defaults_col =
cudf::is_dictionary(input.type()) ? dictionary_column_view(input).indices() : input;
// Clamp preceding/following to column boundaries.
// E.g. If preceding_window == [2, 2, 2, 2, 2] for a column of 5 elements, the new
// preceding_window will be: [1, 2, 2, 2, 2]
auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator(
0, [preceding = preceding_window.begin<size_type>()] __device__(size_type i) {
return thrust::min(i + 1, preceding[i]);
});
auto const following_window_begin = cudf::detail::make_counting_transform_iterator(
0,
[col_size = input.size(), following = following_window.begin<size_type>()] __device__(
size_type i) { return thrust::min(col_size - i - 1, following[i]); });
return cudf::detail::rolling_window(input,
empty_like(defaults_col)->view(),
preceding_window.begin<size_type>(),
following_window.begin<size_type>(),
preceding_window_begin,
following_window_begin,
min_periods,
agg,
stream,
Expand Down
23 changes: 11 additions & 12 deletions cpp/src/rolling/rolling_collect_list.cu
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,17 @@ std::pair<std::unique_ptr<column>, std::unique_ptr<column>> purge_null_entries(
auto new_sizes = make_fixed_width_column(
data_type{type_to_id<size_type>()}, input.size(), mask_state::UNALLOCATED, stream);

thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(input.size()),
new_sizes->mutable_view().template begin<size_type>(),
[d_gather_map = gather_map.template begin<size_type>(),
d_old_offsets = offsets.template begin<size_type>(),
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});
thrust::tabulate(rmm::exec_policy(stream),
new_sizes->mutable_view().template begin<size_type>(),
new_sizes->mutable_view().template end<size_type>(),
[d_gather_map = gather_map.template begin<size_type>(),
d_old_offsets = offsets.template begin<size_type>(),
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
input_row_not_null] __device__(auto i) {
return thrust::count_if(thrust::seq,
d_gather_map + d_old_offsets[i],
d_gather_map + d_old_offsets[i + 1],
input_row_not_null);
});

auto new_offsets =
strings::detail::make_offsets_child_column(new_sizes->view().template begin<size_type>(),
Expand Down
21 changes: 20 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "lead_lag_nested_detail.cuh"
#include "nth_element.cuh"
#include "rolling/rolling_collect_list.cuh"
#include "rolling/rolling_detail.hpp"
#include "rolling/rolling_jit_detail.hpp"
Expand Down Expand Up @@ -604,7 +605,7 @@ struct DeviceRollingLag {
};

/**
* @brief Maps an `InputType and `aggregation::Kind` value to it's corresponding
* @brief Maps an `InputType` and `aggregation::Kind` value to its corresponding
* rolling window operator.
*
* @tparam InputType The input type to map to its corresponding operator
Expand Down Expand Up @@ -818,6 +819,13 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre
aggs.push_back(agg.clone());
return aggs;
}

// NTH_ELEMENT aggregations are computed in finalize(). Skip preprocessing by
// contributing no aggregations for this kind.
std::vector<std::unique_ptr<aggregation>> visit(
  data_type, cudf::detail::nth_element_aggregation const&) override
{
  return std::vector<std::unique_ptr<aggregation>>{};
}
};

/**
Expand Down Expand Up @@ -961,6 +969,17 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation
}
}

// NTH_ELEMENT aggregation.
// The aggregation's runtime null policy selects which compile-time specialization of
// rolling::nth_element computes `result`; all other arguments are identical.
void visit(cudf::detail::nth_element_aggregation const& agg) override
{
  if (agg._null_handling == null_policy::EXCLUDE) {
    result = rolling::nth_element<null_policy::EXCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  } else {
    result = rolling::nth_element<null_policy::INCLUDE>(
      agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr);
  }
}

private:
column_view input;
column_view default_outputs;
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ ConfigureTest(
rolling/empty_input_test.cpp
rolling/grouped_rolling_test.cpp
rolling/lead_lag_test.cpp
rolling/nth_element_test.cpp
rolling/range_rolling_window_test.cpp
rolling/range_window_bounds_test.cpp
rolling/rolling_test.cpp
Expand Down
Loading