diff --git a/include/faabric/runner/FaabricMain.h b/include/faabric/runner/FaabricMain.h index 7f49f7727..6a5fa6f72 100644 --- a/include/faabric/runner/FaabricMain.h +++ b/include/faabric/runner/FaabricMain.h @@ -5,6 +5,7 @@ #include #include #include +#include #include namespace faabric::runner { @@ -23,11 +24,14 @@ class FaabricMain void startSnapshotServer(); + void startPointToPointServer(); + void shutdown(); private: faabric::state::StateServer stateServer; faabric::scheduler::FunctionCallServer functionServer; faabric::snapshot::SnapshotServer snapshotServer; + faabric::transport::PointToPointServer pointToPointServer; }; } diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 2e955df5a..b8ca3c212 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -80,7 +80,18 @@ class AsyncSendMessageEndpoint final : public MessageEndpoint void send(const uint8_t* data, size_t dataSize, bool more = false); - zmq::socket_t pushSocket; + zmq::socket_t socket; +}; + +class AsyncInternalSendMessageEndpoint final : public MessageEndpoint +{ + public: + AsyncInternalSendMessageEndpoint(const std::string& inProcLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + void send(const uint8_t* data, size_t dataSize, bool more = false); + + zmq::socket_t socket; }; class SyncSendMessageEndpoint final : public MessageEndpoint @@ -183,6 +194,15 @@ class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint std::optional recv(int size = 0) override; }; +class AsyncInternalRecvMessageEndpoint final : public RecvMessageEndpoint +{ + public: + AsyncInternalRecvMessageEndpoint(const std::string& inprocLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + std::optional recv(int size = 0) override; +}; + class SyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: diff --git a/include/faabric/transport/MessageEndpointClient.h b/include/faabric/transport/MessageEndpointClient.h index 02e945925..7875dc37e 100644 --- a/include/faabric/transport/MessageEndpointClient.h +++ b/include/faabric/transport/MessageEndpointClient.h @@ -30,7 +30,6 @@ class MessageEndpointClient protected: const std::string host; - private: const int asyncPort; const int syncPort; diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 7f9f7438a..c23414f51 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -56,6 +56,8 @@ class MessageEndpointServer virtual void stop(); + virtual void onWorkerStop(); + void setRequestLatch(); void awaitRequestLatch(); diff --git a/include/faabric/transport/PointToPointBroker.h b/include/faabric/transport/PointToPointBroker.h new file mode 100644 index 000000000..19717867e --- /dev/null +++ b/include/faabric/transport/PointToPointBroker.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +namespace faabric::transport { +class PointToPointBroker +{ + public: + PointToPointBroker(); + + std::string getHostForReceiver(int appId, int recvIdx); + + void setHostForReceiver(int appId, int recvIdx, const std::string& host); + + void broadcastMappings(int appId); + + void sendMappings(int appId, const std::string& host); + + std::set getIdxsRegisteredForApp(int appId); + + void sendMessage(int appId, + int sendIdx, + int recvIdx, + const uint8_t* buffer, + size_t bufferSize); + + std::vector 
recvMessage(int appId, int sendIdx, int recvIdx); + + void clear(); + + void resetThreadLocalCache(); + + private: + std::shared_mutex brokerMutex; + + std::unordered_map> appIdxs; + std::unordered_map mappings; + + std::shared_ptr getClient(const std::string& host); + + faabric::scheduler::Scheduler& sch; +}; + +PointToPointBroker& getPointToPointBroker(); +} diff --git a/include/faabric/transport/PointToPointCall.h b/include/faabric/transport/PointToPointCall.h new file mode 100644 index 000000000..a636b0b39 --- /dev/null +++ b/include/faabric/transport/PointToPointCall.h @@ -0,0 +1,10 @@ +#pragma once + +namespace faabric::transport { + +enum PointToPointCall +{ + MAPPING = 0, + MESSAGE = 1 +}; +} diff --git a/include/faabric/transport/PointToPointClient.h b/include/faabric/transport/PointToPointClient.h new file mode 100644 index 000000000..806592426 --- /dev/null +++ b/include/faabric/transport/PointToPointClient.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace faabric::transport { + +std::vector> +getSentMappings(); + +std::vector> +getSentPointToPointMessages(); + +void clearSentMessages(); + +class PointToPointClient : public faabric::transport::MessageEndpointClient +{ + public: + PointToPointClient(const std::string& hostIn); + + void sendMappings(faabric::PointToPointMappings& mappings); + + void sendMessage(faabric::PointToPointMessage& msg); +}; +} diff --git a/include/faabric/transport/PointToPointServer.h b/include/faabric/transport/PointToPointServer.h new file mode 100644 index 000000000..2c961b908 --- /dev/null +++ b/include/faabric/transport/PointToPointServer.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +namespace faabric::transport { + +class PointToPointServer final : public MessageEndpointServer +{ + public: + PointToPointServer(); + + private: + PointToPointBroker& reg; + + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override; + + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override; + + void onWorkerStop() override; + + std::unique_ptr doRecvMappings( + const uint8_t* buffer, + size_t bufferSize); +}; +} diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index 523842787..5a330a831 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -15,3 +15,7 @@ #define SNAPSHOT_ASYNC_PORT 8007 #define SNAPSHOT_SYNC_PORT 8008 #define SNAPSHOT_INPROC_LABEL "snapshot" + +#define POINT_TO_POINT_ASYNC_PORT 8009 +#define POINT_TO_POINT_SYNC_PORT 8010 +#define POINT_TO_POINT_INPROC_LABEL "ptp" diff --git a/include/faabric/util/bytes.h b/include/faabric/util/bytes.h index 609952cad..ef47afeeb 100644 --- a/include/faabric/util/bytes.h +++ b/include/faabric/util/bytes.h @@ -12,6 +12,8 @@ std::vector stringToBytes(const std::string& str); std::string bytesToString(const std::vector& bytes); +std::string formatByteArrayToIntString(const std::vector& bytes); + void trimTrailingZeros(std::vector& vectorIn); int safeCopyToBuffer(const std::vector& dataIn, diff --git a/include/faabric/util/config.h b/include/faabric/util/config.h index f18bd3f1f..b424ebcf4 100644 --- a/include/faabric/util/config.h +++ b/include/faabric/util/config.h @@ -47,6 +47,7 @@ class SystemConfig int functionServerThreads; int stateServerThreads; int snapshotServerThreads; + int pointToPointServerThreads; SystemConfig(); diff --git a/src/proto/faabric.proto b/src/proto/faabric.proto index 07c7ef476..999570e61 100644 --- a/src/proto/faabric.proto +++ 
b/src/proto/faabric.proto @@ -200,3 +200,24 @@ message StateAppendedResponse { string key = 2; repeated AppendedValue values = 3; } + +// --------------------------------------------- +// POINT-TO-POINT +// --------------------------------------------- + +message PointToPointMessage { + int32 appId = 1; + int32 sendIdx = 2; + int32 recvIdx = 3; + + bytes data = 4; +} + +message PointToPointMappings { + message PointToPointMapping { + int32 appId = 1; + int32 recvIdx = 2; + string host = 3; + } + repeated PointToPointMapping mappings = 1; +} diff --git a/src/runner/FaabricMain.cpp b/src/runner/FaabricMain.cpp index 0b61ddf81..43956e916 100644 --- a/src/runner/FaabricMain.cpp +++ b/src/runner/FaabricMain.cpp @@ -40,6 +40,9 @@ void FaabricMain::startBackground() // Snapshots startSnapshotServer(); + // Point-to-point messaging + startPointToPointServer(); + // Work sharing startFunctionCallServer(); } @@ -71,6 +74,12 @@ void FaabricMain::startSnapshotServer() snapshotServer.start(); } +void FaabricMain::startPointToPointServer() +{ + SPDLOG_INFO("Starting point-to-point server"); + pointToPointServer.start(); +} + void FaabricMain::startStateServer() { // Skip state server if not in inmemory mode @@ -99,6 +108,9 @@ void FaabricMain::shutdown() SPDLOG_INFO("Waiting for the snapshot server to finish"); snapshotServer.stop(); + SPDLOG_INFO("Waiting for the point-to-point server to finish"); + pointToPointServer.stop(); + auto& sch = faabric::scheduler::getScheduler(); sch.shutdown(); diff --git a/src/transport/CMakeLists.txt b/src/transport/CMakeLists.txt index b49d705f9..a9ffb8684 100644 --- a/src/transport/CMakeLists.txt +++ b/src/transport/CMakeLists.txt @@ -11,6 +11,9 @@ set(HEADERS "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointClient.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MessageEndpointServer.h" "${FAABRIC_INCLUDE_DIR}/faabric/transport/MpiMessageEndpoint.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/PointToPointBroker.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/PointToPointClient.h" + "${FAABRIC_INCLUDE_DIR}/faabric/transport/PointToPointServer.h" ) set(LIB_FILES @@ -20,6 +23,9 @@ set(LIB_FILES MessageEndpointClient.cpp MessageEndpointServer.cpp MpiMessageEndpoint.cpp + PointToPointBroker.cpp + PointToPointClient.cpp + PointToPointServer.cpp ${HEADERS} ) diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 8fb0b1c22..6eec1c457 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -313,14 +313,14 @@ AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn, int timeoutMs) : MessageEndpoint(hostIn, portIn, timeoutMs) { - pushSocket = + socket = setUpSocket(zmq::socket_type::push, MessageEndpointConnectType::CONNECT); } void AsyncSendMessageEndpoint::sendHeader(int header) { uint8_t headerBytes = static_cast(header); - doSend(pushSocket, &headerBytes, sizeof(headerBytes), true); + doSend(socket, &headerBytes, sizeof(headerBytes), true); } void AsyncSendMessageEndpoint::send(const uint8_t* data, @@ -328,7 +328,24 @@ void AsyncSendMessageEndpoint::send(const uint8_t* data, bool more) { SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more); - doSend(pushSocket, data, dataSize, more); + doSend(socket, data, dataSize, more); +} + +AsyncInternalSendMessageEndpoint::AsyncInternalSendMessageEndpoint( + const std::string& inprocLabel, + int timeoutMs) + : MessageEndpoint("inproc://" + inprocLabel, timeoutMs) +{ + socket = + setUpSocket(zmq::socket_type::push, 
MessageEndpointConnectType::CONNECT); +} + +void AsyncInternalSendMessageEndpoint::send(const uint8_t* data, + size_t dataSize, + bool more) +{ + SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more); + doSend(socket, data, dataSize, more); } // ---------------------------------------------- @@ -495,6 +512,21 @@ std::optional AsyncRecvMessageEndpoint::recv(int size) return RecvMessageEndpoint::recv(size); } +AsyncInternalRecvMessageEndpoint::AsyncInternalRecvMessageEndpoint( + const std::string& inprocLabel, + int timeoutMs) + : RecvMessageEndpoint(inprocLabel, + timeoutMs, + zmq::socket_type::pull, + MessageEndpointConnectType::BIND) +{} + +std::optional AsyncInternalRecvMessageEndpoint::recv(int size) +{ + SPDLOG_TRACE("PULL {} ({} bytes)", address, size); + return RecvMessageEndpoint::recv(size); +} + // ---------------------------------------------- // SYNC RECV ENDPOINT // ---------------------------------------------- diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 3e340675f..e340853a3 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -157,6 +157,9 @@ void MessageEndpointServerHandler::start( } } + // Perform the tidy-up + server->onWorkerStop(); + // Just before the thread dies, check if there's something // waiting on the shutdown latch if (server->shutdownLatch != nullptr) { @@ -286,6 +289,11 @@ void MessageEndpointServer::stop() syncHandler.join(); } +void MessageEndpointServer::onWorkerStop() +{ + // Nothing to do by default +} + void MessageEndpointServer::setRequestLatch() { requestLatch = faabric::util::Latch::create(2); diff --git a/src/transport/PointToPointBroker.cpp b/src/transport/PointToPointBroker.cpp new file mode 100644 index 000000000..b5bd600df --- /dev/null +++ b/src/transport/PointToPointBroker.cpp @@ -0,0 +1,231 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace faabric::transport { + +// NOTE: Keeping 0MQ sockets in TLS is usually a bad idea, as they _must_ be +// closed before the global context. However, in this case it's worth it +// to cache the sockets across messages, as otherwise we'd be creating and +// destroying a lot of them under high throughput. To ensure things are cleared +// up, see the thread-local tidy-up message on this class and its usage in the +// rest of the codebase. 
+thread_local std:: + unordered_map> + recvEndpoints; + +thread_local std:: + unordered_map> + sendEndpoints; + +thread_local std::unordered_map> + clients; + +std::string getPointToPointKey(int appId, int sendIdx, int recvIdx) +{ + return fmt::format("{}-{}-{}", appId, sendIdx, recvIdx); +} + +std::string getPointToPointKey(int appId, int recvIdx) +{ + return fmt::format("{}-{}", appId, recvIdx); +} + +PointToPointBroker::PointToPointBroker() + : sch(faabric::scheduler::getScheduler()) +{} + +std::string PointToPointBroker::getHostForReceiver(int appId, int recvIdx) +{ + faabric::util::SharedLock lock(brokerMutex); + + std::string key = getPointToPointKey(appId, recvIdx); + + if (mappings.find(key) == mappings.end()) { + SPDLOG_ERROR( + "No point-to-point mapping for app {} idx {}", appId, recvIdx); + throw std::runtime_error("No point-to-point mapping found"); + } + + return mappings[key]; +} + +void PointToPointBroker::setHostForReceiver(int appId, + int recvIdx, + const std::string& host) +{ + faabric::util::FullLock lock(brokerMutex); + + SPDLOG_TRACE( + "Setting point-to-point mapping {}:{} to {}", appId, recvIdx, host); + + // Record this index for this app + appIdxs[appId].insert(recvIdx); + + // Add host mapping + std::string key = getPointToPointKey(appId, recvIdx); + mappings[key] = host; +} + +void PointToPointBroker::broadcastMappings(int appId) +{ + auto& sch = faabric::scheduler::getScheduler(); + + // TODO seems excessive to broadcast to all hosts, could we perhaps use the + // set of registered hosts? + std::set hosts = sch.getAvailableHosts(); + + faabric::util::SystemConfig& conf = faabric::util::getSystemConfig(); + + for (const auto& host : hosts) { + // Skip this host + if (host == conf.endpointHost) { + continue; + } + + sendMappings(appId, host); + } +} + +void PointToPointBroker::sendMappings(int appId, const std::string& host) +{ + faabric::util::SharedLock lock(brokerMutex); + + faabric::PointToPointMappings msg; + + std::set& indexes = appIdxs[appId]; + + for (auto i : indexes) { + std::string key = getPointToPointKey(appId, i); + std::string host = mappings[key]; + + auto* mapping = msg.add_mappings(); + mapping->set_appid(appId); + mapping->set_recvidx(i); + mapping->set_host(host); + } + + SPDLOG_DEBUG("Sending {} point-to-point mappings for {} to {}", + indexes.size(), + appId, + host); + + auto cli = getClient(host); + cli->sendMappings(msg); +} + +std::set PointToPointBroker::getIdxsRegisteredForApp(int appId) +{ + faabric::util::SharedLock lock(brokerMutex); + return appIdxs[appId]; +} + +void PointToPointBroker::sendMessage(int appId, + int sendIdx, + int recvIdx, + const uint8_t* buffer, + size_t bufferSize) +{ + std::string host = getHostForReceiver(appId, recvIdx); + + if (host == faabric::util::getSystemConfig().endpointHost) { + std::string label = getPointToPointKey(appId, sendIdx, recvIdx); + + // Note - this map is thread-local so no locking required + if (sendEndpoints.find(label) == sendEndpoints.end()) { + sendEndpoints[label] = + std::make_unique(label); + + SPDLOG_TRACE("Created new internal send endpoint {}", + sendEndpoints[label]->getAddress()); + } + + SPDLOG_TRACE("Local point-to-point message {}:{}:{} to {}", + appId, + sendIdx, + recvIdx, + sendEndpoints[label]->getAddress()); + + sendEndpoints[label]->send(buffer, bufferSize); + + } else { + auto cli = getClient(host); + faabric::PointToPointMessage msg; + msg.set_appid(appId); + msg.set_sendidx(sendIdx); + msg.set_recvidx(recvIdx); + msg.set_data(buffer, bufferSize); + + 
SPDLOG_TRACE("Remote point-to-point message {}:{}:{} to {}", + appId, + sendIdx, + recvIdx, + host); + + cli->sendMessage(msg); + } +} + +std::vector PointToPointBroker::recvMessage(int appId, + int sendIdx, + int recvIdx) +{ + std::string label = getPointToPointKey(appId, sendIdx, recvIdx); + + // Note: this map is thread-local so no locking required + if (recvEndpoints.find(label) == recvEndpoints.end()) { + recvEndpoints[label] = + std::make_unique(label); + SPDLOG_TRACE("Created new internal recv endpoint {}", + recvEndpoints[label]->getAddress()); + } + + std::optional messageDataMaybe = + recvEndpoints[label]->recv().value(); + Message messageData = messageDataMaybe.value(); + + // TODO - possible to avoid this copy? + return messageData.dataCopy(); +} + +std::shared_ptr PointToPointBroker::getClient( + const std::string& host) +{ + // Note - this map is thread-local so no locking required + if (clients.find(host) == clients.end()) { + clients[host] = std::make_shared(host); + + SPDLOG_TRACE("Created new point-to-point client {}", host); + } + + return clients[host]; +} + +void PointToPointBroker::clear() +{ + faabric::util::SharedLock lock(brokerMutex); + + appIdxs.clear(); + mappings.clear(); +} + +void PointToPointBroker::resetThreadLocalCache() +{ + SPDLOG_TRACE("Resetting point-to-point thread-local cache"); + + sendEndpoints.clear(); + recvEndpoints.clear(); + clients.clear(); +} + +PointToPointBroker& getPointToPointBroker() +{ + static PointToPointBroker reg; + return reg; +} +} diff --git a/src/transport/PointToPointClient.cpp b/src/transport/PointToPointClient.cpp new file mode 100644 index 000000000..872b82d3d --- /dev/null +++ b/src/transport/PointToPointClient.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace faabric::transport { + +static std::vector> + sentMappings; + +static std::vector> + sentMessages; + +std::vector> +getSentMappings() +{ + return sentMappings; +} + +std::vector> +getSentPointToPointMessages() +{ + return sentMessages; +} + +void clearSentMessages() +{ + sentMappings.clear(); + sentMessages.clear(); +} + +PointToPointClient::PointToPointClient(const std::string& hostIn) + : faabric::transport::MessageEndpointClient(hostIn, + POINT_TO_POINT_ASYNC_PORT, + POINT_TO_POINT_SYNC_PORT) +{} + +void PointToPointClient::sendMappings(faabric::PointToPointMappings& mappings) +{ + if (faabric::util::isMockMode()) { + sentMappings.emplace_back(host, mappings); + } else { + faabric::EmptyResponse resp; + syncSend(PointToPointCall::MAPPING, &mappings, &resp); + } +} + +void PointToPointClient::sendMessage(faabric::PointToPointMessage& msg) +{ + if (faabric::util::isMockMode()) { + sentMessages.emplace_back(host, msg); + } else { + asyncSend(PointToPointCall::MESSAGE, &msg); + } +} +} diff --git a/src/transport/PointToPointServer.cpp b/src/transport/PointToPointServer.cpp new file mode 100644 index 000000000..581167a0c --- /dev/null +++ b/src/transport/PointToPointServer.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace faabric::transport { + +PointToPointServer::PointToPointServer() + : faabric::transport::MessageEndpointServer( + POINT_TO_POINT_ASYNC_PORT, + POINT_TO_POINT_SYNC_PORT, + POINT_TO_POINT_INPROC_LABEL, + faabric::util::getSystemConfig().pointToPointServerThreads) + , reg(getPointToPointBroker()) +{} + +void PointToPointServer::doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) +{ + switch (header) { + case 
(faabric::transport::PointToPointCall::MESSAGE): { + PARSE_MSG(faabric::PointToPointMessage, buffer, bufferSize) + + // Send the message locally to the downstream socket + reg.sendMessage(msg.appid(), + msg.sendidx(), + msg.recvidx(), + BYTES_CONST(msg.data().c_str()), + msg.data().size()); + break; + } + default: { + SPDLOG_ERROR("Invalid aync point-to-point header: {}", header); + throw std::runtime_error("Invalid async point-to-point message"); + } + } +} + +std::unique_ptr PointToPointServer::doSyncRecv( + int header, + const uint8_t* buffer, + size_t bufferSize) +{ + switch (header) { + case (faabric::transport::PointToPointCall::MAPPING): { + return doRecvMappings(buffer, bufferSize); + } + default: { + SPDLOG_ERROR("Invalid sync point-to-point header: {}", header); + throw std::runtime_error("Invalid sync point-to-point message"); + } + } +} + +std::unique_ptr PointToPointServer::doRecvMappings( + const uint8_t* buffer, + size_t bufferSize) +{ + PARSE_MSG(faabric::PointToPointMappings, buffer, bufferSize) + + for (const auto& m : msg.mappings()) { + reg.setHostForReceiver(m.appid(), m.recvidx(), m.host()); + } + + return std::make_unique(); +} + +void PointToPointServer::onWorkerStop() +{ + // Clear any thread-local cached sockets + reg.resetThreadLocalCache(); +} + +} diff --git a/src/util/bytes.cpp b/src/util/bytes.cpp index 74d7d73ff..43666b242 100644 --- a/src/util/bytes.cpp +++ b/src/util/bytes.cpp @@ -1,5 +1,6 @@ #include +#include #include namespace faabric::util { @@ -72,4 +73,21 @@ std::string bytesToString(const std::vector& bytes) return result; } + +std::string formatByteArrayToIntString(const std::vector& bytes) +{ + std::stringstream ss; + + ss << "["; + for (int i = 0; i < bytes.size(); i++) { + ss << (int)bytes.at(i); + + if (i < bytes.size() - 1) { + ss << ", "; + } + } + ss << "]"; + + return ss.str(); +} } diff --git a/src/util/config.cpp b/src/util/config.cpp index 76a4fb273..633291de3 100644 --- a/src/util/config.cpp +++ b/src/util/config.cpp @@ -65,6 +65,8 @@ void SystemConfig::initialise() this->getSystemConfIntParam("STATE_SERVER_THREADS", "2"); snapshotServerThreads = this->getSystemConfIntParam("SNAPSHOT_SERVER_THREADS", "2"); + pointToPointServerThreads = + this->getSystemConfIntParam("POINT_TO_POINT_SERVER_THREADS", "2"); } int SystemConfig::getSystemConfIntParam(const char* name, diff --git a/src/util/func.cpp b/src/util/func.cpp index 42deb4206..0bef897d7 100644 --- a/src/util/func.cpp +++ b/src/util/func.cpp @@ -65,8 +65,11 @@ std::shared_ptr batchExecFactory( { auto req = batchExecFactory(); + // Force the messages to have the same app ID + uint32_t appId = faabric::util::generateGid(); for (int i = 0; i < count; i++) { *req->add_messages() = messageFactory(user, function); + req->mutable_messages()->at(i).set_appid(appId); } return req; diff --git a/tests/dist/CMakeLists.txt b/tests/dist/CMakeLists.txt index d55b65443..d24e4b6b9 100644 --- a/tests/dist/CMakeLists.txt +++ b/tests/dist/CMakeLists.txt @@ -12,6 +12,7 @@ add_library(faabric_dist_tests_lib DistTestExecutor.h DistTestExecutor.cpp scheduler/functions.cpp + transport/functions.cpp ) target_link_libraries(faabric_dist_tests_lib faabric_test_utils) diff --git a/tests/dist/fixtures.h b/tests/dist/fixtures.h index 0ae349aa1..f34abbb4f 100644 --- a/tests/dist/fixtures.h +++ b/tests/dist/fixtures.h @@ -13,12 +13,14 @@ class DistTestsFixture : public SchedulerTestFixture , public ConfTestFixture , public SnapshotTestFixture + , public PointToPointTestFixture { public: DistTestsFixture() { - // Get 
other hosts - std::string thisHost = conf.endpointHost; + // Make sure the host list is up to date + sch.addHostToGlobalSet(getMasterIP()); + sch.addHostToGlobalSet(getWorkerIP()); // Set up executor std::shared_ptr fac = @@ -26,6 +28,8 @@ class DistTestsFixture faabric::scheduler::setExecutorFactory(fac); } + ~DistTestsFixture() {} + std::string getWorkerIP() { if (workerIP.empty()) { diff --git a/tests/dist/init.cpp b/tests/dist/init.cpp index 0a04a55d4..caffcb260 100644 --- a/tests/dist/init.cpp +++ b/tests/dist/init.cpp @@ -9,5 +9,6 @@ void initDistTests() SPDLOG_INFO("Registering distributed test server functions"); tests::registerSchedulerTestFunctions(); + tests::registerTransportTestFunctions(); } } diff --git a/tests/dist/init.h b/tests/dist/init.h index 0a7455916..54018f219 100644 --- a/tests/dist/init.h +++ b/tests/dist/init.h @@ -6,5 +6,6 @@ void initDistTests(); // Specific test functions void registerSchedulerTestFunctions(); +void registerTransportTestFunctions(); } diff --git a/tests/dist/transport/functions.cpp b/tests/dist/transport/functions.cpp new file mode 100644 index 000000000..f94a80a9b --- /dev/null +++ b/tests/dist/transport/functions.cpp @@ -0,0 +1,77 @@ +#include "faabric_utils.h" +#include + +#include "DistTestExecutor.h" +#include "init.h" + +#include +#include +#include + +using namespace faabric::util; + +namespace tests { + +int handlePointToPointFunction( + faabric::scheduler::Executor* exec, + int threadPoolIdx, + int msgIdx, + std::shared_ptr req) +{ + faabric::Message& msg = req->mutable_messages()->at(msgIdx); + + uint8_t appIdx = (uint8_t)msg.appindex(); + + faabric::transport::PointToPointBroker& broker = + faabric::transport::getPointToPointBroker(); + + // Start by receiving a kick-off message from the master (to make sure the + // mappings have been broadcasted) + std::vector kickOffData = + broker.recvMessage(msg.appid(), 0, appIdx); + + // Check data received + std::vector expectedKickOffData = { 0, 1, 2 }; + if (kickOffData != expectedKickOffData) { + SPDLOG_ERROR("Point-to-point kick-off not as expected {} != {}", + formatByteArrayToIntString(kickOffData), + formatByteArrayToIntString(expectedKickOffData)); + return 1; + } + + // Send to next index in ring and recv from previous in ring. + uint8_t minIdx = 1; + uint8_t maxIdx = 3; + uint8_t sendToIdx = appIdx < maxIdx ? appIdx + 1 : minIdx; + uint8_t recvFromIdx = appIdx > minIdx ? 
appIdx - 1 : maxIdx; + + // Send a series of our own index, expect to receive the same from other + // senders + std::vector sendData(10, appIdx); + std::vector expectedRecvData(10, recvFromIdx); + + // Do the sending + broker.sendMessage( + msg.appid(), appIdx, sendToIdx, sendData.data(), sendData.size()); + + // Do the receiving + std::vector actualRecvData = + broker.recvMessage(msg.appid(), recvFromIdx, appIdx); + + // Check data is as expected + if (actualRecvData != expectedRecvData) { + SPDLOG_ERROR("Point-to-point recv data not as expected {} != {}", + formatByteArrayToIntString(actualRecvData), + formatByteArrayToIntString(expectedRecvData)); + return 1; + } + + return 0; +} + +void registerTransportTestFunctions() +{ + registerDistTestExecutorCallback( + "ptp", "simple", handlePointToPointFunction); +} +} diff --git a/tests/dist/transport/test_point_to_point.cpp b/tests/dist/transport/test_point_to_point.cpp new file mode 100644 index 000000000..495e78dac --- /dev/null +++ b/tests/dist/transport/test_point_to_point.cpp @@ -0,0 +1,82 @@ +#include "faabric_utils.h" +#include + +#include "fixtures.h" +#include "init.h" + +#include +#include +#include +#include +#include + +namespace tests { + +TEST_CASE_METHOD(DistTestsFixture, + "Test point-to-point messaging on multiple hosts", + "[ptp]") +{ + std::set actualAvailable = sch.getAvailableHosts(); + std::set expectedAvailable = { getMasterIP(), getWorkerIP() }; + REQUIRE(actualAvailable == expectedAvailable); + + // Set up this host's resources + // Make sure some functions execute remotely, some locally + int nLocalSlots = 1; + int nFuncs = 3; + + faabric::HostResources res; + res.set_slots(nLocalSlots); + sch.setThisHostResources(res); + + // Set up batch request + std::shared_ptr req = + faabric::util::batchExecFactory("ptp", "simple", nFuncs); + + // Double check app id + int appId = req->messages().at(0).appid(); + REQUIRE(appId > 0); + + faabric::transport::PointToPointBroker& broker = + faabric::transport::getPointToPointBroker(); + + std::vector expectedHosts = { getMasterIP(), + getWorkerIP(), + getWorkerIP() }; + + // Set up individual messages + // Note that this thread is acting as app index 0 + for (int i = 0; i < nFuncs; i++) { + faabric::Message& msg = req->mutable_messages()->at(i); + + msg.set_appindex(i + 1); + + // Register function locations to where we assume they'll be executed + // (we'll confirm this is the case after scheduling) + broker.setHostForReceiver( + msg.appid(), msg.appindex(), expectedHosts.at(i)); + } + + // Call the functions + std::vector actualHosts = sch.callFunctions(req); + REQUIRE(actualHosts == expectedHosts); + + // Broadcast mappings to other hosts + broker.broadcastMappings(appId); + + // Send kick-off message to all functions + std::vector kickOffData = { 0, 1, 2 }; + for (int i = 0; i < nFuncs; i++) { + broker.sendMessage( + appId, 0, i + 1, kickOffData.data(), kickOffData.size()); + } + + // Check other functions executed successfully + for (int i = 0; i < nFuncs; i++) { + faabric::Message& m = req->mutable_messages()->at(i); + + sch.getFunctionResult(m.id(), 2000); + REQUIRE(m.returnvalue() == 0); + } +} +} diff --git a/tests/test/transport/test_point_to_point.cpp b/tests/test/transport/test_point_to_point.cpp new file mode 100644 index 000000000..a4ce27d3b --- /dev/null +++ b/tests/test/transport/test_point_to_point.cpp @@ -0,0 +1,271 @@ +#include + +#include "faabric_utils.h" + +#include + +#include +#include +#include +#include +#include + +using namespace faabric::transport; 
+using namespace faabric::util; + +namespace tests { + +class PointToPointClientServerFixture + : public PointToPointTestFixture + , SchedulerTestFixture +{ + public: + PointToPointClientServerFixture() + : cli(LOCALHOST) + { + server.start(); + } + + ~PointToPointClientServerFixture() { server.stop(); } + + protected: + faabric::transport::PointToPointClient cli; + faabric::transport::PointToPointServer server; +}; + +TEST_CASE_METHOD(PointToPointClientServerFixture, + "Test set and get point-to-point hosts", + "[transport][ptp]") +{ + // Note - deliberately overlap app indexes to make sure app id counts + int appIdA = 123; + int appIdB = 345; + int idxA1 = 0; + int idxB1 = 2; + int idxA2 = 10; + int idxB2 = 10; + + std::string hostA = "host-a"; + std::string hostB = "host-b"; + std::string hostC = "host-c"; + + REQUIRE_THROWS(broker.getHostForReceiver(appIdA, idxA1)); + REQUIRE_THROWS(broker.getHostForReceiver(appIdA, idxA2)); + REQUIRE_THROWS(broker.getHostForReceiver(appIdB, idxB1)); + REQUIRE_THROWS(broker.getHostForReceiver(appIdB, idxB2)); + + broker.setHostForReceiver(appIdA, idxA1, hostA); + broker.setHostForReceiver(appIdB, idxB1, hostB); + + std::set expectedA = { idxA1 }; + std::set expectedB = { idxB1 }; + REQUIRE(broker.getIdxsRegisteredForApp(appIdA) == expectedA); + REQUIRE(broker.getIdxsRegisteredForApp(appIdB) == expectedB); + + REQUIRE(broker.getHostForReceiver(appIdA, idxA1) == hostA); + REQUIRE_THROWS(broker.getHostForReceiver(appIdA, idxA2)); + REQUIRE(broker.getHostForReceiver(appIdB, idxB1) == hostB); + REQUIRE_THROWS(broker.getHostForReceiver(appIdB, idxB2)); + + broker.setHostForReceiver(appIdA, idxA2, hostB); + broker.setHostForReceiver(appIdB, idxB2, hostC); + + expectedA = { idxA1, idxA2 }; + expectedB = { idxB1, idxB2 }; + + REQUIRE(broker.getIdxsRegisteredForApp(appIdA) == expectedA); + REQUIRE(broker.getIdxsRegisteredForApp(appIdB) == expectedB); + + REQUIRE(broker.getHostForReceiver(appIdA, idxA1) == hostA); + REQUIRE(broker.getHostForReceiver(appIdA, idxA2) == hostB); + REQUIRE(broker.getHostForReceiver(appIdB, idxB1) == hostB); + REQUIRE(broker.getHostForReceiver(appIdB, idxB2) == hostC); +} + +TEST_CASE_METHOD(PointToPointClientServerFixture, + "Test sending point-to-point mappings via broker", + "[transport][ptp]") +{ + faabric::util::setMockMode(true); + + int appIdA = 123; + int appIdB = 345; + + int idxA1 = 1; + int idxA2 = 2; + int idxB1 = 1; + + std::string hostA = "host-a"; + std::string hostB = "host-b"; + std::string hostC = "host-c"; + + faabric::scheduler::Scheduler& sch = faabric::scheduler::getScheduler(); + sch.reset(); + + sch.addHostToGlobalSet(hostA); + sch.addHostToGlobalSet(hostB); + sch.addHostToGlobalSet(hostC); + + // Includes this host + REQUIRE(sch.getAvailableHosts().size() == 4); + + broker.setHostForReceiver(appIdA, idxA1, hostA); + broker.setHostForReceiver(appIdA, idxA2, hostB); + broker.setHostForReceiver(appIdB, idxB1, hostB); + + std::vector expectedHosts; + SECTION("Send single host") + { + broker.sendMappings(appIdA, hostC); + expectedHosts = { hostC }; + } + + SECTION("Broadcast all hosts") + { + broker.broadcastMappings(appIdA); + + // Don't expect to be broadcast to this host + expectedHosts = { hostA, hostB, hostC }; + } + + auto actualSent = getSentMappings(); + REQUIRE(actualSent.size() == expectedHosts.size()); + + // Sort the sent mappings based on host + std::sort(actualSent.begin(), + actualSent.end(), + [](const std::pair& a, + const std::pair& b) + -> bool { return a.first < b.first; }); + + // Check each of 
the sent mappings is as we would expect + for (int i = 0; i < expectedHosts.size(); i++) { + REQUIRE(actualSent.at(i).first == expectedHosts.at(i)); + + faabric::PointToPointMappings actualMappings = actualSent.at(i).second; + REQUIRE(actualMappings.mappings().size() == 2); + + faabric::PointToPointMappings::PointToPointMapping mappingA = + actualMappings.mappings().at(0); + faabric::PointToPointMappings::PointToPointMapping mappingB = + actualMappings.mappings().at(1); + + REQUIRE(mappingA.appid() == appIdA); + REQUIRE(mappingB.appid() == appIdA); + + // Note - we don't know the order of the mappings and can't easily sort + // the data in the protobuf object, so it's easiest just to check both + // possible orderings. + if (mappingA.recvidx() == idxA1) { + REQUIRE(mappingA.host() == hostA); + + REQUIRE(mappingB.recvidx() == idxA2); + REQUIRE(mappingB.host() == hostB); + } else if (mappingA.recvidx() == idxA2) { + REQUIRE(mappingA.host() == hostB); + + REQUIRE(mappingB.recvidx() == idxA1); + REQUIRE(mappingB.host() == hostA); + } else { + FAIL(); + } + } +} + +TEST_CASE_METHOD(PointToPointClientServerFixture, + "Test sending point-to-point mappings from client", + "[transport][ptp]") +{ + int appIdA = 123; + int appIdB = 345; + + int idxA1 = 1; + int idxA2 = 2; + int idxB1 = 1; + + std::string hostA = "host-a"; + std::string hostB = "host-b"; + + REQUIRE(broker.getIdxsRegisteredForApp(appIdA).empty()); + REQUIRE(broker.getIdxsRegisteredForApp(appIdB).empty()); + + faabric::PointToPointMappings mappings; + + auto* mappingA1 = mappings.add_mappings(); + mappingA1->set_appid(appIdA); + mappingA1->set_recvidx(idxA1); + mappingA1->set_host(hostA); + + auto* mappingA2 = mappings.add_mappings(); + mappingA2->set_appid(appIdA); + mappingA2->set_recvidx(idxA2); + mappingA2->set_host(hostB); + + auto* mappingB1 = mappings.add_mappings(); + mappingB1->set_appid(appIdB); + mappingB1->set_recvidx(idxB1); + mappingB1->set_host(hostA); + + cli.sendMappings(mappings); + + REQUIRE(broker.getIdxsRegisteredForApp(appIdA).size() == 2); + REQUIRE(broker.getIdxsRegisteredForApp(appIdB).size() == 1); + + REQUIRE(broker.getHostForReceiver(appIdA, idxA1) == hostA); + REQUIRE(broker.getHostForReceiver(appIdA, idxA2) == hostB); + REQUIRE(broker.getHostForReceiver(appIdB, idxB1) == hostA); +} + +TEST_CASE_METHOD(PointToPointClientServerFixture, + "Test send and receive point-to-point messages", + "[transport][ptp]") +{ + int appId = 123; + int idxA = 5; + int idxB = 10; + + // Ensure this host is set to localhost + faabric::util::SystemConfig& conf = faabric::util::getSystemConfig(); + conf.endpointHost = LOCALHOST; + + // Register both indexes on this host + broker.setHostForReceiver(appId, idxA, LOCALHOST); + broker.setHostForReceiver(appId, idxB, LOCALHOST); + + std::vector sentDataA = { 0, 1, 2, 3 }; + std::vector receivedDataA; + std::vector sentDataB = { 3, 4, 5 }; + std::vector receivedDataB; + + // Make sure we send the message before a receiver is available to check + // async handling + broker.sendMessage(appId, idxA, idxB, sentDataA.data(), sentDataA.size()); + + SLEEP_MS(1000); + + std::thread t([appId, idxA, idxB, &receivedDataA, &sentDataB] { + PointToPointBroker& broker = getPointToPointBroker(); + + // Receive the first message + receivedDataA = broker.recvMessage(appId, idxA, idxB); + + // Send a message back (note reversing the indexes) + broker.sendMessage( + appId, idxB, idxA, sentDataB.data(), sentDataB.size()); + + broker.resetThreadLocalCache(); + }); + + // Receive the message sent back + 
receivedDataB = broker.recvMessage(appId, idxB, idxA); + + if (t.joinable()) { + t.join(); + } + + REQUIRE(receivedDataA == sentDataA); + REQUIRE(receivedDataB == sentDataB); + + conf.reset(); +} +} diff --git a/tests/test/util/test_bytes.cpp b/tests/test/util/test_bytes.cpp index a10659d7e..f9573a572 100644 --- a/tests/test/util/test_bytes.cpp +++ b/tests/test/util/test_bytes.cpp @@ -137,4 +137,25 @@ TEST_CASE("Test integer encoding to/from bytes", "[util]") REQUIRE_THROWS_AS(readBytesOf(buffer, offset, &r1), std::range_error); } +TEST_CASE("Test format byte array to string", "[util]") +{ + std::vector bytesIn; + std::string expectedString; + + SECTION("Empty") { expectedString = "[]"; } + + SECTION("Non-empty") + { + bytesIn = { 0, 1, 2, 3, 4, 5, 6, 7 }; + expectedString = "[0, 1, 2, 3, 4, 5, 6, 7]"; + } + + SECTION("Larger int values") + { + bytesIn = { 23, 9, 100 }; + expectedString = "[23, 9, 100]"; + } + + REQUIRE(formatByteArrayToIntString(bytesIn) == expectedString); +} } diff --git a/tests/test/util/test_config.cpp b/tests/test/util/test_config.cpp index ff7e726db..383253057 100644 --- a/tests/test/util/test_config.cpp +++ b/tests/test/util/test_config.cpp @@ -51,6 +51,8 @@ TEST_CASE("Test overriding system config initialisation", "[util]") std::string functionThreads = setEnvVar("FUNCTION_SERVER_THREADS", "111"); std::string stateThreads = setEnvVar("STATE_SERVER_THREADS", "222"); std::string snapshotThreads = setEnvVar("SNAPSHOT_SERVER_THREADS", "333"); + std::string pointToPointThreads = + setEnvVar("POINT_TO_POINT_SERVER_THREADS", "444"); std::string mpiSize = setEnvVar("DEFAULT_MPI_WORLD_SIZE", "2468"); std::string mpiPort = setEnvVar("MPI_BASE_PORT", "9999"); @@ -75,6 +77,7 @@ TEST_CASE("Test overriding system config initialisation", "[util]") REQUIRE(conf.functionServerThreads == 111); REQUIRE(conf.stateServerThreads == 222); REQUIRE(conf.snapshotServerThreads == 333); + REQUIRE(conf.pointToPointServerThreads == 444); REQUIRE(conf.defaultMpiWorldSize == 2468); REQUIRE(conf.mpiBasePort == 9999); @@ -99,6 +102,7 @@ TEST_CASE("Test overriding system config initialisation", "[util]") setEnvVar("FUNCTION_SERVER_THREADS", functionThreads); setEnvVar("STATE_SERVER_THREADS", stateThreads); setEnvVar("SNAPSHOT_SERVER_THREADS", snapshotThreads); + setEnvVar("POINT_TO_POINT_SERVER_THREADS", pointToPointThreads); setEnvVar("DEFAULT_MPI_WORLD_SIZE", mpiSize); setEnvVar("MPI_BASE_PORT", mpiPort); diff --git a/tests/test/util/test_func.cpp b/tests/test/util/test_func.cpp index 29bacc02b..2d381a820 100644 --- a/tests/test/util/test_func.cpp +++ b/tests/test/util/test_func.cpp @@ -30,6 +30,27 @@ TEST_CASE("Test message factory shared", "[util]") REQUIRE(!msg->resultkey().empty()); } +TEST_CASE("Test batch exec factory", "[util]") +{ + int nMessages = 4; + std::shared_ptr req = + faabric::util::batchExecFactory("demo", "echo", nMessages); + + REQUIRE(req->messages().size() == nMessages); + + REQUIRE(req->id() > 0); + + // Expect all messages to have the same app ID by default + int appId = req->messages().at(0).appid(); + REQUIRE(appId > 0); + + for (const auto& m : req->messages()) { + REQUIRE(m.appid() == appId); + REQUIRE(m.user() == "demo"); + REQUIRE(m.function() == "echo"); + } +} + TEST_CASE("Test adding ids to message", "[util]") { faabric::Message msgA; diff --git a/tests/utils/fixtures.h b/tests/utils/fixtures.h index 2826eeb3d..65e490cf8 100644 --- a/tests/utils/fixtures.h +++ b/tests/utils/fixtures.h @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include 
#include #include @@ -265,4 +267,30 @@ class RemoteMpiTestFixture : public MpiBaseTestFixture faabric::scheduler::MpiWorld otherWorld; }; + +class PointToPointTestFixture +{ + public: + PointToPointTestFixture() + : broker(faabric::transport::getPointToPointBroker()) + { + faabric::util::setMockMode(false); + broker.clear(); + } + + ~PointToPointTestFixture() + { + // Note - here we reset the thread-local cache for the test thread. If + // other threads are used in the tests, they too must do this. + broker.resetThreadLocalCache(); + + faabric::transport::clearSentMessages(); + + broker.clear(); + faabric::util::setMockMode(false); + } + + protected: + faabric::transport::PointToPointBroker& broker; +}; }
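As a closing illustration (editor's sketch, not part of the patch), this is how the broker API added above fits together end to end, modelled on tests/dist/transport/test_point_to_point.cpp. The function name, app id, indexes and host value are invented for the example.

#include <faabric/transport/PointToPointBroker.h>

#include <cstdint>
#include <string>
#include <vector>

void pointToPointSketch(const std::string& thisHost)
{
    auto& broker = faabric::transport::getPointToPointBroker();

    int appId = 123;

    // The host scheduling the app records where each receiving index will run,
    // then pushes the mappings out to the other hosts involved
    broker.setHostForReceiver(appId, 1, thisHost);
    broker.setHostForReceiver(appId, 2, thisHost);
    broker.broadcastMappings(appId);

    // Senders address a message by (appId, sendIdx, recvIdx); the broker routes
    // it locally over an inproc PUSH/PULL pair, or remotely via PointToPointClient
    std::vector<uint8_t> data = { 0, 1, 2 };
    broker.sendMessage(appId, 1, 2, data.data(), data.size());

    // The receiving index pulls from the matching (appId, sendIdx, recvIdx) label
    std::vector<uint8_t> received = broker.recvMessage(appId, 1, 2);

    // Any thread that used the broker clears its cached sockets before exiting
    broker.resetThreadLocalCache();
}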