diff --git a/cpp/src/rolling/rolling_detail.cuh b/cpp/src/rolling/rolling_detail.cuh index 56b8bad0bac..525ed31ad82 100644 --- a/cpp/src/rolling/rolling_detail.cuh +++ b/cpp/src/rolling/rolling_detail.cuh @@ -746,11 +746,8 @@ class rolling_aggregation_postprocessor final : public cudf::detail::aggregation stream, rmm::mr::get_current_device_resource()); - result = lists::detail::drop_list_duplicates(lists_column_view(collected_list->view()), - null_equality::EQUAL, - nan_equality::UNEQUAL, - stream, - mr); + result = lists::detail::drop_list_duplicates( + lists_column_view(collected_list->view()), agg._nulls_equal, agg._nans_equal, stream, mr); } std::unique_ptr get_result() diff --git a/cpp/tests/groupby/collect_set_tests.cpp b/cpp/tests/groupby/collect_set_tests.cpp index d5a881a1993..8ce0380ad66 100644 --- a/cpp/tests/groupby/collect_set_tests.cpp +++ b/cpp/tests/groupby/collect_set_tests.cpp @@ -146,6 +146,47 @@ TEST_F(CollectSetTest, StringInput) test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); } +TEST_F(CollectSetTest, FloatsWithNaN) +{ + COL_K keys{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + cudf::test::fixed_width_column_wrapper vals{ + {1.0f, 1.0f, -2.3e-5f, -2.3e-5f, 2.3e5f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f}, + {true, true, true, true, true, true, true, true, true, true, false, false}}; + COL_K keys_expected{1}; + // null equal with nan unequal + cudf::test::lists_column_wrapper vals_expected{ + {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f}, + VALIDITY{true, true, true, true, true, true, true, false}}, + }; + test_single_agg(keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set()); + // null unequal with nan unequal + vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f}, + VALIDITY{true, true, true, true, true, true, true, false, false}}}; + test_single_agg( + keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_unequal()); + // null exclude with nan unequal + vals_expected = {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN}}; + test_single_agg( + keys, vals, keys_expected, vals_expected, CollectSetTest::collect_set_null_exclude()); + // null equal with nan equal + vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, NAN, 0.0f}, VALIDITY{true, true, true, true, false}}}; + test_single_agg(keys, + vals, + keys_expected, + vals_expected, + cudf::make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + // null unequal with nan equal + vals_expected = { + {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, 0.0f, 0.0f}, VALIDITY{true, true, true, true, false, false}}}; + test_single_agg(keys, + vals, + keys_expected, + vals_expected, + cudf::make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::ALL_EQUAL)); +} + TYPED_TEST(CollectSetTypedTest, CollectWithNulls) { // Just use an arbitrary value to store null entries diff --git a/cpp/tests/rolling/collect_ops_test.cpp b/cpp/tests/rolling/collect_ops_test.cpp index f97e13b49f1..8f4cd34fd35 100644 --- a/cpp/tests/rolling/collect_ops_test.cpp +++ b/cpp/tests/rolling/collect_ops_test.cpp @@ -1661,16 +1661,16 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) using T = TypeParam; - auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2, 2}; auto const input_column = fixed_width_column_wrapper{ - {10, 11, 12, 13, 13, 20, 21, 21, 23}, {1, 0, 0, 1, 1, 1, 0, 1, 1}}; + {10, 0, 0, 13, 13, 20, 21, 0, 0, 23}, {1, 0, 0, 1, 1, 1, 1, 0, 0, 1}}; auto const preceding = 2; auto const following = 1; auto const min_periods = 1; { - // Nulls included. + // Nulls included and nulls are equal. auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, input_column, @@ -1679,10 +1679,78 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) min_periods, *make_collect_set_aggregation()); // Null values are sorted to the tails of lists (sets) - auto expected_child = fixed_width_column_wrapper{ - {10, 11, 10, 11, 13, 11, 13, 12, 13, 20, 21, 20, 21, 21, 21, 23, 21, 21, 23}, - {1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1}}; - auto expected_offsets = fixed_width_column_wrapper{0, 2, 4, 6, 8, 9, 11, 14, 17, 19}; + auto expected_child = fixed_width_column_wrapper{{ + 10, 0, // row 0 + 10, 0, // row 1 + 13, 0, // row 2 + 13, 0, // row 3 + 13, // row 4 + 20, 21, // row 5 + 20, 21, 0, // row 6 + 21, 0, // row 7 + 23, 0, // row 8 + 23, 0, // row 9 + }, + { + 1, 0, // row 0 + 1, 0, // row 1 + 1, 0, // row 2 + 1, 0, // row 3 + 1, // row 4 + 1, 1, // row 5 + 1, 1, 0, // row 6 + 1, 0, // row 7 + 1, 0, // row 8 + 1, 0 // row 9 + }}; + auto expected_offsets = + fixed_width_column_wrapper{0, 2, 4, 6, 8, 9, 11, 14, 16, 18, 20}; + + auto expected_result = make_lists_column(static_cast(group_column).size(), + expected_offsets.release(), + expected_child.release(), + 0, + {}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + } + + { + // Nulls included and nulls are NOT equal. + auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::UNEQUAL)); + // Null values are sorted to the tails of lists (sets) + auto expected_child = fixed_width_column_wrapper{{ + 10, 0, // row 0 + 10, 0, 0, // row 1 + 13, 0, 0, // row 2 + 13, 0, // row 3 + 13, // row 4 + 20, 21, // row 5 + 20, 21, 0, // row 6 + 21, 0, 0, // row 7 + 23, 0, 0, // row 8 + 23, 0 // row 9 + }, + { + 1, 0, // row 0 + 1, 0, 0, // row 1 + 1, 0, 0, // row 2 + 1, 0, // row 3 + 1, // row 4 + 1, 1, // row 5 + 1, 1, 0, // row 6 + 1, 0, 0, // row 7 + 1, 0, 0, // row 8 + 1, 0 // row 9 + }}; + auto expected_offsets = + fixed_width_column_wrapper{0, 2, 5, 8, 10, 11, 13, 16, 19, 22, 24}; auto expected_result = make_lists_column(static_cast(group_column).size(), expected_offsets.release(), @@ -1703,10 +1771,22 @@ TYPED_TEST(TypedCollectSetTest, BasicGroupedRollingWindowWithNulls) min_periods, *make_collect_set_aggregation(null_policy::EXCLUDE)); - auto expected_child = - fixed_width_column_wrapper{10, 10, 13, 13, 13, 20, 20, 21, 21, 23, 21, 23}; - - auto expected_offsets = fixed_width_column_wrapper{0, 1, 2, 3, 4, 5, 6, 8, 10, 12}; + auto expected_child = fixed_width_column_wrapper{ + 10, // row 0 + 10, // row 1 + 13, // row 2 + 13, // row 3 + 13, // row 4 + 20, + 21, // row 5 + 20, + 21, // row 6 + 21, // row 7 + 23, // row 8 + 23 // row 9 + }; + + auto expected_offsets = fixed_width_column_wrapper{0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12}; auto expected_result = make_lists_column(static_cast(group_column).size(), expected_offsets.release(), @@ -1957,6 +2037,68 @@ TEST_F(CollectSetTest, BoolGroupedRollingWindow) CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); } +TEST_F(CollectSetTest, FloatGroupedRollingWindowWithNaNs) +{ + using namespace cudf; + using namespace cudf::test; + + auto const group_column = fixed_width_column_wrapper{1, 1, 1, 1, 1, 2, 2, 2, 2}; + auto const input_column = fixed_width_column_wrapper{ + {1.23, 0.2341, 0.2341, -5.23e9, std::nan("1"), 1.1, std::nan("1"), std::nan("1"), 0.0}, + {true, true, true, true, true, true, true, true, false}}; + + auto const preceding = 2; + auto const following = 1; + auto const min_periods = 1; + // test on nan_equality::UNEQUAL + auto const result = grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation()); + + auto const expected_result = lists_column_wrapper{ + {{0.2341, 1.23}, std::initializer_list{true, true}}, + {{0.2341, 1.23}, std::initializer_list{true, true}}, + {{-5.23e9, 0.2341}, std::initializer_list{true, true}}, + {{-5.23e9, 0.2341, std::nan("1")}, std::initializer_list{true, true, true}}, + {{-5.23e9, std::nan("1")}, std::initializer_list{true, true}}, + {{1.1, std::nan("1")}, std::initializer_list{true, true}}, + {{1.1, std::nan("1"), std::nan("1")}, std::initializer_list{true, true, true}}, + {{std::nan("1"), std::nan("1"), 0.0}, std::initializer_list{true, true, false}}, + {{std::nan("1"), 0.0}, + std::initializer_list{ + true, false}}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result->view()); + + // test on nan_equality::ALL_EQUAL + auto const result_nan_equal = + grouped_rolling_window(table_view{std::vector{group_column}}, + input_column, + preceding, + following, + min_periods, + *make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + + auto const expected_result_nan_equal = lists_column_wrapper{ + {{0.2341, 1.23}, std::initializer_list{true, true}}, + {{0.2341, 1.23}, std::initializer_list{true, true}}, + {{-5.23e9, 0.2341}, std::initializer_list{true, true}}, + {{-5.23e9, 0.2341, std::nan("1")}, std::initializer_list{true, true, true}}, + {{-5.23e9, std::nan("1")}, std::initializer_list{true, true}}, + {{1.1, std::nan("1")}, std::initializer_list{true, true}}, + {{1.1, std::nan("1")}, std::initializer_list{true, true}}, + {{std::nan("1"), 0.0}, std::initializer_list{true, false}}, + {{std::nan("1"), 0.0}, + std::initializer_list{true, + false}}}.release(); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_nan_equal->view(), result_nan_equal->view()); +} + TEST_F(CollectSetTest, BasicRollingWindowWithNaNs) { using namespace cudf; @@ -2002,6 +2144,27 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs) *make_collect_set_aggregation(null_policy::EXCLUDE)); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view()); + + auto const expected_result_for_nan_equal = + lists_column_wrapper{ + {0.2341, 1.23}, + {0.2341, 1.23, std::nan("1")}, + {0.2341, std::nan("1")}, + {-5.23e9, std::nan("1")}, + {-5.23e9, std::nan("1")}, + } + .release(); + + auto const result_with_nan_equal = + rolling_window(input_column, + 2, + 1, + 1, + *make_collect_set_aggregation( + null_policy::INCLUDE, null_equality::EQUAL, nan_equality::ALL_EQUAL)); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result_for_nan_equal->view(), + result_with_nan_equal->view()); } TEST_F(CollectSetTest, ListTypeRollingWindow)