Skip to content

Commit

Permalink
Fix min/max sorted groupby aggregation on string column with nulls (a…
Browse files Browse the repository at this point in the history
…rgmin, argmax sentinel value missing on nulls) (#8731)

closes #8717

Usages of argmin and argmax depend on the presence of sentinel values at nulls (ARGMIN_SENTINEL, ARGMAX_SENTINEL), so group_argmin and group_argmax need to guarantee these sentinel values in their output. However, `cudf::detail::gather` doesn't guarantee that. This PR fixes this.
- [x] replace `cudf::detail::gather` with `thrust::gather_if` on indices to fix missing SENTINEL values for argmin, argmax.
- [x] add unit tests.

Authors:
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)
  - Jason Lowe (https://github.com/jlowe)
  - Mark Harris (https://github.com/harrism)

URL: #8731
  • Loading branch information
karthikeyann authored Jul 13, 2021
1 parent d91d011 commit d05de97
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 40 deletions.
33 changes: 14 additions & 19 deletions cpp/src/groupby/sort/group_argmax.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,7 @@

#include <rmm/cuda_stream_view.hpp>

#include <thrust/transform.h>
#include <thrust/gather.h>

namespace cudf {
namespace groupby {
Expand All @@ -39,29 +39,24 @@ std::unique_ptr<column> group_argmax(column_view const& values,
num_groups,
group_labels,
stream,
rmm::mr::get_current_device_resource());
mr);

// The functor returns the index of maximum in the sorted values.
// We need the index of maximum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// Gather map cannot be null so we make a view with the mask removed.
// The values in data buffer of indices corresponding to null values was
// initialized to ARGMAX_SENTINEL which is an out of bounds index value (-1)
// and causes the gathered value to be null.
column_view null_removed_indices(
data_type(type_to_id<size_type>()),
indices->size(),
static_cast<void const*>(indices->view().template data<size_type>()));
auto result_table =
cudf::detail::gather(table_view({key_sort_order}),
null_removed_indices,
indices->nullable() ? cudf::out_of_bounds_policy::NULLIFY
: cudf::out_of_bounds_policy::DONT_CHECK,
cudf::detail::negative_index_policy::NOT_ALLOWED,
stream,
mr);

return std::move(result_table->release()[0]);
// initialized to ARGMAX_SENTINEL. Using gather_if.
// This can't use gather because nulls in gathered column will not store ARGMAX_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMAX_SENTINEL); });
return indices;
}

} // namespace detail
Expand Down
33 changes: 14 additions & 19 deletions cpp/src/groupby/sort/group_argmin.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,7 @@

#include <rmm/cuda_stream_view.hpp>

#include <thrust/transform.h>
#include <thrust/gather.h>

namespace cudf {
namespace groupby {
Expand All @@ -39,29 +39,24 @@ std::unique_ptr<column> group_argmin(column_view const& values,
num_groups,
group_labels,
stream,
rmm::mr::get_current_device_resource());
mr);

// The functor returns the index of minimum in the sorted values.
// We need the index of minimum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// Gather map cannot be null so we make a view with the mask removed.
// The values in data buffer of indices corresponding to null values was
// initialized to ARGMIN_SENTINEL which is an out of bounds index value (-1)
// and causes the gathered value to be null.
column_view null_removed_indices(
data_type(type_to_id<size_type>()),
indices->size(),
static_cast<void const*>(indices->view().template data<size_type>()));
auto result_table =
cudf::detail::gather(table_view({key_sort_order}),
null_removed_indices,
indices->nullable() ? cudf::out_of_bounds_policy::NULLIFY
: cudf::out_of_bounds_policy::DONT_CHECK,
cudf::detail::negative_index_policy::NOT_ALLOWED,
stream,
mr);
// initialized to ARGMIN_SENTINEL. Using gather_if.
// This can't use gather because nulls in gathered column will not store ARGMIN_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMIN_SENTINEL); });

return std::move(result_table->release()[0]);
return indices;
}

} // namespace detail
Expand Down
38 changes: 37 additions & 1 deletion cpp/tests/groupby/max_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -162,6 +162,42 @@ TEST_F(groupby_max_string_test, zero_valid_values)
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES);
}

TEST_F(groupby_max_string_test, max_sorted_strings)
{
  // Regression case reproducing issue #8717.
  // Keys: six null rows up front, then ten distinct keys repeated four times each.
  cudf::test::strings_column_wrapper keys(
    {"",   "",   "",   "",   "",   "",
     "06", "06", "06", "06",
     "10", "10", "10", "10",
     "14", "14", "14", "14",
     "18", "18", "18", "18",
     "22", "22", "22", "22",
     "26", "26", "26", "26",
     "30", "30", "30", "30",
     "34", "34", "34", "34",
     "38", "38", "38", "38",
     "42", "42", "42", "42"},
    {0, 0, 0, 0, 0, 0,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1});

  // Values: within each run of four rows only the first one is valid; the
  // rows paired with the null keys are all null.
  cudf::test::strings_column_wrapper vals(
    {"",   "", "", "", "", "",
     "06", "", "", "",
     "10", "", "", "",
     "14", "", "", "",
     "18", "", "", "",
     "22", "", "", "",
     "26", "", "", "",
     "30", "", "", "",
     "34", "", "", "",
     "38", "", "", "",
     "42", "", "", ""},
    {0, 0, 0, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0});

  // Expected output: one row per distinct key plus a single null-key row
  // whose aggregated value is null as well.
  cudf::test::strings_column_wrapper expect_keys(
    {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""},
    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0});
  cudf::test::strings_column_wrapper expect_vals(
    {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""},
    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0});

  // Reference only (not asserted by this test): the argmax column sketched
  // for this input is
  //   values   {6, 10, 14, 18, 22, 26, 30, 34, 38, 42, -1}
  //   validity {1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  0}

  auto max_agg = cudf::make_max_aggregation();
  test_single_agg(keys,
                  vals,
                  expect_keys,
                  expect_vals,
                  std::move(max_agg),
                  force_use_sort_impl::NO,
                  null_policy::INCLUDE,
                  sorted::YES);
}

struct groupby_dictionary_max_test : public cudf::test::BaseFixture {
};

Expand Down
38 changes: 37 additions & 1 deletion cpp/tests/groupby/min_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -162,6 +162,42 @@ TEST_F(groupby_min_string_test, zero_valid_values)
test_single_agg(keys, vals, expect_keys, expect_vals, std::move(agg2), force_use_sort_impl::YES);
}

TEST_F(groupby_min_string_test, min_sorted_strings)
{
  // Regression case reproducing issue #8717.
  // Keys: six null rows up front, then ten distinct keys repeated four times each.
  cudf::test::strings_column_wrapper keys(
    {"",   "",   "",   "",   "",   "",
     "06", "06", "06", "06",
     "10", "10", "10", "10",
     "14", "14", "14", "14",
     "18", "18", "18", "18",
     "22", "22", "22", "22",
     "26", "26", "26", "26",
     "30", "30", "30", "30",
     "34", "34", "34", "34",
     "38", "38", "38", "38",
     "42", "42", "42", "42"},
    {0, 0, 0, 0, 0, 0,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1,
     1, 1, 1, 1});

  // Values: within each run of four rows only the first one is valid; the
  // rows paired with the null keys are all null.
  cudf::test::strings_column_wrapper vals(
    {"",   "", "", "", "", "",
     "06", "", "", "",
     "10", "", "", "",
     "14", "", "", "",
     "18", "", "", "",
     "22", "", "", "",
     "26", "", "", "",
     "30", "", "", "",
     "34", "", "", "",
     "38", "", "", "",
     "42", "", "", ""},
    {0, 0, 0, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0,
     1, 0, 0, 0});

  // Expected output: one row per distinct key plus a single null-key row
  // whose aggregated value is null as well.
  cudf::test::strings_column_wrapper expect_keys(
    {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""},
    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0});
  cudf::test::strings_column_wrapper expect_vals(
    {"06", "10", "14", "18", "22", "26", "30", "34", "38", "42", ""},
    {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0});

  // Reference only (not asserted by this test): the argmin column sketched
  // for this input is
  //   values   {6, 10, 14, 18, 22, 26, 30, 34, 38, 42, -1}
  //   validity {1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  0}

  auto min_agg = cudf::make_min_aggregation();
  test_single_agg(keys,
                  vals,
                  expect_keys,
                  expect_vals,
                  std::move(min_agg),
                  force_use_sort_impl::NO,
                  null_policy::INCLUDE,
                  sorted::YES);
}

struct groupby_dictionary_min_test : public cudf::test::BaseFixture {
};

Expand Down

0 comments on commit d05de97

Please sign in to comment.