Skip to content

Commit

Permalink
tile transform for distributed array
Browse files Browse the repository at this point in the history
  • Loading branch information
hamsteri15 committed Jan 24, 2024
1 parent 6589471 commit 968db00
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 6 deletions.
36 changes: 30 additions & 6 deletions include/bits/communication/distributed_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#include "channel.hpp"
#include "gather.hpp"
#include "include/bits/algorithms/algorithms.hpp"
#include "include/bits/core/tuple_extensions.hpp"
#include "topology.hpp"
#include "include/bits/algorithms/algorithms.hpp"

namespace jada {

Expand Down Expand Up @@ -341,19 +341,43 @@ static void for_each_indexed(ExecutionPolicy&& policy,
DistributedArray<N, T>& arr,
BinaryIndexFunction f) {

auto boxes = arr.local_boxes();
auto boxes = arr.local_boxes();
auto subspans = make_subspans(arr);
for (size_t i = 0; i < subspans.size(); ++i){
for (size_t i = 0; i < subspans.size(); ++i) {
auto offset = boxes[i].box.m_begin;
auto span = subspans[i];
auto span = subspans[i];

auto F = [=](auto md_idx){
auto F = [=](auto md_idx) {
const auto copy = elementwise_add(md_idx, offset);
f(copy, span(md_idx));
};
detail::md_for_each(policy, all_indices(span), F);
}

}

template <size_t Dir,
class ExecutionPolicy,
size_t N,
class ET1,
class ET2,
class UnaryTileFunction>
static void tile_transform(ExecutionPolicy&& policy,
const DistributedArray<N, ET1>& input,
DistributedArray<N, ET2>& output,
UnaryTileFunction f) {

auto boxes = input.local_boxes();
auto i_subspans = make_subspans(input);
auto o_subspans = make_subspans(output);

for (size_t i = 0; i < i_subspans.size(); ++i) {
auto offset = boxes[i].box.m_begin;
auto i_span = i_subspans[i];
auto o_span = o_subspans[i];

tile_transform<Dir>(policy, i_span, o_span, f);

}
}

} // namespace jada
60 changes: 60 additions & 0 deletions test/test_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,9 +913,69 @@ TEST_CASE("Test DistributedArray")
CHECK(to_vector(arr) == correct);

}


}


SECTION("tile_transform"){

index_type ni = 4;
index_type nj = 3;

std::array<index_type, 2> bpad{1,1};
std::array<index_type, 2> epad{1,1};

std::vector<int> a(size_t(ni*nj), 0);
std::vector<int> b(size_t(ni*nj), 0);

Box<2> domain({0,0}, {nj, ni});
auto topo = decompose(domain, mpi::world_size(), {true, true});

auto arr_a = distribute(a, topo, mpi::get_world_rank(), bpad, epad);
auto arr_b = distribute(b, topo, mpi::get_world_rank(), bpad, epad);

auto op = [](auto f) {
return f(-1) + f(1);
};


for (auto& data : arr_a.local_data()){
std::fill(data.begin(), data.end(), 1);
}
for (auto& data : arr_b.local_data()){
std::fill(data.begin(), data.end(), -1);
}


//mpi::wait();

tile_transform<0>(std::execution::par_unseq, arr_a, arr_b, op);

mpi::wait();

std::vector<int> correct1 =
{
2,2,2,2,
2,2,2,2,
2,2,2,2
};


CHECK(to_vector(arr_b) == correct1);

for (auto& data : arr_b.local_data()){
std::fill(data.begin(), data.end(), -3);
}



tile_transform<1>(std::execution::par_unseq, arr_a, arr_b, op);

CHECK(to_vector(arr_b) == correct1);

}



}
Expand Down

0 comments on commit 968db00

Please sign in to comment.