Skip to content

Commit

Permalink
Simplify construction of the stride iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowe committed Feb 3, 2022
1 parent 6f381ba commit c214cf4
Showing 1 changed file with 10 additions and 33 deletions.
43 changes: 10 additions & 33 deletions java/src/main/native/src/aggregation128_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,13 @@
#include <cudf/utilities/error.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include "aggregation128_utils.hpp"

namespace {

// Iterate every 4th 32-bit value, i.e.: one "chunk" of a __int128_t value
class chunk_strided_range {
public:
typedef typename thrust::iterator_difference<int32_t const *>::type difference_type;

struct stride_functor : public thrust::unary_function<difference_type, difference_type> {
__device__ inline difference_type operator()(difference_type i) const { return i * 4; }
};

typedef typename thrust::counting_iterator<difference_type> CountingIterator;
typedef typename thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator;
typedef typename thrust::permutation_iterator<int32_t const *, TransformIterator>
PermutationIterator;

typedef PermutationIterator iterator;

chunk_strided_range(int32_t const *start, int32_t const *finish, int chunk_idx)
: start(start + chunk_idx), finish(finish + chunk_idx) {}

iterator begin() const {
return PermutationIterator(start, TransformIterator(CountingIterator(0), stride_functor()));
}

iterator end() const { return begin() + ((finish - start) + 3) / 4; }

private:
int32_t const *start;
int32_t const *finish;
};

// Functor to reassemble a 128-bit value from four 64-bit chunks with overflow detection.
class chunk_assembler : public thrust::unary_function<cudf::size_type, __int128_t> {
public:
Expand Down Expand Up @@ -111,9 +83,14 @@ std::unique_ptr<cudf::column> extract_chunk32(cudf::column_view const &in_col, c
auto out_col = cudf::make_fixed_width_column(type, num_rows, copy_bitmask(in_col));
auto out_view = out_col->mutable_view();
auto const in_begin = in_col.begin<int32_t>();
auto const in_end = in_begin + in_col.size() * 4;
chunk_strided_range range(in_begin, in_end, chunk_idx);
thrust::copy(rmm::exec_policy(stream), range.begin(), range.end(), out_view.data<int32_t>());

// Build an iterator for every fourth 32-bit value, i.e.: one "chunk" of a __int128_t value
thrust::transform_iterator transform_iter{thrust::counting_iterator{0},
[] __device__(auto i) { return i * 4; }};
thrust::permutation_iterator stride_iter{in_begin + chunk_idx, transform_iter};

thrust::copy(rmm::exec_policy(stream), stride_iter, stride_iter + num_rows,
out_view.data<int32_t>());
return out_col;
}

Expand Down

0 comments on commit c214cf4

Please sign in to comment.