Skip to content

Commit

Permalink
Try to reverse sort_helper.cu
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed May 29, 2022
1 parent 74a33d4 commit 15d036a
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions cpp/src/groupby/sort/sort_helper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include <cudf/detail/gather.hpp>
#include <cudf/detail/groupby/sort_helper.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/labeling/label_segments.cuh>
#include <cudf/detail/scatter.hpp>
#include <cudf/detail/sorting.hpp>
#include <cudf/detail/structs/utilities.hpp>
Expand All @@ -33,11 +32,17 @@
#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/binary_search.h>
#include <thrust/distance.h>
#include <thrust/fill.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/scan.h>
#include <thrust/scatter.h>
#include <thrust/sequence.h>
#include <thrust/uninitialized_fill.h>
#include <thrust/unique.h>

#include <algorithm>
Expand Down Expand Up @@ -218,13 +223,22 @@ sort_groupby_helper::index_vector const& sort_groupby_helper::group_labels(
_group_labels = std::make_unique<index_vector>(num_keys(stream), stream);

auto& group_labels = *_group_labels;

if (num_keys(stream) == 0) return group_labels;

cudf::detail::label_segments(group_offsets(stream).begin(),
group_offsets(stream).end(),
group_labels.begin(),
group_labels.end(),
stream);
thrust::uninitialized_fill(rmm::exec_policy(stream),
group_labels.begin(),
group_labels.end(),
index_vector::value_type{0});
thrust::scatter(rmm::exec_policy(stream),
thrust::make_constant_iterator(1, decltype(num_groups(stream))(1)),
thrust::make_constant_iterator(1, num_groups(stream)),
group_offsets(stream).begin() + 1,
group_labels.begin());

thrust::inclusive_scan(
rmm::exec_policy(stream), group_labels.begin(), group_labels.end(), group_labels.begin());

return group_labels;
}

Expand Down

0 comments on commit 15d036a

Please sign in to comment.