From d05de978f2d1f34b7629bd54ab9485df1f9949ef Mon Sep 17 00:00:00 2001 From: Karthikeyan <6488848+karthikeyann@users.noreply.github.com> Date: Wed, 14 Jul 2021 04:15:44 +0530 Subject: [PATCH] Fix min/max sorted groupby aggregation on string column with nulls (argmin, argmax sentinel value missing on nulls) (#8731) closes https://github.com/rapidsai/cudf/issues/8717 Usages of argmin, argmax depend on the presence of sentinel values at nulls (ARGMIN_SENTINEL, ARGMAX_SENTINEL), so group_argmin and group_argmax need to guarantee these sentinel values in their output, but `cudf::detail::gather` doesn't guarantee that. This PR fixes this. - [x] replace `cudf::detail::gather` with `thrust::gather_if` on indices to fix missing SENTINEL values for argmin, argmax. - [x] add unit tests. Authors: - Karthikeyan (https://github.com/karthikeyann) Approvers: - Jake Hemstad (https://github.com/jrhemstad) - Jason Lowe (https://github.com/jlowe) - Mark Harris (https://github.com/harrism) URL: https://github.com/rapidsai/cudf/pull/8731 --- cpp/src/groupby/sort/group_argmax.cu | 33 ++++++++++-------------- cpp/src/groupby/sort/group_argmin.cu | 33 ++++++++++-------------- cpp/tests/groupby/max_tests.cpp | 38 +++++++++++++++++++++++++++- cpp/tests/groupby/min_tests.cpp | 38 +++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 40 deletions(-) diff --git a/cpp/src/groupby/sort/group_argmax.cu b/cpp/src/groupby/sort/group_argmax.cu index bed64c5147a..6ce23ffc35b 100644 --- a/cpp/src/groupby/sort/group_argmax.cu +++ b/cpp/src/groupby/sort/group_argmax.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -21,7 +21,7 @@ #include -#include +#include namespace cudf { namespace groupby { @@ -39,29 +39,24 @@ std::unique_ptr group_argmax(column_view const& values, num_groups, group_labels, stream, - rmm::mr::get_current_device_resource()); + mr); // The functor returns the index of maximum in the sorted values. // We need the index of maximum in the original unsorted values. // So use indices to gather the sort order used to sort `values`. // Gather map cannot be null so we make a view with the mask removed. // The values in data buffer of indices corresponding to null values was - // initialized to ARGMAX_SENTINEL which is an out of bounds index value (-1) - // and causes the gathered value to be null. - column_view null_removed_indices( - data_type(type_to_id()), - indices->size(), - static_cast(indices->view().template data())); - auto result_table = - cudf::detail::gather(table_view({key_sort_order}), - null_removed_indices, - indices->nullable() ? cudf::out_of_bounds_policy::NULLIFY - : cudf::out_of_bounds_policy::DONT_CHECK, - cudf::detail::negative_index_policy::NOT_ALLOWED, - stream, - mr); - - return std::move(result_table->release()[0]); + // initialized to ARGMAX_SENTINEL. Using gather_if. + // This can't use gather because nulls in gathered column will not store ARGMAX_SENTINEL. + auto indices_view = indices->mutable_view(); + thrust::gather_if(rmm::exec_policy(stream), + indices_view.begin(), // map first + indices_view.end(), // map last + indices_view.begin(), // stencil + key_sort_order.begin(), // input + indices_view.begin(), // result + [] __device__(auto i) { return (i != cudf::detail::ARGMAX_SENTINEL); }); + return indices; } } // namespace detail diff --git a/cpp/src/groupby/sort/group_argmin.cu b/cpp/src/groupby/sort/group_argmin.cu index ec97a609390..ab91c2c0d29 100644 --- a/cpp/src/groupby/sort/group_argmin.cu +++ b/cpp/src/groupby/sort/group_argmin.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. 
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include -#include +#include namespace cudf { namespace groupby { @@ -39,29 +39,24 @@ std::unique_ptr group_argmin(column_view const& values, num_groups, group_labels, stream, - rmm::mr::get_current_device_resource()); + mr); // The functor returns the index of minimum in the sorted values. // We need the index of minimum in the original unsorted values. // So use indices to gather the sort order used to sort `values`. - // Gather map cannot be null so we make a view with the mask removed. // The values in data buffer of indices corresponding to null values was - // initialized to ARGMIN_SENTINEL which is an out of bounds index value (-1) - // and causes the gathered value to be null. - column_view null_removed_indices( - data_type(type_to_id()), - indices->size(), - static_cast(indices->view().template data())); - auto result_table = - cudf::detail::gather(table_view({key_sort_order}), - null_removed_indices, - indices->nullable() ? cudf::out_of_bounds_policy::NULLIFY - : cudf::out_of_bounds_policy::DONT_CHECK, - cudf::detail::negative_index_policy::NOT_ALLOWED, - stream, - mr); + // initialized to ARGMIN_SENTINEL. Using gather_if. + // This can't use gather because nulls in gathered column will not store ARGMIN_SENTINEL. 
+ auto indices_view = indices->mutable_view(); + thrust::gather_if(rmm::exec_policy(stream), + indices_view.begin(), // map first + indices_view.end(), // map last + indices_view.begin(), // stencil + key_sort_order.begin(), // input + indices_view.begin(), // result + [] __device__(auto i) { return (i != cudf::detail::ARGMIN_SENTINEL); }); - return std::move(result_table->release()[0]); + return indices; } } // namespace detail diff --git a/cpp/tests/groupby/max_tests.cpp b/cpp/tests/groupby/max_tests.cpp index e0da55b080f..b5710d3f4bc 100644 --- a/cpp/tests/groupby/max_tests.cpp +++ b/cpp/tests/groupby/max_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -162,6 +162,42 @@ TEST_F(groupby_max_string_test, zero_valid_values) test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES); } +TEST_F(groupby_max_string_test, max_sorted_strings) +{ + // testcase replicated in issue #8717 + cudf::test::strings_column_wrapper keys( + {"", "", "", "", "", "", "06", "06", "06", "06", "10", "10", "10", "10", "14", "14", + "14", "14", "18", "18", "18", "18", "22", "22", "22", "22", "26", "26", "26", "26", "30", "30", + "30", "30", "34", "34", "34", "34", "38", "38", "38", "38", "42", "42", "42", "42"}, + {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + cudf::test::strings_column_wrapper vals( + {"", "", "", "", "", "", "06", "", "", "", "10", "", "", "", "14", "", + "", "", "18", "", "", "", "22", "", "", "", "26", "", "", "", "30", "", + "", "", "34", "", "", "", "38", "", "", "", "42", "", "", ""}, + {0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 
0, 0}); + cudf::test::strings_column_wrapper expect_keys( + {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}); + cudf::test::strings_column_wrapper expect_vals( + {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}); + + // fixed_width_column_wrapper expect_argmax( + // {6, 10, 14, 18, 22, 26, 30, 34, 38, 42, -1}, + // {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}); + auto agg = cudf::make_max_aggregation(); + test_single_agg(keys, + vals, + expect_keys, + expect_vals, + std::move(agg), + force_use_sort_impl::NO, + null_policy::INCLUDE, + sorted::YES); +} + struct groupby_dictionary_max_test : public cudf::test::BaseFixture { }; diff --git a/cpp/tests/groupby/min_tests.cpp b/cpp/tests/groupby/min_tests.cpp index 8f997875a78..1544e867595 100644 --- a/cpp/tests/groupby/min_tests.cpp +++ b/cpp/tests/groupby/min_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -162,6 +162,42 @@ TEST_F(groupby_min_string_test, zero_valid_values) test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES); } +TEST_F(groupby_min_string_test, min_sorted_strings) +{ + // testcase replicated in issue #8717 + cudf::test::strings_column_wrapper keys( + {"", "", "", "", "", "", "06", "06", "06", "06", "10", "10", "10", "10", "14", "14", + "14", "14", "18", "18", "18", "18", "22", "22", "22", "22", "26", "26", "26", "26", "30", "30", + "30", "30", "34", "34", "34", "34", "38", "38", "38", "38", "42", "42", "42", "42"}, + {0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + cudf::test::strings_column_wrapper vals( + {"", "", "", "", "", "", "06", "", "", "", "10", "", "", "", "14", "", + "", "", "18", "", "", "", "22", "", "", "", "26", "", "", "", "30", "", + "", "", "34", "", "", "", "38", "", "", "", "42", "", "", ""}, + {0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0}); + cudf::test::strings_column_wrapper expect_keys( + {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}); + cudf::test::strings_column_wrapper expect_vals( + {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""}, + {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}); + + // fixed_width_column_wrapper expect_argmin( + // {6, 10, 14, 18, 22, 26, 30, 34, 38, 42, -1}, + // {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0}); + auto agg = cudf::make_min_aggregation(); + test_single_agg(keys, + vals, + expect_keys, + expect_vals, + std::move(agg), + force_use_sort_impl::NO, + null_policy::INCLUDE, + sorted::YES); +} + struct groupby_dictionary_min_test : public cudf::test::BaseFixture { };