From 6589471e9a24ba607cf10f9b4be44f840b1860ab Mon Sep 17 00:00:00 2001 From: Petteri Peltonen Date: Wed, 24 Jan 2024 13:33:59 +0200 Subject: [PATCH] generalised for_each_indexed for all ranks --- .../bits/communication/distributed_array.hpp | 7 ++--- include/bits/core/rank.hpp | 2 ++ include/bits/core/tuple_extensions.hpp | 30 +++++++++++++++++++ include/bits/core/utils.hpp | 15 ++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 include/bits/core/tuple_extensions.hpp diff --git a/include/bits/communication/distributed_array.hpp b/include/bits/communication/distributed_array.hpp index 3b1a873..76e5e3d 100644 --- a/include/bits/communication/distributed_array.hpp +++ b/include/bits/communication/distributed_array.hpp @@ -2,8 +2,9 @@ #include "channel.hpp" #include "gather.hpp" -#include "include/bits/algorithms/algorithms.hpp" +#include "include/bits/core/tuple_extensions.hpp" #include "topology.hpp" +#include "include/bits/algorithms/algorithms.hpp" namespace jada { @@ -347,9 +348,7 @@ static void for_each_indexed(ExecutionPolicy&& policy, auto span = subspans[i]; auto F = [=](auto md_idx){ - auto copy = md_idx; - std::get<0>(copy) += std::get<0>(offset); - std::get<1>(copy) += std::get<1>(offset); + const auto copy = elementwise_add(md_idx, offset); f(copy, span(md_idx)); }; detail::md_for_each(policy, all_indices(span), F); diff --git a/include/bits/core/rank.hpp b/include/bits/core/rank.hpp index edb9fda..355b2cc 100644 --- a/include/bits/core/rank.hpp +++ b/include/bits/core/rank.hpp @@ -36,6 +36,8 @@ template struct Rank> { static constexpr size_t value = 2; }; + + /* template struct Rank> { diff --git a/include/bits/core/tuple_extensions.hpp b/include/bits/core/tuple_extensions.hpp new file mode 100644 index 0000000..f803221 --- /dev/null +++ b/include/bits/core/tuple_extensions.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "rank.hpp" +#include "utils.hpp" + +namespace jada { + +namespace detail { + +template +static constexpr auto +elementwise_add(auto tuple1, auto& tuple2, std::index_sequence) { + return std::make_tuple((std::get(tuple1) + std::get(tuple2))...); +} +} // namespace detail + +/// +///@brief Elementwise addition of two generic types which have std::get implemented. +/// +///@param lhs Left-hand side operand. +///@param rhs Right-hand side operand. +///@return A tuple of elementwise added values. +/// +static constexpr auto elementwise_add(auto lhs, auto rhs) { + static_assert(rank(lhs) == rank(rhs), "Rank mismatch in elemenwise_add"); + constexpr size_t N = rank(lhs); + return detail::elementwise_add(lhs, rhs, std::make_index_sequence{}); +} + +} // namespace jada \ No newline at end of file diff --git a/include/bits/core/utils.hpp b/include/bits/core/utils.hpp index b401574..95b6979 100644 --- a/include/bits/core/utils.hpp +++ b/include/bits/core/utils.hpp @@ -51,4 +51,19 @@ template constexpr auto tuple_to_array(tuple_t&& tuple) { return std::apply(get_array, std::forward(tuple)); } +namespace detail{ + + +template +auto tuple_add(const std::tuple& tuple1, const std::tuple& tuple2, std::index_sequence) { + return std::make_tuple((std::get(tuple1) + std::get(tuple2))...); +} +} + +template +auto tuple_add(const std::tuple& tuple1, const std::tuple& tuple2) { + return detail::tuple_add(tuple1, tuple2, std::index_sequence_for{}); +} + + } // namespace jada