From 968db00022b5c6683f0cb33ddc5862f3694f492b Mon Sep 17 00:00:00 2001 From: Petteri Peltonen Date: Wed, 24 Jan 2024 16:16:46 +0200 Subject: [PATCH] tile transform for distributed array --- .../bits/communication/distributed_array.hpp | 36 +++++++++-- test/test_communication.cpp | 60 +++++++++++++++++++ 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/include/bits/communication/distributed_array.hpp b/include/bits/communication/distributed_array.hpp index 76e5e3d..c95e564 100644 --- a/include/bits/communication/distributed_array.hpp +++ b/include/bits/communication/distributed_array.hpp @@ -2,9 +2,9 @@ #include "channel.hpp" #include "gather.hpp" +#include "include/bits/algorithms/algorithms.hpp" #include "include/bits/core/tuple_extensions.hpp" #include "topology.hpp" -#include "include/bits/algorithms/algorithms.hpp" namespace jada { @@ -341,19 +341,43 @@ static void for_each_indexed(ExecutionPolicy&& policy, DistributedArray& arr, BinaryIndexFunction f) { - auto boxes = arr.local_boxes(); + auto boxes = arr.local_boxes(); auto subspans = make_subspans(arr); - for (size_t i = 0; i < subspans.size(); ++i){ + for (size_t i = 0; i < subspans.size(); ++i) { auto offset = boxes[i].box.m_begin; - auto span = subspans[i]; + auto span = subspans[i]; - auto F = [=](auto md_idx){ + auto F = [=](auto md_idx) { const auto copy = elementwise_add(md_idx, offset); f(copy, span(md_idx)); }; detail::md_for_each(policy, all_indices(span), F); } - +} + +template +static void tile_transform(ExecutionPolicy&& policy, + const DistributedArray& input, + DistributedArray& output, + UnaryTileFunction f) { + + auto boxes = input.local_boxes(); + auto i_subspans = make_subspans(input); + auto o_subspans = make_subspans(output); + + for (size_t i = 0; i < i_subspans.size(); ++i) { + auto offset = boxes[i].box.m_begin; + auto i_span = i_subspans[i]; + auto o_span = o_subspans[i]; + + tile_transform(policy, i_span, o_span, f); + + } } } // namespace jada diff --git a/test/test_communication.cpp b/test/test_communication.cpp index 28fcbef..d29ad6f 100644 --- a/test/test_communication.cpp +++ b/test/test_communication.cpp @@ -913,9 +913,69 @@ TEST_CASE("Test DistributedArray") CHECK(to_vector(arr) == correct); } + } + + SECTION("tile_transform"){ + + index_type ni = 4; + index_type nj = 3; + + std::array bpad{1,1}; + std::array epad{1,1}; + + std::vector a(size_t(ni*nj), 0); + std::vector b(size_t(ni*nj), 0); + + Box<2> domain({0,0}, {nj, ni}); + auto topo = decompose(domain, mpi::world_size(), {true, true}); + + auto arr_a = distribute(a, topo, mpi::get_world_rank(), bpad, epad); + auto arr_b = distribute(b, topo, mpi::get_world_rank(), bpad, epad); + + auto op = [](auto f) { + return f(-1) + f(1); + }; + + + for (auto& data : arr_a.local_data()){ + std::fill(data.begin(), data.end(), 1); + } + for (auto& data : arr_b.local_data()){ + std::fill(data.begin(), data.end(), -1); + } + + + //mpi::wait(); + + tile_transform<0>(std::execution::par_unseq, arr_a, arr_b, op); + + mpi::wait(); + + std::vector correct1 = + { + 2,2,2,2, + 2,2,2,2, + 2,2,2,2 + }; + + + CHECK(to_vector(arr_b) == correct1); + + for (auto& data : arr_b.local_data()){ + std::fill(data.begin(), data.end(), -3); + } + + + + tile_transform<1>(std::execution::par_unseq, arr_a, arr_b, op); + + CHECK(to_vector(arr_b) == correct1); + + } + }