From 65e2bce1df185d015f9debdbf1d63841c99341c7 Mon Sep 17 00:00:00 2001 From: divyegala Date: Wed, 8 Feb 2023 11:30:26 -0800 Subject: [PATCH] fix compile times for rank --- cpp/src/sort/rank.cu | 50 +++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/cpp/src/sort/rank.cu b/cpp/src/sort/rank.cu index 5045878b8cf..b3c8da9d7d7 100644 --- a/cpp/src/sort/rank.cu +++ b/cpp/src/sort/rank.cu @@ -48,6 +48,23 @@ namespace cudf { namespace detail { namespace { +template +struct unique_functor { + unique_functor(PermutationIteratorType permute, DeviceComparatorType device_comparator) + : _permute(permute), _device_comparator(device_comparator) + { + } + + auto __device__ operator()(size_type index) + { + return static_cast(index == 0 || + not _device_comparator(_permute[index], _permute[index - 1])); + } + + PermutationIteratorType _permute; + DeviceComparatorType _device_comparator; +}; + // Assign rank from 1 to n unique values. Equal values get same rank value. rmm::device_uvector sorted_dense_rank(column_view input_col, column_view sorted_order_view, @@ -62,33 +79,32 @@ rmm::device_uvector sorted_dense_rank(column_view input_col, auto const input_size = input_col.size(); rmm::device_uvector dense_rank_sorted(input_size, stream); + auto const comparator_helper = [&](auto const device_comparator) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input_size), + dense_rank_sorted.data(), + unique_functor{ + sorted_index_order, device_comparator}); + }; + if (cudf::detail::has_nested_columns(t_input)) { auto const device_comparator = comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_input)}); - auto conv = [permute = sorted_index_order, device_comparator] __device__(size_type index) { - return static_cast(index == 0 || - not device_comparator(permute[index], permute[index - 1])); - }; - auto const unique_it = cudf::detail::make_counting_transform_iterator(0, conv); - - thrust::inclusive_scan( - rmm::exec_policy(stream), unique_it, unique_it + input_size, dense_rank_sorted.data()); - + comparator_helper(device_comparator); } else { auto const device_comparator = comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_input)}); - auto conv = [permute = sorted_index_order, device_comparator] __device__(size_type index) { - return static_cast(index == 0 || - not device_comparator(permute[index], permute[index - 1])); - }; - auto const unique_it = cudf::detail::make_counting_transform_iterator(0, conv); - - thrust::inclusive_scan( - rmm::exec_policy(stream), unique_it, unique_it + input_size, dense_rank_sorted.data()); + comparator_helper(device_comparator); } + thrust::inclusive_scan(rmm::exec_policy(stream), + dense_rank_sorted.begin(), + dense_rank_sorted.end(), + dense_rank_sorted.data()); + return dense_rank_sorted; }