From cdd4e03566ccfb08071c7ed07644d7f7c4cba2f5 Mon Sep 17 00:00:00 2001 From: MithunR Date: Thu, 7 Jul 2022 14:41:59 -0700 Subject: [PATCH] Support `nth_element` for window functions (#11158) Fixes #9643. This is to address spark-rapids/issues/4005 and spark-rapids/issues/5061. This change adds support for `NTH_ELEMENT` window aggregations. This should allow for the implementation of `FIRST()`, `LAST()`, and `NTH_VALUE()` window functions in Spark SQL. `NTH_ELEMENT` in a window function returns the Nth element from the specified window for each row in a column. `N` is deemed to be zero-based, so `NTH_ELEMENT(0)` translates to the first element in a window. Similarly, `NTH_ELEMENT(-1)` translates to the last. For any window of size `W`, if the specified `N` falls outside the range `[ -W, W-1 ]`, a null element is returned for that row. Authors: - MithunR (https://github.com/mythrocks) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) - David Wendt (https://github.com/davidwendt) - Nghia Truong (https://github.com/ttnghia) URL: https://github.com/rapidsai/cudf/pull/11158 --- cpp/CMakeLists.txt | 2 + .../cudf/detail/aggregation/aggregation.hpp | 4 +- cpp/src/aggregation/aggregation.cpp | 2 + cpp/src/rolling/nth_element.cuh | 167 +++++ cpp/src/rolling/rolling.cu | 93 --- cpp/src/rolling/rolling_collect_list.cu | 23 +- cpp/src/rolling/rolling_detail.cuh | 21 +- cpp/src/rolling/rolling_detail.hpp | 45 +- .../rolling/rolling_detail_fixed_window.cu | 78 +++ .../rolling/rolling_detail_variable_window.cu | 84 +++ cpp/tests/CMakeLists.txt | 1 + cpp/tests/rolling/nth_element_test.cpp | 624 ++++++++++++++++++ 12 files changed, 1032 insertions(+), 112 deletions(-) create mode 100644 cpp/src/rolling/nth_element.cuh create mode 100644 cpp/src/rolling/rolling_detail_fixed_window.cu create mode 100644 cpp/src/rolling/rolling_detail_variable_window.cu create mode 100644 cpp/tests/rolling/nth_element_test.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 
6548cb00c3b..4013f3894eb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -431,6 +431,8 @@ add_library( src/rolling/range_window_bounds.cpp src/rolling/rolling.cu src/rolling/rolling_collect_list.cu + src/rolling/rolling_detail_fixed_window.cu + src/rolling/rolling_detail_variable_window.cu src/round/round.cu src/scalar/scalar.cpp src/scalar/scalar_factories.cpp diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 8ca49dd7d5f..75027c78a68 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -570,7 +570,9 @@ class nunique_aggregation final : public groupby_aggregation, public reduce_aggr /** * @brief Derived class for specifying a nth element aggregation */ -class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation { +class nth_element_aggregation final : public groupby_aggregation, + public reduce_aggregation, + public rolling_aggregation { public: nth_element_aggregation(size_type n, null_policy null_handling) : aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling} diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 27732b25401..6dd014970c7 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -610,6 +610,8 @@ template std::unique_ptr make_nth_element_aggregation make_nth_element_aggregation( size_type n, null_policy null_handling); +template std::unique_ptr make_nth_element_aggregation( + size_type n, null_policy null_handling); /// Factory to create a ROW_NUMBER aggregation template diff --git a/cpp/src/rolling/nth_element.cuh b/cpp/src/rolling/nth_element.cuh new file mode 100644 index 00000000000..4a61b16ecac --- /dev/null +++ b/cpp/src/rolling/nth_element.cuh @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace cudf::detail::rolling { + +/** + * @brief Functor to construct gather-map indices for NTH_ELEMENT rolling aggregation. + * + * By definition, the `N`th element is deemed null (i.e. the gather index is set to "nullify") + * for the following cases: + * 1. The window has fewer elements than `min_periods`. + * 2. N falls outside the window, i.e. N ∉ [-window_size, window_size). + * 3. `null_handling == EXCLUDE`, and the window has fewer than `N` non-null elements. + * + * If none of the above holds true, the result is non-null. How the value is determined + * depends on `null_handling`: + * 1. `null_handling == INCLUDE`: The required value is the `N`th value from the window's start. + * i.e. the gather index is window_start + N (adjusted for negative N). + * 2. `null_handling == EXCLUDE`: The required value is the `N`th non-null value from the + * window's start. i.e. Return index of the `N`th non-null value. + */ +template +struct gather_index_calculator { + size_type n; + bitmask_type const* input_nullmask; + bool exclude_nulls; + PrecedingIter preceding; + FollowingIter following; + size_type min_periods; + rmm::cuda_stream_view stream; + + static size_type constexpr NULL_INDEX = + std::numeric_limits::min(); // For nullifying with gather. 
+ + gather_index_calculator(size_type n, + column_view input, + PrecedingIter preceding, + FollowingIter following, + size_type min_periods, + rmm::cuda_stream_view stream) + : n{n}, + input_nullmask{input.null_mask()}, + exclude_nulls{null_handling == null_policy::EXCLUDE and input.has_nulls()}, + preceding{preceding}, + following{following}, + min_periods{min_periods}, + stream{stream} + { + } + + /// For `null_policy::EXCLUDE`, find gather index for `N`th non-null value. + template + size_type __device__ index_of_nth_non_null(Iter begin, size_type window_size) const + { + auto reqd_valid_count = n >= 0 ? n : (-n - 1); + auto const pred_nth_valid = [&reqd_valid_count, input_nullmask = input_nullmask](size_type j) { + return cudf::bit_is_set(input_nullmask, j) && reqd_valid_count-- == 0; + }; + auto const end = begin + window_size; + auto const found = thrust::find_if(thrust::seq, begin, end, pred_nth_valid); + return found == end ? NULL_INDEX : *found; + } + + size_type __device__ operator()(size_type i) const + { + // preceding[i] includes the current row. + auto const window_size = preceding[i] + following[i]; + if (min_periods > window_size) { return NULL_INDEX; } + + auto const wrapped_n = n >= 0 ? n : (window_size + n); + if (wrapped_n < 0 || wrapped_n > (window_size - 1)) { + return NULL_INDEX; // n lies outside the window. + } + + // Out of short-circuit exits. + // If nulls don't need to be excluded, a fixed window offset calculation is sufficient. + auto const window_start = i - preceding[i] + 1; + if (not exclude_nulls) { return window_start + wrapped_n; } + + // Must exclude nulls. Must examine each row in the window. + auto const window_end = window_start + window_size; + return n >= 0 ? 
index_of_nth_non_null(thrust::make_counting_iterator(window_start), window_size) + : index_of_nth_non_null( + thrust::make_reverse_iterator(thrust::make_counting_iterator(window_end)), + window_size); + } +}; + +/** + * @brief Helper function for NTH_ELEMENT window aggregation + * + * The `N`th element is deemed null for the following cases: + * 1. The window has fewer elements than `min_periods`. + * 2. N falls outside the window, i.e. N ∉ [-window_size, window_size). + * 3. `null_handling == EXCLUDE`, and the window has fewer than `N` non-null elements. + * + * If none of the above holds true, the result is non-null. How the value is determined + * depends on `null_handling`: + * 1. `null_handling == INCLUDE`: The required value is the `N`th value from the window's start. + * 2. `null_handling == EXCLUDE`: The required value is the `N`th *non-null* value from the + * window's start. If the window has fewer than `N` non-null values, the result is null. + * + * @tparam null_handling Whether to include/exclude null rows in the window + * @tparam PrecedingIter Type of iterator for preceding window + * @tparam FollowingIter Type of iterator for following window + * @param n The index of the element to be returned + * @param input The input column + * @param preceding Iterator specifying the preceding window bound + * @param following Iterator specifying the following window bound + * @param min_periods The minimum number of rows required in the window + * @param stream CUDA stream used for device memory operations and kernel launches + * @param mr Device memory resource used to allocate the returned column's device memory + * @return A column containing the `n`th element of the specified window for each row + */ +template +std::unique_ptr nth_element(size_type n, + column_view const& input, + PrecedingIter preceding, + FollowingIter following, + size_type min_periods, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const gather_iter = 
cudf::detail::make_counting_transform_iterator( + 0, + gather_index_calculator{ + n, input, preceding, following, min_periods, stream}); + + auto gather_map = rmm::device_uvector(input.size(), stream); + thrust::copy( + rmm::exec_policy(stream), gather_iter, gather_iter + input.size(), gather_map.begin()); + + auto gathered = cudf::detail::gather(table_view{{input}}, + gather_map, + cudf::out_of_bounds_policy::NULLIFY, + negative_index_policy::NOT_ALLOWED, + stream, + mr) + ->release(); + return std::move(gathered.front()); +} + +} // namespace cudf::detail::rolling diff --git a/cpp/src/rolling/rolling.cu b/cpp/src/rolling/rolling.cu index d4a012610b9..918a7e7943d 100644 --- a/cpp/src/rolling/rolling.cu +++ b/cpp/src/rolling/rolling.cu @@ -22,99 +22,6 @@ #include namespace cudf { -namespace detail { - -// Applies a fixed-size rolling window function to the values in a column. -std::unique_ptr rolling_window(column_view const& input, - column_view const& default_outputs, - size_type preceding_window, - size_type following_window, - size_type min_periods, - rolling_aggregation const& agg, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - - if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); } - - CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative"); - - CUDF_EXPECTS((default_outputs.is_empty() || default_outputs.size() == input.size()), - "Defaults column must be either empty or have as many rows as the input column."); - - if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { - return cudf::detail::rolling_window_udf(input, - preceding_window, - "cudf::size_type", - following_window, - "cudf::size_type", - min_periods, - agg, - stream, - mr); - } else { - auto preceding_window_begin = thrust::make_constant_iterator(preceding_window); - auto following_window_begin = thrust::make_constant_iterator(following_window); - - return 
cudf::detail::rolling_window(input, - default_outputs, - preceding_window_begin, - following_window_begin, - min_periods, - agg, - stream, - mr); - } -} - -// Applies a variable-size rolling window function to the values in a column. -std::unique_ptr rolling_window(column_view const& input, - column_view const& preceding_window, - column_view const& following_window, - size_type min_periods, - rolling_aggregation const& agg, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) -{ - CUDF_FUNC_RANGE(); - - if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) { - return cudf::detail::empty_output_for_rolling_aggregation(input, agg); - } - - CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 && - following_window.type().id() == type_id::INT32, - "preceding_window/following_window must have type_id::INT32 type"); - - CUDF_EXPECTS(preceding_window.size() == input.size() && following_window.size() == input.size(), - "preceding_window/following_window size must match input size"); - - if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { - return cudf::detail::rolling_window_udf(input, - preceding_window.begin(), - "cudf::size_type*", - following_window.begin(), - "cudf::size_type*", - min_periods, - agg, - stream, - mr); - } else { - auto defaults_col = - cudf::is_dictionary(input.type()) ? 
dictionary_column_view(input).indices() : input; - return cudf::detail::rolling_window(input, - empty_like(defaults_col)->view(), - preceding_window.begin(), - following_window.begin(), - min_periods, - agg, - stream, - mr); - } -} - -} // namespace detail // Applies a fixed-size rolling window function to the values in a column, with default output // specified diff --git a/cpp/src/rolling/rolling_collect_list.cu b/cpp/src/rolling/rolling_collect_list.cu index 5617995b348..d6f90b00f36 100644 --- a/cpp/src/rolling/rolling_collect_list.cu +++ b/cpp/src/rolling/rolling_collect_list.cu @@ -136,18 +136,17 @@ std::pair, std::unique_ptr> purge_null_entries( auto new_sizes = make_fixed_width_column( data_type{type_to_id()}, input.size(), mask_state::UNALLOCATED, stream); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.size()), - new_sizes->mutable_view().template begin(), - [d_gather_map = gather_map.template begin(), - d_old_offsets = offsets.template begin(), - input_row_not_null] __device__(auto i) { - return thrust::count_if(thrust::seq, - d_gather_map + d_old_offsets[i], - d_gather_map + d_old_offsets[i + 1], - input_row_not_null); - }); + thrust::tabulate(rmm::exec_policy(stream), + new_sizes->mutable_view().template begin(), + new_sizes->mutable_view().template end(), + [d_gather_map = gather_map.template begin(), + d_old_offsets = offsets.template begin(), + input_row_not_null] __device__(auto i) { + return thrust::count_if(thrust::seq, + d_gather_map + d_old_offsets[i], + d_gather_map + d_old_offsets[i + 1], + input_row_not_null); + }); auto new_offsets = strings::detail::make_offsets_child_column(new_sizes->view().template begin(), diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index ca07d60f426..9c58512ab92 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -17,6 +17,7 @@ #pragma once #include 
"lead_lag_nested_detail.cuh" +#include "nth_element.cuh" #include "rolling/rolling_collect_list.cuh" #include "rolling/rolling_detail.hpp" #include "rolling/rolling_jit_detail.hpp" @@ -604,7 +605,7 @@ struct DeviceRollingLag { }; /** - * @brief Maps an `InputType and `aggregation::Kind` value to it's corresponding + * @brief Maps an `InputType and `aggregation::Kind` value to its corresponding * rolling window operator. * * @tparam InputType The input type to map to its corresponding operator @@ -818,6 +819,13 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre aggs.push_back(agg.clone()); return aggs; } + + // NTH_ELEMENT aggregations are computed in finalize(). Skip preprocessing. + std::vector> visit( + data_type, cudf::detail::nth_element_aggregation const&) override + { + return {}; + } }; /** @@ -961,6 +969,17 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation } } + // Nth_ELEMENT aggregation. + void visit(cudf::detail::nth_element_aggregation const& agg) override + { + result = + agg._null_handling == null_policy::EXCLUDE + ? rolling::nth_element( + agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr) + : rolling::nth_element( + agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr); + } + private: column_view input; column_view default_outputs; diff --git a/cpp/src/rolling/rolling_detail.hpp b/cpp/src/rolling/rolling_detail.hpp index 80a9397922e..d2dfa2f9df5 100644 --- a/cpp/src/rolling/rolling_detail.hpp +++ b/cpp/src/rolling/rolling_detail.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,8 +14,7 @@ * limitations under the License. 
*/ -#ifndef ROLLING_DETAIL_HPP -#define ROLLING_DETAIL_HPP +#pragma once #include #include @@ -57,8 +56,44 @@ struct rolling_store_output_functor<_T, true> { out = static_cast(val.time_since_epoch() / count); } }; + +/** + * @copydoc cudf::rolling_window(column_view const& input, + * column_view const& default_outputs, + * size_type preceding_window, + * size_type following_window, + * size_type min_periods, + * rolling_aggregation const& agg, + * rmm::mr::device_memory_resource* mr) + * + * @param stream CUDA stream to use for device memory operations + */ +std::unique_ptr rolling_window(column_view const& input, + column_view const& default_outputs, + size_type preceding_window, + size_type following_window, + size_type min_periods, + rolling_aggregation const& agg, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); + +/** + * @copydoc cudf::rolling_window(column_view const& input, + * column_view const& preceding_window, + * column_view const& following_window, + * size_type min_periods, + * rolling_aggregation const& agg, + * rmm::mr::device_memory_resource* mr); + * + * @param stream CUDA stream to use for device memory operations + */ +std::unique_ptr rolling_window(column_view const& input, + column_view const& preceding_window, + column_view const& following_window, + size_type min_periods, + rolling_aggregation const& agg, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr); } // namespace detail } // namespace cudf - -#endif diff --git a/cpp/src/rolling/rolling_detail_fixed_window.cu b/cpp/src/rolling/rolling_detail_fixed_window.cu new file mode 100644 index 00000000000..21c406af8f1 --- /dev/null +++ b/cpp/src/rolling/rolling_detail_fixed_window.cu @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rolling_detail.cuh" + +#include +#include + +#include + +namespace cudf::detail { + +// Applies a fixed-size rolling window function to the values in a column. +std::unique_ptr rolling_window(column_view const& input, + column_view const& default_outputs, + size_type preceding_window, + size_type following_window, + size_type min_periods, + rolling_aggregation const& agg, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + + if (input.is_empty()) { return cudf::detail::empty_output_for_rolling_aggregation(input, agg); } + + CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative"); + + CUDF_EXPECTS((default_outputs.is_empty() || default_outputs.size() == input.size()), + "Defaults column must be either empty or have as many rows as the input column."); + + if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { + // TODO: In future, might need to clamp preceding/following to column boundaries. + return cudf::detail::rolling_window_udf(input, + preceding_window, + "cudf::size_type", + following_window, + "cudf::size_type", + min_periods, + agg, + stream, + mr); + } else { + // Clamp preceding/following to column boundaries. + // E.g. 
If preceding_window == 2, then for a column of 5 elements, preceding_window will be: + // [1, 2, 2, 2, 1] + auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator( + 0, + [preceding_window] __device__(size_type i) { return thrust::min(i + 1, preceding_window); }); + auto const following_window_begin = cudf::detail::make_counting_transform_iterator( + 0, [col_size = input.size(), following_window] __device__(size_type i) { + return thrust::min(col_size - i - 1, following_window); + }); + + return cudf::detail::rolling_window(input, + default_outputs, + preceding_window_begin, + following_window_begin, + min_periods, + agg, + stream, + mr); + } +} +} // namespace cudf::detail diff --git a/cpp/src/rolling/rolling_detail_variable_window.cu b/cpp/src/rolling/rolling_detail_variable_window.cu new file mode 100644 index 00000000000..d41717a5dc7 --- /dev/null +++ b/cpp/src/rolling/rolling_detail_variable_window.cu @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "rolling_detail.cuh" + +#include +#include + +#include + +namespace cudf::detail { + +// Applies a variable-size rolling window function to the values in a column. 
+std::unique_ptr rolling_window(column_view const& input, + column_view const& preceding_window, + column_view const& following_window, + size_type min_periods, + rolling_aggregation const& agg, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + + if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) { + return cudf::detail::empty_output_for_rolling_aggregation(input, agg); + } + + CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 && + following_window.type().id() == type_id::INT32, + "preceding_window/following_window must have type_id::INT32 type"); + + CUDF_EXPECTS(preceding_window.size() == input.size() && following_window.size() == input.size(), + "preceding_window/following_window size must match input size"); + + if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { + // TODO: In future, might need to clamp preceding/following to column boundaries. + return cudf::detail::rolling_window_udf(input, + preceding_window.begin(), + "cudf::size_type*", + following_window.begin(), + "cudf::size_type*", + min_periods, + agg, + stream, + mr); + } else { + auto defaults_col = + cudf::is_dictionary(input.type()) ? dictionary_column_view(input).indices() : input; + // Clamp preceding/following to column boundaries. + // E.g. 
If preceding_window == [2, 2, 2, 2, 2] for a column of 5 elements, the new + // preceding_window will be: [1, 2, 2, 2, 1] + auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator( + 0, [preceding = preceding_window.begin()] __device__(size_type i) { + return thrust::min(i + 1, preceding[i]); + }); + auto const following_window_begin = cudf::detail::make_counting_transform_iterator( + 0, + [col_size = input.size(), following = following_window.begin()] __device__( + size_type i) { return thrust::min(col_size - i - 1, following[i]); }); + return cudf::detail::rolling_window(input, + empty_like(defaults_col)->view(), + preceding_window_begin, + following_window_begin, + min_periods, + agg, + stream, + mr); + } +} + +} // namespace cudf::detail diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 816c5a1c59c..8d8fc3210bb 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -340,6 +340,7 @@ ConfigureTest( rolling/empty_input_test.cpp rolling/grouped_rolling_test.cpp rolling/lead_lag_test.cpp + rolling/nth_element_test.cpp rolling/range_rolling_window_test.cpp rolling/range_window_bounds_test.cpp rolling/rolling_test.cpp diff --git a/cpp/tests/rolling/nth_element_test.cpp b/cpp/tests/rolling/nth_element_test.cpp new file mode 100644 index 00000000000..93276abbbb2 --- /dev/null +++ b/cpp/tests/rolling/nth_element_test.cpp @@ -0,0 +1,624 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#include +#include +#include + +namespace cudf::test::rolling { + +auto constexpr X = int32_t{0}; // Placeholder for null. + +template +using fwcw = fixed_width_column_wrapper; +using grouping_keys_column = fwcw; + +using namespace cudf::test::iterators; + +/// Rolling test executor with fluent interface. +class rolling_exec { + size_type _preceding{1}; + size_type _following{0}; + size_type _min_periods{1}; + column_view _grouping; + column_view _input; + null_policy _null_handling = null_policy::INCLUDE; + + public: + rolling_exec& preceding(size_type preceding) + { + _preceding = preceding; + return *this; + } + rolling_exec& following(size_type following) + { + _following = following; + return *this; + } + rolling_exec& min_periods(size_type min_periods) + { + _min_periods = min_periods; + return *this; + } + rolling_exec& grouping(column_view grouping) + { + _grouping = grouping; + return *this; + } + rolling_exec& input(column_view input) + { + _input = input; + return *this; + } + rolling_exec& null_handling(null_policy null_handling) + { + _null_handling = null_handling; + return *this; + } + + std::unique_ptr test_grouped_nth_element( + size_type n, std::optional null_handling = std::nullopt) const + { + return cudf::grouped_rolling_window(table_view{{_grouping}}, + _input, + _preceding, + _following, + _min_periods, + *make_nth_element_aggregation( + n, null_handling.value_or(_null_handling))); + } + + std::unique_ptr test_nth_element( + size_type n, std::optional null_handling = std::nullopt) const + { + return cudf::rolling_window(_input, + _preceding, + _following, + _min_periods, + *make_nth_element_aggregation( + n, null_handling.value_or(_null_handling))); + } +}; + +struct NthElementTest 
: public cudf::test::BaseFixture { +}; + +template +struct NthElementTypedTest : public NthElementTest { +}; + +using TypesForTest = cudf::test::Concat; + +TYPED_TEST_SUITE(NthElementTypedTest, TypesForTest); + +TYPED_TEST(NthElementTypedTest, RollingWindow) +{ + using T = TypeParam; + + auto const input_col = fwcw{{0, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20}, null_at(5)}; + auto tester = rolling_exec{}.input(input_col); + { + // Window of 5 elements, min-periods == 1. + tester.preceding(3).following(2).min_periods(1); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, fwcw{{0, 0, 0, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15}, null_at(7)}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, fwcw{{2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20, 20, 20}, null_at(3)}); + auto const third_element = tester.test_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *third_element, fwcw{{2, 2, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20}, null_at(5)}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + fwcw{{1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 16, 16}, null_at(4)}); + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at column margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + fwcw{{X, 0, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, X}, nulls_at({0, 6, 13})}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + fwcw{{X, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20, X}, nulls_at({0, 4, 13})}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, + fwcw{{X, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, X}, nulls_at({0, 5, 13})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + fwcw{{X, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, X}, nulls_at({0, 5, 13})}); + } + { + // Too large values for `min_periods`. No window has enough periods. + tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = fwcw{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, RollingWindowExcludeNulls) +{ + using T = TypeParam; + + auto const input_col = fwcw{{0, X, X, X, 4, X, 6, 7}, nulls_at({1, 2, 3, 5})}; + auto tester = rolling_exec{}.input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1).null_handling(null_policy::EXCLUDE); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, 0, 4, 4, 4, 4, 6}, no_nulls()}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{0, 0, 4, 4, 6, 7, 7, 7}, no_nulls()}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, + fwcw{{X, X, 4, X, 6, 6, 6, 7}, nulls_at({0, 1, 3})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{X, X, 0, X, 4, 6, 6, 6}, nulls_at({0, 1, 3})}); + } + { + // Window of 3 elements, min-periods == 1. + tester.preceding(2).following(1).min_periods(1).null_handling(null_policy::EXCLUDE); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, X, 4, 4, 4, 6, 6}, null_at(2)}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{0, 0, X, 4, 4, 6, 7, 7}, null_at(2)}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, fwcw{{X, X, X, X, X, 6, 7, 7}, nulls_at({0, 1, 2, 3, 4})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, fwcw{{X, X, X, X, X, 4, 6, 6}, nulls_at({0, 1, 2, 3, 4})}); + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = fwcw{{X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, GroupedRollingWindow) +{ + using T = TypeParam; + + // clang-format off + auto const group_col = fwcw{0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, + 20}; + auto const input_col = fwcw {0, 1, 2, 3, 4, 5, // Group 0 + 10, 11, 12, 13, 14, 15, 16, // Group 10 + 20}; // Group 20 + // clang-format on + auto tester = rolling_exec{}.grouping(group_col).input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, 0, 1, 2, 3, // Group 0 + 10, 10, 10, 11, 12, 13, 14, // Group 10 + 20}, // Group 20 + no_nulls()}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{2, 3, 4, 5, 5, 5, // Group 0 + 12, 13, 14, 15, 16, 16, 16, // Group 10 + 20}, // Group 20 + no_nulls()}); + auto const third_element = tester.test_grouped_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*third_element, + fwcw{{2, 2, 2, 3, 4, 5, // Group 0 + 12, 12, 12, 13, 14, 15, 16, // Group 10 + X}, // Group 20 + null_at(13)}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{1, 2, 3, 4, 4, 4, // Group 0 + 11, 12, 13, 14, 15, 15, 15, // Group 10 + X}, // Group 20 + null_at(13)}); + // clang-format on + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at group margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{X, 0, 1, 2, 3, X, // Group 0 + X, 10, 11, 12, 13, 14, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{X, 2, 3, 4, 5, X, // Group 0 + X, 12, 13, 14, 15, 16, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, + fwcw{{X, 1, 2, 3, 4, X, // Group 0 + X, 11, 12, 13, 14, 15, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{X, 1, 2, 3, 4, X, // Group 0 + X, 11, 12, 13, 14, 15, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + // clang-format on + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = fwcw{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_grouped_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, GroupedRollingWindowExcludeNulls) +{ + using T = TypeParam; + + // clang-format off + auto const group_col = fwcw{0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, + 20, + 30}; + auto const input_col = fwcw {{0, 1, X, 3, X, 5, // Group 0 + 10, X, X, 13, 14, 15, 16, // Group 10 + 20, // Group 20 + X}, // Group 30 + nulls_at({2, 4, 7, 8, 14})}; + // clang-format on + auto tester = rolling_exec{}.grouping(group_col).input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1).null_handling(null_policy::EXCLUDE); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, 0, 1, 3, 3, // Group 0 + 10, 10, 10, 13, 13, 13, 14, // Group 10 + 20, // Group 20 + X}, // Group 30 + null_at(14)}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{1, 3, 3, 5, 5, 5, // Group 0 + 10, 13, 14, 15, 16, 16, 16, // Group 10 + 20, // Group 20 + X}, // Group 30 + null_at(14)}); + auto const third_element = tester.test_grouped_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*third_element, + fwcw{{X, 3, 3, 5, X, X, // Group 0 + X, X, 14, 15, 15, 15, 16, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 4, 5, 6, 7, 13, 14})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{0, 1, 1, 3, 3, 3, // Group 0 + X, 10, 13, 14, 15, 15, 15, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({6, 13, 14})}); + // clang-format on + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at group margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{X, 0, 1, 3, 3, X, // Group 0 + X, 10, 13, 13, 13, 14, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 5, 6, 12, 13, 14})}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{X, 1, 3, 3, 5, X, // Group 0 + X, 10, 13, 14, 15, 16, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 5, 6, 12, 13, 14})}); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, + fwcw{{X, 1, 3, X, 5, X, // Group 0 + X, X, X, 14, 14, 15, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 3, 5, 6, 7, 8, 12, 13, 14})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{X, 0, 1, X, 3, X, // Group 0 + X, X, X, 13, 14, 15, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 3, 5, 6, 7, 8, 12, 13, 14})}); + // clang-format on + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = + fwcw{{X, X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_grouped_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, EmptyInput) +{ + using T = TypeParam; + + auto const group_col = fwcw{}; + auto const input_col = fwcw{}; + auto tester = rolling_exec{}.grouping(group_col).input(input_col).preceding(3).following(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*tester.test_grouped_nth_element(0), fwcw{}); +} + +TEST_F(NthElementTest, RollingWindowOnStrings) +{ + using strings = strings_column_wrapper; + + auto constexpr X = ""; // Placeholder for null string. + + auto const input_col = strings{ + {"", "1", "22", "333", "4444", "", "10", "11", "12", "13", "14", "15", "16", "20"}, null_at(5)}; + auto tester = rolling_exec{}.input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{"", "", "", "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15"}, + null_at(7)}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{"22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "20", "20", "20"}, + null_at(3)}); + auto const third_element = tester.test_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *third_element, + strings{{"22", "22", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "20"}, + null_at(5)}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{"1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "16", "16"}, + null_at(4)}); + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at column margins. 
+ tester.preceding(2).following(1).min_periods(3); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{X, "", "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", X}, + nulls_at({0, 6, 13})}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{X, "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "20", X}, + nulls_at({0, 4, 13})}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, + strings{{X, "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", X}, + nulls_at({0, 5, 13})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{X, "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", X}, + nulls_at({0, 5, 13})}); + } + { + // Too large values for `min_periods`. No window has enough periods. + tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = strings{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TEST_F(NthElementTest, GroupedRollingWindowForStrings) +{ + using strings = strings_column_wrapper; + auto constexpr X = ""; // Placeholder for null strings. 
+ + // clang-format off + auto const group_col = fwcw{0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, + 20}; + auto const input_col = strings{{"", "1", "22", "333", "4444", X, // Group 0 + "10", "11", "12", "13", "14", "15", "16", // Group 10 + "20"}, // Group 20 + null_at(5)}; + // clang-format on + auto tester = rolling_exec{}.grouping(group_col).input(input_col); + + { + // Window of 5 elements, min-periods == 1. + tester.preceding(3).following(2).min_periods(1); + + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{"", "", "", "1", "22", "333", // Group 0 + "10", "10", "10", "11", "12", "13", "14", // Group 10 + "20"}, // Group 20 + no_nulls()}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{"22", "333", "4444", X, X, X, // Group 0 + "12", "13", "14", "15", "16", "16", "16", // Group 10 + "20"}, // Group 20 + nulls_at({3, 4, 5})}); + auto const third_element = tester.test_grouped_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *third_element, + strings{{"22", "22", "22", "333", "4444", X, // Group 0 + "12", "12", "12", "13", "14", "15", "16", // Group 10 + X}, // Group 20 + nulls_at({5, 13})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{"1", "22", "333", "4444", "4444", "4444", // Group 0 + "11", "12", "13", "14", "15", "15", "15", // Group 10 + X}, // Group 20 + null_at(13)}); + // clang-format on + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at group margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{X, "", "1", "22", "333", X, // Group 0 + X, "10", "11", "12", "13", "14", X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{X, "22", "333", "4444", X, X, // Group 0 + X, "12", "13", "14", "15", "16", X, // Group 10 + X}, // Group 20 + nulls_at({0, 4, 5, 6, 12, 13})}); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, + strings{{X, "1", "22", "333", "4444", X, // Group 0 + X, "11", "12", "13", "14", "15", X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{X, "1", "22", "333", "4444", X, // Group 0 + X, "11", "12", "13", "14", "15", X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + // clang-format on + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_strings = strings{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_grouped_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_strings); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_strings); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_strings); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_strings); + } +} + +} // namespace cudf::test::rolling