From 911972654f8522139a494f974a06b7b9006d5bd0 Mon Sep 17 00:00:00 2001 From: MithunR Date: Thu, 18 Feb 2021 08:07:31 -0800 Subject: [PATCH] Support null_policy::EXCLUDE for COLLECT rolling aggregation (#7264) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #7258. #7189 implements `COLLECT` aggregations to be done from window functions. The semantics of how null input rows are handled are consistent with CUDF semantics. E.g. ```c++ auto input_col = fixed_width_column_wrapper{70, ∅, 72, 73, 74}; auto output_col = cudf::rolling_window(input_col, 2, 1, 1, collect_aggr); // == [ [70,∅], [70,∅,72], [∅,72,73], [72,73,74], [73,74] ] ``` Note that the null element (`∅`) is replicated in the first 3 rows of the output. SparkSQL (and Hive, and other big data SQL systems) have different semantics, in that all null elements are purged. The output for the same operation should yield the following: ```c++ auto sparkish_output_col = cudf::rolling_window(input_col, 2, 1, 1, collect_aggr); // == [ [70], [70,72], [72,73], [72,73,74], [73,74] ] ``` CUDF should allow the `COLLECT` aggregation to be constructed with an optional `null_policy` argument (with default `INCLUDE`). The `COLLECT` window function should check the policy, and filter out null list-elements _a posteriori_. 
Authors: - MithunR (@mythrocks) Approvers: - Ram (Ramakrishna Prabhu) (@rgsl888prabhu) - AJ Schmidt (@ajschmidt8) - Vukasin Milovanovic (@vuule) - Jake Hemstad (@jrhemstad) URL: https://github.com/rapidsai/cudf/pull/7264 --- conda/recipes/libcudf/meta.yaml | 1 + cpp/include/cudf/aggregation.hpp | 14 +- .../cudf/detail/aggregation/aggregation.hpp | 24 + cpp/include/cudf_test/iterator_utilities.hpp | 106 ++++ cpp/src/aggregation/aggregation.cpp | 4 +- cpp/src/groupby/sort/groupby.cu | 5 + cpp/src/rolling/rolling_detail.cuh | 86 +++ cpp/tests/collect_list/collect_list_test.cpp | 558 +++++++++++++++++- cpp/tests/groupby/group_collect_test.cpp | 19 + 9 files changed, 796 insertions(+), 21 deletions(-) create mode 100644 cpp/include/cudf_test/iterator_utilities.hpp diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index a1953a2d358..44a402262cc 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -204,6 +204,7 @@ test: - test -f $PREFIX/include/cudf_test/cudf_gtest.hpp - test -f $PREFIX/include/cudf_test/cxxopts.hpp - test -f $PREFIX/include/cudf_test/file_utilities.hpp + - test -f $PREFIX/include/cudf_test/iterator_utilities.hpp - test -f $PREFIX/include/cudf_test/table_utilities.hpp - test -f $PREFIX/include/cudf_test/timestamp_utilities.cuh - test -f $PREFIX/include/cudf_test/type_list_utilities.hpp diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index 4a7bc129aaf..a81b6ebc8a1 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -204,8 +204,18 @@ std::unique_ptr make_nth_element_aggregation( /// Factory to create a ROW_NUMBER aggregation std::unique_ptr make_row_number_aggregation(); -/// Factory to create a COLLECT_NUMBER aggregation -std::unique_ptr make_collect_aggregation(); +/** + * @brief Factory to create a COLLECT aggregation + * + * `COLLECT` returns a list column of all included elements in the group/series. 
+ * + * If `null_handling` is set to `EXCLUDE`, null elements are dropped from each + * of the list rows. + * + * @param null_handling Indicates whether to include/exclude nulls in list elements. + */ +std::unique_ptr make_collect_aggregation( + null_policy null_handling = null_policy::INCLUDE); /// Factory to create a LAG aggregation std::unique_ptr make_lag_aggregation(size_type offset); diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 6b4a537d21b..1cafad25c9c 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -154,6 +154,9 @@ struct quantile_aggregation final : derived_aggregation { } }; +/** + * @brief Derived aggregation class for specifying LEAD/LAG window aggregations + */ struct lead_lag_aggregation final : derived_aggregation { lead_lag_aggregation(Kind kind, size_type offset) : derived_aggregation{offset < 0 ? (kind == LAG ? LEAD : LAG) : kind}, @@ -316,6 +319,27 @@ struct udf_aggregation final : derived_aggregation { } }; +/** + * @brief Derived aggregation class for specifying COLLECT aggregation + */ +struct collect_list_aggregation final : derived_aggregation { + explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE) + : derived_aggregation{COLLECT}, _null_handling{null_handling} + { + } + null_policy _null_handling; ///< include or exclude nulls + + protected: + friend class derived_aggregation; + + bool operator==(collect_list_aggregation const& other) const + { + return _null_handling == other._null_handling; + } + + size_t hash_impl() const { return std::hash{}(static_cast(_null_handling)); } +}; + /** * @brief Sentinel value used for `ARGMAX` aggregation. 
* diff --git a/cpp/include/cudf_test/iterator_utilities.hpp b/cpp/include/cudf_test/iterator_utilities.hpp new file mode 100644 index 00000000000..40c275a13d3 --- /dev/null +++ b/cpp/include/cudf_test/iterator_utilities.hpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include + +namespace cudf { +namespace test { + +/** + * @brief Bool iterator for marking (possibly multiple) null elements in a column_wrapper. + * + * The returned iterator yields `false` (to mark `null`) at all the specified indices, + * and yields `true` (to mark valid rows) for all other indices. E.g. + * + * @code + * auto indices = std::vector{8,9}; + * auto iter = iterator_with_null_at(indices.cbegin(), indices.cend()); + * iter[6] == true; // i.e. Valid row at index 6. + * iter[7] == true; // i.e. Valid row at index 7. + * iter[8] == false; // i.e. Invalid row at index 8. + * iter[9] == false; // i.e. Invalid row at index 9. + * @endcode + * + * @tparam Iter Iterator type + * @param index_start Iterator to start of indices for which the validity iterator + * must return `false` (i.e. 
null) + * @param index_end Iterator to end of indices for the validity iterator + * @return auto Validity iterator + */ +template +static auto iterator_with_null_at(Iter index_start, Iter index_end) +{ + using index_type = typename std::iterator_traits::value_type; + + return cudf::detail::make_counting_transform_iterator( + 0, [indices = std::vector{index_start, index_end}](auto i) { + return std::find(indices.cbegin(), indices.cend(), i) == indices.cend(); + }); +} + +/** + * @brief Bool iterator for marking (possibly multiple) null elements in a column_wrapper. + * + * The returned iterator yields `false` (to mark `null`) at all the specified indices, + * and yields `true` (to mark valid rows) for all other indices. E.g. + * + * @code + * using host_span = cudf::detail::host_span; + * auto iter = iterator_with_null_at(host_span{std::vector{8,9}}); + * iter[6] == true; // i.e. Valid row at index 6. + * iter[7] == true; // i.e. Valid row at index 7. + * iter[8] == false; // i.e. Invalid row at index 8. + * iter[9] == false; // i.e. Invalid row at index 9. + * @endcode + * + * @param indices The indices for which the validity iterator must return `false` (i.e. null) + * @return auto Validity iterator + */ +static auto iterator_with_null_at(cudf::detail::host_span const& indices) +{ + return iterator_with_null_at(indices.begin(), indices.end()); +} + +/** + * @brief Bool iterator for marking a single null element in a column_wrapper + * + * The returned iterator yields `false` (to mark `null`) at the specified index, + * and yields `true` (to mark valid rows) for all other indices. E.g. + * + * @code + * auto iter = iterator_with_null_at(8); + * iter[7] == true; // i.e. Valid row at index 7. + * iter[8] == false; // i.e. Invalid row at index 8. + * @endcode + * + * @param index The index for which the validity iterator must return `false` (i.e. 
null) + * @return auto Validity iterator + */ +static auto iterator_with_null_at(cudf::size_type const& index) +{ + return iterator_with_null_at(std::vector{index}); +} + +} // namespace test +} // namespace cudf diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 6c1ad58e81b..04dc8776d20 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -126,9 +126,9 @@ std::unique_ptr make_row_number_aggregation() return std::make_unique(aggregation::ROW_NUMBER); } /// Factory to create a COLLECT aggregation -std::unique_ptr make_collect_aggregation() +std::unique_ptr make_collect_aggregation(null_policy null_handling) { - return std::make_unique(aggregation::COLLECT); + return std::make_unique(null_handling); } /// Factory to create a LAG aggregation std::unique_ptr make_lag_aggregation(size_type offset) diff --git a/cpp/src/groupby/sort/groupby.cu b/cpp/src/groupby/sort/groupby.cu index a88e45c4c7f..5c54dd3cb4c 100644 --- a/cpp/src/groupby/sort/groupby.cu +++ b/cpp/src/groupby/sort/groupby.cu @@ -403,6 +403,11 @@ void store_result_functor::operator()(aggregation cons template <> void store_result_functor::operator()(aggregation const& agg) { + auto null_handling = + static_cast(agg)._null_handling; + CUDF_EXPECTS(null_handling == null_policy::INCLUDE, + "null exclusion is not supported on groupby COLLECT aggregation."); + if (cache.has_result(col_idx, agg)) return; auto result = detail::group_collect( diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index 2ede50b468a..dcc48aafb39 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -1063,6 +1063,81 @@ struct rolling_window_launcher { return gather_map; } + /** + * @brief Count null entries in result of COLLECT. 
+ */ + size_type count_child_nulls(column_view const& input, + std::unique_ptr const& gather_map, + rmm::cuda_stream_view stream) + { + auto input_device_view = column_device_view::create(input, stream); + + auto input_row_is_null = [d_input = *input_device_view] __device__(auto i) { + return d_input.is_null_nocheck(i); + }; + + return thrust::count_if(rmm::exec_policy(stream), + gather_map->view().template begin(), + gather_map->view().template end(), + input_row_is_null); + } + + /** + * @brief Purge entries for null inputs from gather_map, and adjust offsets. + */ + std::pair, std::unique_ptr> purge_null_entries( + column_view const& input, + column_view const& gather_map, + column_view const& offsets, + size_type num_child_nulls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto input_device_view = column_device_view::create(input, stream); + + auto input_row_not_null = [d_input = *input_device_view] __device__(auto i) { + return d_input.is_valid_nocheck(i); + }; + + // Purge entries in gather_map that correspond to null input. + auto new_gather_map = make_fixed_width_column(data_type{type_to_id()}, + gather_map.size() - num_child_nulls, + mask_state::UNALLOCATED, + stream, + mr); + thrust::copy_if(rmm::exec_policy(stream), + gather_map.template begin(), + gather_map.template end(), + new_gather_map->mutable_view().template begin(), + input_row_not_null); + + // Recalculate offsets after null entries are purged. 
+ auto new_sizes = make_fixed_width_column( + data_type{type_to_id()}, input.size(), mask_state::UNALLOCATED, stream, mr); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.size()), + new_sizes->mutable_view().template begin(), + [d_gather_map = gather_map.template begin(), + d_old_offsets = offsets.template begin(), + input_row_not_null] __device__(auto i) { + return thrust::count_if(thrust::seq, + d_gather_map + d_old_offsets[i], + d_gather_map + d_old_offsets[i + 1], + input_row_not_null); + }); + + auto new_offsets = + strings::detail::make_offsets_child_column(new_sizes->view().template begin(), + new_sizes->view().template end(), + stream, + mr); + + return std::make_pair, std::unique_ptr>( + std::move(new_gather_map), std::move(new_offsets)); + } + template std::enable_if_t<(op == aggregation::COLLECT), std::unique_ptr> operator()( column_view const& input, @@ -1106,6 +1181,17 @@ struct rolling_window_launcher { auto gather_map = create_collect_gather_map( offsets->view(), per_row_mapping->view(), preceding_begin, stream, mr); + // If gather_map collects null elements, and null_policy == EXCLUDE, + // those elements must be filtered out, and offsets recomputed. + auto null_handling = static_cast(agg.get())->_null_handling; + if (null_handling == null_policy::EXCLUDE && input.has_nulls()) { + auto num_child_nulls = count_child_nulls(input, gather_map, stream); + if (num_child_nulls != 0) { + std::tie(gather_map, offsets) = + purge_null_entries(input, *gather_map, *offsets, num_child_nulls, stream, mr); + } + } + // gather(), to construct child column. 
auto gather_output = cudf::gather(table_view{std::vector{input}}, gather_map->view()); diff --git a/cpp/tests/collect_list/collect_list_test.cpp b/cpp/tests/collect_list/collect_list_test.cpp index dada82c3dc1..98a7b2bacc2 100644 --- a/cpp/tests/collect_list/collect_list_test.cpp +++ b/cpp/tests/collect_list/collect_list_test.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -78,6 +79,11 @@ TYPED_TEST(TypedCollectListTest, BasicRollingWindow) auto const result_fixed_window = rolling_window(input_column, 2, 1, 1, make_collect_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, 2, 1, 1, make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) @@ -110,6 +116,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) .release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); + + auto const result_with_nulls_excluded = rolling_window( + input_column, prev_column, foll_column, 0, make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) @@ -131,6 +142,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) lists_column_wrapper{{}, {0, 1, 2}, {1, 2, 3}, {2, 3, 4}, {3, 4, 5}, {}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = rolling_window( + input_column, prev_column, foll_column, 0, make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), 
result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) @@ -160,6 +176,15 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + preceding = 2; following = 2; min_periods = 4; @@ -173,6 +198,16 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); + + auto result_2_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), + result_2_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) @@ -188,7 +223,6 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) auto const input_column = fixed_width_column_wrapper{{0, 1, 2, 3, 4, 5}, {1, 0, 1, 1, 0, 1}}; - // auto const num_elements = static_cast(input_column).size(); { // One result row at each end should be null. @@ -219,6 +253,36 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); } + { + // One result row at each end should be null. + // Exclude nulls: No null elements for any output list rows. 
+ auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + auto expected_result_child_values = std::vector{0, 2, 2, 3, 2, 3, 3, 5}; + auto expected_result_child = fixed_width_column_wrapper( + expected_result_child_values.begin(), expected_result_child_values.end()); + auto expected_offsets = fixed_width_column_wrapper{0, 0, 2, 4, 6, 8, 8}.release(); + auto expected_num_rows = expected_offsets->size() - 1; + auto null_mask_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [expected_num_rows](auto i) { return i != 0 && i != (expected_num_rows - 1); }); + + auto expected_result = make_lists_column( + expected_num_rows, + std::move(expected_offsets), + expected_result_child.release(), + 2, + cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } + { // First result row, and the last two result rows should be null. auto preceding = 2; @@ -248,6 +312,37 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); } + + { + // First result row, and the last two result rows should be null. + // Exclude nulls: No null elements for any output list rows. 
+ auto preceding = 2; + auto following = 2; + auto min_periods = 4; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + auto expected_result_child_values = std::vector{0, 2, 3, 2, 3, 2, 3, 5}; + auto expected_result_child = fixed_width_column_wrapper( + expected_result_child_values.begin(), expected_result_child_values.end()); + + auto expected_offsets = fixed_width_column_wrapper{0, 0, 3, 5, 8, 8, 8}.release(); + auto expected_num_rows = expected_offsets->size() - 1; + auto null_mask_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [expected_num_rows](auto i) { return i > 0 && i < 4; }); + + auto expected_result = make_lists_column( + expected_num_rows, + std::move(expected_offsets), + expected_result_child.release(), + 3, + cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } } TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) @@ -275,6 +370,15 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + preceding = 2; following = 2; min_periods = 4; @@ -288,6 +392,16 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); + + auto result_2_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), + 
result_2_with_nulls_excluded->view()); } TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) @@ -329,6 +443,16 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), + result_with_nulls_excluded->view()); } { @@ -357,6 +481,16 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), + result_with_nulls_excluded->view()); } } @@ -393,6 +527,16 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindow) {22, 23}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) @@ -409,26 +553,53 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) auto const preceding = 2; auto const following = 1; auto const min_periods = 1; - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - 
input_column, - preceding, - following, - min_periods, - make_collect_aggregation()); - auto expected_child = fixed_width_column_wrapper{ - {10, 11, 10, 11, 12, 11, 12, 13, 12, 13, 14, 13, 14, 20, 21, 20, 21, 22, 21, 22, 23, 22, 23}, - {1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1}}; + { + // Nulls included. + auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto expected_child = fixed_width_column_wrapper{ + {10, 11, 10, 11, 12, 11, 12, 13, 12, 13, 14, 13, 14, 20, 21, 20, 21, 22, 21, 22, 23, 22, 23}, + {1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1}}; + + auto expected_offsets = fixed_width_column_wrapper{0, 2, 5, 8, 11, 13, 15, 18, 21, 23}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); - auto expected_offsets = fixed_width_column_wrapper{0, 2, 5, 8, 11, 13, 15, 18, 21, 23}; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } - auto expected_result = make_lists_column(static_cast(group_column).size(), - expected_offsets.release(), - expected_child.release(), - 0, - {}); + { + // Nulls excluded. 
+ auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + auto expected_child = fixed_width_column_wrapper{ + 10, 10, 12, 12, 13, 12, 13, 14, 13, 14, 20, 20, 22, 22, 23, 22, 23}; + + auto expected_offsets = fixed_width_column_wrapper{0, 1, 3, 5, 8, 10, 11, 13, 15, 17}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } } TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) @@ -468,6 +639,86 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) {21, 22, 23}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNulls) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {10, 11, 12, 13, 14, 20, 21, 22, 23}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + 
time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_0 = iterator_with_null_at(0); + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. + auto const expected_result = lists_column_wrapper{ + {{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{20}, null_at_1}, + {{21, 22}, null_at_0}, + {{21, 22, 23}, null_at_0}, + {{21, 22, 23}, null_at_0}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. 
+ auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {10, 12, 13}, + {10, 12, 13}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {20}, + {22}, + {22, 23}, + {22, 23}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) @@ -505,6 +756,85 @@ TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) {"21", "22", "23"}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNulls) +{ + using namespace cudf; + using namespace cudf::test; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = strings_column_wrapper{ + {"10", "11", "12", "13", "14", "20", "21", "22", "23"}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_0 = iterator_with_null_at(0); + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. 
+ auto const expected_result = lists_column_wrapper{ + {{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {"20"}, + {{"21", "22"}, null_at_0}, + {{"21", "22", "23"}, null_at_0}, + {{"21", "22", "23"}, + null_at_0}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. + auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {"10", "12", "13"}, + {"10", "12", "13"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {"20"}, + {"22"}, + {"22", "23"}, + {"22", "23"}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) @@ -557,6 +887,18 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) 9, std::move(expected_offsets_column), std::move(expected_structs_column), 0, {}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + struct_column->view(), + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, 
GroupedTimeRangeRollingWindowWithMinPeriods) @@ -601,6 +943,92 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithMinPeriods) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNullsAndMinPeriods) +{ + // Test that min_periods is honoured. + // i.e. output row is null when min_periods exceeds number of observations. + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {10, 11, 12, 13, 14, 20, 21, 22, 23}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 4; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. 
+ auto const expected_result = lists_column_wrapper{ + {{{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator(0, [](auto i) { + return i < 5; + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. + auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {{10, 12, 13}, + {10, 12, 13}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return i < 5; })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) @@ -643,6 +1071,90 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNullsAndMinPeriods) +{ + // Test that min_periods is honoured. + // i.e. 
output row is null when min_periods exceeds number of observations. + using namespace cudf; + using namespace cudf::test; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = strings_column_wrapper{ + {"10", "11", "12", "13", "14", "20", "21", "22", "23"}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 4; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. + auto const expected_result = lists_column_wrapper{ + {{{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator(0, [](auto i) { + return i < 5; + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. 
+ auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {{"10", "12", "13"}, + {"10", "12", "13"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return i < 5; })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPeriods) @@ -703,6 +1215,18 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe std::move(expected_null_mask)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + struct_column->view(), + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } CUDF_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/groupby/group_collect_test.cpp b/cpp/tests/groupby/group_collect_test.cpp index 1e387fa3455..9edd0a6932a 100644 --- a/cpp/tests/groupby/group_collect_test.cpp +++ b/cpp/tests/groupby/group_collect_test.cpp @@ -108,5 +108,24 @@ TYPED_TEST(groupby_collect_test, dictionary) test_single_agg(keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_aggregation()); } +TYPED_TEST(groupby_collect_test, CollectFailsWithNullExclusion) +{ + using K = int32_t; + using V = TypeParam; + + fixed_width_column_wrapper keys{1, 1, 2, 2, 3, 3}; + groupby::groupby gby{table_view{{keys}}}; + + fixed_width_column_wrapper values{{1, 2, 3, 4, 5, 6}, + {true, false, true, false, true, false}}; + + std::vector agg_requests(1); + agg_requests[0].values = values; + 
agg_requests[0].aggregations.push_back(cudf::make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_EXPECT_THROW_MESSAGE(gby.aggregate(agg_requests), + "null exclusion is not supported on groupby COLLECT aggregation."); +} + } // namespace test } // namespace cudf