Skip to content

Commit

Permalink
Add support for STRUCT input to groupby (#9024)
Browse files Browse the repository at this point in the history
This commit adds support for `STRUCT` columns in `groupby`. This should now allow for groupby aggregations to work when any of the grouping columns are `STRUCT`, including nested `STRUCTS`.

Note: List columns are still not supported on `groupby`, even as members of `STRUCT` columns, at any level of nesting. Only `STRUCT`, `STRUCT<STRUCT>`, etc. are currently supported.

Depends on #8956 (i.e. `unflatten_nested_columns()`).

Authors:
  - MithunR (https://github.com/mythrocks)

Approvers:
  - Jake Hemstad (https://github.com/jrhemstad)

URL: #9024
  • Loading branch information
mythrocks authored Aug 26, 2021
1 parent 4e0584b commit d9d565e
Show file tree
Hide file tree
Showing 5 changed files with 361 additions and 14 deletions.
14 changes: 3 additions & 11 deletions cpp/include/cudf/detail/groupby/sort_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,7 @@ struct sort_groupby_helper {
*/
sort_groupby_helper(table_view const& keys,
null_policy include_null_keys = null_policy::EXCLUDE,
sorted keys_pre_sorted = sorted::NO)
: _keys(keys),
_num_keys(-1),
_keys_pre_sorted(keys_pre_sorted),
_include_null_keys(include_null_keys)
{
if (keys_pre_sorted == sorted::YES and include_null_keys == null_policy::EXCLUDE and
has_nulls(keys)) {
_keys_pre_sorted = sorted::NO;
}
};
sorted keys_pre_sorted = sorted::NO);

~sort_groupby_helper() = default;
sort_groupby_helper(sort_groupby_helper const&) = delete;
Expand Down Expand Up @@ -227,6 +217,8 @@ struct sort_groupby_helper {
column_ptr _unsorted_keys_labels; ///< Group labels for unsorted _keys
column_ptr _keys_bitmask_column; ///< Column representing rows with one or more nulls values
table_view _keys; ///< Input keys to sort by
table_view _unflattened_keys; ///< Input keys, unflattened and possibly nested
std::vector<column_ptr> _struct_null_vectors; ///< Null vectors for struct columns in _keys

index_vector_ptr
_group_offsets; ///< Indices into sorted _keys indicating starting index of each groups
Expand Down
11 changes: 10 additions & 1 deletion cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <structs/utilities.hpp>

#include <rmm/cuda_stream_view.hpp>

Expand Down Expand Up @@ -62,6 +63,8 @@ std::pair<std::unique_ptr<table>, std::vector<aggregation_result>> groupby::disp
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
using namespace cudf::structs::detail;

// If sort groupby has been called once on this groupby object, then
// always use sort groupby from now on. Because once keys are sorted,
// all the aggs that can be done by hash groupby are efficiently done by
Expand All @@ -70,7 +73,13 @@ std::pair<std::unique_ptr<table>, std::vector<aggregation_result>> groupby::disp
// satisfied with a hash implementation
if (_keys_are_sorted == sorted::NO and not _helper and
detail::hash::can_use_hash_groupby(_keys, requests)) {
return detail::hash::groupby(_keys, requests, _include_null_keys, stream, mr);
// Optionally flatten nested key columns.
auto [flattened_keys, _, __, ___] =
flatten_nested_columns(_keys, {}, {}, column_nullability::FORCE);
auto [grouped_keys, results] =
detail::hash::groupby(flattened_keys, requests, _include_null_keys, stream, mr);
return std::make_pair(unflatten_nested_columns(std::move(grouped_keys), _keys),
std::move(results));
} else {
return sort_aggregate(requests, stream, mr);
}
Expand Down
30 changes: 28 additions & 2 deletions cpp/src/groupby/sort/sort_helper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cudf/detail/sorting.hpp>
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_device_view.cuh>
#include <structs/utilities.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>
Expand Down Expand Up @@ -88,6 +89,31 @@ namespace cudf {
namespace groupby {
namespace detail {
namespace sort {

sort_groupby_helper::sort_groupby_helper(table_view const& keys,
null_policy include_null_keys,
sorted keys_pre_sorted)
: _unflattened_keys(keys),
_num_keys(-1),
_keys_pre_sorted(keys_pre_sorted),
_include_null_keys(include_null_keys)
{
using namespace cudf::structs::detail;

auto [flattened_keys, _, __, struct_null_vectors] =
flatten_nested_columns(keys, {}, {}, column_nullability::FORCE);
_struct_null_vectors = std::move(struct_null_vectors);
_keys = flattened_keys;

// Cannot depend on caller's sorting if the column contains nulls,
// and null values are to be excluded.
// Re-sort the data, to filter out nulls more easily.
if (keys_pre_sorted == sorted::YES and include_null_keys == null_policy::EXCLUDE and
has_nulls(keys)) {
_keys_pre_sorted = sorted::NO;
}
};

size_type sort_groupby_helper::num_keys(rmm::cuda_stream_view stream)
{
if (_num_keys > -1) return _num_keys;
Expand Down Expand Up @@ -309,7 +335,7 @@ std::unique_ptr<table> sort_groupby_helper::unique_keys(rmm::cuda_stream_view st
auto gather_map_it = thrust::make_transform_iterator(
group_offsets(stream).begin(), [idx_data] __device__(size_type i) { return idx_data[i]; });

return cudf::detail::gather(_keys,
return cudf::detail::gather(_unflattened_keys,
gather_map_it,
gather_map_it + num_groups(stream),
out_of_bounds_policy::DONT_CHECK,
Expand All @@ -320,7 +346,7 @@ std::unique_ptr<table> sort_groupby_helper::unique_keys(rmm::cuda_stream_view st
std::unique_ptr<table> sort_groupby_helper::sorted_keys(rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
return cudf::detail::gather(_keys,
return cudf::detail::gather(_unflattened_keys,
key_sort_order(stream),
cudf::out_of_bounds_policy::DONT_CHECK,
cudf::detail::negative_index_policy::NOT_ALLOWED,
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ ConfigureTest(GROUPBY_TEST
groupby/replace_nulls_tests.cpp
groupby/shift_tests.cpp
groupby/std_tests.cpp
groupby/structs_tests.cpp
groupby/sum_of_squares_tests.cpp
groupby/sum_scan_tests.cpp
groupby/sum_tests.cpp
Expand Down
Loading

0 comments on commit d9d565e

Please sign in to comment.