diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index a26a0c7947b..f0b190e6438 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -588,8 +588,9 @@ std::unique_ptr make_merge_lists_aggregation(); * @return A MERGE_SETS aggregation object */ template -std::unique_ptr make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL, - nan_equality nans_equal = nan_equality::UNEQUAL); +std::unique_ptr make_merge_sets_aggregation( + null_equality nulls_equal = null_equality::EQUAL, + nan_equality nans_equal = nan_equality::ALL_EQUAL); /** * @brief Factory to create a MERGE_M2 aggregation diff --git a/cpp/tests/reductions/collect_ops_tests.cpp b/cpp/tests/reductions/collect_ops_tests.cpp index a0fdab5e994..c00d43ad320 100644 --- a/cpp/tests/reductions/collect_ops_tests.cpp +++ b/cpp/tests/reductions/collect_ops_tests.cpp @@ -239,14 +239,17 @@ TEST_F(CollectTest, MergeSetsWithNaN) // nan unequal with null equal fp_wrapper expected1{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f}, {1, 1, 1, 1, 1, 1, 0}}; - auto const ret1 = collect_set(col, make_merge_sets_aggregation()); + auto const ret1 = collect_set( + col, + make_merge_sets_aggregation(null_equality::EQUAL, nan_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast(ret1.get())->view()); // nan unequal with null unequal fp_wrapper expected2{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f, 0.0f, 0.0f}, {1, 1, 1, 1, 1, 1, 0, 0, 0}}; - auto const ret2 = - collect_set(col, make_merge_sets_aggregation(null_equality::UNEQUAL)); + auto const ret2 = collect_set( + col, + make_merge_sets_aggregation(null_equality::UNEQUAL, nan_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast(ret2.get())->view()); // nan equal with null equal