Skip to content

Commit

Permalink
Fix elements_with_stride_nd when the initial index is outside the extent
Browse files Browse the repository at this point in the history
Also:
  - add support for the AccCpuThreads acelerator
  - rename member variables for consistency
  - improve comments
  • Loading branch information
fwyzard committed Aug 13, 2023
1 parent f81c6e1 commit 938674c
Showing 1 changed file with 51 additions and 27 deletions.
78 changes: 51 additions & 27 deletions HeterogeneousCore/AlpakaInterface/interface/workdivision.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ namespace cms::alpakatools {
struct requires_single_thread_per_block<alpaka::AccGpuHipRt<TDim, Idx>> : public std::false_type {};
#endif // ALPAKA_ACC_GPU_HIP_ENABLED

#ifdef ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLED
template <typename TDim>
struct requires_single_thread_per_block<alpaka::AccCpuThreads<TDim, Idx>> : public std::false_type {};
#endif // ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLED

// Whether or not the accelerator expects the threads-per-block and elements-per-thread to be swapped
template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc>>>
inline constexpr bool requires_single_thread_per_block_v = requires_single_thread_per_block<TAcc>::value;
Expand Down Expand Up @@ -75,13 +80,13 @@ namespace cms::alpakatools {
public:
ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
extent_{stride_} {}

ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc, Idx extent)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
extent_{extent} {}

Expand All @@ -94,7 +99,7 @@ namespace cms::alpakatools {
extent_{extent},
first_{std::min(first, extent)},
index_{first_},
last_{std::min(first + elements, extent)} {}
range_{std::min(first + elements, extent)} {}

public:
ALPAKA_FN_ACC inline Idx operator*() const { return index_; }
Expand All @@ -104,21 +109,21 @@ namespace cms::alpakatools {
if constexpr (requires_single_thread_per_block_v<TAcc>) {
// increment the index along the elements processed by the current thread
++index_;
if (index_ < last_)
if (index_ < range_)
return *this;
}

// increment the thread index with the grid stride
first_ += stride_;
index_ = first_;
last_ = std::min(first_ + elements_, extent_);
range_ = std::min(first_ + elements_, extent_);
if (index_ < extent_)
return *this;

// the iterator has reached or passed the end of the extent, clamp it to the extent
first_ = extent_;
index_ = extent_;
last_ = extent_;
range_ = extent_;
return *this;
}

Expand All @@ -143,16 +148,16 @@ namespace cms::alpakatools {
// modified by the pre/post-increment operator
Idx first_;
Idx index_;
Idx last_;
Idx range_;
};

ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, first_); }
ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, thread_); }

ALPAKA_FN_ACC inline iterator end() const { return iterator(elements_, stride_, extent_, extent_); }

private:
const Idx elements_;
const Idx first_;
const Idx thread_;
const Idx stride_;
const Idx extent_;
};
Expand All @@ -165,16 +170,19 @@ namespace cms::alpakatools {

ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc) * elements_},
extent_{stride_} {}

ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc, Vec extent)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc) * elements_},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc) * elements_},
extent_{extent} {}

// tag used to construct an end iterator
struct at_end_t {};

class iterator {
friend class elements_with_stride_nd;

Expand All @@ -199,19 +207,23 @@ namespace cms::alpakatools {
ALPAKA_FN_ACC constexpr inline bool operator!=(iterator const& other) const { return not(*this == other); }

private:
// private, explicit constructor
// construct an iterator pointing to the first element to be processed by the current thread
ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, Vec first)
: loop_{loop},
thread_{alpaka::elementwise_min(first, loop->extent_)},
first_{alpaka::elementwise_min(first, loop->extent_)},
range_{alpaka::elementwise_min(first + loop->elements_, loop->extent_)},
index_{thread_} {}
index_{first_} {}

// construct an end iterator, pointing post the end of the extent
ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, at_end_t const&)
: loop_{loop}, first_{loop_->extent_}, range_{loop_->extent_}, index_{loop_->extent_} {}

template <size_t I>
ALPAKA_FN_ACC inline constexpr bool nth_elements_loop() {
bool overflow = false;
++index_[I];
if (index_[I] >= range_[I]) {
index_[I] = thread_[I];
index_[I] = first_[I];
overflow = true;
}
return overflow;
Expand All @@ -234,13 +246,13 @@ namespace cms::alpakatools {
template <size_t I>
ALPAKA_FN_ACC inline constexpr bool nth_strided_loop() {
bool overflow = false;
thread_[I] += loop_->stride_[I];
if (thread_[I] >= loop_->extent_[I]) {
thread_[I] = loop_->first_[I];
first_[I] += loop_->stride_[I];
if (first_[I] >= loop_->extent_[I]) {
first_[I] = loop_->thread_[I];
overflow = true;
}
index_[I] = thread_[I];
range_[I] = std::min(thread_[I] + loop_->elements_[I], loop_->extent_[I]);
index_[I] = first_[I];
range_[I] = std::min(first_[I] + loop_->elements_[I], loop_->extent_[I]);
return overflow;
}

Expand Down Expand Up @@ -277,7 +289,7 @@ namespace cms::alpakatools {
}

// the iterator has reached or passed the end of the extent, clamp it to the extent
thread_ = loop_->extent_;
first_ = loop_->extent_;
range_ = loop_->extent_;
index_ = loop_->extent_;
}
Expand All @@ -286,18 +298,30 @@ namespace cms::alpakatools {
const elements_with_stride_nd* loop_;

// modified by the pre/post-increment operator
Vec thread_; // first element processed by this thread
Vec range_; // last element processed by this thread
Vec index_; // current element processed by this thread
Vec first_; // first element processed by this thread
Vec range_; // last element processed by this thread
Vec index_; // current element processed by this thread
};

ALPAKA_FN_ACC inline iterator begin() const { return iterator{this, first_}; }
ALPAKA_FN_ACC inline iterator begin() const {
// check that all dimensions of the current thread index are within the extent
if ((thread_ < extent_).all()) {
// construct an iterator pointing to the first element to be processed by the current thread
return iterator{this, thread_};
} else {
// construct an end iterator, pointing post the end of the extent
return iterator{this, at_end_t{}};
}
}

ALPAKA_FN_ACC inline iterator end() const { return iterator{this, extent_}; }
ALPAKA_FN_ACC inline iterator end() const {
// construct an end iterator, pointing post the end of the extent
return iterator{this, at_end_t{}};
}

private:
const Vec elements_;
const Vec first_;
const Vec thread_;
const Vec stride_;
const Vec extent_;
};
Expand Down

0 comments on commit 938674c

Please sign in to comment.