Skip to content

Commit

Permalink
generalised for_each_indexed for all ranks
Browse files Browse the repository at this point in the history
  • Loading branch information
hamsteri15 committed Jan 24, 2024
1 parent 303a18c commit 6589471
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 4 deletions.
7 changes: 3 additions & 4 deletions include/bits/communication/distributed_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

#include "channel.hpp"
#include "gather.hpp"
#include "include/bits/algorithms/algorithms.hpp"
#include "include/bits/core/tuple_extensions.hpp"
#include "topology.hpp"
#include "include/bits/algorithms/algorithms.hpp"

namespace jada {

Expand Down Expand Up @@ -347,9 +348,7 @@ static void for_each_indexed(ExecutionPolicy&& policy,
auto span = subspans[i];

auto F = [=](auto md_idx){
auto copy = md_idx;
std::get<0>(copy) += std::get<0>(offset);
std::get<1>(copy) += std::get<1>(offset);
const auto copy = elementwise_add(md_idx, offset);
f(copy, span(md_idx));
};
detail::md_for_each(policy, all_indices(span), F);
Expand Down
2 changes: 2 additions & 0 deletions include/bits/core/rank.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ template <class... Idx> struct Rank<std::pair<Idx...>> {
static constexpr size_t value = 2;
};



/*
template <class... Idx> struct Rank<ranges::common_tuple<Idx...>> {
Expand Down
30 changes: 30 additions & 0 deletions include/bits/core/tuple_extensions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "rank.hpp"
#include "utils.hpp"

namespace jada {

namespace detail {

template <size_t... Is>
static constexpr auto
elementwise_add(auto tuple1, auto& tuple2, std::index_sequence<Is...>) {
return std::make_tuple((std::get<Is>(tuple1) + std::get<Is>(tuple2))...);
}
} // namespace detail

///
///@brief Elementwise addition of two generic types which have std::get implemented.
///
///@param lhs Left-hand side operand.
///@param rhs Right-hand side operand.
///@return A tuple of elementwise added values.
///
static constexpr auto elementwise_add(auto lhs, auto rhs) {
static_assert(rank(lhs) == rank(rhs), "Rank mismatch in elemenwise_add");
constexpr size_t N = rank(lhs);
return detail::elementwise_add(lhs, rhs, std::make_index_sequence<N>{});
}

} // namespace jada
15 changes: 15 additions & 0 deletions include/bits/core/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,19 @@ template <typename tuple_t> constexpr auto tuple_to_array(tuple_t&& tuple) {
return std::apply(get_array, std::forward<tuple_t>(tuple));
}

namespace detail{


template<typename... Ts, size_t... Is>
auto tuple_add(const std::tuple<Ts...>& tuple1, const std::tuple<Ts...>& tuple2, std::index_sequence<Is...>) {
return std::make_tuple((std::get<Is>(tuple1) + std::get<Is>(tuple2))...);
}
}

template<typename... Ts>
auto tuple_add(const std::tuple<Ts...>& tuple1, const std::tuple<Ts...>& tuple2) {
return detail::tuple_add(tuple1, tuple2, std::index_sequence_for<Ts...>{});
}


} // namespace jada

0 comments on commit 6589471

Please sign in to comment.