diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index a1953a2d358..44a402262cc 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -204,6 +204,7 @@ test: - test -f $PREFIX/include/cudf_test/cudf_gtest.hpp - test -f $PREFIX/include/cudf_test/cxxopts.hpp - test -f $PREFIX/include/cudf_test/file_utilities.hpp + - test -f $PREFIX/include/cudf_test/iterator_utilities.hpp - test -f $PREFIX/include/cudf_test/table_utilities.hpp - test -f $PREFIX/include/cudf_test/timestamp_utilities.cuh - test -f $PREFIX/include/cudf_test/type_list_utilities.hpp diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index 4a7bc129aaf..a81b6ebc8a1 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -204,8 +204,18 @@ std::unique_ptr make_nth_element_aggregation( /// Factory to create a ROW_NUMBER aggregation std::unique_ptr make_row_number_aggregation(); -/// Factory to create a COLLECT_NUMBER aggregation -std::unique_ptr make_collect_aggregation(); +/** + * @brief Factory to create a COLLECT aggregation + * + * `COLLECT` returns a list column of all included elements in the group/series. + * + * If `null_handling` is set to `EXCLUDE`, null elements are dropped from each + * of the list rows. + * + * @param null_handling Indicates whether to include/exclude nulls in list elements. 
+ */ +std::unique_ptr make_collect_aggregation( + null_policy null_handling = null_policy::INCLUDE); /// Factory to create a LAG aggregation std::unique_ptr make_lag_aggregation(size_type offset); diff --git a/cpp/include/cudf/detail/aggregation/aggregation.hpp b/cpp/include/cudf/detail/aggregation/aggregation.hpp index 6b4a537d21b..1cafad25c9c 100644 --- a/cpp/include/cudf/detail/aggregation/aggregation.hpp +++ b/cpp/include/cudf/detail/aggregation/aggregation.hpp @@ -154,6 +154,9 @@ struct quantile_aggregation final : derived_aggregation { } }; +/** + * @brief Derived aggregation class for specifying LEAD/LAG window aggregations + */ struct lead_lag_aggregation final : derived_aggregation { lead_lag_aggregation(Kind kind, size_type offset) : derived_aggregation{offset < 0 ? (kind == LAG ? LEAD : LAG) : kind}, @@ -316,6 +319,27 @@ struct udf_aggregation final : derived_aggregation { } }; +/** + * @brief Derived aggregation class for specifying COLLECT aggregation and its null handling + */ +struct collect_list_aggregation final : derived_aggregation { + explicit collect_list_aggregation(null_policy null_handling = null_policy::INCLUDE) + : derived_aggregation{COLLECT}, _null_handling{null_handling} + { + } + null_policy _null_handling; ///< whether null elements are included in or excluded from lists + + protected: + friend class derived_aggregation; + + bool operator==(collect_list_aggregation const& other) const + { + return _null_handling == other._null_handling; + } + + size_t hash_impl() const { return std::hash{}(static_cast(_null_handling)); } +}; + /** * @brief Sentinel value used for `ARGMAX` aggregation. * diff --git a/cpp/include/cudf_test/iterator_utilities.hpp b/cpp/include/cudf_test/iterator_utilities.hpp new file mode 100644 index 00000000000..40c275a13d3 --- /dev/null +++ b/cpp/include/cudf_test/iterator_utilities.hpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#include + +namespace cudf { +namespace test { + +/** + * @brief Bool iterator for marking (possibly multiple) null elements in a column_wrapper. + * + * The returned iterator yields `false` (to mark `null`) at all the specified indices, + * and yields `true` (to mark valid rows) for all other indices. E.g. + * + * @code + * auto indices = std::vector{8,9}; + * auto iter = iterator_with_null_at(indices.cbegin(), indices.cend()); + * iter[6] == true; // i.e. Valid row at index 6. + * iter[7] == true; // i.e. Valid row at index 7. + * iter[8] == false; // i.e. Invalid row at index 8. + * iter[9] == false; // i.e. Invalid row at index 9. + * @endcode + * + * @tparam Iter Iterator type + * @param index_start Iterator to start of indices for which the validity iterator + * must return `false` (i.e. 
null) + * @param index_end Iterator to end of indices for the validity iterator + * @return auto Validity iterator + */ +template +static auto iterator_with_null_at(Iter index_start, Iter index_end) +{ + using index_type = typename std::iterator_traits::value_type; + + return cudf::detail::make_counting_transform_iterator( + 0, [indices = std::vector{index_start, index_end}](auto i) { + return std::find(indices.cbegin(), indices.cend(), i) == indices.cend(); + }); +} + +/** + * @brief Bool iterator for marking (possibly multiple) null elements in a column_wrapper. + * + * The returned iterator yields `false` (to mark `null`) at all the specified indices, + * and yields `true` (to mark valid rows) for all other indices. E.g. + * + * @code + * using host_span = cudf::detail::host_span; + * auto iter = iterator_with_null_at(host_span{std::vector{8,9}}); + * iter[6] == true; // i.e. Valid row at index 6. + * iter[7] == true; // i.e. Valid row at index 7. + * iter[8] == false; // i.e. Invalid row at index 8. + * iter[9] == false; // i.e. Invalid row at index 9. + * @endcode + * + * @param indices The indices for which the validity iterator must return `false` (i.e. null) + * @return auto Validity iterator + */ +static auto iterator_with_null_at(cudf::detail::host_span const& indices) +{ + return iterator_with_null_at(indices.begin(), indices.end()); +} + +/** + * @brief Bool iterator for marking a single null element in a column_wrapper + * + * The returned iterator yields `false` (to mark `null`) at the specified index, + * and yields `true` (to mark valid rows) for all other indices. E.g. + * + * @code + * auto iter = iterator_with_null_at(8); + * iter[7] == true; // i.e. Valid row at index 7. + * iter[8] == false; // i.e. Invalid row at index 8. + * @endcode + * + * @param index The index for which the validity iterator must return `false` (i.e. 
null) + * @return auto Validity iterator + */ +static auto iterator_with_null_at(cudf::size_type const& index) +{ + return iterator_with_null_at(std::vector{index}); +} + +} // namespace test +} // namespace cudf diff --git a/cpp/src/aggregation/aggregation.cpp b/cpp/src/aggregation/aggregation.cpp index 6c1ad58e81b..04dc8776d20 100644 --- a/cpp/src/aggregation/aggregation.cpp +++ b/cpp/src/aggregation/aggregation.cpp @@ -126,9 +126,9 @@ std::unique_ptr make_row_number_aggregation() return std::make_unique(aggregation::ROW_NUMBER); } /// Factory to create a COLLECT aggregation -std::unique_ptr make_collect_aggregation() +std::unique_ptr make_collect_aggregation(null_policy null_handling) { - return std::make_unique(aggregation::COLLECT); + return std::make_unique(null_handling); } /// Factory to create a LAG aggregation std::unique_ptr make_lag_aggregation(size_type offset) diff --git a/cpp/src/groupby/sort/groupby.cu b/cpp/src/groupby/sort/groupby.cu index a88e45c4c7f..5c54dd3cb4c 100644 --- a/cpp/src/groupby/sort/groupby.cu +++ b/cpp/src/groupby/sort/groupby.cu @@ -403,6 +403,11 @@ void store_result_functor::operator()(aggregation cons template <> void store_result_functor::operator()(aggregation const& agg) { + auto null_handling = + static_cast(agg)._null_handling; + CUDF_EXPECTS(null_handling == null_policy::INCLUDE, + "null exclusion is not supported on groupby COLLECT aggregation."); + if (cache.has_result(col_idx, agg)) return; auto result = detail::group_collect( diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index 2ede50b468a..dcc48aafb39 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -1063,6 +1063,81 @@ struct rolling_window_launcher { return gather_map; } + /** + * @brief Count null entries in result of COLLECT. 
+ */ + size_type count_child_nulls(column_view const& input, + std::unique_ptr const& gather_map, + rmm::cuda_stream_view stream) + { + auto input_device_view = column_device_view::create(input, stream); + + auto input_row_is_null = [d_input = *input_device_view] __device__(auto i) { + return d_input.is_null_nocheck(i); + }; + + return thrust::count_if(rmm::exec_policy(stream), + gather_map->view().template begin(), + gather_map->view().template end(), + input_row_is_null); + } + + /** + * @brief Purge entries for null inputs from gather_map, and adjust offsets. + */ + std::pair, std::unique_ptr> purge_null_entries( + column_view const& input, + column_view const& gather_map, + column_view const& offsets, + size_type num_child_nulls, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + auto input_device_view = column_device_view::create(input, stream); + + auto input_row_not_null = [d_input = *input_device_view] __device__(auto i) { + return d_input.is_valid_nocheck(i); + }; + + // Purge entries in gather_map that correspond to null input. + auto new_gather_map = make_fixed_width_column(data_type{type_to_id()}, + gather_map.size() - num_child_nulls, + mask_state::UNALLOCATED, + stream, + mr); + thrust::copy_if(rmm::exec_policy(stream), + gather_map.template begin(), + gather_map.template end(), + new_gather_map->mutable_view().template begin(), + input_row_not_null); + + // Recalculate offsets after null entries are purged. 
+ auto new_sizes = make_fixed_width_column( + data_type{type_to_id()}, input.size(), mask_state::UNALLOCATED, stream, mr); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.size()), + new_sizes->mutable_view().template begin(), + [d_gather_map = gather_map.template begin(), + d_old_offsets = offsets.template begin(), + input_row_not_null] __device__(auto i) { + return thrust::count_if(thrust::seq, + d_gather_map + d_old_offsets[i], + d_gather_map + d_old_offsets[i + 1], + input_row_not_null); + }); + + auto new_offsets = + strings::detail::make_offsets_child_column(new_sizes->view().template begin(), + new_sizes->view().template end(), + stream, + mr); + + return std::make_pair, std::unique_ptr>( + std::move(new_gather_map), std::move(new_offsets)); + } + template std::enable_if_t<(op == aggregation::COLLECT), std::unique_ptr> operator()( column_view const& input, @@ -1106,6 +1181,17 @@ struct rolling_window_launcher { auto gather_map = create_collect_gather_map( offsets->view(), per_row_mapping->view(), preceding_begin, stream, mr); + // If gather_map collects null elements, and null_policy == EXCLUDE, + // those elements must be filtered out, and offsets recomputed. + auto null_handling = static_cast(agg.get())->_null_handling; + if (null_handling == null_policy::EXCLUDE && input.has_nulls()) { + auto num_child_nulls = count_child_nulls(input, gather_map, stream); + if (num_child_nulls != 0) { + std::tie(gather_map, offsets) = + purge_null_entries(input, *gather_map, *offsets, num_child_nulls, stream, mr); + } + } + // gather(), to construct child column. 
auto gather_output = cudf::gather(table_view{std::vector{input}}, gather_map->view()); diff --git a/cpp/tests/collect_list/collect_list_test.cpp b/cpp/tests/collect_list/collect_list_test.cpp index dada82c3dc1..98a7b2bacc2 100644 --- a/cpp/tests/collect_list/collect_list_test.cpp +++ b/cpp/tests/collect_list/collect_list_test.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -78,6 +79,11 @@ TYPED_TEST(TypedCollectListTest, BasicRollingWindow) auto const result_fixed_window = rolling_window(input_column, 2, 1, 1, make_collect_aggregation()); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, 2, 1, 1, make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) @@ -110,6 +116,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputLists) .release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view()); + + auto const result_with_nulls_excluded = rolling_window( + input_column, prev_column, foll_column, 0, make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) @@ -131,6 +142,11 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithEmptyOutputListsAtEnds) lists_column_wrapper{{}, {0, 1, 2}, {1, 2, 3}, {2, 3, 4}, {3, 4, 5}, {}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = rolling_window( + input_column, prev_column, foll_column, 0, make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), 
result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) @@ -160,6 +176,15 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + preceding = 2; following = 2; min_periods = 4; @@ -173,6 +198,16 @@ TYPED_TEST(TypedCollectListTest, RollingWindowHonoursMinPeriods) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); + + auto result_2_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), + result_2_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) @@ -188,7 +223,6 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) auto const input_column = fixed_width_column_wrapper{{0, 1, 2, 3, 4, 5}, {1, 0, 1, 1, 0, 1}}; - // auto const num_elements = static_cast(input_column).size(); { // One result row at each end should be null. @@ -219,6 +253,36 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); } + { + // One result row at each end should be null. + // Exclude nulls: No nulls elements for any output list rows. 
+ auto preceding = 2; + auto following = 1; + auto min_periods = 3; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + auto expected_result_child_values = std::vector{0, 2, 2, 3, 2, 3, 3, 5}; + auto expected_result_child = fixed_width_column_wrapper( + expected_result_child_values.begin(), expected_result_child_values.end()); + auto expected_offsets = fixed_width_column_wrapper{0, 0, 2, 4, 6, 8, 8}.release(); + auto expected_num_rows = expected_offsets->size() - 1; + auto null_mask_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [expected_num_rows](auto i) { return i != 0 && i != (expected_num_rows - 1); }); + + auto expected_result = make_lists_column( + expected_num_rows, + std::move(expected_offsets), + expected_result_child.release(), + 2, + cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } + { // First result row, and the last two result rows should be null. auto preceding = 2; @@ -248,6 +312,37 @@ TYPED_TEST(TypedCollectListTest, RollingWindowWithNullInputsHonoursMinPeriods) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); } + + { + // First result row, and the last two result rows should be null. + // Exclude nulls: No nulls elements for any output list rows. 
+ auto preceding = 2; + auto following = 2; + auto min_periods = 4; + auto const result = rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + auto expected_result_child_values = std::vector{0, 2, 3, 2, 3, 2, 3, 5}; + auto expected_result_child = fixed_width_column_wrapper( + expected_result_child_values.begin(), expected_result_child_values.end()); + + auto expected_offsets = fixed_width_column_wrapper{0, 0, 3, 5, 8, 8, 8}.release(); + auto expected_num_rows = expected_offsets->size() - 1; + auto null_mask_iter = cudf::detail::make_counting_transform_iterator( + size_type{0}, [expected_num_rows](auto i) { return i > 0 && i < 4; }); + + auto expected_result = make_lists_column( + expected_num_rows, + std::move(expected_offsets), + expected_result_child.release(), + 3, + cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } } TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) @@ -275,6 +370,15 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + preceding = 2; following = 2; min_periods = 4; @@ -288,6 +392,16 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsOnStrings) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), result_2->view()); + + auto result_2_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_2->view(), + 
result_2_with_nulls_excluded->view()); } TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) @@ -329,6 +443,16 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), + result_with_nulls_excluded->view()); } { @@ -357,6 +481,16 @@ TEST_F(CollectListTest, RollingWindowHonoursMinPeriodsWithDecimal) cudf::test::detail::make_null_mask(null_mask_iter, null_mask_iter + expected_num_rows)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + rolling_window(input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), + result_with_nulls_excluded->view()); } } @@ -393,6 +527,16 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindow) {22, 23}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) @@ -409,26 +553,53 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedRollingWindowWithNulls) auto const preceding = 2; auto const following = 1; auto const min_periods = 1; - auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, - 
input_column, - preceding, - following, - min_periods, - make_collect_aggregation()); - auto expected_child = fixed_width_column_wrapper{ - {10, 11, 10, 11, 12, 11, 12, 13, 12, 13, 14, 13, 14, 20, 21, 20, 21, 22, 21, 22, 23, 22, 23}, - {1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1}}; + { + // Nulls included. + auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto expected_child = fixed_width_column_wrapper{ + {10, 11, 10, 11, 12, 11, 12, 13, 12, 13, 14, 13, 14, 20, 21, 20, 21, 22, 21, 22, 23, 22, 23}, + {1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1}}; + + auto expected_offsets = fixed_width_column_wrapper{0, 2, 5, 8, 11, 13, 15, 18, 21, 23}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); - auto expected_offsets = fixed_width_column_wrapper{0, 2, 5, 8, 11, 13, 15, 18, 21, 23}; + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } - auto expected_result = make_lists_column(static_cast(group_column).size(), - expected_offsets.release(), - expected_child.release(), - 0, - {}); + { + // Nulls excluded. 
+ auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); - CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + auto expected_child = fixed_width_column_wrapper{ + 10, 10, 12, 12, 13, 12, 13, 14, 13, 14, 20, 20, 22, 22, 23, 22, 23}; + + auto expected_offsets = fixed_width_column_wrapper{0, 1, 3, 5, 8, 10, 11, 13, 15, 17}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } } TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) @@ -468,6 +639,86 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindow) {21, 22, 23}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNulls) +{ + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {10, 11, 12, 13, 14, 20, 21, 22, 23}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + 
time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_0 = iterator_with_null_at(0); + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. + auto const expected_result = lists_column_wrapper{ + {{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{20}, null_at_1}, + {{21, 22}, null_at_0}, + {{21, 22, 23}, null_at_0}, + {{21, 22, 23}, null_at_0}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. 
+ auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {10, 12, 13}, + {10, 12, 13}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {20}, + {22}, + {22, 23}, + {22, 23}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) @@ -505,6 +756,85 @@ TEST_F(CollectListTest, BasicGroupedTimeRangeRollingWindowOnStrings) {"21", "22", "23"}}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNulls) +{ + using namespace cudf; + using namespace cudf::test; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = strings_column_wrapper{ + {"10", "11", "12", "13", "14", "20", "21", "22", "23"}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_0 = iterator_with_null_at(0); + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. 
+ auto const expected_result = lists_column_wrapper{ + {{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {"20"}, + {{"21", "22"}, null_at_0}, + {{"21", "22", "23"}, null_at_0}, + {{"21", "22", "23"}, + null_at_0}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. + auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {"10", "12", "13"}, + {"10", "12", "13"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {"20"}, + {"22"}, + {"22", "23"}, + {"22", "23"}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) @@ -557,6 +887,18 @@ TYPED_TEST(TypedCollectListTest, BasicGroupedTimeRangeRollingWindowOnStructs) 9, std::move(expected_offsets_column), std::move(expected_structs_column), 0, {}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + struct_column->view(), + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, 
GroupedTimeRangeRollingWindowWithMinPeriods) @@ -601,6 +943,92 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithMinPeriods) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowWithNullsAndMinPeriods) +{ + // Test that min_periods is honoured. + // i.e. output row is null when min_periods exceeds number of observations. + using namespace cudf; + using namespace cudf::test; + + using T = TypeParam; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {10, 11, 12, 13, 14, 20, 21, 22, 23}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 4; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` and `21` should be nulls. 
+ auto const expected_result = lists_column_wrapper{ + {{{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {{10, 11, 12, 13, 14}, null_at_1}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator(0, [](auto i) { + return i < 5; + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, `11`, `21`, and `null` should not appear. + auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {{10, 12, 13}, + {10, 12, 13}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {10, 12, 13, 14}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return i < 5; })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) @@ -643,6 +1071,90 @@ TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithMinPeriods) })}.release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); +} + +TEST_F(CollectListTest, GroupedTimeRangeRollingWindowOnStringsWithNullsAndMinPeriods) +{ + // Test that min_periods is honoured. + // i.e. 
output row is null when min_periods exceeds number of observations. + using namespace cudf; + using namespace cudf::test; + + auto const time_column = fixed_width_column_wrapper{ + 1, 1, 2, 2, 3, 1, 4, 5, 6}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = strings_column_wrapper{ + {"10", "11", "12", "13", "14", "20", "21", "22", "23"}, {1, 0, 1, 1, 1, 1, 0, 1, 1}}; + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 4; + auto const result = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation()); + + auto null_at_1 = iterator_with_null_at(1); + + // In the results, `11` should be a null element; rows of group 2 (with null `21`) are all null. + auto const expected_result = lists_column_wrapper{ + {{{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {{"10", "11", "12", "13", "14"}, null_at_1}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator(0, [](auto i) { + return i < 5; + })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + input_column, + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + // After null exclusion, the null elements `11` and `21` should not appear in any list.
+ auto const expected_result_with_nulls_excluded = lists_column_wrapper{ + {{"10", "12", "13"}, + {"10", "12", "13"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {"10", "12", "13", "14"}, + {}, + {}, + {}, + {}}, + cudf::detail::make_counting_transform_iterator( + 0, [](auto i) { return i < 5; })}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_with_nulls_excluded->view(), + result_with_nulls_excluded->view()); } TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPeriods) @@ -703,6 +1215,18 @@ TYPED_TEST(TypedCollectListTest, GroupedTimeRangeRollingWindowOnStructsWithMinPe std::move(expected_null_mask)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + auto const result_with_nulls_excluded = + grouped_time_range_rolling_window(table_view{std::vector{group_column}}, + time_column, + cudf::order::ASCENDING, + struct_column->view(), + preceding, + following, + min_periods, + make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } CUDF_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/groupby/group_collect_test.cpp b/cpp/tests/groupby/group_collect_test.cpp index 1e387fa3455..9edd0a6932a 100644 --- a/cpp/tests/groupby/group_collect_test.cpp +++ b/cpp/tests/groupby/group_collect_test.cpp @@ -108,5 +108,24 @@ TYPED_TEST(groupby_collect_test, dictionary) test_single_agg(keys, vals, expect_keys, expect_vals->view(), cudf::make_collect_aggregation()); } +TYPED_TEST(groupby_collect_test, CollectFailsWithNullExclusion) +{ + using K = int32_t; + using V = TypeParam; + + fixed_width_column_wrapper keys{1, 1, 2, 2, 3, 3}; + groupby::groupby gby{table_view{{keys}}}; + + fixed_width_column_wrapper values{{1, 2, 3, 4, 5, 6}, + {true, false, true, false, true, false}}; + + std::vector agg_requests(1); + agg_requests[0].values = values; +
agg_requests[0].aggregations.push_back(cudf::make_collect_aggregation(null_policy::EXCLUDE)); + + CUDF_EXPECT_THROW_MESSAGE(gby.aggregate(agg_requests), + "null exclusion is not supported on groupby COLLECT aggregation."); +} + } // namespace test } // namespace cudf