From c1e262d9f4a89b8064563f8008e42624ad92d425 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 30 Sep 2021 14:34:07 +0000 Subject: [PATCH 01/16] Rework message endpoints to accommodate router/ dealer --- include/faabric/transport/MessageEndpoint.h | 27 ++++-- .../faabric/transport/MessageEndpointServer.h | 14 +-- include/faabric/util/config.h | 5 + src/transport/MessageEndpoint.cpp | 94 ++++++++++++------- src/transport/MessageEndpointServer.cpp | 26 ++--- 5 files changed, 106 insertions(+), 60 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 43f5ae8ad..fadc5c016 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -32,25 +32,21 @@ class MessageEndpoint public: MessageEndpoint(const std::string& hostIn, int portIn, int timeoutMsIn); + MessageEndpoint(const std::string& addressIn, int timeoutMsIn); + // Delete assignment and copy-constructor as we need to be very careful with // scoping and same-thread instantiation MessageEndpoint& operator=(const MessageEndpoint&) = delete; MessageEndpoint(const MessageEndpoint& ctx) = delete; - std::string getHost(); - - int getPort(); - protected: - const std::string host; - const int port; const std::string address; const int timeoutMs; const std::thread::id tid; const int id; - zmq::socket_t setUpSocket(zmq::socket_type socketType, int socketPort); + zmq::socket_t setUpSocket(zmq::socket_type socketType); void doSend(zmq::socket_t& socket, const uint8_t* data, @@ -103,6 +99,10 @@ class RecvMessageEndpoint : public MessageEndpoint public: RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType); + RecvMessageEndpoint(std::string inProcLabel, + int timeoutMs, + zmq::socket_type socketType); + virtual ~RecvMessageEndpoint(){}; virtual std::optional recv(int size = 0); @@ -111,6 +111,19 @@ class RecvMessageEndpoint : public MessageEndpoint zmq::socket_t socket; }; +class 
RouterMessageEndpoint final : public RecvMessageEndpoint +{ + public: + RouterMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); +}; + +class DealerMessageEndpoint final : public RecvMessageEndpoint +{ + public: + DealerMessageEndpoint(const std::string& inProcLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); +}; + class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index a7c3f88a2..60cad086c 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -13,10 +13,10 @@ namespace faabric::transport { // one for asynchronous. Each is run inside its own background thread. class MessageEndpointServer; -class MessageEndpointServerThread +class MessageEndpointServerHandler { public: - MessageEndpointServerThread(MessageEndpointServer* serverIn, bool asyncIn); + MessageEndpointServerHandler(MessageEndpointServer* serverIn, bool asyncIn); void start(std::shared_ptr latch); @@ -26,7 +26,9 @@ class MessageEndpointServerThread MessageEndpointServer* server; bool async = false; - std::thread backgroundThread; + std::thread receiverThread; + + std::vector workerThreads; }; class MessageEndpointServer @@ -51,13 +53,13 @@ class MessageEndpointServer doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) = 0; private: - friend class MessageEndpointServerThread; + friend class MessageEndpointServerHandler; const int asyncPort; const int syncPort; - MessageEndpointServerThread asyncThread; - MessageEndpointServerThread syncThread; + MessageEndpointServerHandler asyncHandler; + MessageEndpointServerHandler syncHandler; AsyncSendMessageEndpoint asyncShutdownSender; SyncSendMessageEndpoint syncShutdownSender; diff --git a/include/faabric/util/config.h b/include/faabric/util/config.h index 42d87cf5b..9503896fd 100644 --- a/include/faabric/util/config.h +++ 
b/include/faabric/util/config.h @@ -42,6 +42,11 @@ class SystemConfig int endpointPort; int endpointNumThreads; + // Transport + int functionCallServerNumThreads = 4; + int stateServerNumThreads = 2; + int snapshotServerNumThreads = 2; + SystemConfig(); void print(); diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index a62b43d68..995d3d031 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -47,12 +47,8 @@ namespace faabric::transport { -MessageEndpoint::MessageEndpoint(const std::string& hostIn, - int portIn, - int timeoutMsIn) - : host(hostIn) - , port(portIn) - , address("tcp://" + host + ":" + std::to_string(port)) +MessageEndpoint::MessageEndpoint(const std::string& addressIn, int timeoutMsIn) + : address(addressIn) , timeoutMs(timeoutMsIn) , tid(std::this_thread::get_id()) , id(faabric::util::generateGid()) @@ -64,8 +60,15 @@ MessageEndpoint::MessageEndpoint(const std::string& hostIn, } } -zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, - int socketPort) +// Convenience constructor for standard TCP ports +MessageEndpoint::MessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMsIn) + : MessageEndpoint("tcp://" + hostIn + ":" + std::to_string(portIn), + timeoutMsIn) +{} + +zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType) { zmq::socket_t socket; @@ -82,25 +85,37 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType, switch (socketType) { case zmq::socket_type::req: { SPDLOG_TRACE( - "New socket: req {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: req {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") break; } case zmq::socket_type::push: { SPDLOG_TRACE( - "New socket: push {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: push {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") 
break; } case zmq::socket_type::pull: { SPDLOG_TRACE( - "New socket: pull {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: pull {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } case zmq::socket_type::rep: { SPDLOG_TRACE( - "New socket: rep {}:{} (timeout {}ms)", host, port, timeoutMs); + "New socket: rep {} (timeout {}ms)", address, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::router: { + SPDLOG_TRACE( + "New socket: router {} (timeout {}ms)", address, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::dealer: { + SPDLOG_TRACE( + "New socket: dealer {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } @@ -171,7 +186,7 @@ std::optional MessageEndpoint::recvBuffer(zmq::socket_t& socket, } } catch (zmq::error_t& e) { if (e.num() == ZMQ_ETERM) { - SPDLOG_WARN("Endpoint {}:{} received ETERM on recv", host, port); + SPDLOG_WARN("Endpoint {} received ETERM on recv", address); return Message(); } @@ -195,7 +210,7 @@ std::optional MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) } } catch (zmq::error_t& e) { if (e.num() == ZMQ_ETERM) { - SPDLOG_WARN("Endpoint {}:{} received ETERM on recv", host, port); + SPDLOG_WARN("Endpoint {} received ETERM on recv", address); return Message(); } throw; @@ -206,16 +221,6 @@ std::optional MessageEndpoint::recvNoBuffer(zmq::socket_t& socket) return Message(msg); } -std::string MessageEndpoint::getHost() -{ - return host; -} - -int MessageEndpoint::getPort() -{ - return port; -} - // ---------------------------------------------- // ASYNC SEND ENDPOINT // ---------------------------------------------- @@ -225,7 +230,7 @@ AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn, int timeoutMs) : MessageEndpoint(hostIn, portIn, timeoutMs) { - pushSocket = 
setUpSocket(zmq::socket_type::push, portIn); + pushSocket = setUpSocket(zmq::socket_type::push); } void AsyncSendMessageEndpoint::sendHeader(int header) @@ -238,7 +243,7 @@ void AsyncSendMessageEndpoint::send(const uint8_t* data, size_t dataSize, bool more) { - SPDLOG_TRACE("PUSH {}:{} ({} bytes, more {})", host, port, dataSize, more); + SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more); doSend(pushSocket, data, dataSize, more); } @@ -251,7 +256,7 @@ SyncSendMessageEndpoint::SyncSendMessageEndpoint(const std::string& hostIn, int timeoutMs) : MessageEndpoint(hostIn, portIn, timeoutMs) { - reqSocket = setUpSocket(zmq::socket_type::req, portIn); + reqSocket = setUpSocket(zmq::socket_type::req); } void SyncSendMessageEndpoint::sendHeader(int header) @@ -262,7 +267,7 @@ void SyncSendMessageEndpoint::sendHeader(int header) void SyncSendMessageEndpoint::sendRaw(const uint8_t* data, size_t dataSize) { - SPDLOG_TRACE("REQ {}:{} ({} bytes)", host, port, dataSize); + SPDLOG_TRACE("REQ {} ({} bytes)", address, dataSize); doSend(reqSocket, data, dataSize, false); } @@ -270,11 +275,11 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, size_t dataSize, bool more) { - SPDLOG_TRACE("REQ {}:{} ({} bytes, more {})", host, port, dataSize, more); + SPDLOG_TRACE("REQ {} ({} bytes, more {})", address, dataSize, more); doSend(reqSocket, data, dataSize, more); // Do the receive - SPDLOG_TRACE("RECV (REQ) {}", port); + SPDLOG_TRACE("RECV (REQ) {}", address); auto msgMaybe = recvNoBuffer(reqSocket); if (!msgMaybe.has_value()) { throw MessageTimeoutException("SendAwaitResponse timeout"); @@ -286,12 +291,20 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, // RECV ENDPOINT // ---------------------------------------------- +RecvMessageEndpoint::RecvMessageEndpoint(std::string inProcLabel, + int timeoutMs, + zmq::socket_type socketType) + : MessageEndpoint("inproc://" + inProcLabel, timeoutMs) +{ + socket = 
setUpSocket(socketType); +} + RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType) : MessageEndpoint(ANY_HOST, portIn, timeoutMs) { - socket = setUpSocket(socketType, portIn); + socket = setUpSocket(socketType); } std::optional RecvMessageEndpoint::recv(int size) @@ -299,6 +312,19 @@ std::optional RecvMessageEndpoint::recv(int size) return doRecv(socket, size); } +// ---------------------------------------------- +// ROUTER AND DEALER ENDPOINTS +// ---------------------------------------------- + +RouterMessageEndpoint::RouterMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) +{} + +DealerMessageEndpoint::DealerMessageEndpoint(const std::string& inProcLabel, + int timeoutMs) + : RecvMessageEndpoint(inProcLabel, timeoutMs, zmq::socket_type::dealer) +{} + // ---------------------------------------------- // ASYNC RECV ENDPOINT // ---------------------------------------------- @@ -309,7 +335,7 @@ AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) std::optional AsyncRecvMessageEndpoint::recv(int size) { - SPDLOG_TRACE("PULL {} ({} bytes)", port, size); + SPDLOG_TRACE("PULL {} ({} bytes)", address, size); return RecvMessageEndpoint::recv(size); } @@ -323,13 +349,13 @@ SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) std::optional SyncRecvMessageEndpoint::recv(int size) { - SPDLOG_TRACE("RECV (REP) {} ({} bytes)", port, size); + SPDLOG_TRACE("RECV (REP) {} ({} bytes)", address, size); return RecvMessageEndpoint::recv(size); } void SyncRecvMessageEndpoint::sendResponse(const uint8_t* data, int size) { - SPDLOG_TRACE("REP {} ({} bytes)", port, size); + SPDLOG_TRACE("REP {} ({} bytes)", address, size); doSend(socket, data, size, false); } } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index dbcac4529..d0a9ba3eb 100644 --- a/src/transport/MessageEndpointServer.cpp +++ 
b/src/transport/MessageEndpointServer.cpp @@ -11,17 +11,17 @@ namespace faabric::transport { static const std::vector shutdownHeader = { 0, 0, 1, 1 }; -MessageEndpointServerThread::MessageEndpointServerThread( +MessageEndpointServerHandler::MessageEndpointServerHandler( MessageEndpointServer* serverIn, bool asyncIn) : server(serverIn) , async(asyncIn) {} -void MessageEndpointServerThread::start( +void MessageEndpointServerHandler::start( std::shared_ptr latch) { - backgroundThread = std::thread([this, latch] { + receiverThread = std::thread([this, latch] { std::unique_ptr endpoint = nullptr; int port = -1; @@ -100,18 +100,18 @@ void MessageEndpointServerThread::start( }); } -void MessageEndpointServerThread::join() +void MessageEndpointServerHandler::join() { - if (backgroundThread.joinable()) { - backgroundThread.join(); + if (receiverThread.joinable()) { + receiverThread.join(); } } MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) : asyncPort(asyncPortIn) , syncPort(syncPortIn) - , asyncThread(this, true) - , syncThread(this, false) + , asyncHandler(this, true) + , syncHandler(this, false) , asyncShutdownSender(LOCALHOST, asyncPort) , syncShutdownSender(LOCALHOST, syncPort) {} @@ -123,8 +123,8 @@ void MessageEndpointServer::start() // ready to use). 
auto startLatch = faabric::util::Latch::create(3); - asyncThread.start(startLatch); - syncThread.start(startLatch); + asyncHandler.start(startLatch); + syncHandler.start(startLatch); startLatch->wait(); } @@ -139,9 +139,9 @@ void MessageEndpointServer::stop() syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); - // Join the threads - asyncThread.join(); - syncThread.join(); + // Join the handlers + asyncHandler.join(); + syncHandler.join(); } void MessageEndpointServer::setAsyncLatch() From 26ebc9be85f8a128c48b93e6b3d34c2895fe9410 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Thu, 30 Sep 2021 15:40:21 +0000 Subject: [PATCH 02/16] Hooked up router/ dealer --- include/faabric/transport/MessageEndpoint.h | 21 +- .../faabric/transport/MessageEndpointServer.h | 16 +- include/faabric/transport/common.h | 5 + src/scheduler/FunctionCallServer.cpp | 3 +- src/snapshot/SnapshotServer.cpp | 3 +- src/state/StateServer.cpp | 4 +- src/transport/MessageEndpoint.cpp | 43 ++++- src/transport/MessageEndpointServer.cpp | 180 ++++++++++-------- tests/test/transport/test_message_server.cpp | 6 +- 9 files changed, 186 insertions(+), 95 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index fadc5c016..4dcbbf368 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -40,6 +40,8 @@ class MessageEndpoint MessageEndpoint(const MessageEndpoint& ctx) = delete; + bool forceConnectNotBind = false; + protected: const std::string address; const int timeoutMs; @@ -107,26 +109,30 @@ class RecvMessageEndpoint : public MessageEndpoint virtual std::optional recv(int size = 0); - protected: zmq::socket_t socket; }; -class RouterMessageEndpoint final : public RecvMessageEndpoint +class DealerMessageEndpoint final : public RecvMessageEndpoint { public: - RouterMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + DealerMessageEndpoint(const 
std::string& inProcLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); }; -class DealerMessageEndpoint final : public RecvMessageEndpoint +class RouterMessageEndpoint final : public RecvMessageEndpoint { public: - DealerMessageEndpoint(const std::string& inProcLabel, - int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + RouterMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + void proxyWithDealer(std::unique_ptr& dealer); }; class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: + AsyncRecvMessageEndpoint(const std::string& inprocLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + AsyncRecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); @@ -136,6 +142,9 @@ class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint class SyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: + SyncRecvMessageEndpoint(const std::string& inprocLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + SyncRecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 60cad086c..ea88bacc4 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -7,6 +7,8 @@ #include +#define DEFAULT_MESSAGE_SERVER_THREADS 4 + namespace faabric::transport { // Each server has two underlying sockets, one for synchronous communication and @@ -16,7 +18,10 @@ class MessageEndpointServer; class MessageEndpointServerHandler { public: - MessageEndpointServerHandler(MessageEndpointServer* serverIn, bool asyncIn); + MessageEndpointServerHandler(MessageEndpointServer* serverIn, + bool asyncIn, + const std::string& inprocLabelIn, + int nThreadsIn); void start(std::shared_ptr latch); @@ -25,6 +30,8 @@ class MessageEndpointServerHandler private: MessageEndpointServer* server; bool async = false; + const std::string inprocLabel; + int nThreads; 
std::thread receiverThread; @@ -34,7 +41,10 @@ class MessageEndpointServerHandler class MessageEndpointServer { public: - MessageEndpointServer(int asyncPortIn, int syncPortIn); + MessageEndpointServer(int asyncPortIn, + int syncPortIn, + const std::string &inprocLabelIn, + int nThreadsIn = DEFAULT_MESSAGE_SERVER_THREADS); virtual void start(); @@ -57,6 +67,8 @@ class MessageEndpointServer const int asyncPort; const int syncPort; + const std::string inprocLabel; + const int nThreads; MessageEndpointServerHandler asyncHandler; MessageEndpointServerHandler syncHandler; diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index 7ee8ee759..4f15a9ee7 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -6,9 +6,14 @@ #define STATE_ASYNC_PORT 8003 #define STATE_SYNC_PORT 8004 +#define STATE_INPROC_LABEL "state" + #define FUNCTION_CALL_ASYNC_PORT 8005 #define FUNCTION_CALL_SYNC_PORT 8006 +#define FUNCTION_INPROC_LABEL "function" + #define SNAPSHOT_ASYNC_PORT 8007 #define SNAPSHOT_SYNC_PORT 8008 +#define SNAPSHOT_INPROC_LABEL "snapshot" #define DEFAULT_MPI_BASE_PORT 8800 diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 5b33627ff..9ed5273f3 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -10,7 +10,8 @@ namespace faabric::scheduler { FunctionCallServer::FunctionCallServer() : faabric::transport::MessageEndpointServer(FUNCTION_CALL_ASYNC_PORT, - FUNCTION_CALL_SYNC_PORT) + FUNCTION_CALL_SYNC_PORT, + FUNCTION_INPROC_LABEL) , scheduler(getScheduler()) {} diff --git a/src/snapshot/SnapshotServer.cpp b/src/snapshot/SnapshotServer.cpp index adcc9922c..b4924ad70 100644 --- a/src/snapshot/SnapshotServer.cpp +++ b/src/snapshot/SnapshotServer.cpp @@ -14,7 +14,8 @@ namespace faabric::snapshot { SnapshotServer::SnapshotServer() : faabric::transport::MessageEndpointServer(SNAPSHOT_ASYNC_PORT, - SNAPSHOT_SYNC_PORT) + 
SNAPSHOT_SYNC_PORT, + SNAPSHOT_INPROC_LABEL) {} void SnapshotServer::doAsyncRecv(int header, diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index b162fbee6..baf082769 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -13,7 +13,9 @@ namespace faabric::state { StateServer::StateServer(State& stateIn) - : faabric::transport::MessageEndpointServer(STATE_ASYNC_PORT, STATE_SYNC_PORT) + : faabric::transport::MessageEndpointServer(STATE_ASYNC_PORT, + STATE_SYNC_PORT, + STATE_INPROC_LABEL) , state(stateIn) {} diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 995d3d031..1c8e5ff28 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -96,15 +96,24 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType) break; } case zmq::socket_type::pull: { + SPDLOG_TRACE( "New socket: pull {} (timeout {}ms)", address, timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + if (forceConnectNotBind) { + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + } else { + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + } break; } case zmq::socket_type::rep: { SPDLOG_TRACE( "New socket: rep {} (timeout {}ms)", address, timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + if (forceConnectNotBind) { + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + } else { + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + } break; } case zmq::socket_type::router: { @@ -296,6 +305,10 @@ RecvMessageEndpoint::RecvMessageEndpoint(std::string inProcLabel, zmq::socket_type socketType) : MessageEndpoint("inproc://" + inProcLabel, timeoutMs) { + // Because this is listening to a local dealer we have to force + // a connect rather than a bind + forceConnectNotBind = true; + socket = setUpSocket(socketType); } @@ -316,19 +329,32 @@ std::optional RecvMessageEndpoint::recv(int size) // ROUTER AND DEALER ENDPOINTS // 
---------------------------------------------- -RouterMessageEndpoint::RouterMessageEndpoint(int portIn, int timeoutMs) - : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) -{} - DealerMessageEndpoint::DealerMessageEndpoint(const std::string& inProcLabel, int timeoutMs) : RecvMessageEndpoint(inProcLabel, timeoutMs, zmq::socket_type::dealer) {} +RouterMessageEndpoint::RouterMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) +{} + +void RouterMessageEndpoint::proxyWithDealer( + std::unique_ptr& dealer) +{ + // Connect this router to a dealer via a queue + zmq::proxy(socket, dealer->socket); +} + // ---------------------------------------------- // ASYNC RECV ENDPOINT // ---------------------------------------------- +AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint( + const std::string& inprocLabel, + int timeoutMs) + : RecvMessageEndpoint(inprocLabel, timeoutMs, zmq::socket_type::pull) +{} + AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} @@ -343,6 +369,11 @@ std::optional AsyncRecvMessageEndpoint::recv(int size) // SYNC RECV ENDPOINT // ---------------------------------------------- +SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(const std::string& inprocLabel, + int timeoutMs) + : RecvMessageEndpoint(inprocLabel, timeoutMs, zmq::socket_type::rep) +{} + SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::rep) {} diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index d0a9ba3eb..3087784e9 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,3 +1,4 @@ +#include "faabric/transport/MessageEndpoint.h" #include #include #include @@ -13,90 +14,114 @@ static const std::vector shutdownHeader = { 0, 0, 1, 1 }; 
MessageEndpointServerHandler::MessageEndpointServerHandler( MessageEndpointServer* serverIn, - bool asyncIn) + bool asyncIn, + const std::string& inprocLabelIn, + int nThreadsIn) : server(serverIn) , async(asyncIn) + , inprocLabel(inprocLabelIn) + , nThreads(nThreadsIn) {} void MessageEndpointServerHandler::start( std::shared_ptr latch) { receiverThread = std::thread([this, latch] { - std::unique_ptr endpoint = nullptr; - int port = -1; - - if (async) { - port = server->asyncPort; - endpoint = std::make_unique(port); - } else { - port = server->syncPort; - endpoint = std::make_unique(port); - } - - latch->wait(); - - while (true) { - // Receive header and body - std::optional headerMessageMaybe = endpoint->recv(); - if (!headerMessageMaybe.has_value()) { - SPDLOG_TRACE("Server on port {}, looping after no message", - port); - continue; - } - Message& headerMessage = headerMessageMaybe.value(); - - if (headerMessage.size() == shutdownHeader.size()) { - if (headerMessage.dataCopy() == shutdownHeader) { - SPDLOG_TRACE("Server on {} received shutdown message", - port); - break; + // TODO - pass through proper inproc label to here + // Make it a class property set through the constructor + int port = async ? 
server->asyncPort : server->syncPort; + + // Set up router/ dealer + auto router = std::make_unique(port); + auto dealer = std::make_unique(inprocLabel); + + // Lauch worker threads + for (int i = 0; i < nThreads; i++) { + workerThreads.emplace_back([this, port] { + std::unique_ptr endpoint = nullptr; + + if (async) { + endpoint = + std::make_unique(inprocLabel); + } else { + endpoint = + std::make_unique(inprocLabel); } - } - - if (!headerMessage.more()) { - throw std::runtime_error("Header sent without SNDMORE flag"); - } - - std::optional bodyMaybe = endpoint->recv(); - if (!bodyMaybe.has_value()) { - SPDLOG_ERROR("Server on port {}, got header, timed out on body", - port); - throw MessageTimeoutException( - "Server, got header, timed out on body"); - } - Message& body = bodyMaybe.value(); - if (body.more()) { - throw std::runtime_error("Body sent with SNDMORE flag"); - } - - assert(headerMessage.size() == sizeof(uint8_t)); - uint8_t header = static_cast(*headerMessage.data()); - - if (async) { - // Server-specific async handling - server->doAsyncRecv(header, body.udata(), body.size()); - } else { - // Server-specific sync handling - std::unique_ptr resp = - server->doSyncRecv(header, body.udata(), body.size()); - size_t respSize = resp->ByteSizeLong(); - - uint8_t buffer[respSize]; - if (!resp->SerializeToArray(buffer, respSize)) { - throw std::runtime_error("Error serialising message"); + + while (true) { + // Receive header and body + std::optional headerMessageMaybe = + endpoint->recv(); + if (!headerMessageMaybe.has_value()) { + SPDLOG_TRACE( + "Server on port {}, looping after no message", port); + continue; + } + Message& headerMessage = headerMessageMaybe.value(); + + if (headerMessage.size() == shutdownHeader.size()) { + if (headerMessage.dataCopy() == shutdownHeader) { + SPDLOG_TRACE( + "Server on {} received shutdown message", port); + break; + } + } + + if (!headerMessage.more()) { + throw std::runtime_error( + "Header sent without SNDMORE flag"); + 
} + + std::optional bodyMaybe = endpoint->recv(); + if (!bodyMaybe.has_value()) { + SPDLOG_ERROR( + "Server on port {}, got header, timed out on body", + port); + throw MessageTimeoutException( + "Server, got header, timed out on body"); + } + Message& body = bodyMaybe.value(); + if (body.more()) { + throw std::runtime_error("Body sent with SNDMORE flag"); + } + + assert(headerMessage.size() == sizeof(uint8_t)); + uint8_t header = + static_cast(*headerMessage.data()); + + if (async) { + // Server-specific async handling + server->doAsyncRecv(header, body.udata(), body.size()); + } else { + // Server-specific sync handling + std::unique_ptr resp = + server->doSyncRecv(header, body.udata(), body.size()); + size_t respSize = resp->ByteSizeLong(); + + uint8_t buffer[respSize]; + if (!resp->SerializeToArray(buffer, respSize)) { + throw std::runtime_error( + "Error serialising message"); + } + + // Return the response + static_cast(endpoint.get()) + ->sendResponse(buffer, respSize); + } + + // Wait on the async latch if necessary + if (server->asyncLatch != nullptr) { + SPDLOG_TRACE("Server thread waiting on async latch"); + server->asyncLatch->wait(); + } } + }); + } - // Return the response - static_cast(endpoint.get()) - ->sendResponse(buffer, respSize); - } + // Connect the router and dealer + router->proxyWithDealer(dealer); - // Wait on the async latch if necessary - if (server->asyncLatch != nullptr) { - SPDLOG_TRACE("Server thread waiting on async latch"); - server->asyncLatch->wait(); - } - } + latch->wait(); }); } @@ -107,11 +132,16 @@ void MessageEndpointServerHandler::join() } } -MessageEndpointServer::MessageEndpointServer(int asyncPortIn, int syncPortIn) +MessageEndpointServer::MessageEndpointServer(int asyncPortIn, + int syncPortIn, + const std::string& inprocLabelIn, + int nThreadsIn) : asyncPort(asyncPortIn) , syncPort(syncPortIn) - , asyncHandler(this, true) - , syncHandler(this, false) + , inprocLabel(inprocLabelIn) + , nThreads(nThreadsIn) + , 
asyncHandler(this, true, inprocLabel + "async", nThreadsIn) + , syncHandler(this, false, inprocLabel, nThreadsIn) , asyncShutdownSender(LOCALHOST, asyncPort) , syncShutdownSender(LOCALHOST, syncPort) {} diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 2449d55b8..b352ada11 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -20,7 +20,7 @@ class DummyServer final : public MessageEndpointServer { public: DummyServer() - : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC) + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-dummy") {} std::atomic messageCount = 0; @@ -46,7 +46,7 @@ class EchoServer final : public MessageEndpointServer { public: EchoServer() - : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC) + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-echo") {} protected: @@ -75,7 +75,7 @@ class SleepServer final : public MessageEndpointServer int delayMs = 1000; SleepServer() - : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC) + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-sleep") {} protected: From bd37e72cb3a4aa2d351015489b09f51234f9606c Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 1 Oct 2021 15:24:35 +0000 Subject: [PATCH 03/16] Progressing on async multi-threading --- include/faabric/transport/MessageEndpoint.h | 21 +++-- .../faabric/transport/MessageEndpointServer.h | 8 +- src/transport/MessageEndpoint.cpp | 39 ++++++--- src/transport/MessageEndpointServer.cpp | 81 +++++++++++++------ .../scheduler/test_function_client_server.cpp | 8 +- tests/test/transport/test_message_server.cpp | 12 +-- 6 files changed, 118 insertions(+), 51 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 4dcbbf368..a76a82889 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ 
b/include/faabric/transport/MessageEndpoint.h @@ -42,6 +42,8 @@ class MessageEndpoint bool forceConnectNotBind = false; + std::string getAddress(); + protected: const std::string address; const int timeoutMs; @@ -99,7 +101,10 @@ class SyncSendMessageEndpoint final : public MessageEndpoint class RecvMessageEndpoint : public MessageEndpoint { public: - RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType); + RecvMessageEndpoint(int portIn, + int timeoutMs, + zmq::socket_type socketType, + bool connectNotBind = false); RecvMessageEndpoint(std::string inProcLabel, int timeoutMs, @@ -127,18 +132,24 @@ class RouterMessageEndpoint final : public RecvMessageEndpoint void proxyWithDealer(std::unique_ptr& dealer); }; -class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint +class AsyncRecvMessageEndpoint : public RecvMessageEndpoint { public: - AsyncRecvMessageEndpoint(const std::string& inprocLabel, - int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - AsyncRecvMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); std::optional recv(int size = 0) override; }; +class MultiAsyncRecvMessageEndpoint final : public RecvMessageEndpoint +{ + public: + MultiAsyncRecvMessageEndpoint(int portIn, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + std::optional recv(int size = 0) override; +}; + class SyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index ea88bacc4..6f0196240 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -43,16 +43,16 @@ class MessageEndpointServer public: MessageEndpointServer(int asyncPortIn, int syncPortIn, - const std::string &inprocLabelIn, + const std::string& inprocLabelIn, int nThreadsIn = DEFAULT_MESSAGE_SERVER_THREADS); virtual void start(); virtual void stop(); - void setAsyncLatch(); + void setWorkerLatch(); - 
void awaitAsyncLatch(); + void awaitWorkerLatch(); protected: virtual void doAsyncRecv(int header, @@ -76,6 +76,6 @@ class MessageEndpointServer AsyncSendMessageEndpoint asyncShutdownSender; SyncSendMessageEndpoint syncShutdownSender; - std::shared_ptr asyncLatch; + std::shared_ptr workerLatch; }; } diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 1c8e5ff28..fc895576f 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,3 +1,4 @@ +#include "zmq.hpp" #include #include #include @@ -68,6 +69,11 @@ MessageEndpoint::MessageEndpoint(const std::string& hostIn, timeoutMsIn) {} +std::string MessageEndpoint::getAddress() +{ + return address; +} + zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType) { zmq::socket_t socket; @@ -96,7 +102,6 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType) break; } case zmq::socket_type::pull: { - SPDLOG_TRACE( "New socket: pull {} (timeout {}ms)", address, timeoutMs); if (forceConnectNotBind) { @@ -305,8 +310,8 @@ RecvMessageEndpoint::RecvMessageEndpoint(std::string inProcLabel, zmq::socket_type socketType) : MessageEndpoint("inproc://" + inProcLabel, timeoutMs) { - // Because this is listening to a local dealer we have to force - // a connect rather than a bind + // All inproc sockets will be listening to a dealer port in our case, so we + // always connect, not bind forceConnectNotBind = true; socket = setUpSocket(socketType); @@ -314,9 +319,13 @@ RecvMessageEndpoint::RecvMessageEndpoint(std::string inProcLabel, RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs, - zmq::socket_type socketType) + zmq::socket_type socketType, + bool connectNotBind) : MessageEndpoint(ANY_HOST, portIn, timeoutMs) { + // In some cases (e.g. 
many PULLs, we may want to connect, not bind) + forceConnectNotBind = connectNotBind; + socket = setUpSocket(socketType); } @@ -342,6 +351,7 @@ void RouterMessageEndpoint::proxyWithDealer( std::unique_ptr& dealer) { // Connect this router to a dealer via a queue + SPDLOG_TRACE("Proxying {} to {}", address, dealer->getAddress()); zmq::proxy(socket, dealer->socket); } @@ -349,12 +359,6 @@ void RouterMessageEndpoint::proxyWithDealer( // ASYNC RECV ENDPOINT // ---------------------------------------------- -AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint( - const std::string& inprocLabel, - int timeoutMs) - : RecvMessageEndpoint(inprocLabel, timeoutMs, zmq::socket_type::pull) -{} - AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} @@ -365,6 +369,21 @@ std::optional AsyncRecvMessageEndpoint::recv(int size) return RecvMessageEndpoint::recv(size); } +// ---------------------------------------------- +// MULTI ASYNC RECV ENDPOINT +// ---------------------------------------------- + +MultiAsyncRecvMessageEndpoint::MultiAsyncRecvMessageEndpoint(int portIn, + int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull, true) +{} + +std::optional MultiAsyncRecvMessageEndpoint::recv(int size) +{ + SPDLOG_TRACE("PULL (multi) {} ({} bytes)", address, size); + return RecvMessageEndpoint::recv(size); +} + // ---------------------------------------------- // SYNC RECV ENDPOINT // ---------------------------------------------- diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 3087784e9..b29917d7b 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -26,14 +26,20 @@ MessageEndpointServerHandler::MessageEndpointServerHandler( void MessageEndpointServerHandler::start( std::shared_ptr latch) { + // See 0MQ docs on multi-threaded server for req/ rep: + // 
https://zguide.zeromq.org/docs/chapter2/#Multithreading-with-ZeroMQ receiverThread = std::thread([this, latch] { - // TODO - pass through proper inproc label to here - // Make it a class property set through the constructor int port = async ? server->asyncPort : server->syncPort; - // Set up router/ dealer - auto router = std::make_unique(port); - auto dealer = std::make_unique(inprocLabel); + // For sync, we need to set up a router and dealer + std::unique_ptr router = nullptr; + std::unique_ptr dealer = nullptr; + + if (!async) { + // Set up router/ dealer + router = std::make_unique(port); + dealer = std::make_unique(inprocLabel); + } // Lauch worker threads for (int i = 0; i < nThreads; i++) { @@ -41,9 +47,11 @@ void MessageEndpointServerHandler::start( std::unique_ptr endpoint = nullptr; if (async) { + // Async workers have a PULL socket endpoint = - std::make_unique(inprocLabel); + std::make_unique(port); } else { + // Sync workers have an in-proc REP socket endpoint = std::make_unique(inprocLabel); } @@ -53,8 +61,8 @@ void MessageEndpointServerHandler::start( std::optional headerMessageMaybe = endpoint->recv(); if (!headerMessageMaybe.has_value()) { - SPDLOG_TRACE( - "Server on port {}, looping after no message", port); + SPDLOG_TRACE("Server on {}, looping after no message", + endpoint->getAddress()); continue; } Message& headerMessage = headerMessageMaybe.value(); @@ -62,7 +70,13 @@ void MessageEndpointServerHandler::start( if (headerMessage.size() == shutdownHeader.size()) { if (headerMessage.dataCopy() == shutdownHeader) { SPDLOG_TRACE( - "Server on {} received shutdown message", port); + "Server on {} received shutdown message", + endpoint->getAddress()); + + // Allow things to wait on shutdown + if (server->workerLatch != nullptr) { + server->workerLatch->wait(); + } break; } } @@ -76,10 +90,11 @@ void MessageEndpointServerHandler::start( if (!bodyMaybe.has_value()) { SPDLOG_ERROR( "Server on port {}, got header, timed out on body", - port); + 
endpoint->getAddress()); throw MessageTimeoutException( "Server, got header, timed out on body"); } + Message& body = bodyMaybe.value(); if (body.more()) { throw std::runtime_error("Body sent with SNDMORE flag"); @@ -110,23 +125,34 @@ void MessageEndpointServerHandler::start( } // Wait on the async latch if necessary - if (server->asyncLatch != nullptr) { + if (server->workerLatch != nullptr) { SPDLOG_TRACE("Server thread waiting on async latch"); - server->asyncLatch->wait(); + server->workerLatch->wait(); } } }); } - // Connect the router and dealer - router->proxyWithDealer(dealer); - + // Wait on the latch latch->wait(); + + // Connect the router and dealer if sync + if (!async) { + router->proxyWithDealer(dealer); + } }); } void MessageEndpointServerHandler::join() { + // Join each worker + for (auto& t : workerThreads) { + if (t.joinable()) { + t.join(); + } + } + + // Join the receiver thread if (receiverThread.joinable()) { receiverThread.join(); } @@ -140,7 +166,7 @@ MessageEndpointServer::MessageEndpointServer(int asyncPortIn, , syncPort(syncPortIn) , inprocLabel(inprocLabelIn) , nThreads(nThreadsIn) - , asyncHandler(this, true, inprocLabel + "async", nThreadsIn) + , asyncHandler(this, true, inprocLabel + "-async", nThreadsIn) , syncHandler(this, false, inprocLabel, nThreadsIn) , asyncShutdownSender(LOCALHOST, asyncPort) , syncShutdownSender(LOCALHOST, syncPort) @@ -165,26 +191,35 @@ void MessageEndpointServer::stop() SPDLOG_TRACE( "Server sending shutdown messages to ports {} {}", asyncPort, syncPort); - asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); + for (int i = 0; i < nThreads; i++) { + setWorkerLatch(); + asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); + awaitWorkerLatch(); + } - syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); + for (int i = 0; i < nThreads; i++) { + setWorkerLatch(); + syncShutdownSender.sendRaw(shutdownHeader.data(), + shutdownHeader.size()); + 
awaitWorkerLatch(); + } // Join the handlers asyncHandler.join(); syncHandler.join(); } -void MessageEndpointServer::setAsyncLatch() +void MessageEndpointServer::setWorkerLatch() { - asyncLatch = faabric::util::Latch::create(2); + workerLatch = faabric::util::Latch::create(2); } -void MessageEndpointServer::awaitAsyncLatch() +void MessageEndpointServer::awaitWorkerLatch() { SPDLOG_TRACE("Waiting on async latch for port {}", asyncPort); - asyncLatch->wait(); + workerLatch->wait(); SPDLOG_TRACE("Finished async latch for port {}", asyncPort); - asyncLatch = nullptr; + workerLatch = nullptr; } } diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index ccbfca203..e96a3b86f 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -220,9 +220,9 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") *reqA.mutable_function() = msg; // Check that nothing's happened - server.setAsyncLatch(); + server.setWorkerLatch(); cli.unregister(reqA); - server.awaitAsyncLatch(); + server.awaitWorkerLatch(); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 1); // Make the request to unregister the actual host @@ -230,9 +230,9 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqB.set_host(otherHost); *reqB.mutable_function() = msg; - server.setAsyncLatch(); + server.setWorkerLatch(); cli.unregister(reqB); - server.awaitAsyncLatch(); + server.awaitWorkerLatch(); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index b352ada11..2c2fb4906 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -101,11 +101,13 @@ class SleepServer final : public MessageEndpointServer namespace tests { -TEST_CASE("Test send one message to 
server", "[transport]") +TEST_CASE("Test sending one message to server", "[transport]") { DummyServer server; server.start(); + SPDLOG_DEBUG("Dummy server started"); + REQUIRE(server.messageCount == 0); MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC); @@ -114,16 +116,16 @@ TEST_CASE("Test send one message to server", "[transport]") std::string body = "body"; const uint8_t* bodyMsg = BYTES_CONST(body.c_str()); - server.setAsyncLatch(); + server.setWorkerLatch(); cli.asyncSend(0, bodyMsg, body.size()); - server.awaitAsyncLatch(); + server.awaitWorkerLatch(); REQUIRE(server.messageCount == 1); server.stop(); } -TEST_CASE("Test send response to client", "[transport]") +TEST_CASE("Test sending response to client", "[transport]") { EchoServer server; server.start(); @@ -137,7 +139,7 @@ TEST_CASE("Test send response to client", "[transport]") faabric::StatePart response; cli.syncSend(0, BYTES(expectedMsg.data()), expectedMsg.size(), &response); - assert(response.data() == expectedMsg); + REQUIRE(response.data() == expectedMsg); server.stop(); } From a89f660ea04747eb88c1e584bf4b2ac0c002183d Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 1 Oct 2021 17:15:52 +0000 Subject: [PATCH 04/16] Connect fan in/ out --- include/faabric/transport/MessageEndpoint.h | 61 +++-- src/transport/MessageEndpoint.cpp | 235 +++++++++++++------- src/transport/MessageEndpointServer.cpp | 54 +++-- 3 files changed, 236 insertions(+), 114 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index a76a82889..31a18493d 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -24,6 +24,12 @@ namespace faabric::transport { +enum MessageEndpointConnectType +{ + BIND = 0, + CONNECT = 1, +}; + // Note: sockets must be open-ed and close-ed from the _same_ thread. In a given // communication group, one socket may bind, and all the rest must connect. 
// Order does not matter. @@ -40,8 +46,6 @@ class MessageEndpoint MessageEndpoint(const MessageEndpoint& ctx) = delete; - bool forceConnectNotBind = false; - std::string getAddress(); protected: @@ -50,7 +54,8 @@ class MessageEndpoint const std::thread::id tid; const int id; - zmq::socket_t setUpSocket(zmq::socket_type socketType); + zmq::socket_t setUpSocket(zmq::socket_type socketType, + MessageEndpointConnectType connectType); void doSend(zmq::socket_t& socket, const uint8_t* data, @@ -75,7 +80,6 @@ class AsyncSendMessageEndpoint final : public MessageEndpoint void send(const uint8_t* data, size_t dataSize, bool more = false); - private: zmq::socket_t pushSocket; }; @@ -101,14 +105,12 @@ class SyncSendMessageEndpoint final : public MessageEndpoint class RecvMessageEndpoint : public MessageEndpoint { public: - RecvMessageEndpoint(int portIn, - int timeoutMs, - zmq::socket_type socketType, - bool connectNotBind = false); + RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType); RecvMessageEndpoint(std::string inProcLabel, int timeoutMs, - zmq::socket_type socketType); + zmq::socket_type socketType, + MessageEndpointConnectType connectType); virtual ~RecvMessageEndpoint(){}; @@ -117,35 +119,52 @@ class RecvMessageEndpoint : public MessageEndpoint zmq::socket_t socket; }; -class DealerMessageEndpoint final : public RecvMessageEndpoint +class AsyncFanOutMessageEndpoint final : public MessageEndpoint { public: - DealerMessageEndpoint(const std::string& inProcLabel, - int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + AsyncFanOutMessageEndpoint(const std::string& inProcLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + void sendHeader(int header); + + void send(const uint8_t* data, size_t dataSize, bool more = false); + + zmq::socket_t socket; }; -class RouterMessageEndpoint final : public RecvMessageEndpoint +class AsyncFanInMessageEndpoint final : public RecvMessageEndpoint { public: - RouterMessageEndpoint(int portIn, int timeoutMs = 
DEFAULT_RECV_TIMEOUT_MS); + AsyncFanInMessageEndpoint(int portIn, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - void proxyWithDealer(std::unique_ptr& dealer); + void attachFanOut(std::unique_ptr& dealer); }; -class AsyncRecvMessageEndpoint : public RecvMessageEndpoint +class SyncFanOutMessageEndpoint final : public RecvMessageEndpoint { public: - AsyncRecvMessageEndpoint(int portIn, + SyncFanOutMessageEndpoint(const std::string& inProcLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); +}; + +class SyncFanInMessageEndpoint final : public RecvMessageEndpoint +{ + public: + SyncFanInMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - std::optional recv(int size = 0) override; + void attachFanOut(std::unique_ptr& dealer); }; -class MultiAsyncRecvMessageEndpoint final : public RecvMessageEndpoint +class AsyncRecvMessageEndpoint : public RecvMessageEndpoint { public: - MultiAsyncRecvMessageEndpoint(int portIn, - int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + AsyncRecvMessageEndpoint(const std::string& inprocLabel, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); + + AsyncRecvMessageEndpoint(int portIn, + int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); std::optional recv(int size = 0) override; }; diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index fc895576f..00bd0cb81 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -74,7 +74,9 @@ std::string MessageEndpoint::getAddress() return address; } -zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType) +zmq::socket_t MessageEndpoint::setUpSocket( + zmq::socket_type socketType, + MessageEndpointConnectType connectType) { zmq::socket_t socket; @@ -88,53 +90,95 @@ zmq::socket_t MessageEndpoint::setUpSocket(zmq::socket_type socketType) // Note - setting linger here is essential to avoid infinite hangs socket.set(zmq::sockopt::linger, LINGER_MS); - switch (socketType) { - case zmq::socket_type::req: { - SPDLOG_TRACE( - "New socket: 
req {} (timeout {}ms)", address, timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") - break; - } - case zmq::socket_type::push: { - SPDLOG_TRACE( - "New socket: push {} (timeout {}ms)", address, timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") - break; - } - case zmq::socket_type::pull: { - SPDLOG_TRACE( - "New socket: pull {} (timeout {}ms)", address, timeoutMs); - if (forceConnectNotBind) { - CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") - } else { - CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + switch (connectType) { + case (MessageEndpointConnectType::BIND): { + switch (socketType) { + case zmq::socket_type::push: { + SPDLOG_TRACE("Bind socket: push {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::pull: { + SPDLOG_TRACE("Bind socket: pull {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::rep: { + SPDLOG_TRACE( + "Bind socket: rep {} (timeout {}ms)", address, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::router: { + SPDLOG_TRACE("Bind socket: router {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + case zmq::socket_type::dealer: { + SPDLOG_TRACE("Bind socket: dealer {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } + default: { + SPDLOG_ERROR( + "Invalid bind socket type {} ({})", socketType, address); + throw std::runtime_error( + "Binding with invalid socket type"); + } } break; } - case zmq::socket_type::rep: { - SPDLOG_TRACE( - "New socket: rep {} (timeout {}ms)", address, timeoutMs); - if (forceConnectNotBind) { - CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") - } else { - 
CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + case (MessageEndpointConnectType::CONNECT): { + switch (socketType) { + case zmq::socket_type::req: { + SPDLOG_TRACE("Connect socket: req {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + break; + } + case zmq::socket_type::push: { + SPDLOG_TRACE("Connect socket: push {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + break; + } + case zmq::socket_type::pull: { + SPDLOG_TRACE("Connect socket: pull {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + break; + } + case zmq::socket_type::rep: { + SPDLOG_TRACE("Connect socket: rep {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + break; + } + default: { + SPDLOG_ERROR("Invalid connect socket type {} ({})", + socketType, + address); + throw std::runtime_error( + "Connecting with unrecognized socket type"); + } } break; } - case zmq::socket_type::router: { - SPDLOG_TRACE( - "New socket: router {} (timeout {}ms)", address, timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") - break; - } - case zmq::socket_type::dealer: { - SPDLOG_TRACE( - "New socket: dealer {} (timeout {}ms)", address, timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") - break; - } default: { - throw std::runtime_error("Opening unrecognized socket type"); + SPDLOG_ERROR("Unrecognised socket connect type {}", connectType); + throw std::runtime_error("Unrecognised connect type"); } } @@ -244,7 +288,8 @@ AsyncSendMessageEndpoint::AsyncSendMessageEndpoint(const std::string& hostIn, int timeoutMs) : MessageEndpoint(hostIn, portIn, timeoutMs) { - pushSocket = setUpSocket(zmq::socket_type::push); + pushSocket = + setUpSocket(zmq::socket_type::push, MessageEndpointConnectType::CONNECT); } void AsyncSendMessageEndpoint::sendHeader(int 
header) @@ -270,7 +315,8 @@ SyncSendMessageEndpoint::SyncSendMessageEndpoint(const std::string& hostIn, int timeoutMs) : MessageEndpoint(hostIn, portIn, timeoutMs) { - reqSocket = setUpSocket(zmq::socket_type::req); + reqSocket = + setUpSocket(zmq::socket_type::req, MessageEndpointConnectType::CONNECT); } void SyncSendMessageEndpoint::sendHeader(int header) @@ -307,26 +353,19 @@ Message SyncSendMessageEndpoint::sendAwaitResponse(const uint8_t* data, RecvMessageEndpoint::RecvMessageEndpoint(std::string inProcLabel, int timeoutMs, - zmq::socket_type socketType) + zmq::socket_type socketType, + MessageEndpointConnectType connectType) : MessageEndpoint("inproc://" + inProcLabel, timeoutMs) { - // All inproc sockets will be listening to a dealer port in our case, so we - // always connect, not bind - forceConnectNotBind = true; - - socket = setUpSocket(socketType); + socket = setUpSocket(socketType, connectType); } RecvMessageEndpoint::RecvMessageEndpoint(int portIn, int timeoutMs, - zmq::socket_type socketType, - bool connectNotBind) + zmq::socket_type socketType) : MessageEndpoint(ANY_HOST, portIn, timeoutMs) { - // In some cases (e.g. 
many PULLs, we may want to connect, not bind) - forceConnectNotBind = connectNotBind; - - socket = setUpSocket(socketType); + socket = setUpSocket(socketType, MessageEndpointConnectType::BIND); } std::optional RecvMessageEndpoint::recv(int size) @@ -335,52 +374,89 @@ std::optional RecvMessageEndpoint::recv(int size) } // ---------------------------------------------- -// ROUTER AND DEALER ENDPOINTS +// ASYNC FAN IN AND FAN OUT // ---------------------------------------------- -DealerMessageEndpoint::DealerMessageEndpoint(const std::string& inProcLabel, - int timeoutMs) - : RecvMessageEndpoint(inProcLabel, timeoutMs, zmq::socket_type::dealer) -{} +AsyncFanOutMessageEndpoint::AsyncFanOutMessageEndpoint( + const std::string& inprocLabel, + int timeoutMs) + : MessageEndpoint("inproc://" + inprocLabel, timeoutMs) +{ + socket = + setUpSocket(zmq::socket_type::push, MessageEndpointConnectType::BIND); +} -RouterMessageEndpoint::RouterMessageEndpoint(int portIn, int timeoutMs) - : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) +void AsyncFanOutMessageEndpoint::sendHeader(int header) +{ + uint8_t headerBytes = static_cast(header); + doSend(socket, &headerBytes, sizeof(headerBytes), true); +} + +void AsyncFanOutMessageEndpoint::send(const uint8_t* data, + size_t dataSize, + bool more) +{ + SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more); + doSend(socket, data, dataSize, more); +} + +AsyncFanInMessageEndpoint::AsyncFanInMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} -void RouterMessageEndpoint::proxyWithDealer( - std::unique_ptr& dealer) +void AsyncFanInMessageEndpoint::attachFanOut( + std::unique_ptr& dealer) { - // Connect this router to a dealer via a queue + // Connect this to a fan out SPDLOG_TRACE("Proxying {} to {}", address, dealer->getAddress()); zmq::proxy(socket, dealer->socket); } // ---------------------------------------------- -// ASYNC RECV ENDPOINT +// SYNC 
FAN IN AND FAN OUT // ---------------------------------------------- -AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) - : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) +SyncFanOutMessageEndpoint::SyncFanOutMessageEndpoint( + const std::string& inProcLabel, + int timeoutMs) + : RecvMessageEndpoint(inProcLabel, + timeoutMs, + zmq::socket_type::dealer, + MessageEndpointConnectType::BIND) {} -std::optional AsyncRecvMessageEndpoint::recv(int size) +SyncFanInMessageEndpoint::SyncFanInMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) +{} + +void SyncFanInMessageEndpoint::attachFanOut( + std::unique_ptr& dealer) { - SPDLOG_TRACE("PULL {} ({} bytes)", address, size); - return RecvMessageEndpoint::recv(size); + // Connect this to a fan out + SPDLOG_TRACE("Proxying {} to {}", address, dealer->getAddress()); + zmq::proxy(socket, dealer->socket); } // ---------------------------------------------- -// MULTI ASYNC RECV ENDPOINT +// ASYNC RECV ENDPOINT // ---------------------------------------------- -MultiAsyncRecvMessageEndpoint::MultiAsyncRecvMessageEndpoint(int portIn, - int timeoutMs) - : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull, true) +AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint( + const std::string& inprocLabel, + int timeoutMs) + : RecvMessageEndpoint(inprocLabel, + timeoutMs, + zmq::socket_type::pull, + MessageEndpointConnectType::CONNECT) {} -std::optional MultiAsyncRecvMessageEndpoint::recv(int size) +AsyncRecvMessageEndpoint::AsyncRecvMessageEndpoint(int portIn, int timeoutMs) + : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) +{} + +std::optional AsyncRecvMessageEndpoint::recv(int size) { - SPDLOG_TRACE("PULL (multi) {} ({} bytes)", address, size); + SPDLOG_TRACE("PULL {} ({} bytes)", address, size); return RecvMessageEndpoint::recv(size); } @@ -390,7 +466,10 @@ std::optional MultiAsyncRecvMessageEndpoint::recv(int 
size) SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(const std::string& inprocLabel, int timeoutMs) - : RecvMessageEndpoint(inprocLabel, timeoutMs, zmq::socket_type::rep) + : RecvMessageEndpoint(inprocLabel, + timeoutMs, + zmq::socket_type::rep, + MessageEndpointConnectType::CONNECT) {} SyncRecvMessageEndpoint::SyncRecvMessageEndpoint(int portIn, int timeoutMs) diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index b29917d7b..f6b1c8a0e 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -26,30 +26,44 @@ MessageEndpointServerHandler::MessageEndpointServerHandler( void MessageEndpointServerHandler::start( std::shared_ptr latch) { - // See 0MQ docs on multi-threaded server for req/ rep: + // For both sync and async, we want to fan out the messages to multiple + // worker threads. + // For sync, we use the router/ dealer pattern: // https://zguide.zeromq.org/docs/chapter2/#Multithreading-with-ZeroMQ + // For push/ pull we receive on a pull socket, then proxy with another push + // to multiple downstream pull sockets + // In both cases, the downstream fan-out is done over inproc sockets. receiverThread = std::thread([this, latch] { int port = async ? 
server->asyncPort : server->syncPort; - // For sync, we need to set up a router and dealer - std::unique_ptr router = nullptr; - std::unique_ptr dealer = nullptr; + // Sync: router and dealer + std::unique_ptr syncFanIn = nullptr; + std::unique_ptr syncFanOut = nullptr; - if (!async) { + // Async: pull/ push pair + std::unique_ptr asyncFanIn = nullptr; + std::unique_ptr asyncFanOut = nullptr; + + if (async) { + asyncFanIn = std::make_unique(port); + asyncFanOut = + std::make_unique(inprocLabel); + } else { // Set up router/ dealer - router = std::make_unique(port); - dealer = std::make_unique(inprocLabel); + syncFanIn = std::make_unique(port); + syncFanOut = + std::make_unique(inprocLabel); } // Lauch worker threads for (int i = 0; i < nThreads; i++) { - workerThreads.emplace_back([this, port] { + workerThreads.emplace_back([this] { std::unique_ptr endpoint = nullptr; if (async) { // Async workers have a PULL socket endpoint = - std::make_unique(port); + std::make_unique(inprocLabel); } else { // Sync workers have an in-proc REP socket endpoint = @@ -65,6 +79,7 @@ void MessageEndpointServerHandler::start( endpoint->getAddress()); continue; } + Message& headerMessage = headerMessageMaybe.value(); if (headerMessage.size() == shutdownHeader.size()) { @@ -77,6 +92,7 @@ void MessageEndpointServerHandler::start( if (server->workerLatch != nullptr) { server->workerLatch->wait(); } + break; } } @@ -137,8 +153,10 @@ void MessageEndpointServerHandler::start( latch->wait(); // Connect the router and dealer if sync - if (!async) { - router->proxyWithDealer(dealer); + if (async) { + asyncFanIn->attachFanOut(asyncFanOut); + } else { + syncFanIn->attachFanOut(syncFanOut); } }); } @@ -187,18 +205,24 @@ void MessageEndpointServer::start() void MessageEndpointServer::stop() { - // Send shutdown messages - SPDLOG_TRACE( - "Server sending shutdown messages to ports {} {}", asyncPort, syncPort); - for (int i = 0; i < nThreads; i++) { setWorkerLatch(); + SPDLOG_TRACE("Sending async 
shutdown message {}/{} to port {}", + i + 1, + nThreads, + asyncPort); + asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); awaitWorkerLatch(); } for (int i = 0; i < nThreads; i++) { setWorkerLatch(); + SPDLOG_TRACE("Sending sync shutdown message {}/{} to port {}", + i + 1, + nThreads, + syncPort); + syncShutdownSender.sendRaw(shutdownHeader.data(), shutdownHeader.size()); awaitWorkerLatch(); From 2c071341d3948713d9097a838489a044155a4254 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 1 Oct 2021 17:30:12 +0000 Subject: [PATCH 05/16] Async shutdown working --- src/transport/MessageEndpointServer.cpp | 155 +++++++++++++----------- 1 file changed, 84 insertions(+), 71 deletions(-) diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index f6b1c8a0e..985fa22ec 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -57,95 +57,108 @@ void MessageEndpointServerHandler::start( // Lauch worker threads for (int i = 0; i < nThreads; i++) { - workerThreads.emplace_back([this] { - std::unique_ptr endpoint = nullptr; - - if (async) { - // Async workers have a PULL socket - endpoint = - std::make_unique(inprocLabel); - } else { - // Sync workers have an in-proc REP socket - endpoint = - std::make_unique(inprocLabel); - } + workerThreads.emplace_back([this, i] { + // Here we want to isolate all ZeroMQ stuff in its own + // context, so we can do things after it's been destroyed + { + std::unique_ptr endpoint = nullptr; - while (true) { - // Receive header and body - std::optional headerMessageMaybe = - endpoint->recv(); - if (!headerMessageMaybe.has_value()) { - SPDLOG_TRACE("Server on {}, looping after no message", - endpoint->getAddress()); - continue; + if (async) { + // Async workers have a PULL socket + endpoint = std::make_unique( + inprocLabel); + } else { + // Sync workers have an in-proc REP socket + endpoint = std::make_unique( + inprocLabel); } - 
Message& headerMessage = headerMessageMaybe.value(); - - if (headerMessage.size() == shutdownHeader.size()) { - if (headerMessage.dataCopy() == shutdownHeader) { + while (true) { + // Receive header and body + std::optional headerMessageMaybe = + endpoint->recv(); + if (!headerMessageMaybe.has_value()) { SPDLOG_TRACE( - "Server on {} received shutdown message", + "Server on {}, looping after no message", endpoint->getAddress()); - - // Allow things to wait on shutdown - if (server->workerLatch != nullptr) { - server->workerLatch->wait(); - } - - break; + continue; } - } - if (!headerMessage.more()) { - throw std::runtime_error( - "Header sent without SNDMORE flag"); - } + Message& headerMessage = headerMessageMaybe.value(); - std::optional bodyMaybe = endpoint->recv(); - if (!bodyMaybe.has_value()) { - SPDLOG_ERROR( - "Server on port {}, got header, timed out on body", - endpoint->getAddress()); - throw MessageTimeoutException( - "Server, got header, timed out on body"); - } + if (headerMessage.size() == shutdownHeader.size()) { + if (headerMessage.dataCopy() == shutdownHeader) { + SPDLOG_TRACE( + "Server thread {} on {} got shutdown message", + i, + endpoint->getAddress()); - Message& body = bodyMaybe.value(); - if (body.more()) { - throw std::runtime_error("Body sent with SNDMORE flag"); - } + break; + } + } - assert(headerMessage.size() == sizeof(uint8_t)); - uint8_t header = - static_cast(*headerMessage.data()); + if (!headerMessage.more()) { + throw std::runtime_error( + "Header sent without SNDMORE flag"); + } - if (async) { - // Server-specific async handling - server->doAsyncRecv(header, body.udata(), body.size()); - } else { - // Server-specific sync handling - std::unique_ptr resp = - server->doSyncRecv(header, body.udata(), body.size()); - size_t respSize = resp->ByteSizeLong(); + std::optional bodyMaybe = endpoint->recv(); + if (!bodyMaybe.has_value()) { + SPDLOG_ERROR("Server on port {}, got header, timed " + "out on body", + endpoint->getAddress()); + 
throw MessageTimeoutException( + "Server, got header, timed out on body"); + } - uint8_t buffer[respSize]; - if (!resp->SerializeToArray(buffer, respSize)) { + Message& body = bodyMaybe.value(); + if (body.more()) { throw std::runtime_error( - "Error serialising message"); + "Body sent with SNDMORE flag"); } - // Return the response - static_cast(endpoint.get()) - ->sendResponse(buffer, respSize); - } + assert(headerMessage.size() == sizeof(uint8_t)); + uint8_t header = + static_cast(*headerMessage.data()); + + if (async) { + // Server-specific async handling + server->doAsyncRecv( + header, body.udata(), body.size()); + } else { + // Server-specific sync handling + std::unique_ptr resp = + server->doSyncRecv( + header, body.udata(), body.size()); + size_t respSize = resp->ByteSizeLong(); + + uint8_t buffer[respSize]; + if (!resp->SerializeToArray(buffer, respSize)) { + throw std::runtime_error( + "Error serialising message"); + } - // Wait on the async latch if necessary - if (server->workerLatch != nullptr) { - SPDLOG_TRACE("Server thread waiting on async latch"); - server->workerLatch->wait(); + // Return the response + static_cast( + endpoint.get()) + ->sendResponse(buffer, respSize); + } + + // Wait on the async latch if necessary + if (server->workerLatch != nullptr) { + SPDLOG_TRACE( + "Server thread waiting on async latch"); + server->workerLatch->wait(); + } } } + + // Just before the thread dies, check if there's something + // waiting on the latch + if (server->workerLatch != nullptr) { + SPDLOG_TRACE("Server thread {} waiting on async latch", i); + server->workerLatch->wait(); + } }); } From f6f7521817af2d033bf3c80dcaaa6e448ec5cc88 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 1 Oct 2021 17:44:08 +0000 Subject: [PATCH 06/16] Sync shutdown working --- src/transport/MessageEndpointServer.cpp | 26 ++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/transport/MessageEndpointServer.cpp 
b/src/transport/MessageEndpointServer.cpp index 985fa22ec..4f13eaa53 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -1,8 +1,9 @@ -#include "faabric/transport/MessageEndpoint.h" +#include #include #include #include #include +#include #include #include @@ -93,6 +94,16 @@ void MessageEndpointServerHandler::start( i, endpoint->getAddress()); + // Send an empty response if in sync mode + // (otherwise upstream socket will hang) + if (!async) { + std::vector empty(4, 0); + static_cast( + endpoint.get()) + ->sendResponse(empty.data(), + empty.size()); + } + break; } } @@ -162,10 +173,11 @@ void MessageEndpointServerHandler::start( }); } - // Wait on the latch + // Wait on the start-up latch latch->wait(); - // Connect the router and dealer if sync + // Connect the relevant fan-in/ out sockets (these will run until + // context is closed) if (async) { asyncFanIn->attachFanOut(asyncFanOut); } else { @@ -219,25 +231,25 @@ void MessageEndpointServer::start() void MessageEndpointServer::stop() { for (int i = 0; i < nThreads; i++) { - setWorkerLatch(); SPDLOG_TRACE("Sending async shutdown message {}/{} to port {}", i + 1, nThreads, asyncPort); + setWorkerLatch(); asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); awaitWorkerLatch(); } for (int i = 0; i < nThreads; i++) { - setWorkerLatch(); SPDLOG_TRACE("Sending sync shutdown message {}/{} to port {}", i + 1, nThreads, syncPort); - syncShutdownSender.sendRaw(shutdownHeader.data(), - shutdownHeader.size()); + setWorkerLatch(); + syncShutdownSender.sendAwaitResponse(shutdownHeader.data(), + shutdownHeader.size()); awaitWorkerLatch(); } From 46f12a6192267e716dcf2940abc6db9894c989ba Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Fri, 1 Oct 2021 18:01:50 +0000 Subject: [PATCH 07/16] Proxy steerable --- include/faabric/transport/MessageEndpoint.h | 23 ++++++++++---- src/transport/MessageEndpoint.cpp | 33 ++++++++++----------- 
src/transport/MessageEndpointServer.cpp | 4 +-- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 31a18493d..684a0abe9 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -119,6 +119,21 @@ class RecvMessageEndpoint : public MessageEndpoint zmq::socket_t socket; }; +class FanInMessageEndpoint : public RecvMessageEndpoint +{ + public: + FanInMessageEndpoint(int portIn, + int timeoutMs, + zmq::socket_type socketType); + + void attachFanOut(zmq::socket_t& fanOutSock); + + void stop(); + + private: + zmq::socket_t controlSock; +}; + class AsyncFanOutMessageEndpoint final : public MessageEndpoint { public: @@ -132,13 +147,11 @@ class AsyncFanOutMessageEndpoint final : public MessageEndpoint zmq::socket_t socket; }; -class AsyncFanInMessageEndpoint final : public RecvMessageEndpoint +class AsyncFanInMessageEndpoint final : public FanInMessageEndpoint { public: AsyncFanInMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - - void attachFanOut(std::unique_ptr& dealer); }; class SyncFanOutMessageEndpoint final : public RecvMessageEndpoint @@ -148,13 +161,11 @@ class SyncFanOutMessageEndpoint final : public RecvMessageEndpoint int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); }; -class SyncFanInMessageEndpoint final : public RecvMessageEndpoint +class SyncFanInMessageEndpoint final : public FanInMessageEndpoint { public: SyncFanInMessageEndpoint(int portIn, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - - void attachFanOut(std::unique_ptr& dealer); }; class AsyncRecvMessageEndpoint : public RecvMessageEndpoint diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 00bd0cb81..dab8cbf14 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -377,6 +377,19 @@ std::optional RecvMessageEndpoint::recv(int size) // ASYNC FAN IN AND FAN OUT // 
---------------------------------------------- +FanInMessageEndpoint::FanInMessageEndpoint(int portIn, + int timeoutMs, + zmq::socket_type socketType) + : RecvMessageEndpoint(portIn, timeoutMs, socketType) +{} + +void FanInMessageEndpoint::attachFanOut(zmq::socket_t& fanOutSock) +{ + zmq::proxy_steerable(socket, fanOutSock, zmq::socket_ref(), controlSock); +} + +void FanInMessageEndpoint::stop() {} + AsyncFanOutMessageEndpoint::AsyncFanOutMessageEndpoint( const std::string& inprocLabel, int timeoutMs) @@ -401,17 +414,9 @@ void AsyncFanOutMessageEndpoint::send(const uint8_t* data, } AsyncFanInMessageEndpoint::AsyncFanInMessageEndpoint(int portIn, int timeoutMs) - : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) + : FanInMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} -void AsyncFanInMessageEndpoint::attachFanOut( - std::unique_ptr& dealer) -{ - // Connect this to a fan out - SPDLOG_TRACE("Proxying {} to {}", address, dealer->getAddress()); - zmq::proxy(socket, dealer->socket); -} - // ---------------------------------------------- // SYNC FAN IN AND FAN OUT // ---------------------------------------------- @@ -426,17 +431,9 @@ SyncFanOutMessageEndpoint::SyncFanOutMessageEndpoint( {} SyncFanInMessageEndpoint::SyncFanInMessageEndpoint(int portIn, int timeoutMs) - : RecvMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) + : FanInMessageEndpoint(portIn, timeoutMs, zmq::socket_type::router) {} -void SyncFanInMessageEndpoint::attachFanOut( - std::unique_ptr& dealer) -{ - // Connect this to a fan out - SPDLOG_TRACE("Proxying {} to {}", address, dealer->getAddress()); - zmq::proxy(socket, dealer->socket); -} - // ---------------------------------------------- // ASYNC RECV ENDPOINT // ---------------------------------------------- diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 4f13eaa53..d6fd8a1f1 100644 --- a/src/transport/MessageEndpointServer.cpp +++ 
b/src/transport/MessageEndpointServer.cpp @@ -179,9 +179,9 @@ void MessageEndpointServerHandler::start( // Connect the relevant fan-in/ out sockets (these will run until // context is closed) if (async) { - asyncFanIn->attachFanOut(asyncFanOut); + asyncFanIn->attachFanOut(asyncFanOut->socket); } else { - syncFanIn->attachFanOut(syncFanOut); + syncFanIn->attachFanOut(syncFanOut->socket); } }); } From 9e20c1aa0cb00686c5db9d549b9d23a1dd05568a Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 07:34:12 +0000 Subject: [PATCH 08/16] Killing proxies properly --- include/faabric/transport/MessageEndpoint.h | 1 + .../faabric/transport/MessageEndpointServer.h | 3 + src/transport/MessageEndpoint.cpp | 145 ++++++++++++------ src/transport/MessageEndpointServer.cpp | 35 +++-- 4 files changed, 127 insertions(+), 57 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 684a0abe9..031cbcc26 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -132,6 +132,7 @@ class FanInMessageEndpoint : public RecvMessageEndpoint private: zmq::socket_t controlSock; + std::string controlSockAddress; }; class AsyncFanOutMessageEndpoint final : public MessageEndpoint diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 6f0196240..9f693da4a 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -36,6 +36,9 @@ class MessageEndpointServerHandler std::thread receiverThread; std::vector workerThreads; + + std::unique_ptr asyncFanIn = nullptr; + std::unique_ptr syncFanIn = nullptr; }; class MessageEndpointServer diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index dab8cbf14..28c12d1ff 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -48,35 +48,15 @@ 
namespace faabric::transport { -MessageEndpoint::MessageEndpoint(const std::string& addressIn, int timeoutMsIn) - : address(addressIn) - , timeoutMs(timeoutMsIn) - , tid(std::this_thread::get_id()) - , id(faabric::util::generateGid()) -{ - // Check and set socket timeout - if (timeoutMs <= 0) { - SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); - throw std::runtime_error("Setting invalid timeout"); - } -} - -// Convenience constructor for standard TCP ports -MessageEndpoint::MessageEndpoint(const std::string& hostIn, - int portIn, - int timeoutMsIn) - : MessageEndpoint("tcp://" + hostIn + ":" + std::to_string(portIn), - timeoutMsIn) -{} - -std::string MessageEndpoint::getAddress() -{ - return address; -} - -zmq::socket_t MessageEndpoint::setUpSocket( - zmq::socket_type socketType, - MessageEndpointConnectType connectType) +/** + * This is the core of our zmq usage, where we set up sockets. It handles + * setting timeouts and catching errors in the creation process, as well as + * logging and validating our use of socket types and connection types. 
+ */ +zmq::socket_t socketFactory(zmq::socket_type socketType, + MessageEndpointConnectType connectType, + int timeoutMs, + const std::string& address) { zmq::socket_t socket; @@ -93,13 +73,19 @@ zmq::socket_t MessageEndpoint::setUpSocket( switch (connectType) { case (MessageEndpointConnectType::BIND): { switch (socketType) { - case zmq::socket_type::push: { - SPDLOG_TRACE("Bind socket: push {} (timeout {}ms)", + case zmq::socket_type::dealer: { + SPDLOG_TRACE("Bind socket: dealer {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } + case zmq::socket_type::pub: { + SPDLOG_TRACE( + "Bind socket: pub {} (timeout {}ms)", address, timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } case zmq::socket_type::pull: { SPDLOG_TRACE("Bind socket: pull {} (timeout {}ms)", address, @@ -107,6 +93,13 @@ zmq::socket_t MessageEndpoint::setUpSocket( CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } + case zmq::socket_type::push: { + SPDLOG_TRACE("Bind socket: push {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") + break; + } case zmq::socket_type::rep: { SPDLOG_TRACE( "Bind socket: rep {} (timeout {}ms)", address, timeoutMs); @@ -120,13 +113,6 @@ zmq::socket_t MessageEndpoint::setUpSocket( CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") break; } - case zmq::socket_type::dealer: { - SPDLOG_TRACE("Bind socket: dealer {} (timeout {}ms)", - address, - timeoutMs); - CATCH_ZMQ_ERR_RETRY_ONCE(socket.bind(address), "bind") - break; - } default: { SPDLOG_ERROR( "Invalid bind socket type {} ({})", socketType, address); @@ -138,8 +124,8 @@ zmq::socket_t MessageEndpoint::setUpSocket( } case (MessageEndpointConnectType::CONNECT): { switch (socketType) { - case zmq::socket_type::req: { - SPDLOG_TRACE("Connect socket: req {} (timeout {}ms)", + case zmq::socket_type::pull: { + SPDLOG_TRACE("Connect socket: pull {} (timeout {}ms)", address, 
timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") @@ -152,15 +138,22 @@ zmq::socket_t MessageEndpoint::setUpSocket( CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") break; } - case zmq::socket_type::pull: { - SPDLOG_TRACE("Connect socket: pull {} (timeout {}ms)", + case zmq::socket_type::rep: { + SPDLOG_TRACE("Connect socket: rep {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") break; } - case zmq::socket_type::rep: { - SPDLOG_TRACE("Connect socket: rep {} (timeout {}ms)", + case zmq::socket_type::req: { + SPDLOG_TRACE("Connect socket: req {} (timeout {}ms)", + address, + timeoutMs); + CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") + break; + } + case zmq::socket_type::sub: { + SPDLOG_TRACE("Connect socket: sub {} (timeout {}ms)", address, timeoutMs); CATCH_ZMQ_ERR_RETRY_ONCE(socket.connect(address), "connect") @@ -185,6 +178,39 @@ zmq::socket_t MessageEndpoint::setUpSocket( return socket; } +MessageEndpoint::MessageEndpoint(const std::string& addressIn, int timeoutMsIn) + : address(addressIn) + , timeoutMs(timeoutMsIn) + , tid(std::this_thread::get_id()) + , id(faabric::util::generateGid()) +{ + // Check and set socket timeout + if (timeoutMs <= 0) { + SPDLOG_ERROR("Setting invalid timeout of {}", timeoutMs); + throw std::runtime_error("Setting invalid timeout"); + } +} + +// Convenience constructor for standard TCP ports +MessageEndpoint::MessageEndpoint(const std::string& hostIn, + int portIn, + int timeoutMsIn) + : MessageEndpoint("tcp://" + hostIn + ":" + std::to_string(portIn), + timeoutMsIn) +{} + +std::string MessageEndpoint::getAddress() +{ + return address; +} + +zmq::socket_t MessageEndpoint::setUpSocket( + zmq::socket_type socketType, + MessageEndpointConnectType connectType) +{ + return socketFactory(socketType, connectType, timeoutMs, address); +} + void MessageEndpoint::doSend(zmq::socket_t& socket, const uint8_t* data, size_t dataSize, @@ -381,14 
+407,41 @@ FanInMessageEndpoint::FanInMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType) : RecvMessageEndpoint(portIn, timeoutMs, socketType) -{} + , controlSockAddress("inproc://" + std::to_string(portIn) + "-control") +{ + // Connect the control sock. Note that even though the control socket lives + // longer than the control killer socket, we must *connect* here, not bind. + controlSock = socketFactory(zmq::socket_type::sub, + MessageEndpointConnectType::CONNECT, + timeoutMs, + controlSockAddress); + + // Subscribe to all topics + controlSock.set(zmq::sockopt::subscribe, zmq::str_buffer("")); +} void FanInMessageEndpoint::attachFanOut(zmq::socket_t& fanOutSock) { + // Discussion on proxy_steerable here: + // https://github.com/zeromq/cppzmq/issues/478 + SPDLOG_TRACE("Connecting proxy on {} ({})", address, controlSockAddress); zmq::proxy_steerable(socket, fanOutSock, zmq::socket_ref(), controlSock); } -void FanInMessageEndpoint::stop() {} +void FanInMessageEndpoint::stop() +{ + SPDLOG_TRACE("Sending TERMINATE on control socket {}", controlSockAddress); + // Note that even though this killer socket is short-lived sending a message + // to the control socket, we must *bind* here, not connect. + zmq::socket_t controlKillerSock = + socketFactory(zmq::socket_type::pub, + MessageEndpointConnectType::BIND, + timeoutMs, + controlSockAddress); + + controlKillerSock.send(zmq::str_buffer("TERMINATE"), zmq::send_flags::none); + controlKillerSock.close(); +} AsyncFanOutMessageEndpoint::AsyncFanOutMessageEndpoint( const std::string& inprocLabel, diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index d6fd8a1f1..d8a463707 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -37,15 +37,11 @@ void MessageEndpointServerHandler::start( receiverThread = std::thread([this, latch] { int port = async ? 
server->asyncPort : server->syncPort; - // Sync: router and dealer - std::unique_ptr syncFanIn = nullptr; std::unique_ptr syncFanOut = nullptr; - - // Async: pull/ push pair - std::unique_ptr asyncFanIn = nullptr; std::unique_ptr asyncFanOut = nullptr; if (async) { + // Set up push/ pull pair asyncFanIn = std::make_unique(port); asyncFanOut = std::make_unique(inprocLabel); @@ -155,10 +151,10 @@ void MessageEndpointServerHandler::start( ->sendResponse(buffer, respSize); } - // Wait on the async latch if necessary + // Wait on the worker latch if necessary if (server->workerLatch != nullptr) { SPDLOG_TRACE( - "Server thread waiting on async latch"); + "Server thread waiting on worker latch"); server->workerLatch->wait(); } } @@ -167,13 +163,14 @@ void MessageEndpointServerHandler::start( // Just before the thread dies, check if there's something // waiting on the latch if (server->workerLatch != nullptr) { - SPDLOG_TRACE("Server thread {} waiting on async latch", i); + SPDLOG_TRACE("Server thread {} waiting on worker latch", i); server->workerLatch->wait(); } }); } - // Wait on the start-up latch + // Wait on the start-up latch passed in by the caller. + // TODO - does this still work with the fan-in/-out approach? 
latch->wait(); // Connect the relevant fan-in/ out sockets (these will run until @@ -188,6 +185,16 @@ void MessageEndpointServerHandler::start( void MessageEndpointServerHandler::join() { + // Note that we have to kill any running proxies before anything else + // https://github.com/zeromq/cppzmq/issues/478 + if (syncFanIn != nullptr) { + syncFanIn->stop(); + } + + if (asyncFanIn != nullptr) { + asyncFanIn->stop(); + } + // Join each worker for (auto& t : workerThreads) { if (t.joinable()) { @@ -230,6 +237,12 @@ void MessageEndpointServer::start() void MessageEndpointServer::stop() { + // Here we send shutdown messages to each worker in turn, however, because + // they're all connected on the same inproc port, we have to wait until each + // one has shut down fully (i.e. the zmq socket has gone out of scope), + // before sending the next shutdown message (hence the use of the latch). If + // we don't do this, zmq will direct messages to sockets that are in the + // process of shutting down and cause errors. 
for (int i = 0; i < nThreads; i++) { SPDLOG_TRACE("Sending async shutdown message {}/{} to port {}", i + 1, @@ -265,10 +278,10 @@ void MessageEndpointServer::setWorkerLatch() void MessageEndpointServer::awaitWorkerLatch() { - SPDLOG_TRACE("Waiting on async latch for port {}", asyncPort); + SPDLOG_TRACE("Waiting on worker latch for port {}", asyncPort); workerLatch->wait(); - SPDLOG_TRACE("Finished async latch for port {}", asyncPort); + SPDLOG_TRACE("Finished worker latch for port {}", asyncPort); workerLatch = nullptr; } } From 683188279a20d16a12b8b59d2a85d8b55163da96 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 11:43:27 +0000 Subject: [PATCH 09/16] Connect up configurable thread numbers --- .../faabric/transport/MessageEndpointServer.h | 2 +- include/faabric/util/config.h | 6 ++--- src/scheduler/FunctionCallServer.cpp | 8 +++--- src/snapshot/SnapshotServer.cpp | 8 +++--- src/state/StateServer.cpp | 8 +++--- src/util/config.cpp | 8 ++++++ tests/test/transport/test_message_server.cpp | 26 ++++++++++++++----- 7 files changed, 47 insertions(+), 19 deletions(-) diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index 9f693da4a..bcf1478e5 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -47,7 +47,7 @@ class MessageEndpointServer MessageEndpointServer(int asyncPortIn, int syncPortIn, const std::string& inprocLabelIn, - int nThreadsIn = DEFAULT_MESSAGE_SERVER_THREADS); + int nThreadsIn); virtual void start(); diff --git a/include/faabric/util/config.h b/include/faabric/util/config.h index 9503896fd..e71d46490 100644 --- a/include/faabric/util/config.h +++ b/include/faabric/util/config.h @@ -43,9 +43,9 @@ class SystemConfig int endpointNumThreads; // Transport - int functionCallServerNumThreads = 4; - int stateServerNumThreads = 2; - int snapshotServerNumThreads = 2; + int functionServerThreads; + int 
stateServerThreads; + int snapshotServerThreads; SystemConfig(); diff --git a/src/scheduler/FunctionCallServer.cpp b/src/scheduler/FunctionCallServer.cpp index 9ed5273f3..e8b6b35e0 100644 --- a/src/scheduler/FunctionCallServer.cpp +++ b/src/scheduler/FunctionCallServer.cpp @@ -9,9 +9,11 @@ namespace faabric::scheduler { FunctionCallServer::FunctionCallServer() - : faabric::transport::MessageEndpointServer(FUNCTION_CALL_ASYNC_PORT, - FUNCTION_CALL_SYNC_PORT, - FUNCTION_INPROC_LABEL) + : faabric::transport::MessageEndpointServer( + FUNCTION_CALL_ASYNC_PORT, + FUNCTION_CALL_SYNC_PORT, + FUNCTION_INPROC_LABEL, + faabric::util::getSystemConfig().functionServerThreads) , scheduler(getScheduler()) {} diff --git a/src/snapshot/SnapshotServer.cpp b/src/snapshot/SnapshotServer.cpp index b4924ad70..7d0d87b36 100644 --- a/src/snapshot/SnapshotServer.cpp +++ b/src/snapshot/SnapshotServer.cpp @@ -13,9 +13,11 @@ namespace faabric::snapshot { SnapshotServer::SnapshotServer() - : faabric::transport::MessageEndpointServer(SNAPSHOT_ASYNC_PORT, - SNAPSHOT_SYNC_PORT, - SNAPSHOT_INPROC_LABEL) + : faabric::transport::MessageEndpointServer( + SNAPSHOT_ASYNC_PORT, + SNAPSHOT_SYNC_PORT, + SNAPSHOT_INPROC_LABEL, + faabric::util::getSystemConfig().snapshotServerThreads) {} void SnapshotServer::doAsyncRecv(int header, diff --git a/src/state/StateServer.cpp b/src/state/StateServer.cpp index baf082769..7439a09d9 100644 --- a/src/state/StateServer.cpp +++ b/src/state/StateServer.cpp @@ -13,9 +13,11 @@ namespace faabric::state { StateServer::StateServer(State& stateIn) - : faabric::transport::MessageEndpointServer(STATE_ASYNC_PORT, - STATE_SYNC_PORT, - STATE_INPROC_LABEL) + : faabric::transport::MessageEndpointServer( + STATE_ASYNC_PORT, + STATE_SYNC_PORT, + STATE_INPROC_LABEL, + faabric::util::getSystemConfig().stateServerThreads) , state(stateIn) {} diff --git a/src/util/config.cpp b/src/util/config.cpp index 6875b1854..3dd139aee 100644 --- a/src/util/config.cpp +++ b/src/util/config.cpp @@ 
-56,6 +56,14 @@ void SystemConfig::initialise() endpointHost = faabric::util::getPrimaryIPForThisHost(endpointInterface); } + + // Transport + functionServerThreads = + this->getSystemConfIntParam("FUNCTION_SERVER_THREADS", "2"); + stateServerThreads = + this->getSystemConfIntParam("STATE_SERVER_THREADS", "2"); + snapshotServerThreads = + this->getSystemConfIntParam("SNAPSHOT_SERVER_THREADS", "2"); } int SystemConfig::getSystemConfIntParam(const char* name, diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 2c2fb4906..76678d255 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -20,7 +20,7 @@ class DummyServer final : public MessageEndpointServer { public: DummyServer() - : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-dummy") + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-dummy", 2) {} std::atomic messageCount = 0; @@ -46,7 +46,7 @@ class EchoServer final : public MessageEndpointServer { public: EchoServer() - : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-echo") + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-echo", 2) {} protected: @@ -75,7 +75,7 @@ class SleepServer final : public MessageEndpointServer int delayMs = 1000; SleepServer() - : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-sleep") + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-sleep", 2) {} protected: @@ -215,9 +215,23 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") faabric::StatePart response; if (expectFailure) { - // Check for failure - REQUIRE_THROWS_AS(cli.syncSend(0, sleepBytes, sizeof(int), &response), - MessageTimeoutException); + bool failed = false; + + // Note - here we must wait until the server has finished handling the + // request, even though it's failed + server.setWorkerLatch(); + + // Make the call and check it fails + try { + 
cli.syncSend(0, sleepBytes, sizeof(int), &response); + } catch (MessageTimeoutException& ex) { + failed = true; + } + + REQUIRE(failed); + + // Wait for request to finish + server.awaitWorkerLatch(); } else { cli.syncSend(0, sleepBytes, sizeof(int), &response); REQUIRE(response.data() == "Response after sleep"); From 52fe76548d47e16b689c460532f3c425172cf891 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 13:18:55 +0000 Subject: [PATCH 10/16] Add tests for multi-threaded server --- tests/test/transport/test_message_server.cpp | 95 ++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index 76678d255..ff77e95ae 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -99,6 +99,39 @@ class SleepServer final : public MessageEndpointServer } }; +class BlockServer final : public MessageEndpointServer +{ + public: + BlockServer() + : MessageEndpointServer(TEST_PORT_ASYNC, TEST_PORT_SYNC, "test-lock", 2) + , latch(faabric::util::Latch::create(2)) + {} + + protected: + void doAsyncRecv(int header, + const uint8_t* buffer, + size_t bufferSize) override + { + throw std::runtime_error("Lock server not expecting async recv"); + } + + std::unique_ptr + doSyncRecv(int header, const uint8_t* buffer, size_t bufferSize) override + { + // Wait on the latch, requires multiple threads executing in parallel to + // get a response. 
+ latch->wait(); + + // Echo input data + auto response = std::make_unique(); + response->set_data(buffer, bufferSize); + return response; + } + + private: + std::shared_ptr latch = nullptr; +}; + namespace tests { TEST_CASE("Test sending one message to server", "[transport]") @@ -239,4 +272,66 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") server.stop(); } + +TEST_CASE("Test blocking requests in multi-threaded server", "[transport]") +{ + // Start server in the background + BlockServer server; + server.start(); + + bool successes[2] = { false, false }; + + // Create two background threads to make the blocking requests + std::thread tA([&successes] { + MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC); + + std::string expectedMsg = "Background thread A"; + + faabric::StatePart response; + cli.syncSend( + 0, BYTES(expectedMsg.data()), expectedMsg.size(), &response); + + if (response.data() != expectedMsg) { + SPDLOG_ERROR("A did not get expected response: {} != {}", + response.data(), + expectedMsg); + successes[0] = false; + } else { + successes[0] = true; + } + }); + + std::thread tB([&successes] { + MessageEndpointClient cli(LOCALHOST, TEST_PORT_ASYNC, TEST_PORT_SYNC); + + std::string expectedMsg = "Background thread B"; + + faabric::StatePart response; + cli.syncSend( + 0, BYTES(expectedMsg.data()), expectedMsg.size(), &response); + + if (response.data() != expectedMsg) { + SPDLOG_ERROR("B did not get expected response: {} != {}", + response.data(), + expectedMsg); + + successes[1] = false; + } else { + successes[1] = true; + } + }); + + if (tA.joinable()) { + tA.join(); + } + + if (tB.joinable()) { + tB.join(); + } + + REQUIRE(successes[0]); + REQUIRE(successes[1]); + + server.stop(); +} } From 16b202c28386b7c2c73a252cfc9b82a682e2f435 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 15:37:08 +0000 Subject: [PATCH 11/16] Tidy-up and sleep --- include/faabric/transport/MessageEndpoint.h | 4 
---- .../faabric/transport/MessageEndpointServer.h | 5 +++- src/transport/MessageEndpoint.cpp | 23 ++++--------------- src/transport/MessageEndpointServer.cpp | 23 +++++++++++-------- 4 files changed, 22 insertions(+), 33 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index 031cbcc26..bc1c8ec8e 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -141,10 +141,6 @@ class AsyncFanOutMessageEndpoint final : public MessageEndpoint AsyncFanOutMessageEndpoint(const std::string& inProcLabel, int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); - void sendHeader(int header); - - void send(const uint8_t* data, size_t dataSize, bool more = false); - zmq::socket_t socket; }; diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index bcf1478e5..a5c884730 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -37,8 +37,11 @@ class MessageEndpointServerHandler std::vector workerThreads; - std::unique_ptr asyncFanIn = nullptr; std::unique_ptr syncFanIn = nullptr; + std::unique_ptr syncFanOut = nullptr; + + std::unique_ptr asyncFanIn = nullptr; + std::unique_ptr asyncFanOut = nullptr; }; class MessageEndpointServer diff --git a/src/transport/MessageEndpoint.cpp b/src/transport/MessageEndpoint.cpp index 28c12d1ff..8fb0b1c22 100644 --- a/src/transport/MessageEndpoint.cpp +++ b/src/transport/MessageEndpoint.cpp @@ -1,4 +1,3 @@ -#include "zmq.hpp" #include #include #include @@ -49,9 +48,9 @@ namespace faabric::transport { /** - * This is the core of our zmq usage, where we set up sockets. It handles - * setting timeouts and catching errors in the creation process, as well as - * logging and validating our use of socket types and connection types. + * This is where we set up all our sockets. 
It handles setting timeouts and + * catching errors in the creation process, as well as logging and validating + * our use of socket types and connection types. */ zmq::socket_t socketFactory(zmq::socket_type socketType, MessageEndpointConnectType connectType, @@ -422,7 +421,7 @@ FanInMessageEndpoint::FanInMessageEndpoint(int portIn, void FanInMessageEndpoint::attachFanOut(zmq::socket_t& fanOutSock) { - // Discussion on proxy_steerable here: + // Useful discussion on proxy_steerable here: // https://github.com/zeromq/cppzmq/issues/478 SPDLOG_TRACE("Connecting proxy on {} ({})", address, controlSockAddress); zmq::proxy_steerable(socket, fanOutSock, zmq::socket_ref(), controlSock); @@ -452,20 +451,6 @@ AsyncFanOutMessageEndpoint::AsyncFanOutMessageEndpoint( setUpSocket(zmq::socket_type::push, MessageEndpointConnectType::BIND); } -void AsyncFanOutMessageEndpoint::sendHeader(int header) -{ - uint8_t headerBytes = static_cast(header); - doSend(socket, &headerBytes, sizeof(headerBytes), true); -} - -void AsyncFanOutMessageEndpoint::send(const uint8_t* data, - size_t dataSize, - bool more) -{ - SPDLOG_TRACE("PUSH {} ({} bytes, more {})", address, dataSize, more); - doSend(socket, data, dataSize, more); -} - AsyncFanInMessageEndpoint::AsyncFanInMessageEndpoint(int portIn, int timeoutMs) : FanInMessageEndpoint(portIn, timeoutMs, zmq::socket_type::pull) {} diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index d8a463707..1a6052ea2 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -37,9 +37,6 @@ void MessageEndpointServerHandler::start( receiverThread = std::thread([this, latch] { int port = async ? 
server->asyncPort : server->syncPort; - std::unique_ptr syncFanOut = nullptr; - std::unique_ptr asyncFanOut = nullptr; - if (async) { // Set up push/ pull pair asyncFanIn = std::make_unique(port); @@ -169,12 +166,13 @@ void MessageEndpointServerHandler::start( }); } - // Wait on the start-up latch passed in by the caller. - // TODO - does this still work with the fan-in/-out approach? + // Wait on the start-up latch passed in by the caller. This means that + // once the latch is freed, this handler is just about to start its + // proxy, so a short sleep should mean things are ready to go. latch->wait(); // Connect the relevant fan-in/ out sockets (these will run until - // context is closed) + // they receive a terminate message) if (async) { asyncFanIn->attachFanOut(asyncFanOut->socket); } else { @@ -222,17 +220,24 @@ MessageEndpointServer::MessageEndpointServer(int asyncPortIn, , syncShutdownSender(LOCALHOST, syncPort) {} +/** + * We need to guarantee to callers of this function, that when it returns, the + * server will be ready to use. + */ void MessageEndpointServer::start() { - // This latch means that callers can guarantee that when this function - // completes, both sockets will have been opened (and hence the server is - // ready to use). + // This latch allows us to block on the handlers until _just_ before they + // start their proxies. auto startLatch = faabric::util::Latch::create(3); asyncHandler.start(startLatch); syncHandler.start(startLatch); startLatch->wait(); + + // Unfortunately we can't know precisely when the proxies have started, + // hence have to add a sleep.
+ SLEEP_MS(500); } void MessageEndpointServer::stop() From 9d163e66ab5d29a39d91f8b7a33d7ff6bdff623d Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 16:17:09 +0000 Subject: [PATCH 12/16] Temporary trace logging in tests --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6b40a0e25..4e5872a2a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,6 +44,7 @@ jobs: HOST_TYPE: ci REDIS_QUEUE_HOST: redis REDIS_STATE_HOST: redis + LOG_LEVEL: trace container: image: faasm/faabric:0.1.2 defaults: From fd27c508edcc9e485b08a99de159b67e272a3299 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 16:52:52 +0000 Subject: [PATCH 13/16] Config tests, different latches --- .../faabric/transport/MessageEndpointServer.h | 13 +++-- src/transport/MessageEndpointServer.cpp | 52 +++++++++++++------ .../scheduler/test_function_client_server.cpp | 19 +++++-- .../snapshot/test_snapshot_client_server.cpp | 11 ++++ tests/test/state/test_state_server.cpp | 11 ++++ tests/test/transport/test_message_server.cpp | 8 +-- tests/test/util/test_config.cpp | 12 +++++ 7 files changed, 99 insertions(+), 27 deletions(-) diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index a5c884730..a40a520fd 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -56,9 +56,11 @@ class MessageEndpointServer virtual void stop(); - void setWorkerLatch(); + void setRequestLatch(); - void awaitWorkerLatch(); + void awaitRequestLatch(); + + int getNThreads(); protected: virtual void doAsyncRecv(int header, @@ -76,12 +78,17 @@ class MessageEndpointServer const std::string inprocLabel; const int nThreads; + void setShutdownLatch(); + + void awaitShutdownLatch(); + MessageEndpointServerHandler asyncHandler; MessageEndpointServerHandler 
syncHandler; AsyncSendMessageEndpoint asyncShutdownSender; SyncSendMessageEndpoint syncShutdownSender; - std::shared_ptr workerLatch; + std::shared_ptr requestLatch; + std::shared_ptr shutdownLatch; }; } diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 1a6052ea2..4e497c416 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -148,20 +148,21 @@ void MessageEndpointServerHandler::start( ->sendResponse(buffer, respSize); } - // Wait on the worker latch if necessary - if (server->workerLatch != nullptr) { + // Wait on the request latch if necessary + if (server->requestLatch != nullptr) { SPDLOG_TRACE( "Server thread waiting on worker latch"); - server->workerLatch->wait(); + server->requestLatch->wait(); } } } // Just before the thread dies, check if there's something - // waiting on the latch - if (server->workerLatch != nullptr) { - SPDLOG_TRACE("Server thread {} waiting on worker latch", i); - server->workerLatch->wait(); + // waiting on the shutdown latch + if (server->shutdownLatch != nullptr) { + SPDLOG_TRACE("Server thread {} waiting on shutdown latch", + i); + server->shutdownLatch->wait(); } }); } @@ -254,9 +255,9 @@ void MessageEndpointServer::stop() nThreads, asyncPort); - setWorkerLatch(); + setShutdownLatch(); asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); - awaitWorkerLatch(); + awaitShutdownLatch(); } for (int i = 0; i < nThreads; i++) { @@ -265,10 +266,10 @@ void MessageEndpointServer::stop() nThreads, syncPort); - setWorkerLatch(); + setShutdownLatch(); syncShutdownSender.sendAwaitResponse(shutdownHeader.data(), shutdownHeader.size()); - awaitWorkerLatch(); + awaitShutdownLatch(); } // Join the handlers @@ -276,17 +277,36 @@ void MessageEndpointServer::stop() syncHandler.join(); } -void MessageEndpointServer::setWorkerLatch() +void MessageEndpointServer::setRequestLatch() { - workerLatch = faabric::util::Latch::create(2); + 
requestLatch = faabric::util::Latch::create(2); } -void MessageEndpointServer::awaitWorkerLatch() +void MessageEndpointServer::awaitRequestLatch() { SPDLOG_TRACE("Waiting on worker latch for port {}", asyncPort); - workerLatch->wait(); + requestLatch->wait(); SPDLOG_TRACE("Finished worker latch for port {}", asyncPort); - workerLatch = nullptr; + requestLatch = nullptr; +} + +void MessageEndpointServer::setShutdownLatch() +{ + shutdownLatch = faabric::util::Latch::create(2); +} + +void MessageEndpointServer::awaitShutdownLatch() +{ + SPDLOG_TRACE("Waiting on shutdown latch for port {}", asyncPort); + shutdownLatch->wait(); + + SPDLOG_TRACE("Finished shutdown latch for port {}", asyncPort); + shutdownLatch = nullptr; +} + +int MessageEndpointServer::getNThreads() +{ + return nThreads; } } diff --git a/tests/test/scheduler/test_function_client_server.cpp b/tests/test/scheduler/test_function_client_server.cpp index e96a3b86f..519679935 100644 --- a/tests/test/scheduler/test_function_client_server.cpp +++ b/tests/test/scheduler/test_function_client_server.cpp @@ -49,6 +49,17 @@ class ClientServerFixture } }; +TEST_CASE_METHOD(ConfTestFixture, + "Test setting function call server threads", + "[scheduler]") +{ + conf.functionServerThreads = 6; + + faabric::scheduler::FunctionCallServer server; + + REQUIRE(server.getNThreads() == 6); +} + TEST_CASE_METHOD(ClientServerFixture, "Test sending flush message", "[scheduler]") @@ -220,9 +231,9 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") *reqA.mutable_function() = msg; // Check that nothing's happened - server.setWorkerLatch(); + server.setRequestLatch(); cli.unregister(reqA); - server.awaitWorkerLatch(); + server.awaitRequestLatch(); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 1); // Make the request to unregister the actual host @@ -230,9 +241,9 @@ TEST_CASE_METHOD(ClientServerFixture, "Test unregister request", "[scheduler]") reqB.set_host(otherHost); *reqB.mutable_function() = 
msg; - server.setWorkerLatch(); + server.setRequestLatch(); cli.unregister(reqB); - server.awaitWorkerLatch(); + server.awaitRequestLatch(); REQUIRE(sch.getFunctionRegisteredHostCount(msg) == 0); diff --git a/tests/test/snapshot/test_snapshot_client_server.cpp b/tests/test/snapshot/test_snapshot_client_server.cpp index 60abe9ec4..c2c67c830 100644 --- a/tests/test/snapshot/test_snapshot_client_server.cpp +++ b/tests/test/snapshot/test_snapshot_client_server.cpp @@ -37,6 +37,17 @@ class SnapshotClientServerFixture ~SnapshotClientServerFixture() { server.stop(); } }; +TEST_CASE_METHOD(ConfTestFixture, + "Test setting snapshot server threads", + "[snapshot]") +{ + conf.snapshotServerThreads = 5; + + faabric::snapshot::SnapshotServer server; + + REQUIRE(server.getNThreads() == 5); +} + TEST_CASE_METHOD(SnapshotClientServerFixture, "Test pushing and deleting snapshots", "[snapshot]") diff --git a/tests/test/state/test_state_server.cpp b/tests/test/state/test_state_server.cpp index 2a5603dca..b2d8c6699 100644 --- a/tests/test/state/test_state_server.cpp +++ b/tests/test/state/test_state_server.cpp @@ -56,6 +56,17 @@ class SimpleStateServerTestFixture std::vector dataB; }; +TEST_CASE_METHOD(ConfTestFixture, + "Test setting state server threads", + "[state]") +{ + conf.stateServerThreads = 7; + + StateServer server(faabric::state::getGlobalState()); + + REQUIRE(server.getNThreads() == 7); +} + TEST_CASE_METHOD(SimpleStateServerTestFixture, "Test state request/ response", "[state]") diff --git a/tests/test/transport/test_message_server.cpp b/tests/test/transport/test_message_server.cpp index ff77e95ae..fdb4b4080 100644 --- a/tests/test/transport/test_message_server.cpp +++ b/tests/test/transport/test_message_server.cpp @@ -149,9 +149,9 @@ TEST_CASE("Test sending one message to server", "[transport]") std::string body = "body"; const uint8_t* bodyMsg = BYTES_CONST(body.c_str()); - server.setWorkerLatch(); + server.setRequestLatch(); cli.asyncSend(0, bodyMsg, body.size()); - 
server.awaitWorkerLatch(); + server.awaitRequestLatch(); REQUIRE(server.messageCount == 1); @@ -252,7 +252,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") // Note - here we must wait until the server has finished handling the // request, even though it's failed - server.setWorkerLatch(); + server.setRequestLatch(); // Make the call and check it fails try { @@ -264,7 +264,7 @@ TEST_CASE("Test client timeout on requests to valid server", "[transport]") REQUIRE(failed); // Wait for request to finish - server.awaitWorkerLatch(); + server.awaitRequestLatch(); } else { cli.syncSend(0, sleepBytes, sizeof(int), &response); REQUIRE(response.data() == "Response after sleep"); diff --git a/tests/test/util/test_config.cpp b/tests/test/util/test_config.cpp index 75c2d88ef..06e741e2f 100644 --- a/tests/test/util/test_config.cpp +++ b/tests/test/util/test_config.cpp @@ -47,6 +47,10 @@ TEST_CASE("Test overriding system config initialisation", "[util]") std::string globalTimeout = setEnvVar("GLOBAL_MESSAGE_TIMEOUT", "9876"); std::string boundTimeout = setEnvVar("BOUND_TIMEOUT", "6666"); + std::string functionThreads = setEnvVar("FUNCTION_SERVER_THREADS", "111"); + std::string stateThreads = setEnvVar("STATE_SERVER_THREADS", "222"); + std::string snapshotThreads = setEnvVar("SNAPSHOT_SERVER_THREADS", "333"); + std::string mpiSize = setEnvVar("DEFAULT_MPI_WORLD_SIZE", "2468"); // Create new conf for test @@ -66,6 +70,10 @@ TEST_CASE("Test overriding system config initialisation", "[util]") REQUIRE(conf.globalMessageTimeout == 9876); REQUIRE(conf.boundTimeout == 6666); + REQUIRE(conf.functionServerThreads == 111); + REQUIRE(conf.stateServerThreads == 222); + REQUIRE(conf.snapshotServerThreads == 333); + REQUIRE(conf.defaultMpiWorldSize == 2468); // Be careful with host type @@ -85,6 +93,10 @@ TEST_CASE("Test overriding system config initialisation", "[util]") setEnvVar("GLOBAL_MESSAGE_TIMEOUT", globalTimeout); setEnvVar("BOUND_TIMEOUT", boundTimeout); 
+ setEnvVar("FUNCTION_SERVER_THREADS", functionThreads); + setEnvVar("STATE_SERVER_THREADS", stateThreads); + setEnvVar("SNAPSHOT_SERVER_THREADS", snapshotThreads); + setEnvVar("DEFAULT_MPI_WORLD_SIZE", mpiSize); } From 275692a2c377a7fc5990e22a26ac7fd364a38058 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Mon, 4 Oct 2021 17:02:34 +0000 Subject: [PATCH 14/16] Remove trace log in tests --- .github/workflows/tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4e5872a2a..6b40a0e25 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,6 @@ jobs: HOST_TYPE: ci REDIS_QUEUE_HOST: redis REDIS_STATE_HOST: redis - LOG_LEVEL: trace container: image: faasm/faabric:0.1.2 defaults: From 989ef1042205fa5658711a09b44c85fc8481a3ab Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 6 Oct 2021 10:52:39 +0000 Subject: [PATCH 15/16] PR comments --- include/faabric/transport/MessageEndpoint.h | 8 +++- .../faabric/transport/MessageEndpointServer.h | 4 -- src/transport/MessageEndpointServer.cpp | 37 ++++++++----------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/include/faabric/transport/MessageEndpoint.h b/include/faabric/transport/MessageEndpoint.h index bc1c8ec8e..2e955df5a 100644 --- a/include/faabric/transport/MessageEndpoint.h +++ b/include/faabric/transport/MessageEndpoint.h @@ -105,8 +105,14 @@ class SyncSendMessageEndpoint final : public MessageEndpoint class RecvMessageEndpoint : public MessageEndpoint { public: + /** + * Constructor for external TCP sockets + */ RecvMessageEndpoint(int portIn, int timeoutMs, zmq::socket_type socketType); + /** + * Constructor for internal inproc sockets + */ RecvMessageEndpoint(std::string inProcLabel, int timeoutMs, zmq::socket_type socketType, @@ -165,7 +171,7 @@ class SyncFanInMessageEndpoint final : public FanInMessageEndpoint int timeoutMs = DEFAULT_RECV_TIMEOUT_MS); }; -class AsyncRecvMessageEndpoint : 
public RecvMessageEndpoint +class AsyncRecvMessageEndpoint final : public RecvMessageEndpoint { public: AsyncRecvMessageEndpoint(const std::string& inprocLabel, diff --git a/include/faabric/transport/MessageEndpointServer.h b/include/faabric/transport/MessageEndpointServer.h index a40a520fd..7f9f7438a 100644 --- a/include/faabric/transport/MessageEndpointServer.h +++ b/include/faabric/transport/MessageEndpointServer.h @@ -78,10 +78,6 @@ class MessageEndpointServer const std::string inprocLabel; const int nThreads; - void setShutdownLatch(); - - void awaitShutdownLatch(); - MessageEndpointServerHandler asyncHandler; MessageEndpointServerHandler syncHandler; diff --git a/src/transport/MessageEndpointServer.cpp b/src/transport/MessageEndpointServer.cpp index 4e497c416..3e340675f 100644 --- a/src/transport/MessageEndpointServer.cpp +++ b/src/transport/MessageEndpointServer.cpp @@ -49,7 +49,7 @@ void MessageEndpointServerHandler::start( std::make_unique(inprocLabel); } - // Lauch worker threads + // Launch worker threads for (int i = 0; i < nThreads; i++) { workerThreads.emplace_back([this, i] { // Here we want to isolate all ZeroMQ stuff in its own @@ -246,18 +246,24 @@ void MessageEndpointServer::stop() // Here we send shutdown messages to each worker in turn, however, because // they're all connected on the same inproc port, we have to wait until each // one has shut down fully (i.e. the zmq socket has gone out of scope), - // before sending the next shutdown message (hence the use of the latch). If - // we don't do this, zmq will direct messages to sockets that are in the + // before sending the next shutdown message. + // If we don't do this, zmq will direct messages to sockets that are in the // process of shutting down and cause errors. + // To ensure each socket has closed, we use a latch with two slots, where + // this thread takes one of the slots, and the worker thread takes the other + // once it's finished shutting down. 
for (int i = 0; i < nThreads; i++) { SPDLOG_TRACE("Sending async shutdown message {}/{} to port {}", i + 1, nThreads, asyncPort); - setShutdownLatch(); + shutdownLatch = faabric::util::Latch::create(2); + asyncShutdownSender.send(shutdownHeader.data(), shutdownHeader.size()); - awaitShutdownLatch(); + + shutdownLatch->wait(); + shutdownLatch = nullptr; } for (int i = 0; i < nThreads; i++) { @@ -266,10 +272,13 @@ void MessageEndpointServer::stop() nThreads, syncPort); - setShutdownLatch(); + shutdownLatch = faabric::util::Latch::create(2); + syncShutdownSender.sendAwaitResponse(shutdownHeader.data(), shutdownHeader.size()); - awaitShutdownLatch(); + + shutdownLatch->wait(); + shutdownLatch = nullptr; } // Join the handlers @@ -291,20 +300,6 @@ void MessageEndpointServer::awaitRequestLatch() requestLatch = nullptr; } -void MessageEndpointServer::setShutdownLatch() -{ - shutdownLatch = faabric::util::Latch::create(2); -} - -void MessageEndpointServer::awaitShutdownLatch() -{ - SPDLOG_TRACE("Waiting on shutdown latch for port {}", asyncPort); - shutdownLatch->wait(); - - SPDLOG_TRACE("Finished shutdown latch for port {}", asyncPort); - shutdownLatch = nullptr; -} - int MessageEndpointServer::getNThreads() { return nThreads; From 587e6c08cc3c1fc12a001ff6702d430728a718e7 Mon Sep 17 00:00:00 2001 From: Simon Shillaker Date: Wed, 6 Oct 2021 11:00:11 +0000 Subject: [PATCH 16/16] Formatting --- include/faabric/transport/common.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/faabric/transport/common.h b/include/faabric/transport/common.h index d1d310224..523842787 100644 --- a/include/faabric/transport/common.h +++ b/include/faabric/transport/common.h @@ -15,4 +15,3 @@ #define SNAPSHOT_ASYNC_PORT 8007 #define SNAPSHOT_SYNC_PORT 8008 #define SNAPSHOT_INPROC_LABEL "snapshot" -