From c214cf43cc0a9f2548afdbab0cc5e777de0134ad Mon Sep 17 00:00:00 2001 From: Jason Lowe Date: Thu, 3 Feb 2022 14:33:12 -0600 Subject: [PATCH] Simplify construction of the stride iterator --- .../main/native/src/aggregation128_utils.cu | 43 +++++-------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/java/src/main/native/src/aggregation128_utils.cu b/java/src/main/native/src/aggregation128_utils.cu index 4bd428dab03..865f607ff7d 100644 --- a/java/src/main/native/src/aggregation128_utils.cu +++ b/java/src/main/native/src/aggregation128_utils.cu @@ -23,41 +23,13 @@ #include #include #include +#include +#include #include "aggregation128_utils.hpp" namespace { -// Iterate every 4th 32-bit value, i.e.: one "chunk" of a __int128_t value -class chunk_strided_range { -public: - typedef typename thrust::iterator_difference::type difference_type; - - struct stride_functor : public thrust::unary_function { - __device__ inline difference_type operator()(difference_type i) const { return i * 4; } - }; - - typedef typename thrust::counting_iterator CountingIterator; - typedef typename thrust::transform_iterator TransformIterator; - typedef typename thrust::permutation_iterator - PermutationIterator; - - typedef PermutationIterator iterator; - - chunk_strided_range(int32_t const *start, int32_t const *finish, int chunk_idx) - : start(start + chunk_idx), finish(finish + chunk_idx) {} - - iterator begin() const { - return PermutationIterator(start, TransformIterator(CountingIterator(0), stride_functor())); - } - - iterator end() const { return begin() + ((finish - start) + 3) / 4; } - -private: - int32_t const *start; - int32_t const *finish; -}; - // Functor to reassemble a 128-bit value from four 64-bit chunks with overflow detection. class chunk_assembler : public thrust::unary_function { public: @@ -111,9 +83,14 @@ std::unique_ptr extract_chunk32(cudf::column_view const &in_col, c auto out_col = cudf::make_fixed_width_column(type, num_rows, copy_bitmask(in_col)); auto out_view = out_col->mutable_view(); auto const in_begin = in_col.begin(); - auto const in_end = in_begin + in_col.size() * 4; - chunk_strided_range range(in_begin, in_end, chunk_idx); - thrust::copy(rmm::exec_policy(stream), range.begin(), range.end(), out_view.data()); + + // Build an iterator for every fourth 32-bit value, i.e.: one "chunk" of a __int128_t value + thrust::transform_iterator transform_iter{thrust::counting_iterator{0}, + [] __device__(auto i) { return i * 4; }}; + thrust::permutation_iterator stride_iter{in_begin + chunk_idx, transform_iter}; + + thrust::copy(rmm::exec_policy(stream), stride_iter, stride_iter + num_rows, + out_view.data()); return out_col; }