Skip to content

Commit

Permalink
Pass lambdas by && in helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ghugo83 committed Mar 29, 2021
1 parent 48d4d15 commit 36f8c3d
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions src/alpaka/AlpakaCore/alpakaWorkDivHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,12 @@ namespace cms {
ALPAKA_FN_ACC void for_each_element_in_thread_1D_index_in_block(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const uint32_t elementIdxShift,
const Func func) {
Func&& func) {
const auto& [firstElementIdx, endElementIdx] = cms::alpakatools::element_index_range_in_block_truncated(
acc, Vec1::all(maxNumberOfElements), Vec1::all(elementIdxShift));

for (uint32_t elementIdx = firstElementIdx[0u]; elementIdx < endElementIdx[0u]; ++elementIdx) {
func(elementIdx);
std::forward<Func>(func)(elementIdx);
}
}

Expand All @@ -171,9 +171,10 @@ namespace cms {
template <typename T_Acc, typename Func>
ALPAKA_FN_ACC void for_each_element_in_thread_1D_index_in_block(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const Func func) {
Func&& func) {
const uint32_t elementIdxShift = 0;
cms::alpakatools::for_each_element_in_thread_1D_index_in_block(acc, maxNumberOfElements, elementIdxShift, func);
cms::alpakatools::for_each_element_in_thread_1D_index_in_block(
acc, maxNumberOfElements, elementIdxShift, std::forward<Func>(func));
}

/*
Expand All @@ -185,13 +186,13 @@ namespace cms {
ALPAKA_FN_ACC void for_each_element_in_thread_1D_index_in_grid(const T_Acc& acc,
const uint32_t maxNumberOfElements,
uint32_t elementIdxShift,
const Func func) {
Func&& func) {
// Take into account the block index in grid to compute the element indices.
const uint32_t blockIdxInGrid(alpaka::idx::getIdx<alpaka::Grid, alpaka::Blocks>(acc)[0u]);
const uint32_t blockDimension(alpaka::workdiv::getWorkDiv<alpaka::Block, alpaka::Elems>(acc)[0u]);
elementIdxShift += blockIdxInGrid * blockDimension;

for_each_element_in_thread_1D_index_in_block(acc, maxNumberOfElements, elementIdxShift, func);
for_each_element_in_thread_1D_index_in_block(acc, maxNumberOfElements, elementIdxShift, std::forward<Func>(func));
}

/*
Expand All @@ -200,9 +201,10 @@ namespace cms {
template <typename T_Acc, typename Func>
ALPAKA_FN_ACC void for_each_element_in_thread_1D_index_in_grid(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const Func func) {
Func&& func) {
const uint32_t elementIdxShift = 0;
cms::alpakatools::for_each_element_in_thread_1D_index_in_grid(acc, maxNumberOfElements, elementIdxShift, func);
cms::alpakatools::for_each_element_in_thread_1D_index_in_grid(
acc, maxNumberOfElements, elementIdxShift, std::forward<Func>(func));
}

/******************************************************************************
Expand All @@ -219,7 +221,7 @@ namespace cms {
ALPAKA_FN_ACC void for_each_element_1D_block_stride(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const uint32_t elementIdxShift,
const Func func) {
Func&& func) {
// Get thread / element indices in block.
const auto& [firstElementIdxNoStride, endElementIdxNoStride] =
cms::alpakatools::element_index_range_in_block(acc, Vec1::all(elementIdxShift));
Expand All @@ -233,7 +235,7 @@ namespace cms {
threadIdx += blockDimension, endElementIdx += blockDimension) {
// (CPU) Loop on all elements.
for (uint32_t i = threadIdx; i < std::min(endElementIdx, maxNumberOfElements); ++i) {
func(i);
std::forward<Func>(func)(i);
}
}
}
Expand All @@ -244,9 +246,10 @@ namespace cms {
template <typename T_Acc, typename Func>
ALPAKA_FN_ACC void for_each_element_1D_block_stride(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const Func func) {
Func&& func) {
const uint32_t elementIdxShift = 0;
cms::alpakatools::for_each_element_1D_block_stride(acc, maxNumberOfElements, elementIdxShift, func);
cms::alpakatools::for_each_element_1D_block_stride(
acc, maxNumberOfElements, elementIdxShift, std::forward<Func>(func));
}

/*
Expand All @@ -259,7 +262,7 @@ namespace cms {
ALPAKA_FN_ACC void for_each_element_1D_grid_stride(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const uint32_t elementIdxShift,
const Func func) {
Func&& func) {
Vec1 elementIdxShiftVec = Vec1::all(elementIdxShift);

// Get thread / element indices in block.
Expand All @@ -275,7 +278,7 @@ namespace cms {
threadIdx += gridDimension, endElementIdx += gridDimension) {
// (CPU) Loop on all elements.
for (uint32_t i = threadIdx; i < std::min(endElementIdx, maxNumberOfElements); ++i) {
func(i);
std::forward<Func>(func)(i);
}
}
}
Expand All @@ -286,9 +289,10 @@ namespace cms {
template <typename T_Acc, typename Func>
ALPAKA_FN_ACC void for_each_element_1D_grid_stride(const T_Acc& acc,
const uint32_t maxNumberOfElements,
const Func func) {
Func&& func) {
const uint32_t elementIdxShift = 0;
cms::alpakatools::for_each_element_1D_grid_stride(acc, maxNumberOfElements, elementIdxShift, func);
cms::alpakatools::for_each_element_1D_grid_stride(
acc, maxNumberOfElements, elementIdxShift, std::forward<Func>(func));
}

} // namespace alpakatools
Expand Down

0 comments on commit 36f8c3d

Please sign in to comment.