Skip to content

Commit

Permalink
Use stream in groupby calls (#7705)
Browse files Browse the repository at this point in the history
**sort_groupby_helper::**
- [x] sorted_values()
- [x] grouped_values()
-  unique_keys()
-  sorted_keys()
- [x] num_groups()
-  num_keys()
- [x] key_sort_order()
- [x] group_offsets()
- [x] group_labels()
- [x] unsorted_keys_labels()
- [x] keys_bitmask_column()

**groupby::**
- [x] - dispatch_aggregation()

Authors:
  - Karthikeyan (@karthikeyann)

Approvers:
  - David (@davidwendt)
  - Ram (Ramakrishna Prabhu) (@rgsl888prabhu)

URL: #7705
  • Loading branch information
karthikeyann authored Mar 27, 2021
1 parent 44adf97 commit ccc28d5
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 65 deletions.
26 changes: 13 additions & 13 deletions cpp/include/cudf/detail/groupby/sort_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct sort_groupby_helper {
*/
std::unique_ptr<column> sorted_values(
column_view const& values,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -108,7 +108,7 @@ struct sort_groupby_helper {
*/
std::unique_ptr<column> grouped_values(
column_view const& values,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -117,7 +117,7 @@ struct sort_groupby_helper {
* @return a new table in which each row is a unique row in the sorted key table.
*/
std::unique_ptr<table> unique_keys(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -126,13 +126,13 @@ struct sort_groupby_helper {
* @return a new table containing the sorted keys.
*/
std::unique_ptr<table> sorted_keys(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Get the number of groups in `keys`
*/
size_type num_groups() { return group_offsets().size() - 1; }
size_type num_groups(rmm::cuda_stream_view stream) { return group_offsets(stream).size() - 1; }

/**
* @brief Return the effective number of keys
Expand All @@ -141,7 +141,7 @@ struct sort_groupby_helper {
* When include_null_keys = NO, returned value is the number of rows in `keys`
* in which no element is null
*/
size_type num_keys(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
size_type num_keys(rmm::cuda_stream_view stream);

/**
* @brief Get the sorted order of `keys`.
Expand All @@ -156,7 +156,7 @@ struct sort_groupby_helper {
*
* @return the sort order indices for `keys`.
*/
column_view key_sort_order(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
column_view key_sort_order(rmm::cuda_stream_view stream);

/**
* @brief Get each group's offset into the sorted order of `keys`.
Expand All @@ -169,13 +169,13 @@ struct sort_groupby_helper {
* @return vector of offsets of the starting point of each group in the sorted
* key table
*/
index_vector const& group_offsets(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
index_vector const& group_offsets(rmm::cuda_stream_view stream);

/**
* @brief Get the group labels corresponding to the sorted order of `keys`.
*
* Each group is assigned a unique numerical "label" in
* `[0, 1, 2, ... , num_groups() - 1, num_groups())`.
* `[0, 1, 2, ... , num_groups() - 1, num_groups(stream))`.
* For a row in sorted `keys`, its corresponding group label indicates which
* group it belongs to.
*
Expand All @@ -184,15 +184,15 @@ struct sort_groupby_helper {
*
* @return vector of group labels for each row in the sorted key column
*/
index_vector const& group_labels(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
index_vector const& group_labels(rmm::cuda_stream_view stream);

private:
/**
* @brief Get the group labels for unsorted keys
*
* Returns the group label for every row in the original `keys` table. For a
* given unique key row, its group label is equivalent to what is returned by
* `group_labels()`. However, if a row contains a null value, and
* `group_labels(stream)`. However, if a row contains a null value, and
* `include_null_keys == NO`, then its label is NULL.
*
* Computes and stores unsorted labels on first invocation and returns stored
Expand All @@ -201,7 +201,7 @@ struct sort_groupby_helper {
* @return A nullable column of `INT32` containing group labels in the order
* of the unsorted key table
*/
column_view unsorted_keys_labels(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
column_view unsorted_keys_labels(rmm::cuda_stream_view stream);

/**
* @brief Get the column representing the row bitmask for the `keys`
Expand All @@ -215,7 +215,7 @@ struct sort_groupby_helper {
* Computes and stores bitmask on first invocation and returns stored column
* on subsequent calls.
*/
column_view keys_bitmask_column(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
column_view keys_bitmask_column(rmm::cuda_stream_view stream);

private:
column_ptr _key_sorted_order; ///< Indices to produce _keys in sorted order
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ std::pair<std::unique_ptr<table>, std::vector<aggregation_result>> groupby::aggr

if (_keys.num_rows() == 0) { return std::make_pair(empty_like(_keys), empty_results(requests)); }

return dispatch_aggregation(requests, 0, mr);
return dispatch_aggregation(requests, rmm::cuda_stream_default, mr);
}

// Compute scan requests
Expand Down Expand Up @@ -190,7 +190,7 @@ groupby::groups groupby::get_groups(table_view values, rmm::mr::device_memory_re

if (values.num_columns()) {
auto grouped_values = cudf::detail::gather(values,
helper().key_sort_order(),
helper().key_sort_order(rmm::cuda_stream_default),
cudf::out_of_bounds_policy::DONT_CHECK,
cudf::detail::negative_index_policy::NOT_ALLOWED,
rmm::cuda_stream_default,
Expand Down
60 changes: 32 additions & 28 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ void aggregrate_result_functor::operator()<aggregation::COUNT_VALID>(aggregation
agg,
get_grouped_values().nullable()
? detail::group_count_valid(
get_grouped_values(), helper.group_labels(), helper.num_groups(), stream, mr)
: detail::group_count_all(helper.group_offsets(), helper.num_groups(), stream, mr));
get_grouped_values(), helper.group_labels(stream), helper.num_groups(stream), stream, mr)
: detail::group_count_all(
helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
}

template <>
Expand All @@ -80,18 +81,21 @@ void aggregrate_result_functor::operator()<aggregation::COUNT_ALL>(aggregation c
if (cache.has_result(col_idx, agg)) return;

cache.add_result(
col_idx, agg, detail::group_count_all(helper.group_offsets(), helper.num_groups(), stream, mr));
col_idx,
agg,
detail::group_count_all(helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
}

template <>
void aggregrate_result_functor::operator()<aggregation::SUM>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) return;

cache.add_result(col_idx,
agg,
detail::group_sum(
get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
cache.add_result(
col_idx,
agg,
detail::group_sum(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
};

template <>
Expand All @@ -102,9 +106,9 @@ void aggregrate_result_functor::operator()<aggregation::ARGMAX>(aggregation cons
cache.add_result(col_idx,
agg,
detail::group_argmax(get_grouped_values(),
helper.num_groups(),
helper.group_labels(),
helper.key_sort_order(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr));
};
Expand All @@ -117,9 +121,9 @@ void aggregrate_result_functor::operator()<aggregation::ARGMIN>(aggregation cons
cache.add_result(col_idx,
agg,
detail::group_argmin(get_grouped_values(),
helper.num_groups(),
helper.group_labels(),
helper.key_sort_order(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr));
};
Expand All @@ -132,7 +136,7 @@ void aggregrate_result_functor::operator()<aggregation::MIN>(aggregation const&
auto result = [&]() {
if (cudf::is_fixed_width(values.type())) {
return detail::group_min(
get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr);
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr);
} else {
auto argmin_agg = make_argmin_aggregation();
operator()<aggregation::ARGMIN>(*argmin_agg);
Expand Down Expand Up @@ -169,7 +173,7 @@ void aggregrate_result_functor::operator()<aggregation::MAX>(aggregation const&
auto result = [&]() {
if (cudf::is_fixed_width(values.type())) {
return detail::group_max(
get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr);
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr);
} else {
auto argmax_agg = make_argmax_aggregation();
operator()<aggregation::ARGMAX>(*argmax_agg);
Expand Down Expand Up @@ -238,7 +242,7 @@ void aggregrate_result_functor::operator()<aggregation::VARIANCE>(aggregation co
auto result = detail::group_var(get_grouped_values(),
mean_result,
group_sizes,
helper.group_labels(),
helper.group_labels(stream),
var_agg._ddof,
stream,
mr);
Expand Down Expand Up @@ -271,8 +275,8 @@ void aggregrate_result_functor::operator()<aggregation::QUANTILE>(aggregation co

auto result = detail::group_quantiles(get_sorted_values(),
group_sizes,
helper.group_offsets(),
helper.num_groups(),
helper.group_offsets(stream),
helper.num_groups(stream),
quantile_agg._quantiles,
quantile_agg._interpolation,
stream,
Expand All @@ -291,8 +295,8 @@ void aggregrate_result_functor::operator()<aggregation::MEDIAN>(aggregation cons

auto result = detail::group_quantiles(get_sorted_values(),
group_sizes,
helper.group_offsets(),
helper.num_groups(),
helper.group_offsets(stream),
helper.num_groups(stream),
{0.5},
interpolation::LINEAR,
stream,
Expand All @@ -308,9 +312,9 @@ void aggregrate_result_functor::operator()<aggregation::NUNIQUE>(aggregation con
auto nunique_agg = static_cast<cudf::detail::nunique_aggregation const&>(agg);

auto result = detail::group_nunique(get_sorted_values(),
helper.group_labels(),
helper.num_groups(),
helper.group_offsets(),
helper.group_labels(stream),
helper.num_groups(stream),
helper.group_offsets(stream),
nunique_agg._null_handling,
stream,
mr);
Expand All @@ -337,9 +341,9 @@ void aggregrate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation
agg,
detail::group_nth_element(get_grouped_values(),
group_sizes,
helper.group_labels(),
helper.group_offsets(),
helper.num_groups(),
helper.group_labels(stream),
helper.group_offsets(stream),
helper.num_groups(stream),
nth_element_agg._n,
nth_element_agg._null_handling,
stream,
Expand All @@ -357,7 +361,7 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregatio
if (cache.has_result(col_idx, agg)) return;

auto result = detail::group_collect(
get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);

cache.add_result(col_idx, agg, std::move(result));
};
Expand All @@ -373,7 +377,7 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
if (cache.has_result(col_idx, agg)) { return; }

auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto const nulls_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_equal;
cache.add_result(col_idx,
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/groupby/sort/functors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct store_result_functor {
// It's overridden in scan implementation.
return sorted_values->view();
else
return (grouped_values = helper.grouped_values(values))->view();
return (grouped_values = helper.grouped_values(values, stream))->view();
};

/**
Expand All @@ -76,7 +76,7 @@ struct store_result_functor {
column_view get_sorted_values()
{
return sorted_values ? sorted_values->view()
: (sorted_values = helper.sorted_values(values))->view();
: (sorted_values = helper.sorted_values(values, stream))->view();
};

protected:
Expand Down
13 changes: 8 additions & 5 deletions cpp/src/groupby/sort/scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct scan_result_functor final : store_result_functor {
if (grouped_values)
return grouped_values->view();
else
return (grouped_values = helper.grouped_values(values))->view();
return (grouped_values = helper.grouped_values(values, stream))->view();
};
};

Expand All @@ -71,7 +71,8 @@ void scan_result_functor::operator()<aggregation::SUM>(aggregation const& agg)
cache.add_result(
col_idx,
agg,
detail::sum_scan(get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
detail::sum_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
}

template <>
Expand All @@ -82,7 +83,8 @@ void scan_result_functor::operator()<aggregation::MIN>(aggregation const& agg)
cache.add_result(
col_idx,
agg,
detail::min_scan(get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
detail::min_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
}

template <>
Expand All @@ -93,15 +95,16 @@ void scan_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
cache.add_result(
col_idx,
agg,
detail::max_scan(get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
detail::max_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
}

template <>
void scan_result_functor::operator()<aggregation::COUNT_ALL>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) return;

cache.add_result(col_idx, agg, detail::count_scan(helper.group_labels(), stream, mr));
cache.add_result(col_idx, agg, detail::count_scan(helper.group_labels(stream), stream, mr));
}
} // namespace detail

Expand Down
Loading

0 comments on commit ccc28d5

Please sign in to comment.