From 5c2150e0d942fa525205451cd954e48ff35b8a84 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Fri, 21 Oct 2022 15:10:17 -0500 Subject: [PATCH] Default to equal NaNs in make_merge_sets_aggregation. (#11952) Partially resolves https://github.com/rapidsai/cudf/issues/11329. This helps to align our default behaviors for null and NaN equality across APIs, specifically for `make_merge_sets_aggregation` in this PR. All functions should default to treating null values as equal to one another and NaN values as equal to one another. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Nghia Truong (https://github.com/ttnghia) - Vyas Ramasubramani (https://github.com/vyasr) - David Wendt (https://github.com/davidwendt) URL: https://github.com/rapidsai/cudf/pull/11952 --- cpp/include/cudf/aggregation.hpp | 5 +++-- cpp/tests/reductions/collect_ops_tests.cpp | 9 ++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cpp/include/cudf/aggregation.hpp b/cpp/include/cudf/aggregation.hpp index a92da0b0347..d319041f8b1 100644 --- a/cpp/include/cudf/aggregation.hpp +++ b/cpp/include/cudf/aggregation.hpp @@ -589,8 +589,9 @@ std::unique_ptr make_merge_lists_aggregation(); * @return A MERGE_SETS aggregation object */ template -std::unique_ptr make_merge_sets_aggregation(null_equality nulls_equal = null_equality::EQUAL, - nan_equality nans_equal = nan_equality::UNEQUAL); +std::unique_ptr make_merge_sets_aggregation( + null_equality nulls_equal = null_equality::EQUAL, + nan_equality nans_equal = nan_equality::ALL_EQUAL); /** * @brief Factory to create a MERGE_M2 aggregation diff --git a/cpp/tests/reductions/collect_ops_tests.cpp b/cpp/tests/reductions/collect_ops_tests.cpp index 2bb13fd671f..90014c3b10f 100644 --- a/cpp/tests/reductions/collect_ops_tests.cpp +++ b/cpp/tests/reductions/collect_ops_tests.cpp @@ -243,14 +243,17 @@ TEST_F(CollectTest, MergeSetsWithNaN) // nan unequal with null equal fp_wrapper expected1{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f}, {1, 1, 1, 1, 1, 1, 0}}; - auto const ret1 = collect_set(col, make_merge_sets_aggregation()); + auto const ret1 = collect_set( + col, + make_merge_sets_aggregation(null_equality::EQUAL, nan_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected1, dynamic_cast(ret1.get())->view()); // nan unequal with null unequal fp_wrapper expected2{{-2.3e-5f, 1.0f, 2.3e5f, -NAN, NAN, NAN, 0.0f, 0.0f, 0.0f}, {1, 1, 1, 1, 1, 1, 0, 0, 0}}; - auto const ret2 = - collect_set(col, make_merge_sets_aggregation(null_equality::UNEQUAL)); + auto const ret2 = collect_set( + col, + make_merge_sets_aggregation(null_equality::UNEQUAL, nan_equality::UNEQUAL)); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, dynamic_cast(ret2.get())->view()); // nan equal with null equal