Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default to equal NaNs in make_merge_sets_aggregation. #11952

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,9 @@ std::unique_ptr<Base> make_merge_lists_aggregation();
* @return A MERGE_SETS aggregation object
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);
std::unique_ptr<Base> make_merge_sets_aggregation(
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::ALL_EQUAL);

/**
* @brief Factory to create a MERGE_M2 aggregation
Expand Down
9 changes: 6 additions & 3 deletions cpp/tests/reductions/collect_ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,14 +239,17 @@ TEST_F(CollectTest, MergeSetsWithNaN)

// nan unequal with null equal
fp_wrapper expected1{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f}, {1, 1, 1, 1, 1, 1, 0}};
auto const ret1 = collect_set(col, make_merge_sets_aggregation<reduce_aggregation>());
auto const ret1 = collect_set(
col,
make_merge_sets_aggregation<reduce_aggregation>(null_equality::EQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast<list_scalar*>(ret1.get())->view());

// nan unequal with null unequal
fp_wrapper expected2{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f, 0.0f, 0.0f},
{1, 1, 1, 1, 1, 1, 0, 0, 0}};
auto const ret2 =
collect_set(col, make_merge_sets_aggregation<reduce_aggregation>(null_equality::UNEQUAL));
auto const ret2 = collect_set(
col,
make_merge_sets_aggregation<reduce_aggregation>(null_equality::UNEQUAL, nan_equality::UNEQUAL));
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast<list_scalar*>(ret2.get())->view());

// nan equal with null equal
Expand Down