Add data structure for V-style MPI collective return data and use
rupertnash committed Oct 5, 2022
1 parent d5cf70d commit 7f4b2f5
Showing 3 changed files with 81 additions and 77 deletions.
55 changes: 51 additions & 4 deletions Code/net/MpiCommunicator.h
@@ -6,18 +6,62 @@
#ifndef HEMELB_NET_MPICOMMUNICATOR_H
#define HEMELB_NET_MPICOMMUNICATOR_H

#include <vector>
#include <cassert>
#include <map>
#include "net/MpiError.h"
#include <memory>
#include <cassert>
#include <numeric>
#include <span>
#include <vector>

#include "net/MpiError.h"

namespace hemelb
{
namespace net
{
class MpiGroup;

// This type holds the data that results from V-style collective
// operations like MPI_Allgatherv. The data is stored contiguously
// in the `data` vector and the displacements (an array of int, as
// required by MPI) in `displacements`, which has one entry per
// participating process, plus one final entry holding the total
// element count.
//
// This type presents an interface similar to an array of
// arrays: specifically, indexing returns a span over the
// corresponding process's data.
template <typename T>
struct displaced_data {
std::vector<int> displacements;
std::vector<T> data;

// Default is empty.
displaced_data() = default;
// Construct from a vector of sizes, computing the displacements.
displaced_data(std::vector<int> const& sizes) {
auto const N = sizes.size();
displacements.resize(N + 1);
displacements[0] = 0;
std::inclusive_scan(sizes.begin(),
sizes.end(),
displacements.begin() + 1);
data.resize(displacements[N]);
}

// The number of participating processes in the collective.
std::size_t size() const {
return displacements.size() - 1;
}

// Return a span over the data belonging to the given process.
std::span<T> operator[](int i) {
assert(i >= 0 && std::size_t(i) < size());
auto start = displacements[i];
auto count = displacements[i+1] - start;
return std::span<T>{data.data() + start, unsigned(count)};
}
};
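// A minimal usage sketch of displaced_data in isolation (illustrative
// values, not part of the commit):
//
//   std::vector<int> sizes{2, 0, 3};            // element counts for 3 processes
//   auto d = displaced_data<int>{sizes};
//   // displacements == {0, 2, 2, 5}; data.size() == 5; d.size() == 3
//   std::iota(d.data.begin(), d.data.end(), 0); // data == {0, 1, 2, 3, 4}
//   auto block = d[2];                          // span over {2, 3, 4}
//   assert(d[1].empty());                       // process 1 contributed nothing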

class MpiCommunicator
{
public:
@@ -147,6 +191,9 @@ namespace hemelb
template <typename T>
std::vector<T> AllGather(const T& val) const;

template <typename T>
displaced_data<T> AllGatherV(const std::vector<T>& vals) const;

/**
* Performs an all gather operation of fixed size among the neighbours defined in an MPI graph communicator
* @param val local contribution to all gather operation
@@ -161,7 +208,7 @@
* @return displaced_data with contributions from each neighbour. Use GetNeighbors() to map zero-based indices to MPI ranks
*/
template<typename T>
std::vector<std::vector<T>> AllNeighGatherV(const std::vector<T>& val) const;
displaced_data<T> AllNeighGatherV(const std::vector<T>& val) const;

template<typename T>
std::vector<T> AllToAll(const std::vector<T>& vals) const;
51 changes: 19 additions & 32 deletions Code/net/MpiCommunicator.hpp
@@ -167,6 +167,17 @@ namespace hemelb
return ans;
}

template <typename T>
displaced_data<T> MpiCommunicator::AllGatherV(const std::vector<T>& local_data) const {
std::vector<int> per_rank_sizes = AllGather((int) local_data.size());
auto ans = displaced_data<T>{per_rank_sizes};
HEMELB_MPI_CALL(MPI_Allgatherv,
(local_data.data(), local_data.size(), MpiDataType<T>(),
ans.data.data(), per_rank_sizes.data(), ans.displacements.data(), MpiDataType<T>(),
*this)
);
return ans;
}
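// A hypothetical call site for the new method (assumes `comm` is a valid
// MpiCommunicator; Rank() named as used elsewhere in this class):
//
//   std::vector<int> mine(comm.Rank() + 1, comm.Rank()); // rank r sends r+1 copies of r
//   auto everything = comm.AllGatherV(mine);
//   for (int p = 0; p < int(everything.size()); ++p)
//     assert(everything[p].size() == std::size_t(p) + 1); // span over rank p's block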
template<typename T>
std::vector<T> MpiCommunicator::AllNeighGather(const T& val) const
{
@@ -179,40 +190,16 @@
}

template<typename T>
std::vector<std::vector<T>> MpiCommunicator::AllNeighGatherV(const std::vector<T>& val) const
displaced_data<T> MpiCommunicator::AllNeighGatherV(const std::vector<T>& val) const
{
int numProcs = GetNeighborsCount();
std::vector<int> valSizes = AllNeighGather((int) val.size());
std::vector<int> valDisplacements(numProcs + 1);

// TODO: that's a scan
int totalSize = std::accumulate(valSizes.begin(),
valSizes.end(),
0);

valDisplacements[0] = 0;
for (int j = 0; j < numProcs; ++j)
{
valDisplacements[j + 1] = valDisplacements[j] + valSizes[j];
}

std::vector<T> allVal(totalSize);
HEMELB_MPI_CALL(MPI_Neighbor_allgatherv,
( val.data(), val.size(), MpiDataType<T>(),
allVal.data(), valSizes.data(), valDisplacements.data(), MpiDataType<T>(),
*this ));

std::vector<std::vector<T>> ans(numProcs);
for (int procIndex = 0; procIndex < numProcs; ++procIndex)
{
ans[procIndex].reserve(valDisplacements[procIndex + 1] - valDisplacements[procIndex]);
for (auto indexAllCoords = valDisplacements[procIndex];
indexAllCoords < valDisplacements[procIndex + 1]; ++indexAllCoords)
{
ans[procIndex].push_back(allVal[indexAllCoords]);
}
}

auto ans = displaced_data<T>{valSizes};
HEMELB_MPI_CALL(
MPI_Neighbor_allgatherv,
(val.data(), val.size(), MpiDataType<T>(),
ans.data.data(), valSizes.data(), ans.displacements.data(), MpiDataType<T>(),
*this)
);
return ans;
}
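For context, this is how the underlying neighbourhood collective behaves, as a self-contained sketch in plain MPI (not HemeLB code; a ring topology is assumed, run with at least three ranks):

#include <mpi.h>
#include <vector>

int main(int argc, char** argv)
{
  MPI_Init(&argc, &argv);
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  // Ring topology: each rank's neighbours are rank-1 and rank+1 (mod size).
  int nbrs[2] = {(rank + size - 1) % size, (rank + 1) % size};
  MPI_Comm graph;
  MPI_Dist_graph_create_adjacent(MPI_COMM_WORLD,
                                 2, nbrs, MPI_UNWEIGHTED,  // sources
                                 2, nbrs, MPI_UNWEIGHTED,  // destinations
                                 MPI_INFO_NULL, 0, &graph);

  // Each rank contributes a variable amount: rank % 3 + 1 copies of its rank.
  std::vector<int> local(rank % 3 + 1, rank);
  int myCount = int(local.size());

  // First exchange counts with the two neighbours, as AllNeighGatherV does.
  int counts[2];
  MPI_Neighbor_allgather(&myCount, 1, MPI_INT, counts, 1, MPI_INT, graph);

  // Then the variable-size gather itself.
  int displs[2] = {0, counts[0]};
  std::vector<int> recv(counts[0] + counts[1]);
  MPI_Neighbor_allgatherv(local.data(), myCount, MPI_INT,
                          recv.data(), counts, displs, MPI_INT, graph);

  MPI_Comm_free(&graph);
  MPI_Finalize();
}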

52 changes: 11 additions & 41 deletions Code/redblood/parallel/GraphBasedCommunication.cc
@@ -34,18 +34,16 @@ namespace hemelb
auto const& neighbouringProcs = comm.GetNeighbors();
if (neighbouringProcs.size() > 0)
{
std::vector<std::vector<LatticeVector>> neighSites = comm.AllNeighGatherV(locallyOwnedSites);
auto neighSites = comm.AllNeighGatherV(locallyOwnedSites);
assert(neighSites.size() == comm.GetNeighborsCount());

// Finish populating the map with knowledge coming from neighbours
for (auto const& neighbour : hemelb::util::enumerate(neighbouringProcs))
{
for (auto const& globalCoord : neighSites[neighbour.index])
{
// lattice sites are uniquely owned, so no chance of coordinates being repeated across processes
assert(coordsToProcMap.count(globalCoord) == 0);
coordsToProcMap[globalCoord] = neighbour.value;
}
for (auto&& [i, p]: util::enumerate(neighbouringProcs)) {
for (auto const& globalCoord: neighSites[i]) {
// lattice sites are uniquely owned, so no chance of coordinates being repeated across processes
assert(coordsToProcMap.count(globalCoord) == 0);
coordsToProcMap[globalCoord] = p;
}
}
}

@@ -86,36 +84,7 @@
serialisedLocalCoords.push_back(domain.GetSite(siteIndex).GetGlobalSiteCoords());
}

/// @\todo refactor into a method net::MpiCommunicator::AllGatherv
int numProcs = comm.Size();
std::vector<int> allSerialisedCoordSizes = comm.AllGather((int) serialisedLocalCoords.size());
std::vector<int> allSerialisedCoordDisplacements(numProcs + 1);

site_t totalSize = std::accumulate(allSerialisedCoordSizes.begin(),
allSerialisedCoordSizes.end(),
0);

allSerialisedCoordDisplacements[0] = 0;
for (int j = 0; j < numProcs; ++j)
{
allSerialisedCoordDisplacements[j + 1] = allSerialisedCoordDisplacements[j]
+ allSerialisedCoordSizes[j];
}

std::vector<LatticeVector> allSerialisedCoords(totalSize);
HEMELB_MPI_CALL(MPI_Allgatherv,
( serialisedLocalCoords.data(), serialisedLocalCoords.size(), net::MpiDataType<LatticeVector>(), allSerialisedCoords.data(), allSerialisedCoordSizes.data(), allSerialisedCoordDisplacements.data(), net::MpiDataType<LatticeVector>(), comm ));

std::vector<std::vector<LatticeVector>> coordsPerProc(numProcs);
for (decltype(numProcs) procIndex = 0; procIndex < numProcs; ++procIndex)
{
for (auto indexAllCoords = allSerialisedCoordDisplacements[procIndex];
indexAllCoords < allSerialisedCoordDisplacements[procIndex + 1]; ++indexAllCoords)
{
coordsPerProc[procIndex].push_back(allSerialisedCoords[indexAllCoords]);
}
}
/// end of refactoring
auto coordsPerProc = comm.AllGatherV(serialisedLocalCoords);

auto cellsEffectiveSizeSq = cellsEffectiveSize * cellsEffectiveSize;
auto areProcsNeighbours =
@@ -138,10 +107,11 @@
return distanceSqBetweenSubdomainEdges < cellsEffectiveSizeSq;
};

auto const numProcs = comm.Size();
std::vector<std::vector<int>> vertices(numProcs);
for (int procA(0); procA < numProcs; ++procA)
for (int procA = 0; procA < numProcs; ++procA)
{
for (int procB(procA+1); procB < numProcs; ++procB)
for (int procB = procA + 1; procB < numProcs; ++procB)
{
if (areProcsNeighbours(procA, procB))
{
