diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 8ca49dd7d5f..75027c78a68 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -570,7 +570,9 @@ class nunique_aggregation final : public groupby_aggregation, public reduce_aggr /** * @brief Derived class for specifying a nth element aggregation */ -class nth_element_aggregation final : public groupby_aggregation, public reduce_aggregation { +class nth_element_aggregation final : public groupby_aggregation, + public reduce_aggregation, + public rolling_aggregation { public: nth_element_aggregation(size_type n, null_policy null_handling) : aggregation{NTH_ELEMENT}, _n{n}, _null_handling{null_handling} diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 27732b25401..6dd014970c7 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -610,6 +610,8 @@ template std::unique_ptr make_nth_element_aggregation make_nth_element_aggregation( size_type n, null_policy null_handling); +template std::unique_ptr make_nth_element_aggregation( + size_type n, null_policy null_handling); /// Factory to create a ROW_NUMBER aggregation template diff --git a/cpp/src/rolling/nth_element.cuh b/cpp/src/rolling/nth_element.cuh new file mode 100644 index 00000000000..7f81c0ca657 --- /dev/null +++ b/cpp/src/rolling/nth_element.cuh @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace cudf::detail::rolling { + +auto constexpr NULL_INDEX = std::numeric_limits::min(); // For nullifying with gather. + +template +std::unique_ptr nth_element(size_type n, + column_view const& input, + PrecedingIter preceding, + FollowingIter following, + size_type min_periods, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + auto const gather_iter = cudf::detail::make_counting_transform_iterator( + 0, + [exclude_nulls = null_handling == null_policy::EXCLUDE and input.nullable(), + preceding, + following, + min_periods, + n, + input_nullmask = input.null_mask()] __device__(size_type i) { + // preceding[i] includes the current row. + auto const window_size = preceding[i] + following[i]; + if (min_periods > window_size) { return NULL_INDEX; } + + auto const wrapped_n = n >= 0 ? n : (window_size + n); + if (wrapped_n < 0 || wrapped_n > (window_size - 1)) { + return NULL_INDEX; // n lies outside the window. + } + + auto const window_start = i - preceding[i] + 1; + + if (not exclude_nulls) { return window_start + wrapped_n; } + + // Must exclude nulls, and n is in range [-window_size, window_size-1]. + // Depending on n >= 0, count forwards from window_start, or backwards from window_end. + auto const window_end = window_start + window_size; + + auto reqd_valid_count = n >= 0 ? 
n : (-n - 1); + auto const nth_valid = [&reqd_valid_count, input_nullmask](size_type j) { + return cudf::bit_is_set(input_nullmask, j) && reqd_valid_count-- == 0; + }; + + if (n >= 0) { // Search forwards from window_start. + auto const begin = thrust::make_counting_iterator(window_start); + auto const end = begin + window_size; + auto const found = thrust::find_if(thrust::seq, begin, end, nth_valid); + return found == end ? NULL_INDEX : *found; + } else { // Search backwards from window-end. + auto const begin = + thrust::make_reverse_iterator(thrust::make_counting_iterator(window_end)); + auto const end = begin + window_size; + auto const found = thrust::find_if(thrust::seq, begin, end, nth_valid); + return found == end ? NULL_INDEX : *found; + } + }); + + auto gathered = cudf::detail::gather(table_view{{input}}, + gather_iter, + gather_iter + input.size(), + cudf::out_of_bounds_policy::NULLIFY, + stream, + mr) + ->release(); + return std::move(gathered[0]); +} + +} // namespace cudf::detail::rolling \ No newline at end of file diff --git a/cpp/src/rolling/rolling.cu b/cpp/src/rolling/rolling.cu index d4a012610b9..850937cafe0 100644 --- a/cpp/src/rolling/rolling.cu +++ b/cpp/src/rolling/rolling.cu @@ -44,6 +44,7 @@ std::unique_ptr rolling_window(column_view const& input, "Defaults column must be either empty or have as many rows as the input column."); if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { + // TODO: In future, might need to clamp preceding/following to column boundaries. return cudf::detail::rolling_window_udf(input, preceding_window, "cudf::size_type", @@ -54,8 +55,16 @@ std::unique_ptr rolling_window(column_view const& input, stream, mr); } else { - auto preceding_window_begin = thrust::make_constant_iterator(preceding_window); - auto following_window_begin = thrust::make_constant_iterator(following_window); + // Clamp preceding/following to column boundaries. + // E.g. 
If preceding_window == 2, then for a column of 5 elements, preceding_window will be: + // [1, 2, 2, 2, 2] + auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator( + 0, + [preceding_window] __device__(size_type i) { return thrust::min(i + 1, preceding_window); }); + auto const following_window_begin = cudf::detail::make_counting_transform_iterator( + 0, [col_size = input.size(), following_window] __device__(size_type i) { + return thrust::min(col_size - i - 1, following_window); + }); return cudf::detail::rolling_window(input, default_outputs, @@ -91,6 +100,7 @@ std::unique_ptr rolling_window(column_view const& input, "preceding_window/following_window size must match input size"); if (agg.kind == aggregation::CUDA || agg.kind == aggregation::PTX) { + // TODO: In future, might need to clamp preceding/following to column boundaries. return cudf::detail::rolling_window_udf(input, preceding_window.begin(), "cudf::size_type*", @@ -103,10 +113,21 @@ std::unique_ptr rolling_window(column_view const& input, } else { auto defaults_col = cudf::is_dictionary(input.type()) ? dictionary_column_view(input).indices() : input; + // Clamp preceding/following to column boundaries. + // E.g.
If preceding_window == [2, 2, 2, 2, 2] for a column of 5 elements, the new + // preceding_window will be: [1, 2, 2, 2, 2] + auto const preceding_window_begin = cudf::detail::make_counting_transform_iterator( + 0, [preceding = preceding_window.begin()] __device__(size_type i) { + return thrust::min(i + 1, preceding[i]); + }); + auto const following_window_begin = cudf::detail::make_counting_transform_iterator( + 0, + [col_size = input.size(), following = following_window.begin()] __device__( + size_type i) { return thrust::min(col_size - i - 1, following[i]); }); return cudf::detail::rolling_window(input, empty_like(defaults_col)->view(), - preceding_window.begin(), - following_window.begin(), + preceding_window_begin, + following_window_begin, min_periods, agg, stream, diff --git a/cpp/src/rolling/rolling_collect_list.cu b/cpp/src/rolling/rolling_collect_list.cu index 5617995b348..7e9506ee1cd 100644 --- a/cpp/src/rolling/rolling_collect_list.cu +++ b/cpp/src/rolling/rolling_collect_list.cu @@ -136,18 +136,17 @@ std::pair, std::unique_ptr> purge_null_entries( auto new_sizes = make_fixed_width_column( data_type{type_to_id()}, input.size(), mask_state::UNALLOCATED, stream); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.size()), - new_sizes->mutable_view().template begin(), - [d_gather_map = gather_map.template begin(), - d_old_offsets = offsets.template begin(), - input_row_not_null] __device__(auto i) { - return thrust::count_if(thrust::seq, - d_gather_map + d_old_offsets[i], - d_gather_map + d_old_offsets[i + 1], - input_row_not_null); - }); + thrust::tabulate(rmm::exec_policy(stream), + new_sizes->mutable_view().template begin(), + new_sizes->mutable_view().template end(), + [d_gather_map = gather_map.template begin(), + d_old_offsets = offsets.template begin(), + input_row_not_null] __device__(auto i) { + return thrust::count_if(thrust::seq, + d_gather_map + d_old_offsets[i], +
d_gather_map + d_old_offsets[i + 1], + input_row_not_null); + }); auto new_offsets = strings::detail::make_offsets_child_column(new_sizes->view().template begin(), diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index ca07d60f426..9c58512ab92 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -17,6 +17,7 @@ #pragma once #include "lead_lag_nested_detail.cuh" +#include "nth_element.cuh" #include "rolling/rolling_collect_list.cuh" #include "rolling/rolling_detail.hpp" #include "rolling/rolling_jit_detail.hpp" @@ -604,7 +605,7 @@ struct DeviceRollingLag { }; /** - * @brief Maps an `InputType and `aggregation::Kind` value to it's corresponding + * @brief Maps an `InputType` and `aggregation::Kind` value to its corresponding * rolling window operator. * * @tparam InputType The input type to map to its corresponding operator @@ -818,6 +819,13 @@ class rolling_aggregation_preprocessor final : public cudf::detail::simple_aggre aggs.push_back(agg.clone()); return aggs; } + + // NTH_ELEMENT aggregations are computed in finalize(). Skip preprocessing. + std::vector> visit( + data_type, cudf::detail::nth_element_aggregation const&) override + { + return {}; + } }; /** @@ -961,6 +969,17 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation } } + // NTH_ELEMENT aggregation. + void visit(cudf::detail::nth_element_aggregation const& agg) override + { + result = + agg._null_handling == null_policy::EXCLUDE + ?
rolling::nth_element( + agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr) + : rolling::nth_element( + agg._n, input, preceding_window_begin, following_window_begin, min_periods, stream, mr); + } + private: column_view input; column_view default_outputs; diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 816c5a1c59c..8d8fc3210bb 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -340,6 +340,7 @@ ConfigureTest( rolling/empty_input_test.cpp rolling/grouped_rolling_test.cpp rolling/lead_lag_test.cpp + rolling/nth_element_test.cpp rolling/range_rolling_window_test.cpp rolling/range_window_bounds_test.cpp rolling/rolling_test.cpp diff --git a/cpp/tests/rolling/nth_element_test.cpp b/cpp/tests/rolling/nth_element_test.cpp new file mode 100644 index 00000000000..93276abbbb2 --- /dev/null +++ b/cpp/tests/rolling/nth_element_test.cpp @@ -0,0 +1,624 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +#include +#include +#include + +namespace cudf::test::rolling { + +auto constexpr X = int32_t{0}; // Placeholder for null. 
+ +template +using fwcw = fixed_width_column_wrapper; +using grouping_keys_column = fwcw; + +using namespace cudf::test::iterators; + +/// Rolling test executor with fluent interface. +class rolling_exec { + size_type _preceding{1}; + size_type _following{0}; + size_type _min_periods{1}; + column_view _grouping; + column_view _input; + null_policy _null_handling = null_policy::INCLUDE; + + public: + rolling_exec& preceding(size_type preceding) + { + _preceding = preceding; + return *this; + } + rolling_exec& following(size_type following) + { + _following = following; + return *this; + } + rolling_exec& min_periods(size_type min_periods) + { + _min_periods = min_periods; + return *this; + } + rolling_exec& grouping(column_view grouping) + { + _grouping = grouping; + return *this; + } + rolling_exec& input(column_view input) + { + _input = input; + return *this; + } + rolling_exec& null_handling(null_policy null_handling) + { + _null_handling = null_handling; + return *this; + } + + std::unique_ptr test_grouped_nth_element( + size_type n, std::optional null_handling = std::nullopt) const + { + return cudf::grouped_rolling_window(table_view{{_grouping}}, + _input, + _preceding, + _following, + _min_periods, + *make_nth_element_aggregation( + n, null_handling.value_or(_null_handling))); + } + + std::unique_ptr test_nth_element( + size_type n, std::optional null_handling = std::nullopt) const + { + return cudf::rolling_window(_input, + _preceding, + _following, + _min_periods, + *make_nth_element_aggregation( + n, null_handling.value_or(_null_handling))); + } +}; + +struct NthElementTest : public cudf::test::BaseFixture { +}; + +template +struct NthElementTypedTest : public NthElementTest { +}; + +using TypesForTest = cudf::test::Concat; + +TYPED_TEST_SUITE(NthElementTypedTest, TypesForTest); + +TYPED_TEST(NthElementTypedTest, RollingWindow) +{ + using T = TypeParam; + + auto const input_col = fwcw{{0, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20}, null_at(5)}; + auto 
tester = rolling_exec{}.input(input_col); + { + // Window of 5 elements, min-periods == 1. + tester.preceding(3).following(2).min_periods(1); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, fwcw{{0, 0, 0, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15}, null_at(7)}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, fwcw{{2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20, 20, 20}, null_at(3)}); + auto const third_element = tester.test_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *third_element, fwcw{{2, 2, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20}, null_at(5)}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + fwcw{{1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 16, 16}, null_at(4)}); + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at column margins. + tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + fwcw{{X, 0, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, X}, nulls_at({0, 6, 13})}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + fwcw{{X, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, 20, X}, nulls_at({0, 4, 13})}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, + fwcw{{X, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, X}, nulls_at({0, 5, 13})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + fwcw{{X, 1, 2, 3, 4, X, 10, 11, 12, 13, 14, 15, 16, X}, nulls_at({0, 5, 13})}); + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = fwcw{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, RollingWindowExcludeNulls) +{ + using T = TypeParam; + + auto const input_col = fwcw{{0, X, X, X, 4, X, 6, 7}, nulls_at({1, 2, 3, 5})}; + auto tester = rolling_exec{}.input(input_col); + + { + // Window of 5 elements, min-periods == 1. + tester.preceding(3).following(2).min_periods(1).null_handling(null_policy::EXCLUDE); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, 0, 4, 4, 4, 4, 6}, no_nulls()}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{0, 0, 4, 4, 6, 7, 7, 7}, no_nulls()}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, + fwcw{{X, X, 4, X, 6, 6, 6, 7}, nulls_at({0, 1, 3})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{X, X, 0, X, 4, 6, 6, 6}, nulls_at({0, 1, 3})}); + } + { + // Window of 3 elements, min-periods == 1.
+ tester.preceding(2).following(1).min_periods(1).null_handling(null_policy::EXCLUDE); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, X, 4, 4, 4, 6, 6}, null_at(2)}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{0, 0, X, 4, 4, 6, 7, 7}, null_at(2)}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, fwcw{{X, X, X, X, X, 6, 7, 7}, nulls_at({0, 1, 2, 3, 4})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, fwcw{{X, X, X, X, X, 4, 6, 6}, nulls_at({0, 1, 2, 3, 4})}); + } + { + // Too large values for `min_periods`. No window has enough periods. + tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = fwcw{{X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, GroupedRollingWindow) +{ + using T = TypeParam; + + // clang-format off + auto const group_col = fwcw{0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, + 20}; + auto const input_col = fwcw {0, 1, 2, 3, 4, 5, // Group 0 + 10, 11, 12, 13, 14, 15, 16, // Group 10 + 20}; // Group 20 + // clang-format on + auto tester = rolling_exec{}.grouping(group_col).input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, 0, 1, 2, 3, // Group 0 + 10, 10, 10, 11, 12, 13, 14, // Group 10 + 20}, // Group 20 + no_nulls()}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{2, 3, 4, 5, 5, 5, // Group 0 + 12, 13, 14, 15, 16, 16, 16, // Group 10 + 20}, // Group 20 + no_nulls()}); + auto const third_element = tester.test_grouped_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*third_element, + fwcw{{2, 2, 2, 3, 4, 5, // Group 0 + 12, 12, 12, 13, 14, 15, 16, // Group 10 + X}, // Group 20 + null_at(13)}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{1, 2, 3, 4, 4, 4, // Group 0 + 11, 12, 13, 14, 15, 15, 15, // Group 10 + X}, // Group 20 + null_at(13)}); + // clang-format on + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at group margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{X, 0, 1, 2, 3, X, // Group 0 + X, 10, 11, 12, 13, 14, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{X, 2, 3, 4, 5, X, // Group 0 + X, 12, 13, 14, 15, 16, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, + fwcw{{X, 1, 2, 3, 4, X, // Group 0 + X, 11, 12, 13, 14, 15, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{X, 1, 2, 3, 4, X, // Group 0 + X, 11, 12, 13, 14, 15, X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + // clang-format on + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = fwcw{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_grouped_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, GroupedRollingWindowExcludeNulls) +{ + using T = TypeParam; + + // clang-format off + auto const group_col = fwcw{0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, + 20, + 30}; + auto const input_col = fwcw {{0, 1, X, 3, X, 5, // Group 0 + 10, X, X, 13, 14, 15, 16, // Group 10 + 20, // Group 20 + X}, // Group 30 + nulls_at({2, 4, 7, 8, 14})}; + // clang-format on + auto tester = rolling_exec{}.grouping(group_col).input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1).null_handling(null_policy::EXCLUDE); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{0, 0, 0, 1, 3, 3, // Group 0 + 10, 10, 10, 13, 13, 13, 14, // Group 10 + 20, // Group 20 + X}, // Group 30 + null_at(14)}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{1, 3, 3, 5, 5, 5, // Group 0 + 10, 13, 14, 15, 16, 16, 16, // Group 10 + 20, // Group 20 + X}, // Group 30 + null_at(14)}); + auto const third_element = tester.test_grouped_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*third_element, + fwcw{{X, 3, 3, 5, X, X, // Group 0 + X, X, 14, 15, 15, 15, 16, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 4, 5, 6, 7, 13, 14})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{0, 1, 1, 3, 3, 3, // Group 0 + X, 10, 13, 14, 15, 15, 15, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({6, 13, 14})}); + // clang-format on + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at group margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, + fwcw{{X, 0, 1, 3, 3, X, // Group 0 + X, 10, 13, 13, 13, 14, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 5, 6, 12, 13, 14})}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, + fwcw{{X, 1, 3, 3, 5, X, // Group 0 + X, 10, 13, 14, 15, 16, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 5, 6, 12, 13, 14})}); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, + fwcw{{X, 1, 3, X, 5, X, // Group 0 + X, X, X, 14, 14, 15, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 3, 5, 6, 7, 8, 12, 13, 14})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, + fwcw{{X, 0, 1, X, 3, X, // Group 0 + X, X, X, 13, 14, 15, X, // Group 10 + X, // Group 20 + X}, // Group 30 + nulls_at({0, 3, 5, 6, 7, 8, 12, 13, 14})}); + // clang-format on + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = + fwcw{{X, X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_grouped_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TYPED_TEST(NthElementTypedTest, EmptyInput) +{ + using T = TypeParam; + + auto const group_col = fwcw{}; + auto const input_col = fwcw{}; + auto tester = rolling_exec{}.grouping(group_col).input(input_col).preceding(3).following(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*tester.test_grouped_nth_element(0), fwcw{}); +} + +TEST_F(NthElementTest, RollingWindowOnStrings) +{ + using strings = strings_column_wrapper; + + auto constexpr X = ""; // Placeholder for null string. + + auto const input_col = strings{ + {"", "1", "22", "333", "4444", "", "10", "11", "12", "13", "14", "15", "16", "20"}, null_at(5)}; + auto tester = rolling_exec{}.input(input_col); + + { + // Window of 5 elements, min-periods == 1. 
+ tester.preceding(3).following(2).min_periods(1); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{"", "", "", "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15"}, + null_at(7)}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{"22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "20", "20", "20"}, + null_at(3)}); + auto const third_element = tester.test_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *third_element, + strings{{"22", "22", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "20"}, + null_at(5)}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{"1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "16", "16"}, + null_at(4)}); + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at column margins. 
+ tester.preceding(2).following(1).min_periods(3); + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{X, "", "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", X}, + nulls_at({0, 6, 13})}); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{X, "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", "20", X}, + nulls_at({0, 4, 13})}); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, + strings{{X, "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", X}, + nulls_at({0, 5, 13})}); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{X, "1", "22", "333", "4444", X, "10", "11", "12", "13", "14", "15", "16", X}, + nulls_at({0, 5, 13})}); + } + { + // Too large values for `min_periods`. No window has enough periods. + tester.preceding(2).following(1).min_periods(4); + auto const all_null_values = strings{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_values); + auto const last_element = tester.test_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_values); + auto const second_element = tester.test_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_values); + auto const second_last_element = tester.test_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_values); + } +} + +TEST_F(NthElementTest, GroupedRollingWindowForStrings) +{ + using strings = strings_column_wrapper; + auto constexpr X = ""; // Placeholder for null strings. 
+ + // clang-format off + auto const group_col = fwcw{0, 0, 0, 0, 0, 0, + 10, 10, 10, 10, 10, 10, 10, + 20}; + auto const input_col = strings{{"", "1", "22", "333", "4444", X, // Group 0 + "10", "11", "12", "13", "14", "15", "16", // Group 10 + "20"}, // Group 20 + null_at(5)}; + // clang-format on + auto tester = rolling_exec{}.grouping(group_col).input(input_col); + + { + // Window of 5 elements, min-periods == 1. + tester.preceding(3).following(2).min_periods(1); + + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{"", "", "", "1", "22", "333", // Group 0 + "10", "10", "10", "11", "12", "13", "14", // Group 10 + "20"}, // Group 20 + no_nulls()}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{"22", "333", "4444", X, X, X, // Group 0 + "12", "13", "14", "15", "16", "16", "16", // Group 10 + "20"}, // Group 20 + nulls_at({3, 4, 5})}); + auto const third_element = tester.test_grouped_nth_element(2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *third_element, + strings{{"22", "22", "22", "333", "4444", X, // Group 0 + "12", "12", "12", "13", "14", "15", "16", // Group 10 + X}, // Group 20 + nulls_at({5, 13})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{"1", "22", "333", "4444", "4444", "4444", // Group 0 + "11", "12", "13", "14", "15", "15", "15", // Group 10 + X}, // Group 20 + null_at(13)}); + // clang-format on + } + { + // Window of 3 elements, min-periods == 3. Expect null elements at group margins. 
+ tester.preceding(2).following(1).min_periods(3); + auto const first_element = tester.test_grouped_nth_element(0); + // clang-format off + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *first_element, + strings{{X, "", "1", "22", "333", X, // Group 0 + X, "10", "11", "12", "13", "14", X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *last_element, + strings{{X, "22", "333", "4444", X, X, // Group 0 + X, "12", "13", "14", "15", "16", X, // Group 10 + X}, // Group 20 + nulls_at({0, 4, 5, 6, 12, 13})}); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_element, + strings{{X, "1", "22", "333", "4444", X, // Group 0 + X, "11", "12", "13", "14", "15", X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT( + *second_last_element, + strings{{X, "1", "22", "333", "4444", X, // Group 0 + X, "11", "12", "13", "14", "15", X, // Group 10 + X}, // Group 20 + nulls_at({0, 5, 6, 12, 13})}); + // clang-format on + } + { + // Too large values for `min_periods`. No window has enough periods. 
+ tester.preceding(2).following(1).min_periods(4); + auto const all_null_strings = strings{{X, X, X, X, X, X, X, X, X, X, X, X, X, X}, all_nulls()}; + + auto const first_element = tester.test_grouped_nth_element(0); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*first_element, all_null_strings); + auto const last_element = tester.test_grouped_nth_element(-1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*last_element, all_null_strings); + auto const second_element = tester.test_grouped_nth_element(1); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_element, all_null_strings); + auto const second_last_element = tester.test_grouped_nth_element(-2); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*second_last_element, all_null_strings); + } +} + +} // namespace cudf::test::rolling