diff --git a/HeterogeneousCore/AlpakaInterface/interface/workdivision.h b/HeterogeneousCore/AlpakaInterface/interface/workdivision.h index 0d295855976da..7449bb153c9f7 100644 --- a/HeterogeneousCore/AlpakaInterface/interface/workdivision.h +++ b/HeterogeneousCore/AlpakaInterface/interface/workdivision.h @@ -32,6 +32,11 @@ namespace cms::alpakatools { struct requires_single_thread_per_block> : public std::false_type {}; #endif // ALPAKA_ACC_GPU_HIP_ENABLED +#ifdef ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLED + template + struct requires_single_thread_per_block> : public std::false_type {}; +#endif // ALPAKA_ACC_CPU_B_SEQ_T_THREADS_ENABLED + // Whether or not the accelerator expects the threads-per-block and elements-per-thread to be swapped template >> inline constexpr bool requires_single_thread_per_block_v = requires_single_thread_per_block::value; @@ -75,13 +80,13 @@ namespace cms::alpakatools { public: ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc) : elements_{alpaka::getWorkDiv(acc)[0u]}, - first_{alpaka::getIdx(acc)[0u] * elements_}, + thread_{alpaka::getIdx(acc)[0u] * elements_}, stride_{alpaka::getWorkDiv(acc)[0u] * elements_}, extent_{stride_} {} ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc, Idx extent) : elements_{alpaka::getWorkDiv(acc)[0u]}, - first_{alpaka::getIdx(acc)[0u] * elements_}, + thread_{alpaka::getIdx(acc)[0u] * elements_}, stride_{alpaka::getWorkDiv(acc)[0u] * elements_}, extent_{extent} {} @@ -94,7 +99,7 @@ namespace cms::alpakatools { extent_{extent}, first_{std::min(first, extent)}, index_{first_}, - last_{std::min(first + elements, extent)} {} + range_{std::min(first + elements, extent)} {} public: ALPAKA_FN_ACC inline Idx operator*() const { return index_; } @@ -104,21 +109,21 @@ namespace cms::alpakatools { if constexpr (requires_single_thread_per_block_v) { // increment the index along the elements processed by the current thread ++index_; - if (index_ < last_) + if (index_ < range_) return *this; } // increment the thread index with the grid stride first_ += stride_; index_ = first_; - last_ = std::min(first_ + elements_, extent_); + range_ = std::min(first_ + elements_, extent_); if (index_ < extent_) return *this; // the iterator has reached or passed the end of the extent, clamp it to the extent first_ = extent_; index_ = extent_; - last_ = extent_; + range_ = extent_; return *this; } @@ -143,16 +148,16 @@ namespace cms::alpakatools { // modified by the pre/post-increment operator Idx first_; Idx index_; - Idx last_; + Idx range_; }; - ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, first_); } + ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, thread_); } ALPAKA_FN_ACC inline iterator end() const { return iterator(elements_, stride_, extent_, extent_); } private: const Idx elements_; - const Idx first_; + const Idx thread_; const Idx stride_; const Idx extent_; }; @@ -165,16 +170,19 @@ namespace cms::alpakatools { ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc) : elements_{alpaka::getWorkDiv(acc)}, - first_{alpaka::getIdx(acc) * elements_}, + thread_{alpaka::getIdx(acc) * elements_}, stride_{alpaka::getWorkDiv(acc) * elements_}, extent_{stride_} {} ALPAKA_FN_ACC inline elements_with_stride_nd(TAcc const& acc, Vec extent) : elements_{alpaka::getWorkDiv(acc)}, - first_{alpaka::getIdx(acc) * elements_}, + thread_{alpaka::getIdx(acc) * elements_}, stride_{alpaka::getWorkDiv(acc) * elements_}, extent_{extent} {} + // tag used to construct an end iterator + struct at_end_t {}; + class iterator { friend class elements_with_stride_nd; @@ -199,19 +207,23 @@ namespace cms::alpakatools { ALPAKA_FN_ACC constexpr inline bool operator!=(iterator const& other) const { return not(*this == other); } private: - // private, explicit constructor + // construct an iterator pointing to the first element to be processed by the current thread ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, Vec first) : loop_{loop}, - thread_{alpaka::elementwise_min(first, loop->extent_)}, + first_{alpaka::elementwise_min(first, loop->extent_)}, range_{alpaka::elementwise_min(first + loop->elements_, loop->extent_)}, - index_{thread_} {} + index_{first_} {} + + // construct an end iterator, pointing post the end of the extent + ALPAKA_FN_ACC inline iterator(elements_with_stride_nd const* loop, at_end_t const&) + : loop_{loop}, first_{loop_->extent_}, range_{loop_->extent_}, index_{loop_->extent_} {} template ALPAKA_FN_ACC inline constexpr bool nth_elements_loop() { bool overflow = false; ++index_[I]; if (index_[I] >= range_[I]) { - index_[I] = thread_[I]; + index_[I] = first_[I]; overflow = true; } return overflow; @@ -234,13 +246,13 @@ namespace cms::alpakatools { template ALPAKA_FN_ACC inline constexpr bool nth_strided_loop() { bool overflow = false; - thread_[I] += loop_->stride_[I]; - if (thread_[I] >= loop_->extent_[I]) { - thread_[I] = loop_->first_[I]; + first_[I] += loop_->stride_[I]; + if (first_[I] >= loop_->extent_[I]) { + first_[I] = loop_->thread_[I]; overflow = true; } - index_[I] = thread_[I]; - range_[I] = std::min(thread_[I] + loop_->elements_[I], loop_->extent_[I]); + index_[I] = first_[I]; + range_[I] = std::min(first_[I] + loop_->elements_[I], loop_->extent_[I]); return overflow; } @@ -277,7 +289,7 @@ namespace cms::alpakatools { } // the iterator has reached or passed the end of the extent, clamp it to the extent - thread_ = loop_->extent_; + first_ = loop_->extent_; range_ = loop_->extent_; index_ = loop_->extent_; } @@ -286,18 +298,30 @@ namespace cms::alpakatools { const elements_with_stride_nd* loop_; // modified by the pre/post-increment operator - Vec thread_; // first element processed by this thread - Vec range_; // last element processed by this thread - Vec index_; // current element processed by this thread + Vec first_; // first element processed by this thread + Vec range_; // last element processed by this thread + Vec index_; // current element processed by this thread }; - ALPAKA_FN_ACC inline iterator begin() const { return iterator{this, first_}; } + ALPAKA_FN_ACC inline iterator begin() const { + // check that all dimensions of the current thread index are within the extent + if ((thread_ < extent_).all()) { + // construct an iterator pointing to the first element to be processed by the current thread + return iterator{this, thread_}; + } else { + // construct an end iterator, pointing post the end of the extent + return iterator{this, at_end_t{}}; + } + } - ALPAKA_FN_ACC inline iterator end() const { return iterator{this, extent_}; } + ALPAKA_FN_ACC inline iterator end() const { + // construct an end iterator, pointing post the end of the extent + return iterator{this, at_end_t{}}; + } private: const Vec elements_; - const Vec first_; + const Vec thread_; const Vec stride_; const Vec extent_; };