Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use stream in groupby calls #7705

Merged
merged 10 commits into from
Mar 27, 2021
26 changes: 13 additions & 13 deletions cpp/include/cudf/detail/groupby/sort_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct sort_groupby_helper {
*/
std::unique_ptr<column> sorted_values(
column_view const& values,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -108,7 +108,7 @@ struct sort_groupby_helper {
*/
std::unique_ptr<column> grouped_values(
column_view const& values,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -117,7 +117,7 @@ struct sort_groupby_helper {
* @return a new table in which each row is a unique row in the sorted key table.
*/
std::unique_ptr<table> unique_keys(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
Expand All @@ -126,13 +126,13 @@ struct sort_groupby_helper {
* @return a new table containing the sorted keys.
*/
std::unique_ptr<table> sorted_keys(
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Get the number of groups in `keys`
*/
size_type num_groups() { return group_offsets().size() - 1; }
size_type num_groups(rmm::cuda_stream_view stream) { return group_offsets(stream).size() - 1; }

/**
* @brief Return the effective number of keys
Expand All @@ -141,7 +141,7 @@ struct sort_groupby_helper {
* When include_null_keys = NO, returned value is the number of rows in `keys`
* in which no element is null
*/
size_type num_keys(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
size_type num_keys(rmm::cuda_stream_view stream);

/**
* @brief Get the sorted order of `keys`.
Expand All @@ -156,7 +156,7 @@ struct sort_groupby_helper {
*
* @return the sort order indices for `keys`.
*/
column_view key_sort_order(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
column_view key_sort_order(rmm::cuda_stream_view stream);

/**
* @brief Get each group's offset into the sorted order of `keys`.
Expand All @@ -169,13 +169,13 @@ struct sort_groupby_helper {
* @return vector of offsets of the starting point of each group in the sorted
* key table
*/
index_vector const& group_offsets(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
index_vector const& group_offsets(rmm::cuda_stream_view stream);

/**
* @brief Get the group labels corresponding to the sorted order of `keys`.
*
* Each group is assigned a unique numerical "label" in
* `[0, 1, 2, ... , num_groups() - 1, num_groups())`.
* `[0, 1, 2, ... , num_groups() - 1, num_groups(stream))`.
* For a row in sorted `keys`, its corresponding group label indicates which
* group it belongs to.
*
Expand All @@ -184,15 +184,15 @@ struct sort_groupby_helper {
*
* @return vector of group labels for each row in the sorted key column
*/
index_vector const& group_labels(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
index_vector const& group_labels(rmm::cuda_stream_view stream);

private:
/**
* @brief Get the group labels for unsorted keys
*
* Returns the group label for every row in the original `keys` table. For a
* given unique key row, its group label is equivalent to what is returned by
* `group_labels()`. However, if a row contains a null value, and
* `group_labels(stream)`. However, if a row contains a null value, and
* `include_null_keys == NO`, then its label is NULL.
*
* Computes and stores unsorted labels on first invocation and returns stored
Expand All @@ -201,7 +201,7 @@ struct sort_groupby_helper {
* @return A nullable column of `INT32` containing group labels in the order
* of the unsorted key table
*/
column_view unsorted_keys_labels(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
column_view unsorted_keys_labels(rmm::cuda_stream_view stream);

/**
* @brief Get the column representing the row bitmask for the `keys`
Expand All @@ -215,7 +215,7 @@ struct sort_groupby_helper {
* Computes and stores bitmask on first invocation and returns stored column
* on subsequent calls.
*/
column_view keys_bitmask_column(rmm::cuda_stream_view stream = rmm::cuda_stream_default);
column_view keys_bitmask_column(rmm::cuda_stream_view stream);

private:
column_ptr _key_sorted_order; ///< Indices to produce _keys in sorted order
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/groupby/groupby.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ std::pair<std::unique_ptr<table>, std::vector<aggregation_result>> groupby::aggr

if (_keys.num_rows() == 0) { return std::make_pair(empty_like(_keys), empty_results(requests)); }

return dispatch_aggregation(requests, 0, mr);
return dispatch_aggregation(requests, rmm::cuda_stream_default, mr);
}

// Compute scan requests
Expand Down Expand Up @@ -190,7 +190,7 @@ groupby::groups groupby::get_groups(table_view values, rmm::mr::device_memory_re

if (values.num_columns()) {
auto grouped_values = cudf::detail::gather(values,
helper().key_sort_order(),
helper().key_sort_order(rmm::cuda_stream_default),
cudf::out_of_bounds_policy::DONT_CHECK,
cudf::detail::negative_index_policy::NOT_ALLOWED,
rmm::cuda_stream_default,
Expand Down
60 changes: 32 additions & 28 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ void aggregrate_result_functor::operator()<aggregation::COUNT_VALID>(aggregation
agg,
get_grouped_values().nullable()
? detail::group_count_valid(
get_grouped_values(), helper.group_labels(), helper.num_groups(), stream, mr)
: detail::group_count_all(helper.group_offsets(), helper.num_groups(), stream, mr));
get_grouped_values(), helper.group_labels(stream), helper.num_groups(stream), stream, mr)
: detail::group_count_all(
helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
}

template <>
Expand All @@ -80,18 +81,21 @@ void aggregrate_result_functor::operator()<aggregation::COUNT_ALL>(aggregation c
if (cache.has_result(col_idx, agg)) return;

cache.add_result(
col_idx, agg, detail::group_count_all(helper.group_offsets(), helper.num_groups(), stream, mr));
col_idx,
agg,
detail::group_count_all(helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
}

template <>
void aggregrate_result_functor::operator()<aggregation::SUM>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) return;

cache.add_result(col_idx,
agg,
detail::group_sum(
get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
cache.add_result(
col_idx,
agg,
detail::group_sum(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
};

template <>
Expand All @@ -102,9 +106,9 @@ void aggregrate_result_functor::operator()<aggregation::ARGMAX>(aggregation cons
cache.add_result(col_idx,
agg,
detail::group_argmax(get_grouped_values(),
helper.num_groups(),
helper.group_labels(),
helper.key_sort_order(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr));
};
Expand All @@ -117,9 +121,9 @@ void aggregrate_result_functor::operator()<aggregation::ARGMIN>(aggregation cons
cache.add_result(col_idx,
agg,
detail::group_argmin(get_grouped_values(),
helper.num_groups(),
helper.group_labels(),
helper.key_sort_order(),
helper.num_groups(stream),
helper.group_labels(stream),
helper.key_sort_order(stream),
stream,
mr));
};
Expand All @@ -132,7 +136,7 @@ void aggregrate_result_functor::operator()<aggregation::MIN>(aggregation const&
auto result = [&]() {
if (cudf::is_fixed_width(values.type())) {
return detail::group_min(
get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr);
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr);
} else {
auto argmin_agg = make_argmin_aggregation();
operator()<aggregation::ARGMIN>(*argmin_agg);
Expand Down Expand Up @@ -169,7 +173,7 @@ void aggregrate_result_functor::operator()<aggregation::MAX>(aggregation const&
auto result = [&]() {
if (cudf::is_fixed_width(values.type())) {
return detail::group_max(
get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr);
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr);
} else {
auto argmax_agg = make_argmax_aggregation();
operator()<aggregation::ARGMAX>(*argmax_agg);
Expand Down Expand Up @@ -238,7 +242,7 @@ void aggregrate_result_functor::operator()<aggregation::VARIANCE>(aggregation co
auto result = detail::group_var(get_grouped_values(),
mean_result,
group_sizes,
helper.group_labels(),
helper.group_labels(stream),
var_agg._ddof,
stream,
mr);
Expand Down Expand Up @@ -271,8 +275,8 @@ void aggregrate_result_functor::operator()<aggregation::QUANTILE>(aggregation co

auto result = detail::group_quantiles(get_sorted_values(),
group_sizes,
helper.group_offsets(),
helper.num_groups(),
helper.group_offsets(stream),
helper.num_groups(stream),
quantile_agg._quantiles,
quantile_agg._interpolation,
stream,
Expand All @@ -291,8 +295,8 @@ void aggregrate_result_functor::operator()<aggregation::MEDIAN>(aggregation cons

auto result = detail::group_quantiles(get_sorted_values(),
group_sizes,
helper.group_offsets(),
helper.num_groups(),
helper.group_offsets(stream),
helper.num_groups(stream),
{0.5},
interpolation::LINEAR,
stream,
Expand All @@ -308,9 +312,9 @@ void aggregrate_result_functor::operator()<aggregation::NUNIQUE>(aggregation con
auto nunique_agg = static_cast<cudf::detail::nunique_aggregation const&>(agg);

auto result = detail::group_nunique(get_sorted_values(),
helper.group_labels(),
helper.num_groups(),
helper.group_offsets(),
helper.group_labels(stream),
helper.num_groups(stream),
helper.group_offsets(stream),
nunique_agg._null_handling,
stream,
mr);
Expand All @@ -337,9 +341,9 @@ void aggregrate_result_functor::operator()<aggregation::NTH_ELEMENT>(aggregation
agg,
detail::group_nth_element(get_grouped_values(),
group_sizes,
helper.group_labels(),
helper.group_offsets(),
helper.num_groups(),
helper.group_labels(stream),
helper.group_offsets(stream),
helper.num_groups(stream),
nth_element_agg._n,
nth_element_agg._null_handling,
stream,
Expand All @@ -357,7 +361,7 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT_LIST>(aggregatio
if (cache.has_result(col_idx, agg)) return;

auto result = detail::group_collect(
get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);

cache.add_result(col_idx, agg, std::move(result));
};
Expand All @@ -373,7 +377,7 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
if (cache.has_result(col_idx, agg)) { return; }

auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto const nulls_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_equal;
cache.add_result(col_idx,
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/groupby/sort/functors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct store_result_functor {
// It's overridden in scan implementation.
return sorted_values->view();
else
return (grouped_values = helper.grouped_values(values))->view();
return (grouped_values = helper.grouped_values(values, stream))->view();
};

/**
Expand All @@ -76,7 +76,7 @@ struct store_result_functor {
column_view get_sorted_values()
{
return sorted_values ? sorted_values->view()
: (sorted_values = helper.sorted_values(values))->view();
: (sorted_values = helper.sorted_values(values, stream))->view();
};

protected:
Expand Down
13 changes: 8 additions & 5 deletions cpp/src/groupby/sort/scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ struct scan_result_functor final : store_result_functor {
if (grouped_values)
return grouped_values->view();
else
return (grouped_values = helper.grouped_values(values))->view();
return (grouped_values = helper.grouped_values(values, stream))->view();
};
};

Expand All @@ -71,7 +71,8 @@ void scan_result_functor::operator()<aggregation::SUM>(aggregation const& agg)
cache.add_result(
col_idx,
agg,
detail::sum_scan(get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
detail::sum_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
}

template <>
Expand All @@ -82,7 +83,8 @@ void scan_result_functor::operator()<aggregation::MIN>(aggregation const& agg)
cache.add_result(
col_idx,
agg,
detail::min_scan(get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
detail::min_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
}

template <>
Expand All @@ -93,15 +95,16 @@ void scan_result_functor::operator()<aggregation::MAX>(aggregation const& agg)
cache.add_result(
col_idx,
agg,
detail::max_scan(get_grouped_values(), helper.num_groups(), helper.group_labels(), stream, mr));
detail::max_scan(
get_grouped_values(), helper.num_groups(stream), helper.group_labels(stream), stream, mr));
}

template <>
void scan_result_functor::operator()<aggregation::COUNT_ALL>(aggregation const& agg)
{
if (cache.has_result(col_idx, agg)) return;

cache.add_result(col_idx, agg, detail::count_scan(helper.group_labels(), stream, mr));
cache.add_result(col_idx, agg, detail::count_scan(helper.group_labels(stream), stream, mr));
}
} // namespace detail

Expand Down
Loading