Skip to content

Commit

Permalink
fix compile times for rank
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala committed Feb 8, 2023
1 parent 53e918f commit 65e2bce
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions cpp/src/sort/rank.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@ namespace cudf {
namespace detail {
namespace {

template <typename PermutationIteratorType, typename DeviceComparatorType>
struct unique_functor {
unique_functor(PermutationIteratorType permute, DeviceComparatorType device_comparator)
: _permute(permute), _device_comparator(device_comparator)
{
}

auto __device__ operator()(size_type index)
{
return static_cast<size_type>(index == 0 ||
not _device_comparator(_permute[index], _permute[index - 1]));
}

PermutationIteratorType _permute;
DeviceComparatorType _device_comparator;
};

// Assign rank from 1 to n unique values. Equal values get same rank value.
rmm::device_uvector<size_type> sorted_dense_rank(column_view input_col,
column_view sorted_order_view,
Expand All @@ -62,33 +79,32 @@ rmm::device_uvector<size_type> sorted_dense_rank(column_view input_col,
auto const input_size = input_col.size();
rmm::device_uvector<size_type> dense_rank_sorted(input_size, stream);

auto const comparator_helper = [&](auto const device_comparator) {
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(input_size),
dense_rank_sorted.data(),
unique_functor<decltype(sorted_index_order), decltype(device_comparator)>{
sorted_index_order, device_comparator});
};

if (cudf::detail::has_nested_columns(t_input)) {
auto const device_comparator =
comparator.equal_to<true>(nullate::DYNAMIC{has_nested_nulls(t_input)});

auto conv = [permute = sorted_index_order, device_comparator] __device__(size_type index) {
return static_cast<size_type>(index == 0 ||
not device_comparator(permute[index], permute[index - 1]));
};
auto const unique_it = cudf::detail::make_counting_transform_iterator(0, conv);

thrust::inclusive_scan(
rmm::exec_policy(stream), unique_it, unique_it + input_size, dense_rank_sorted.data());

comparator_helper(device_comparator);
} else {
auto const device_comparator =
comparator.equal_to<false>(nullate::DYNAMIC{has_nested_nulls(t_input)});

auto conv = [permute = sorted_index_order, device_comparator] __device__(size_type index) {
return static_cast<size_type>(index == 0 ||
not device_comparator(permute[index], permute[index - 1]));
};
auto const unique_it = cudf::detail::make_counting_transform_iterator(0, conv);

thrust::inclusive_scan(
rmm::exec_policy(stream), unique_it, unique_it + input_size, dense_rank_sorted.data());
comparator_helper(device_comparator);
}

thrust::inclusive_scan(rmm::exec_policy(stream),
dense_rank_sorted.begin(),
dense_rank_sorted.end(),
dense_rank_sorted.data());

return dense_rank_sorted;
}

Expand Down

0 comments on commit 65e2bce

Please sign in to comment.