adding mocking and test for new broadcast
csegarragonz committed Dec 1, 2021
1 parent 6c33600 commit 62d04b3
Showing 3 changed files with 167 additions and 12 deletions.
9 changes: 9 additions & 0 deletions include/faabric/scheduler/MpiWorld.h
@@ -20,6 +20,15 @@
#define MPI_MSGTYPE_COUNT_PREFIX "mpi-msgtype-torank"

namespace faabric::scheduler {

// -----------------------------------
// Mocking
// -----------------------------------
std::vector<faabric::MpiHostsToRanksMessage> getMpiHostsToRanksMessages();

std::vector<std::shared_ptr<faabric::MPIMessage>> getMpiMockedMessages(
  int sendRank);

typedef faabric::util::Queue<std::shared_ptr<faabric::MPIMessage>>
  InMemoryMpiQueue;

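These accessors expose, per sending rank, the MPI messages that would have been sent while mock mode is on. As a rough sketch of the intended usage (illustration only, not part of this commit; "world" stands for an already-initialised MpiWorld such as the test fixtures below provide, and NORMAL is the standard point-to-point message type):

    faabric::util::setMockMode(true);

    std::vector<int> data = { 1, 2, 3 };
    world.send(
      0, 1, BYTES(data.data()), MPI_INT, data.size(), faabric::MPIMessage::NORMAL);

    // The send was recorded rather than dispatched
    auto sent = getMpiMockedMessages(0);
    REQUIRE(sent.size() == 1);
    REQUIRE(sent.at(0)->destination() == 1);

    faabric::util::setMockMode(false);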
65 changes: 53 additions & 12 deletions src/scheduler/MpiWorld.cpp
@@ -37,10 +37,32 @@ static thread_local std::unordered_map<
// Id of the message that created this thread-local instance
static thread_local faabric::Message* thisRankMsg = nullptr;

namespace faabric::scheduler {

// -----------------------------------
// Mocking
// -----------------------------------
static std::mutex mockMutex;

static std::vector<faabric::MpiHostsToRanksMessage> rankMessages;

// The identifier in this map is the sending rank. For the receiver's rank
// we can inspect the MPIMessage object
static std::map<int, std::vector<std::shared_ptr<faabric::MPIMessage>>>
  mpiMockedMessages;

std::vector<faabric::MpiHostsToRanksMessage> getMpiHostsToRanksMessages()
{
    faabric::util::UniqueLock lock(mockMutex);
    return rankMessages;
}

std::vector<std::shared_ptr<faabric::MPIMessage>> getMpiMockedMessages(
  int sendRank)
{
    faabric::util::UniqueLock lock(mockMutex);
    return mpiMockedMessages[sendRank];
}

MpiWorld::MpiWorld()
: thisHost(faabric::util::getSystemConfig().endpointHost)
@@ -303,6 +325,10 @@ void MpiWorld::destroy()
                     iSendRequests.size());
        throw std::runtime_error("Destroying world with outstanding requests");
    }

    // Clear structures used for mocking
    rankMessages.clear();
    mpiMockedMessages.clear();
}

void MpiWorld::initialiseFromMsg(faabric::Message& msg)
@@ -640,6 +666,12 @@ void MpiWorld::send(int sendRank,
        m->set_buffer(buffer, dataType->size * count);
    }

    // Mock the message sending in tests
    if (faabric::util::isMockMode()) {
        mpiMockedMessages[sendRank].push_back(m);
        return;
    }

    // Dispatch the message locally or globally
    if (isLocal) {
        SPDLOG_TRACE("MPI - send {} -> {}", sendRank, recvRank);
@@ -676,6 +708,11 @@ void MpiWorld::recv(int sendRank,
    // Sanity-check input parameters
    checkRanksRange(sendRank, recvRank);

    // If mocking the messages, ignore calls to receive that may block
    if (faabric::util::isMockMode()) {
        return;
    }

    // Recv message from underlying transport
    std::shared_ptr<faabric::MPIMessage> m =
      recvBatchReturnLast(sendRank, recvRank);
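Note that in mock mode recv() must return early: mocked sends are recorded in mpiMockedMessages rather than enqueued on any transport, so a matching receive would block forever. This is also why the test below asserts only on the sending side of each broadcast.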
@@ -758,39 +795,43 @@ void MpiWorld::sendRecv(uint8_t* sendBuffer,
    awaitAsyncRequest(recvId);
}

void MpiWorld::broadcast(int rootRank,
                         int thisRank,
                         uint8_t* buffer,
                         faabric_datatype_t* dataType,
                         int count,
                         faabric::MPIMessage::MPIMessageType messageType)
{
    SPDLOG_TRACE("MPI - bcast {} -> all", rootRank);

    if (thisRank == rootRank) {
        // The rootRank (originator of the broadcast) sends a message to all
        // local ranks in the broadcast, and all remote leaders in the broadcast
        for (const int localRecvRank : localRanks) {
            if (localRecvRank == thisRank) {
                continue;
            }

            send(thisRank, localRecvRank, buffer, dataType, count, messageType);
        }

        for (const int remoteRecvRank : remoteMasters) {
            send(
              thisRank, remoteRecvRank, buffer, dataType, count, messageType);
        }
    } else if (thisRank == localMaster) {
        // If we are the local master, first we receive the message sent by
        // the originator of the broadcast
        recv(rootRank, thisRank, buffer, dataType, count, nullptr, messageType);

        // If the broadcast originated locally, we are done. If not, we now
        // distribute to all our local ranks
        if (getHostForRank(rootRank) != thisHost) {
            for (const int localRecvRank : localRanks) {
                if (localRecvRank == thisRank) {
                    continue;
                }

                send(thisRank,
                     localRecvRank,
                     buffer,
@@ -804,7 +845,7 @@ void MpiWorld::broadcast(int sendRank,
        // either our local master if the broadcast originated in a remote host,
        // or the broadcast originator itself if we are on the same host
        int sendingRank =
          getHostForRank(rootRank) == thisHost ? rootRank : localMaster;

        recv(
          sendingRank, thisRank, buffer, dataType, count, nullptr, messageType);
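To make the expectations in the new test easier to follow, here is the broadcast worked through for the four-rank, two-host layout the test registers (ranks 0 and 1 on one host, ranks 2 and 3 on the other, so the local masters are ranks 0 and 2). This is an illustration only, not code from the commit:

    // Broadcast rooted at rank 0:
    //   rank 0 (root, local master):  send -> 1 (local), send -> 2 (remote
    //                                 master)                 => 2 messages
    //   rank 2 (remote local master): recv from 0, send -> 3  => 1 message
    //   ranks 1 and 3:                recv only               => 0 messages
    //
    // A root that is not a local master (e.g. rank 1) still sends exactly two
    // messages: one to local rank 0 and one to remote master rank 2.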
105 changes: 105 additions & 0 deletions tests/test/scheduler/test_remote_mpi_worlds.cpp
@@ -918,4 +918,109 @@ TEST_CASE_METHOD(RemoteMpiTestFixture, "Test UMB creation", "[mpi]")

    thisWorld.destroy();
}

std::set<int> getReceiversFromMessages(
  std::vector<std::shared_ptr<faabric::MPIMessage>> msgs)
{
    std::set<int> retSet;
    for (const auto& msg : msgs) {
        retSet.insert(msg->destination());
    }

    return retSet;
}

TEST_CASE_METHOD(RemoteMpiTestFixture,
                 "Test number of messages sent during broadcast",
                 "[mpi]")
{
    // Register four ranks, two on each host
    setWorldSizes(4, 2, 2);

    // Init worlds
    MpiWorld& thisWorld = getMpiWorldRegistry().createWorld(msg, worldId);
    faabric::util::setMockMode(true);
    thisWorld.broadcastHostsToRanks();
    REQUIRE(getMpiHostsToRanksMessages().size() == 1);
    otherWorld.initialiseFromMsg(msg);

    // Call broadcast and check sent messages
    std::set<int> expectedRecvRanks;
    int expectedNumMsg;
    int sendRank;
    int rootRank;

    SECTION("Check from root rank (local), and root is the local master")
    {
        rootRank = 0;
        sendRank = rootRank;
        expectedNumMsg = 2;
        expectedRecvRanks = { 1, 2 };
    }

    SECTION("Check from root rank (local), and root is not the local master")
    {
        rootRank = 1;
        sendRank = rootRank;
        expectedNumMsg = 2;
        expectedRecvRanks = { 0, 2 };
    }

    SECTION("Check from local non-root rank, and non-root is the local master")
    {
        rootRank = 1;
        sendRank = 0;
        expectedNumMsg = 0;
        expectedRecvRanks = {};
    }

    SECTION("Check from local non-root rank, and non-root is not the local master")
    {
        rootRank = 0;
        sendRank = 1;
        expectedNumMsg = 0;
        expectedRecvRanks = {};
    }

    SECTION("Check from remote rank, and remote rank is the local master")
    {
        rootRank = 0;
        sendRank = 2;
        expectedNumMsg = 1;
        expectedRecvRanks = { 3 };
    }

    SECTION("Check from remote rank, and remote rank is not the local master")
    {
        rootRank = 0;
        sendRank = 3;
        expectedNumMsg = 0;
        expectedRecvRanks = {};
    }

    // Call broadcast from the sending rank's world: ranks 0 and 1 live in
    // thisWorld, ranks 2 and 3 in otherWorld
    std::vector<int> messageData = { 0, 1, 2 };
    if (sendRank < 2) {
        thisWorld.broadcast(rootRank,
                            sendRank,
                            BYTES(messageData.data()),
                            MPI_INT,
                            messageData.size(),
                            faabric::MPIMessage::BROADCAST);
    } else {
        otherWorld.broadcast(rootRank,
                             sendRank,
                             BYTES(messageData.data()),
                             MPI_INT,
                             messageData.size(),
                             faabric::MPIMessage::BROADCAST);
    }
    auto msgs = getMpiMockedMessages(sendRank);
    REQUIRE(msgs.size() == expectedNumMsg);
    REQUIRE(getReceiversFromMessages(msgs) == expectedRecvRanks);

    faabric::util::setMockMode(false);
    otherWorld.destroy();
    thisWorld.destroy();
}
}
