Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to equal NaNs in make_collect_set_aggregation. #11621

Merged
merged 9 commits on Oct 20, 2022
7 changes: 4 additions & 3 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,10 @@ std::unique_ptr<Base> make_collect_list_aggregation(
* @return A COLLECT_SET aggregation object
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);
std::unique_ptr<Base> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::ALL_EQUAL);

/**
* @brief Factory to create a LAG aggregation
Expand Down
15 changes: 12 additions & 3 deletions cpp/tests/groupby/collect_set_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ TEST_F(CollectSetTest, FloatsWithNaN)
vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f},
validity_col{true, true, true, true, true, true, true, false}}};
auto const [out_keys, out_lists] =
groupby_collect_set(keys, vals, CollectSetTest::collect_set());
groupby_collect_set(keys,
vals,
cudf::make_collect_set_aggregation<cudf::groupby_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity);
}
Expand All @@ -258,7 +261,10 @@ TEST_F(CollectSetTest, FloatsWithNaN)
vals_expected = {{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f},
validity_col{true, true, true, true, true, true, true, false, false}}};
auto const [out_keys, out_lists] =
groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_unequal());
groupby_collect_set(keys,
vals,
cudf::make_collect_set_aggregation<cudf::groupby_aggregation>(
null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity);
}
Expand All @@ -267,7 +273,10 @@ TEST_F(CollectSetTest, FloatsWithNaN)
{
vals_expected = {{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN}};
auto const [out_keys, out_lists] =
groupby_collect_set(keys, vals, CollectSetTest::collect_set_null_exclude());
groupby_collect_set(keys,
vals,
cudf::make_collect_set_aggregation<cudf::groupby_aggregation>(
null_policy::EXCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(keys_expected, *out_keys, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUAL(vals_expected, *out_lists, verbosity);
}
Expand Down
12 changes: 8 additions & 4 deletions cpp/tests/reductions/collect_ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,15 +196,19 @@ TEST_F(CollectTest, CollectSetWithNaN)
// nan unequal with null equal
fp_wrapper expected1{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f},
{1, 1, 1, 1, 1, 1, 1, 0}};
auto const ret1 = collect_set(col, make_collect_set_aggregation<reduce_aggregation>());
auto const ret1 =
collect_set(col,
make_collect_set_aggregation<reduce_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast<list_scalar*>(ret1.get())->view());

// nan unequal with null unequal
fp_wrapper expected2{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, -NAN, NAN, NAN, 0.0f, 0.0f},
{1, 1, 1, 1, 1, 1, 1, 0, 0}};
auto const ret2 = collect_set(
col,
make_collect_set_aggregation<reduce_aggregation>(null_policy::INCLUDE, null_equality::UNEQUAL));
auto const ret2 =
collect_set(col,
make_collect_set_aggregation<reduce_aggregation>(
null_policy::INCLUDE, null_equality::UNEQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast<list_scalar*>(ret2.get())->view());

// nan equal with null equal
Expand Down
30 changes: 19 additions & 11 deletions cpp/tests/rolling/collect_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2118,13 +2118,14 @@ TEST_F(CollectSetTest, FloatGroupedRollingWindowWithNaNs)
auto const following = 1;
auto const min_periods = 1;
// test on nan_equality::UNEQUAL
auto const result =
grouped_rolling_collect_set(table_view{std::vector<column_view>{group_column}},
input_column,
preceding,
following,
min_periods,
*make_collect_set_aggregation<rolling_aggregation>());
auto const result = grouped_rolling_collect_set(
table_view{std::vector<column_view>{group_column}},
input_column,
preceding,
following,
min_periods,
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));

auto const expected_result = lists_column_wrapper<double>{
{{0.2341, 1.23}, std::initializer_list<bool>{true, true}},
Expand Down Expand Up @@ -2186,7 +2187,8 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs)
prev_column,
foll_column,
1,
*make_collect_set_aggregation<rolling_aggregation>());
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));

auto const expected_result =
lists_column_wrapper<double>{
Expand All @@ -2200,16 +2202,22 @@ TEST_F(CollectSetTest, BasicRollingWindowWithNaNs)

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_column_based_window->view());

auto const result_fixed_window = rolling_collect_set(
input_column, 2, 1, 1, *make_collect_set_aggregation<rolling_aggregation>());
auto const result_fixed_window =
rolling_collect_set(input_column,
2,
1,
1,
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::INCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_fixed_window->view());

auto const result_with_nulls_excluded =
rolling_collect_set(input_column,
2,
1,
1,
*make_collect_set_aggregation<rolling_aggregation>(null_policy::EXCLUDE));
*make_collect_set_aggregation<rolling_aggregation>(
null_policy::EXCLUDE, null_equality::EQUAL, nan_equality::UNEQUAL));

CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result->view(), result_with_nulls_excluded->view());

Expand Down