Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix groupby argmin/max gather of sorted-order indices #17591

Merged
merged 6 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,7 @@ void aggregate_result_functor::operator()<aggregation::MIN>(aggregation const& a
operator()<aggregation::ARGMIN>(*argmin_agg);
column_view const argmin_result = cache.get_result(values, *argmin_agg);

// We make a view of ARGMIN result without a null mask and gather using
// this mask. The values in data buffer of ARGMIN result corresponding
// to null values was initialized to ARGMIN_SENTINEL which is an out of
// bounds index value and causes the gathered value to be null.
// Compute the ARGMIN result without the null mask in the gather map.
column_view const null_removed_map(
data_type(type_to_id<size_type>()),
argmin_result.size(),
Expand Down Expand Up @@ -250,10 +247,7 @@ void aggregate_result_functor::operator()<aggregation::MAX>(aggregation const& a
operator()<aggregation::ARGMAX>(*argmax_agg);
column_view const argmax_result = cache.get_result(values, *argmax_agg);

// We make a view of ARGMAX result without a null mask and gather using
// this mask. The values in data buffer of ARGMAX result corresponding
// to null values was initialized to ARGMAX_SENTINEL which is an out of
// bounds index value and causes the gathered value to be null.
// Compute the ARGMAX result without the null mask in the gather map.
column_view const null_removed_map(
data_type(type_to_id<size_type>()),
argmax_result.size(),
Expand Down
31 changes: 15 additions & 16 deletions cpp/src/groupby/sort/group_argmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,21 @@ std::unique_ptr<column> group_argmax(column_view const& values,
stream,
mr);

// The functor returns the index of maximum in the sorted values.
// We need the index of maximum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// Gather map cannot be null so we make a view with the mask removed.
// The values in data buffer of indices corresponding to null values was
// initialized to ARGMAX_SENTINEL. Using gather_if.
// This can't use gather because nulls in gathered column will not store ARGMAX_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMAX_SENTINEL); });
return indices;
// The functor returns the indices of minimums based on the sorted keys.
// We need the indices of minimums from the original unsorted keys
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
// so we use these indices and the key_sort_order to map to the correct indices.
// We do not use cudf::gather since we can move the null-mask separately.
auto indices_view = indices->view();
auto output = rmm::device_uvector<size_type>(indices_view.size(), stream, mr);
thrust::gather(rmm::exec_policy_nosync(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
key_sort_order.begin<size_type>(), // input
output.data() // result (most not overlap map)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
);
auto null_count = indices_view.null_count();
auto null_mask = indices->release().null_mask.release();
return std::make_unique<column>(std::move(output), std::move(*null_mask), null_count);
}

} // namespace detail
Expand Down
32 changes: 16 additions & 16 deletions cpp/src/groupby/sort/group_argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cudf/utilities/span.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/gather.h>

Expand All @@ -42,22 +43,21 @@ std::unique_ptr<column> group_argmin(column_view const& values,
stream,
mr);

// The functor returns the index of minimum in the sorted values.
// We need the index of minimum in the original unsorted values.
// So use indices to gather the sort order used to sort `values`.
// The values in data buffer of indices corresponding to null values was
// initialized to ARGMIN_SENTINEL. Using gather_if.
// This can't use gather because nulls in gathered column will not store ARGMIN_SENTINEL.
auto indices_view = indices->mutable_view();
thrust::gather_if(rmm::exec_policy(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
indices_view.begin<size_type>(), // stencil
key_sort_order.begin<size_type>(), // input
indices_view.begin<size_type>(), // result
[] __device__(auto i) { return (i != cudf::detail::ARGMIN_SENTINEL); });

return indices;
// The functor returns the indices of minimums based on the sorted keys.
// We need the indices of minimums from the original unsorted keys
// so we use these and the key_sort_order to map to the correct indices.
// We do not use cudf::gather since we can move the null-mask separately.
auto indices_view = indices->view();
auto output = rmm::device_uvector<size_type>(indices_view.size(), stream, mr);
thrust::gather(rmm::exec_policy_nosync(stream),
indices_view.begin<size_type>(), // map first
indices_view.end<size_type>(), // map last
key_sort_order.begin<size_type>(), // input
output.data() // result (most not overlap map)
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
);
auto null_count = indices_view.null_count();
auto null_mask = indices->release().null_mask.release();
return std::make_unique<column>(std::move(output), std::move(*null_mask), null_count);
}

} // namespace detail
Expand Down
Loading