Locality Aware Broadcast (#185)
* locality aware broadcast

* clarify comments

* fix mpi native broadcast signature

* adding mocking and test for new broadcast

* remove fmt call from mpi native

* refactor to local and remote leaders; use init method to group all initialisation

* clarify terminology and nomenclature around
csegarragonz authored Dec 6, 2021
1 parent 6ff81aa commit 4eef572
Showing 6 changed files with 329 additions and 85 deletions.
24 changes: 22 additions & 2 deletions include/faabric/scheduler/MpiWorld.h
@@ -20,6 +20,17 @@
#define MPI_MSGTYPE_COUNT_PREFIX "mpi-msgtype-torank"

namespace faabric::scheduler {

// -----------------------------------
// Mocking
// -----------------------------------
// MPITOPTP - mocking at the MPI level won't be needed when using the PTP broker
// as the broker already has mocking capabilities
std::vector<faabric::MpiHostsToRanksMessage> getMpiHostsToRanksMessages();

std::vector<std::shared_ptr<faabric::MPIMessage>> getMpiMockedMessages(
int sendRank);

typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
InMemoryMpiQueue;

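These getters only expose what the mocked code paths below record. A minimal sketch of how a test might inspect that state, assuming mock mode has been switched on elsewhere so that MpiWorld::send and the host-to-ranks exchange capture messages instead of dispatching them (the helper name dumpMockedTraffic and the printf output are illustrative only, not part of this commit):

#include <cstdio>

#include <faabric/scheduler/MpiWorld.h>

// Print how much traffic each rank queued while mocking was active
void dumpMockedTraffic(int worldSize)
{
    for (int rank = 0; rank < worldSize; rank++) {
        auto sent = faabric::scheduler::getMpiMockedMessages(rank);
        std::printf("rank %i queued %zu message(s)\n", rank, sent.size());
    }

    // Host-to-ranks messages exchanged during world initialisation are
    // captured in a separate vector
    auto hostMsgs = faabric::scheduler::getMpiHostsToRanksMessages();
    std::printf("captured %zu host-to-ranks message(s)\n", hostMsgs.size());
}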
@@ -76,8 +87,9 @@ class MpiWorld
faabric::MPIMessage::MPIMessageType messageType =
faabric::MPIMessage::NORMAL);

void broadcast(int sendRank,
const uint8_t* buffer,
void broadcast(int rootRank,
int thisRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
faabric::MPIMessage::MPIMessageType messageType =
@@ -214,6 +226,14 @@
// Track at which host each rank lives
std::vector<std::string> rankHosts;
int getIndexForRanks(int sendRank, int recvRank);
std::vector<int> getRanksForHost(const std::string& host);

// Track the ranks that are local to this host, and the local/remote leaders
// MPITOPTP - this information exists in the broker
int localLeader;
std::vector<int> localRanks;
std::vector<int> remoteLeaders;
void initLocalRemoteLeaders();

// In-memory queues for local messaging
std::vector<std::shared_ptr<InMemoryMpiQueue>> localQueues;
22 changes: 8 additions & 14 deletions src/mpi_native/mpi_native.cpp
@@ -212,20 +212,14 @@ int MPI_Bcast(void* buffer,
faabric::scheduler::MpiWorld& world = getExecutingWorld();

int rank = executingContext.getRank();
if (rank == root) {
SPDLOG_DEBUG(fmt::format("MPI_Bcast {} -> all", rank));
world.broadcast(
rank, (uint8_t*)buffer, datatype, count, faabric::MPIMessage::NORMAL);
} else {
SPDLOG_DEBUG(fmt::format("MPI_Bcast {} <- {}", rank, root));
world.recv(root,
rank,
(uint8_t*)buffer,
datatype,
count,
nullptr,
faabric::MPIMessage::NORMAL);
}
SPDLOG_DEBUG("MPI_Bcast {} -> all", rank);
world.broadcast(root,
rank,
(uint8_t*)buffer,
datatype,
count,
faabric::MPIMessage::BROADCAST);

return MPI_SUCCESS;
}

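As an aside, a hypothetical application-side caller under the new scheme is a completely standard MPI program; the only names assumed here are the usual MPI API and an already-initialised runtime. The point is that root and non-root ranks now issue the identical call, with the rank-dependent send/recv behaviour resolved inside MpiWorld::broadcast:

#include <mpi.h>

// Assumes MPI_Init has already been called
void broadcastExample()
{
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Only the root holds the payload before the call
    int root = 0;
    int value = (rank == root) ? 42 : 0;

    // Every rank makes the same call, regardless of whether it is the root
    MPI_Bcast(&value, 1, MPI_INT, root, MPI_COMM_WORLD);

    // After the call, value == 42 on every rank
}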
1 change: 1 addition & 0 deletions src/proto/faabric.proto
@@ -74,6 +74,7 @@ message MPIMessage {
ALLREDUCE = 8;
ALLTOALL = 9;
SENDRECV = 10;
BROADCAST = 11;
};

MPIMessageType messageType = 1;
206 changes: 147 additions & 59 deletions src/scheduler/MpiWorld.cpp
@@ -37,10 +37,32 @@ static thread_local std::unordered_map<
// Id of the message that created this thread-local instance
static thread_local faabric::Message* thisRankMsg = nullptr;

// This is used for mocking in tests
namespace faabric::scheduler {

// -----------------------------------
// Mocking
// -----------------------------------
static std::mutex mockMutex;

static std::vector<faabric::MpiHostsToRanksMessage> rankMessages;

namespace faabric::scheduler {
// The identifier in this map is the sending rank. For the receiver's rank
// we can inspect the MPIMessage object
static std::map<int, std::vector<std::shared_ptr<faabric::MPIMessage>>>
mpiMockedMessages;

std::vector<faabric::MpiHostsToRanksMessage> getMpiHostsToRanksMessages()
{
faabric::util::UniqueLock lock(mockMutex);
return rankMessages;
}

std::vector<std::shared_ptr<faabric::MPIMessage>> getMpiMockedMessages(
int sendRank)
{
faabric::util::UniqueLock lock(mockMutex);
return mpiMockedMessages[sendRank];
}

MpiWorld::MpiWorld()
: thisHost(faabric::util::getSystemConfig().endpointHost)
@@ -223,6 +245,12 @@ void MpiWorld::create(faabric::Message& call, int newId, int newSize)
rankHosts = executedAt;
basePorts = initLocalBasePorts(executedAt);

// Record which ranks are local to this host, and work out all the leaders
initLocalRemoteLeaders();
// Given that we are initialising the whole MpiWorld here, rank 0 lives on
// this host, so it must also be the local leader
assert(localLeader == 0);

// Initialise the memory queues for message reception
initLocalQueues();
}
@@ -295,6 +323,10 @@ void MpiWorld::destroy()
iSendRequests.size());
throw std::runtime_error("Destroying world with outstanding requests");
}

// Clear structures used for mocking
rankMessages.clear();
mpiMockedMessages.clear();
}

void MpiWorld::initialiseFromMsg(faabric::Message& msg)
@@ -322,6 +354,9 @@ void MpiWorld::initialiseFromMsg(faabric::Message& msg)
basePorts = { hostRankMsg.baseports().begin(),
hostRankMsg.baseports().end() };

// Record which ranks are local to this host, and work out all the leaders
initLocalRemoteLeaders();

// Initialise the memory queues for message reception
initLocalQueues();
}
@@ -344,6 +379,40 @@ std::string MpiWorld::getHostForRank(int rank)
return host;
}

std::vector<int> MpiWorld::getRanksForHost(const std::string& host)
{
assert(rankHosts.size() == size);

std::vector<int> ranksForHost;
for (int i = 0; i < rankHosts.size(); i++) {
if (rankHosts.at(i) == host) {
ranksForHost.push_back(i);
}
}

return ranksForHost;
}

// The local leader for an MPI world is the lowest rank assigned to this host;
// the remote leaders are the lowest rank assigned to each of the other hosts
void MpiWorld::initLocalRemoteLeaders()
{
std::set<std::string> uniqueHosts(rankHosts.begin(), rankHosts.end());

for (const std::string& host : uniqueHosts) {
auto ranksInHost = getRanksForHost(host);
// Persist the ranks that are co-located on this host for later use
if (host == thisHost) {
localRanks = ranksInHost;
localLeader =
*std::min_element(ranksInHost.begin(), ranksInHost.end());
} else {
remoteLeaders.push_back(
*std::min_element(ranksInHost.begin(), ranksInHost.end()));
}
}
}

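A standalone sketch of the leader derivation above, using a made-up placement of five ranks across two hosts (the host names and the main function are purely illustrative; the loop simply mirrors getRanksForHost and initLocalRemoteLeaders):

#include <algorithm>
#include <cassert>
#include <set>
#include <string>
#include <vector>

int main()
{
    // Hypothetical placement: ranks 0, 1 and 3 on hostA, ranks 2 and 4 on hostB
    std::vector<std::string> rankHosts = { "hostA", "hostA", "hostB",
                                           "hostA", "hostB" };
    std::string thisHost = "hostB";

    std::vector<int> localRanks;
    std::vector<int> remoteLeaders;
    int localLeader = -1;

    std::set<std::string> uniqueHosts(rankHosts.begin(), rankHosts.end());
    for (const std::string& host : uniqueHosts) {
        // Equivalent of getRanksForHost
        std::vector<int> ranksInHost;
        for (int i = 0; i < (int)rankHosts.size(); i++) {
            if (rankHosts[i] == host) {
                ranksInHost.push_back(i);
            }
        }

        // Equivalent of initLocalRemoteLeaders
        if (host == thisHost) {
            localRanks = ranksInHost;
            localLeader =
              *std::min_element(ranksInHost.begin(), ranksInHost.end());
        } else {
            remoteLeaders.push_back(
              *std::min_element(ranksInHost.begin(), ranksInHost.end()));
        }
    }

    // hostB holds ranks 2 and 4, so its local leader is rank 2; hostA's lowest
    // rank (0) is the only remote leader from hostB's point of view
    assert(localLeader == 2);
    assert(localRanks == std::vector<int>({ 2, 4 }));
    assert(remoteLeaders == std::vector<int>({ 0 }));

    return 0;
}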
// Returns a pair (sendPort, recvPort)
// To assign the send and recv ports, we follow a protocol establishing:
// 1) Port range (offset) corresponding to the world that receives
@@ -580,6 +649,12 @@ void MpiWorld::send(int sendRank,
m->set_buffer(buffer, dataType->size * count);
}

// Mock the message sending in tests
if (faabric::util::isMockMode()) {
mpiMockedMessages[sendRank].push_back(m);
return;
}

// Dispatch the message locally or globally
if (isLocal) {
SPDLOG_TRACE("MPI - send {} -> {}", sendRank, recvRank);
@@ -616,6 +691,11 @@ void MpiWorld::recv(int sendRank,
// Sanity-check input parameters
checkRanksRange(sendRank, recvRank);

// If mocking the messages, ignore calls to receive that may block
if (faabric::util::isMockMode()) {
return;
}

// Recv message from underlying transport
std::shared_ptr<faabric::MPIMessage> m =
recvBatchReturnLast(sendRank, recvRank);
@@ -699,21 +779,59 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
}

void MpiWorld::broadcast(int sendRank,
const uint8_t* buffer,
int recvRank,
uint8_t* buffer,
faabric_datatype_t* dataType,
int count,
faabric::MPIMessage::MPIMessageType messageType)
{
SPDLOG_TRACE("MPI - bcast {} -> all", sendRank);

for (int r = 0; r < size; r++) {
// Skip this rank (it's doing the broadcasting)
if (r == sendRank) {
continue;
if (recvRank == sendRank) {
// The sending rank sends the message to all other ranks on its host, and
// to all remote leaders
for (const int localRecvRank : localRanks) {
if (localRecvRank == recvRank) {
continue;
}

send(recvRank, localRecvRank, buffer, dataType, count, messageType);
}

for (const int remoteRecvRank : remoteLeaders) {
send(
recvRank, remoteRecvRank, buffer, dataType, count, messageType);
}
} else if (recvRank == localLeader) {
// If we are the local leader, first we receive the message sent by
// the sending rank
recv(sendRank, recvRank, buffer, dataType, count, nullptr, messageType);

// If the broadcast originated locally, we are done. If not, we now
// distribute to all our local ranks
if (getHostForRank(sendRank) != thisHost) {
for (const int localRecvRank : localRanks) {
if (localRecvRank == recvRank) {
continue;
}

// Send to the other ranks
send(sendRank, r, buffer, dataType, count, messageType);
send(recvRank,
localRecvRank,
buffer,
dataType,
count,
messageType);
}
}
} else {
// If we are neither the sending rank nor the local leader, we receive
// from our local leader if the broadcast originated on a different host,
// or from the sending rank itself if we are on the same host
int sendingRank =
getHostForRank(sendRank) == thisHost ? sendRank : localLeader;

recv(
sendingRank, recvRank, buffer, dataType, count, nullptr, messageType);
}
}

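To make the three branches concrete, this is the sequence of point-to-point calls for a hypothetical world with ranks 0 and 1 on hostA, ranks 2 and 3 on hostB, and rank 1 as the broadcast root (the hosts and ranks are invented for illustration):

// rank 1 (root, hostA):         send(1 -> 0)   local rank
//                               send(1 -> 2)   remote leader (only cross-host message)
// rank 0 (local leader, hostA): recv(1 -> 0)   root is on the same host, so no re-send
// rank 2 (local leader, hostB): recv(1 -> 2)   root is remote, so fan out locally:
//                               send(2 -> 3)
// rank 3 (hostB):               recv(2 -> 3)   receives from its local leader
//
// The flat broadcast this replaces would have sent 1 -> 2 and 1 -> 3 across
// hosts; the locality-aware version sends a single cross-host message per
// remote host, regardless of how many ranks live there.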
@@ -874,23 +992,14 @@ void MpiWorld::allGather(int rank,
// Note that sendCount and recvCount here are per-rank, so we need to work
// out the full buffer size
int fullCount = recvCount * size;
if (rank == root) {
// Broadcast the result
broadcast(root,
recvBuffer,
recvType,
fullCount,
faabric::MPIMessage::ALLGATHER);
} else {
// Await the broadcast from the master
recv(root,
rank,
recvBuffer,
recvType,
fullCount,
nullptr,
faabric::MPIMessage::ALLGATHER);
}

// Do a broadcast with a hard-coded root
broadcast(root,
rank,
recvBuffer,
recvType,
fullCount,
faabric::MPIMessage::ALLGATHER);
}

void MpiWorld::awaitAsyncRequest(int requestId)
@@ -1006,26 +1115,12 @@ void MpiWorld::allReduce(int rank,
faabric_op_t* operation)
{
// Rank 0 coordinates the allreduce operation
if (rank == 0) {
// Run the standard reduce
reduce(0, 0, sendBuffer, recvBuffer, datatype, count, operation);

// Broadcast the result
broadcast(
0, recvBuffer, datatype, count, faabric::MPIMessage::ALLREDUCE);
} else {
// Run the standard reduce
reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation);
// First, all ranks reduce to rank 0
reduce(rank, 0, sendBuffer, recvBuffer, datatype, count, operation);

// Await the broadcast from the master
recv(0,
rank,
recvBuffer,
datatype,
count,
nullptr,
faabric::MPIMessage::ALLREDUCE);
}
// Second, 0 broadcasts the result to all ranks
broadcast(
0, rank, recvBuffer, datatype, count, faabric::MPIMessage::ALLREDUCE);
}

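A quick worked example of the reshaped allReduce, using hypothetical values (three ranks, a sum operation, local values 1, 2 and 3):

// Phase 1: reduce(rank, 0, ...)    -> rank 0's recvBuffer holds 1 + 2 + 3 = 6
// Phase 2: broadcast(0, rank, ...) -> every rank's recvBuffer holds 6
//
// Because the broadcast phase reuses the locality-aware path above, rank 0
// sends the reduced value across the network once per remote host, not once
// per remote rank.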
void MpiWorld::op_reduce(faabric_op_t* operation,
@@ -1244,8 +1339,9 @@ void MpiWorld::probe(int sendRank, int recvRank, MPI_Status* status)

void MpiWorld::barrier(int thisRank)
{
// Rank 0 coordinates the barrier operation
if (thisRank == 0) {
// This is the root, hence just does the waiting
// This is the root, hence waits for all ranks to get to the barrier
SPDLOG_TRACE("MPI - barrier init {}", thisRank);

// Await messages from all others
Expand All @@ -1255,25 +1351,17 @@ void MpiWorld::barrier(int thisRank)
r, 0, nullptr, MPI_INT, 0, &s, faabric::MPIMessage::BARRIER_JOIN);
SPDLOG_TRACE("MPI - recv barrier join {}", s.MPI_SOURCE);
}

// Broadcast that the barrier is done
broadcast(0, nullptr, MPI_INT, 0, faabric::MPIMessage::BARRIER_DONE);
} else {
// Tell the root that we're waiting
SPDLOG_TRACE("MPI - barrier join {}", thisRank);
send(
thisRank, 0, nullptr, MPI_INT, 0, faabric::MPIMessage::BARRIER_JOIN);

// Receive a message saying the barrier is done
recv(0,
thisRank,
nullptr,
MPI_INT,
0,
nullptr,
faabric::MPIMessage::BARRIER_DONE);
SPDLOG_TRACE("MPI - barrier done {}", thisRank);
}

// Rank 0 broadcasts that the barrier is done (the others block here)
broadcast(
0, thisRank, nullptr, MPI_INT, 0, faabric::MPIMessage::BARRIER_DONE);
SPDLOG_TRACE("MPI - barrier done {}", thisRank);
}

std::shared_ptr<InMemoryMpiQueue> MpiWorld::getLocalQueue(int sendRank,