diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3591d698..3ea0803c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -118,6 +118,7 @@ add_library( src/log.cpp src/request.cpp src/request_am.cpp + src/request_data.cpp src/request_helper.cpp src/request_stream.cpp src/request_tag.cpp diff --git a/cpp/benchmarks/perftest.cpp b/cpp/benchmarks/perftest.cpp index 53998475..a4fa1898 100644 --- a/cpp/benchmarks/perftest.cpp +++ b/cpp/benchmarks/perftest.cpp @@ -30,7 +30,7 @@ enum class ProgressMode { enum transfer_type_t { SEND, RECV }; typedef std::unordered_map> BufferMap; -typedef std::unordered_map TagMap; +typedef std::unordered_map TagMap; typedef std::shared_ptr BufferMapPtr; typedef std::shared_ptr TagMapPtr; @@ -267,7 +267,8 @@ auto doTransfer(const app_context_t& app_context, auto start = std::chrono::high_resolution_clock::now(); std::vector> requests = { endpoint->tagSend((*bufferMap)[SEND].data(), app_context.message_size, (*tagMap)[SEND]), - endpoint->tagRecv((*bufferMap)[RECV].data(), app_context.message_size, (*tagMap)[RECV])}; + endpoint->tagRecv( + (*bufferMap)[RECV].data(), app_context.message_size, (*tagMap)[RECV], ucxx::TagMaskFull)}; // Wait for requests and clear requests waitRequests(app_context.progress_mode, worker, requests); @@ -292,8 +293,8 @@ int main(int argc, char** argv) bool is_server = app_context.server_addr == NULL; auto tagMap = std::make_shared(TagMap{ - {SEND, is_server ? 0 : 1}, - {RECV, is_server ? 1 : 0}, + {SEND, is_server ? ucxx::Tag{0} : ucxx::Tag{1}}, + {RECV, is_server ? 
ucxx::Tag{1} : ucxx::Tag{0}}, }); std::shared_ptr listener_ctx; @@ -337,7 +338,8 @@ int main(int argc, char** argv) (*tagMap)[SEND])); requests.push_back(endpoint->tagRecv((*wireupBufferMap)[RECV].data(), (*wireupBufferMap)[RECV].size() * sizeof(int), - (*tagMap)[RECV])); + (*tagMap)[RECV], + ucxx::TagMaskFull)); // Wait for wireup requests and clear requests waitRequests(app_context.progress_mode, worker, requests); diff --git a/cpp/examples/basic.cpp b/cpp/examples/basic.cpp index 55c6453c..50b148ec 100644 --- a/cpp/examples/basic.cpp +++ b/cpp/examples/basic.cpp @@ -202,25 +202,27 @@ int main(int argc, char** argv) // Schedule small wireup messages to let UCX identify capabilities between endpoints requests.push_back(listener_ctx->getEndpoint()->tagSend( - sendWireupBuffer.data(), sendWireupBuffer.size() * sizeof(int), 0)); - requests.push_back( - endpoint->tagRecv(recvWireupBuffer.data(), sendWireupBuffer.size() * sizeof(int), 0)); + sendWireupBuffer.data(), sendWireupBuffer.size() * sizeof(int), ucxx::Tag{0})); + requests.push_back(endpoint->tagRecv(recvWireupBuffer.data(), + sendWireupBuffer.size() * sizeof(int), + ucxx::Tag{0}, + ucxx::TagMaskFull)); ::waitRequests(progress_mode, worker, requests); requests.clear(); // Schedule send and recv messages on different tags and different ordering requests.push_back(listener_ctx->getEndpoint()->tagSend( - sendBuffers[0].data(), sendBuffers[0].size() * sizeof(int), 0)); + sendBuffers[0].data(), sendBuffers[0].size() * sizeof(int), ucxx::Tag{0})); requests.push_back(listener_ctx->getEndpoint()->tagRecv( - recvBuffers[1].data(), recvBuffers[1].size() * sizeof(int), 1)); + recvBuffers[1].data(), recvBuffers[1].size() * sizeof(int), ucxx::Tag{1}, ucxx::TagMaskFull)); requests.push_back(listener_ctx->getEndpoint()->tagSend( - sendBuffers[2].data(), sendBuffers[2].size() * sizeof(int), 2)); - requests.push_back( - endpoint->tagRecv(recvBuffers[2].data(), recvBuffers[2].size() * sizeof(int), 2)); - requests.push_back( - 
endpoint->tagSend(sendBuffers[1].data(), sendBuffers[1].size() * sizeof(int), 1)); + sendBuffers[2].data(), sendBuffers[2].size() * sizeof(int), ucxx::Tag{2})); + requests.push_back(endpoint->tagRecv( + recvBuffers[2].data(), recvBuffers[2].size() * sizeof(int), ucxx::Tag{2}, ucxx::TagMaskFull)); requests.push_back( - endpoint->tagRecv(recvBuffers[0].data(), recvBuffers[0].size() * sizeof(int), 0)); + endpoint->tagSend(sendBuffers[1].data(), sendBuffers[1].size() * sizeof(int), ucxx::Tag{1})); + requests.push_back(endpoint->tagRecv( + recvBuffers[0].data(), recvBuffers[0].size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); // Wait for requests to be set, i.e., transfers complete ::waitRequests(progress_mode, worker, requests); diff --git a/cpp/include/ucxx/api.h b/cpp/include/ucxx/api.h index e79dd2e4..8f8452f9 100644 --- a/cpp/include/ucxx/api.h +++ b/cpp/include/ucxx/api.h @@ -19,4 +19,5 @@ #include #include #include +#include #include diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index 0d0876dc..847aa6e1 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -8,6 +8,7 @@ #include #include +#include #include namespace ucxx { @@ -55,43 +56,28 @@ std::shared_ptr createWorker(std::shared_ptr context, const bool enableFuture); // Transfers -std::shared_ptr createRequestAmSend(std::shared_ptr endpoint, - void* buffer, - size_t length, - ucs_memory_type_t memoryType, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); - -std::shared_ptr createRequestAmRecv(std::shared_ptr endpoint, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); - -std::shared_ptr createRequestStream(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - const bool enablePythonFuture); - -std::shared_ptr createRequestTag(std::shared_ptr endpointOrWorker, - 
bool send, - void* buffer, - size_t length, - ucp_tag_t tag, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); - -std::shared_ptr createRequestTagMultiSend(std::shared_ptr endpoint, - const std::vector& buffer, - const std::vector& size, - const std::vector& isCUDA, - const ucp_tag_t tag, - const bool enablePythonFuture); - -std::shared_ptr createRequestTagMultiRecv(std::shared_ptr endpoint, - const ucp_tag_t tag, - const bool enablePythonFuture); +std::shared_ptr createRequestAm( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + +std::shared_ptr createRequestStream( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture); + +std::shared_ptr createRequestTag( + std::shared_ptr endpointOrWorker, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + +std::shared_ptr createRequestTagMulti( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture); } // namespace ucxx diff --git a/cpp/include/ucxx/delayed_submission.h b/cpp/include/ucxx/delayed_submission.h index 00b43988..a57494bf 100644 --- a/cpp/include/ucxx/delayed_submission.h +++ b/cpp/include/ucxx/delayed_submission.h @@ -7,56 +7,23 @@ #include #include #include +#include +#include #include #include +#include #include #include +#include #include +#include namespace ucxx { typedef std::function DelayedSubmissionCallbackType; -class DelayedSubmission { - public: - bool _send{false}; ///< Whether this is a send (`true`) operation or recv (`false`) - void* _buffer{nullptr}; ///< Raw pointer to data buffer - size_t _length{0}; ///< Length of the message in bytes - ucp_tag_t _tag{0}; ///< Tag to match - ucs_memory_type_t 
_memoryType{UCS_MEMORY_TYPE_UNKNOWN}; ///< Buffer memory type - - DelayedSubmission() = delete; - - /** - * @brief Constructor for a delayed submission operation. - * - * Construct a delayed submission operation. Delayed submission means that a transfer - * operation will not be submitted immediately, but will rather be delayed for the next - * progress iteration. - * - * This may be useful to avoid any transfer operations to be executed directly in the - * application thread, delaying all of them for the worker progress thread when enabled. - * With this approach any perceived overhead will be removed from the application thread, - * and thus provide some speedup in certain situations. It may be also useful to prevent - * a multi-threaded application for blocking while waiting for the UCX spinlock, since - * all transfer operations may be pushed to the worker progress thread. - * - * @param[in] send whether this is a send (`true`) or receive (`false`) operation. - * @param[in] buffer a raw pointer to the data being transferred. - * @param[in] length the size in bytes of the message being transfer. - * @param[in] tag tag to match for this operation (only applies for tag - * operations). - * @param[in] memoryType the memory type of the buffer. 
- */ - DelayedSubmission(const bool send, - void* buffer, - const size_t length, - const ucp_tag_t tag = 0, - const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_UNKNOWN); -}; - template class BaseDelayedSubmissionCollection { protected: diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 59c92063..d712db90 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -387,7 +387,7 @@ class Endpoint : public Component { */ std::shared_ptr tagSend(void* buffer, size_t length, - ucp_tag_t tag, + Tag tag, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); @@ -408,6 +408,7 @@ class Endpoint : public Component { * data will be stored. * @param[in] length the size in bytes of the tag message to be received. * @param[in] tag the tag to match. + * @param[in] tagMask the tag mask to use. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. @@ -417,7 +418,8 @@ class Endpoint : public Component { */ std::shared_ptr tagRecv(void* buffer, size_t length, - ucp_tag_t tag, + Tag tag, + TagMask tagMask, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); @@ -460,7 +462,7 @@ class Endpoint : public Component { std::shared_ptr tagMultiSend(const std::vector& buffer, const std::vector& size, const std::vector& isCUDA, - const ucp_tag_t tag, + const Tag tag, const bool enablePythonFuture); /** @@ -479,12 +481,15 @@ class Endpoint : public Component { * ensure the transfer has completed. Requires UCXX Python support. * * @param[in] tag the tag to match. + * @param[in] tagMask the tag mask to use. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. 
* * @returns Request to be subsequently checked for the completion and its state. */ - std::shared_ptr tagMultiRecv(const ucp_tag_t tag, const bool enablePythonFuture); + std::shared_ptr tagMultiRecv(const Tag tag, + const TagMask tagMask, + const bool enablePythonFuture); /** * @brief Get `ucxx::Worker` component from a worker or listener object. diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index cd5f070b..6066f2a5 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #define ucxx_trace_req_f(_owner, _req, _name, _message, ...) \ @@ -34,9 +35,8 @@ class Request : public Component { std::shared_ptr _endpoint{ nullptr}; ///< Endpoint that generated request (if not from worker) std::string _ownerString{ - "undetermined owner"}; ///< String to print owner (endpoint or worker) when logging - std::shared_ptr _delayedSubmission{ - nullptr}; ///< The submission object that will dispatch the request + "undetermined owner"}; ///< String to print owner (endpoint or worker) when logging + data::RequestData _requestData{}; ///< The operation-specific data to be used in the request std::string _operationName{ "request_undefined"}; ///< Human-readable operation name, mostly used for log messages std::recursive_mutex _mutex{}; ///< Mutex to prevent checking status while it's being set @@ -62,7 +62,7 @@ class Request : public Component { * subsequently notified. 
*/ Request(std::shared_ptr endpointOrWorker, - std::shared_ptr delayedSubmission, + const data::RequestData requestData, const std::string operationName, const bool enablePythonFuture = false); diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index ecf3f845..586abd89 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -4,6 +4,7 @@ */ #pragma once #include +#include #include #include @@ -24,13 +25,10 @@ class RequestAm : public Request { private: friend class internal::RecvAmMessage; - ucs_memory_type_t _sendHeader{}; ///< The header to send - std::shared_ptr _buffer{nullptr}; ///< The AM received message buffer - /** - * @brief Private constructor of `ucxx::RequestAm` send. + * @brief Private constructor of `ucxx::RequestAm`. * - * This is the internal implementation of `ucxx::RequestAm` send constructor, made private + * This is the internal implementation of `ucxx::RequestAm` constructor, made private * not to be called directly. This constructor is made private to ensure all UCXX objects * are shared pointers and the correct lifetime management of each one. * @@ -38,72 +36,46 @@ class RequestAm : public Request { * * - `ucxx::Endpoint::amSend()` * - `ucxx::createRequestAmSend()` + * - `ucxx::Endpoint::amReceive()` + * - `ucxx::createRequestAmReceive()` * * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. * * @param[in] endpoint the parent endpoint. - * @param[in] buffer a raw pointer to the data to be sent. - * @param[in] length the size in bytes of the active message to be sent. - * @param[in] memoryType the memory type of the buffer. - * @param[in] enablePythonFuture whether a python future should be created and - * subsequently notified. - * @param[in] callbackFunction user-defined callback function to call upon completion. - * @param[in] callbackData user-defined data to pass to the `callbackFunction`. 
- */ - RequestAm(std::shared_ptr endpoint, - void* buffer, - size_t length, - ucs_memory_type_t memoryType, - const bool enablePythonFuture = false, - RequestCallbackUserFunction callbackFunction = nullptr, - RequestCallbackUserData callbackData = nullptr); - - /** - * @brief Private constructor of `ucxx::RequestAm` receive. - * - * This is the internal implementation of `ucxx::RequestAm` receive constructor, made - * private not to be called directly. This constructor is made private to ensure all UCXX - * objects are shared pointers and the correct lifetime management of each one. - * - * Instead the user should use one of the following: - * - * - `ucxx::Endpoint::amRecv()` - * - `ucxx::createRequestAmRecv()` - * - * @throws ucxx::Error if `endpointOrWorker` is not a valid - * `std::shared_ptr` or - * `std::shared_ptr`. - * - * @param[in] endpointOrWorker the parent component, which may either be a - * `std::shared_ptr` or - * `std::shared_ptr`. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. * @param[in] callbackData user-defined data to pass to the `callbackFunction`. */ RequestAm(std::shared_ptr endpointOrWorker, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); public: /** - * @brief Constructor for `std::shared_ptr` send. + * @brief Constructor for `std::shared_ptr`. 
* - * The constructor for a `std::shared_ptr` object, creating a send active + * The constructor for a `std::shared_ptr` object, creating an active * message request, returning a pointer to a request object that can be later awaited and * checked for errors. This is a non-blocking operation, and the status of the transfer - * must be verified from the resulting request object before the data can be - * released. + * must be verified from the resulting request object before the data can be released if + * this is a send operation, or consumed if this is a receive operation. Received data is + * available via the `getRecvBuffer()` method if the receive transfer request completed + * successfully. * * @throws ucxx::Error if `endpoint` is not a valid * `std::shared_ptr`. * * @param[in] endpoint the parent endpoint. - * @param[in] buffer a raw pointer to the data to be transferred. - * @param[in] length the size in bytes of the tag message to be transferred. - * @param[in] memoryType the memory type of the buffer. + * @param[in] requestData container of the specified message type, including all + * type-specific data. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. @@ -111,38 +83,9 @@ class RequestAm : public Request { * * @returns The `shared_ptr` object */ - friend std::shared_ptr createRequestAmSend( - std::shared_ptr endpoint, - void* buffer, - size_t length, - ucs_memory_type_t memoryType, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); - - /** - * @brief Constructor for `std::shared_ptr` receive. - * - * The constructor for a `std::shared_ptr` object, creating a receive - * active message request, returning a pointer to a request object that can be later - * awaited and checked for errors. 
This is a non-blocking operation, and the status of - * the transfer must be verified from the resulting request object before the data can be - * consumed, the data is available via the `getRecvBuffer()` method if the transfer - * completed successfully. - * - * @throws ucxx::Error if `endpoint` is not a valid - * `std::shared_ptr`. - * - * @param[in] endpoint the parent endpoint. - * @param[in] enablePythonFuture whether a python future should be created and - * subsequently notified. - * @param[in] callbackFunction user-defined callback function to call upon completion. - * @param[in] callbackData user-defined data to pass to the `callbackFunction`. - * - * @returns The `shared_ptr` object - */ - friend std::shared_ptr createRequestAmRecv( + friend std::shared_ptr createRequestAm( std::shared_ptr endpoint, + const std::variant requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h new file mode 100644 index 00000000..03b1d9fe --- /dev/null +++ b/cpp/include/ucxx/request_data.h @@ -0,0 +1,206 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include + +#include + +#include + +namespace ucxx { + +class Buffer; + +namespace data { + +class AmSend { + public: + const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. + const size_t _length{0}; ///< The length of the message. + const ucs_memory_type_t _memoryType{UCS_MEMORY_TYPE_HOST}; ///< Memory type used on the operation + + /** + * @brief Constructor for Active Message-specific send data. + * + * Construct an object containing Active Message-specific send data. + * + * @param[in] memoryType the memory type of the buffer. 
+ */ + explicit AmSend(const decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_memoryType) memoryType = UCS_MEMORY_TYPE_HOST); + + AmSend() = delete; +}; + +class AmReceive { + public: + std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer + + /** + * @brief Constructor for Active Message-specific receive data. + * + * Construct an object containing Active Message-specific receive data. + * + * This constructor takes no parameters. + */ + AmReceive(); +}; + +class StreamSend { + public: + const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. + const size_t _length{0}; ///< The length of the message. + + /** + * @brief Constructor for stream-specific send data. + * + * Construct an object containing stream-specific send data. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the stream message to be sent. + */ + explicit StreamSend(const decltype(_buffer) buffer, const decltype(_length) length); + + StreamSend() = delete; +}; + +class StreamReceive { + public: + void* _buffer{nullptr}; ///< The raw pointer where received data should be stored. + const size_t _length{0}; ///< The expected message length. + size_t _lengthReceived{0}; ///< The actual received message length. + + /** + * @brief Constructor for stream-specific receive data. + * + * Construct an object containing stream-specific receive data. + * + * @param[out] buffer a raw pointer to the received data. + * @param[in] length the size in bytes of the stream message to be received. + */ + explicit StreamReceive(decltype(_buffer) buffer, const decltype(_length) length); + + StreamReceive() = delete; +}; + +class TagSend { + public: + const void* _buffer{nullptr}; ///< The raw pointer where data to be sent is stored. + const size_t _length{0}; ///< The length of the message. 
+ const ::ucxx::Tag _tag{0}; ///< Tag to match + + /** + * @brief Constructor for send tag-specific data. + * + * Construct an object containing send tag-specific data. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] tag the tag to match. + */ + explicit TagSend(const decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_tag) tag); + + TagSend() = delete; +}; + +class TagReceive { + public: + void* _buffer{nullptr}; ///< The raw pointer where received data should be stored. + const size_t _length{0}; ///< The length of the message. + const ::ucxx::Tag _tag{0}; ///< Tag to match + const ::ucxx::TagMask _tagMask{0}; ///< Tag mask to use + + /** + * @brief Constructor for receive tag-specific data. + * + * Construct an object containing receive tag-specific data. + * + * @param[out] buffer a raw pointer to the received data. + * @param[in] length the size in bytes of the tag message to be received. + * @param[in] tag the tag to match. + * @param[in] tagMask the tag mask to use. + */ + explicit TagReceive(decltype(_buffer) buffer, + const decltype(_length) length, + const decltype(_tag) tag, + const decltype(_tagMask) tagMask); + + TagReceive() = delete; +}; + +class TagMultiSend { + public: + const std::vector _buffer{}; ///< Raw pointers where data to be sent is stored. + const std::vector _length{}; ///< Lengths of messages. + const std::vector _isCUDA{}; ///< Flags indicating whether the buffer is CUDA or not. + const ::ucxx::Tag _tag{0}; ///< Tag to match + + /** + * @brief Constructor for send multi-buffer tag-specific data. + * + * Construct an object containing send multi-buffer tag-specific data. + * + * @param[in] buffer raw pointers to the data to be sent. + * @param[in] length the sizes in bytes of the messages to be sent. + * @param[in] tag the tag to match. 
+ */ + explicit TagMultiSend(const decltype(_buffer)& buffer, + const decltype(_length)& length, + const decltype(_isCUDA)& isCUDA, + const decltype(_tag) tag); + + TagMultiSend() = delete; +}; + +class TagMultiReceive { + public: + const ::ucxx::Tag _tag{0}; ///< Tag to match + const ::ucxx::TagMask _tagMask{0}; ///< Tag mask to use + + /** + * @brief Constructor for receive multi-buffer tag-specific data. + * + * Construct an object containing receive multi-buffer tag-specific data. + * + * @param[in] tag the tag to match. + * @param[in] tagMask the tag mask to use (only used for receive operations). + */ + explicit TagMultiReceive(const decltype(_tag) tag, const decltype(_tagMask) tagMask); + + TagMultiReceive() = delete; +}; + +using RequestData = std::variant; + +template +struct dispatch : Ts... { + using Ts::operator()...; +}; +template +dispatch(Ts...) -> dispatch; + +template +RequestData getRequestData(T t) +{ + return std::visit([](auto arg) -> RequestData { return arg; }, t); +} + +} // namespace data + +} // namespace ucxx diff --git a/cpp/include/ucxx/request_stream.h b/cpp/include/ucxx/request_stream.h index 27809904..95743e77 100644 --- a/cpp/include/ucxx/request_stream.h +++ b/cpp/include/ucxx/request_stream.h @@ -4,19 +4,19 @@ */ #pragma once #include +#include #include #include #include +#include #include namespace ucxx { class RequestStream : public Request { private: - size_t _length{0}; ///< The stream request length in bytes - /** * @brief Private constructor of `ucxx::RequestStream`. * @@ -31,18 +31,16 @@ class RequestStream : public Request { * - `ucxx::createRequestStream()` * * @param[in] endpoint the `std::shared_ptr` parent component - * @param[in] send whether this is a send (`true`) or receive (`false`) - * stream request. - * @param[in] buffer a raw pointer to the data to be transferred. - * @param[in] length the size in bytes of the stream message to be - * transferred. 
+ * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. */ RequestStream(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture = false); public: @@ -56,21 +54,17 @@ class RequestStream : public Request { * released (for a send operation) or consumed (for a receive operation). * * @param[in] endpoint the `std::shared_ptr` parent component - * @param[in] send whether this is a send (`true`) or receive (`false`) - * stream request. - * @param[in] buffer a raw pointer to the data to be transferred. - * @param[in] length the size in bytes of the stream message to be - * transferred. + * @param[in] requestData container of the specified message type, including all + * type-specific data. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. 
* * @returns The `shared_ptr` object */ - friend std::shared_ptr createRequestStream(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - const bool enablePythonFuture); + friend std::shared_ptr createRequestStream( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture); virtual void populateDelayedSubmission(); diff --git a/cpp/include/ucxx/request_tag.h b/cpp/include/ucxx/request_tag.h index b818d983..e7c47449 100644 --- a/cpp/include/ucxx/request_tag.h +++ b/cpp/include/ucxx/request_tag.h @@ -4,6 +4,7 @@ */ #pragma once #include +#include #include #include @@ -16,8 +17,6 @@ namespace ucxx { class RequestTag : public Request { private: - size_t _length{0}; ///< The tag message length in bytes - /** * @brief Private constructor of `ucxx::RequestTag`. * @@ -38,21 +37,18 @@ class RequestTag : public Request { * @param[in] endpointOrWorker the parent component, which may either be a * `std::shared_ptr` or * `std::shared_ptr`. - * @param[in] send whether this is a send (`true`) or receive (`false`) - * tag request. - * @param[in] buffer a raw pointer to the data to be transferred. - * @param[in] length the size in bytes of the tag message to be transferred. - * @param[in] tag the tag to match. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. * @param[in] callbackData user-defined data to pass to the `callbackFunction`. 
*/ RequestTag(std::shared_ptr endpointOrWorker, - bool send, - void* buffer, - size_t length, - ucp_tag_t tag, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); @@ -73,11 +69,8 @@ class RequestTag : public Request { * @param[in] endpointOrWorker the parent component, which may either be a * `std::shared_ptr` or * `std::shared_ptr`. - * @param[in] send whether this is a send (`true`) or receive (`false`) - * tag request. - * @param[in] buffer a raw pointer to the data to be transferred. - * @param[in] length the size in bytes of the tag message to be transferred. - * @param[in] tag the tag to match. + * @param[in] requestData container of the specified message type, including all + * type-specific data. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. 
@@ -85,14 +78,12 @@ class RequestTag : public Request { * * @returns The `shared_ptr` object */ - friend std::shared_ptr createRequestTag(std::shared_ptr endpointOrWorker, - bool send, - void* buffer, - size_t length, - ucp_tag_t tag, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData); + friend std::shared_ptr createRequestTag( + std::shared_ptr endpointOrWorker, + const std::variant requestData, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); virtual void populateDelayedSubmission(); diff --git a/cpp/include/ucxx/request_tag_multi.h b/cpp/include/ucxx/request_tag_multi.h index 6f705a1d..4dcc7b58 100644 --- a/cpp/include/ucxx/request_tag_multi.h +++ b/cpp/include/ucxx/request_tag_multi.h @@ -38,8 +38,6 @@ typedef std::shared_ptr BufferRequestPtr; class RequestTagMulti : public Request { private: - bool _send{false}; ///< Whether this is a send (`true`) operation or recv (`false`) - ucp_tag_t _tag{0}; ///< Tag to match size_t _totalFrames{0}; ///< The total number of frames handled by this request std::mutex _completedRequestsMutex{}; ///< Mutex to control access to completed requests container @@ -70,15 +68,16 @@ class RequestTagMulti : public Request { * the first request to receive a header. * * @param[in] endpoint the `std::shared_ptr` parent component - * @param[in] send whether this is a send (`true`) or receive (`false`) - * tag request. - * @param[in] tag the tag to match. + * @param[in] requestData container of the specified message type, including all + * type-specific data. + * @param[in] operationName a human-readable operation name to help identifying + * requests by their types when UCXX logging is enabled. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. 
*/ RequestTagMulti(std::shared_ptr endpoint, - const bool send, - const ucp_tag_t tag, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture); /** @@ -111,22 +110,23 @@ class RequestTagMulti : public Request { * @brief Send all header(s) and frame(s). * * Build header request(s) and send them, followed by requests to send all frame(s). - * - * @throws std::length_error if the lengths of `buffer`, `size` and `isCUDA` do not - * match. */ - void send(const std::vector& buffer, - const std::vector& size, - const std::vector& isCUDA); + void send(); public: /** * @brief Enqueue a multi-buffer tag send operation. * - * Initiate a multi-buffer tag send operation, returning a - * `std::shared` that can be later awaited and checked for errors. - * This is a non-blocking operation, and the status of the transfer must be verified from - * the resulting request object before the data can be released. + * Initiate a multi-buffer tag operation, returning a `std::shared` + * that can be later awaited and checked for errors. + * + * This is a non-blocking operation, and the status of a send transfer must be verified + * from the resulting request object before the data can be released. If this is a receive + * transfer and because the receiver has no a priori knowledge of the data being received, + * memory allocations are automatically handled internally. The receiver must have the + * same capabilities of the sender, so that if the sender is compiled with RMM support to + * allow for CUDA transfers, the receiver must have the ability to understand and allocate + * CUDA memory. * * The primary use of multi-buffer transfers is in Python where we want to reduce the * amount of futures needed to watch for, thus reducing Python overhead. However, this @@ -143,7 +143,8 @@ class RequestTagMulti : public Request { * @throws std::runtime_error if sizes of `buffer`, `size` and `isCUDA` do not match. 
* * @param[in] endpoint the `std::shared_ptr` parent component - * @param[in] tag the tag to match. + * @param[in] requestData container of the specified message type, including all + * type-specific data. * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. @@ -151,40 +152,11 @@ class RequestTagMulti : public Request { * * @returns Request to be subsequently checked for the completion and its state. */ - friend std::shared_ptr createRequestTagMultiSend( + friend std::shared_ptr createRequestTagMulti( std::shared_ptr endpoint, - const std::vector& buffer, - const std::vector& size, - const std::vector& isCUDA, - const ucp_tag_t tag, + const std::variant requestData, const bool enablePythonFuture); - /** - * @brief Enqueue a multi-buffer tag receive operation. - * - * Enqueue a multi-buffer tag receive operation, returning a - * `std::shared` that can be later awaited and checked for errors. - * This is a non-blocking operation, and because the receiver has no a priori knowledge - * of the data being received, memory allocations are automatically handled internally. - * The receiver must have the same capabilities of the sender, so that if the sender is - * compiled with RMM support to allow for CUDA transfers, the receiver must have the - * ability to understand and allocate CUDA memory. - * - * Using a Python future may be requested by specifying `enablePythonFuture`. If a - * Python future is requested, the Python application must then await on this future to - * ensure the transfer has completed. Requires UCXX to be compiled with - * `UCXX_ENABLE_PYTHON=1`. - * - * @param[in] endpoint the `std::shared_ptr` parent component - * @param[in] tag the tag to match. - * @param[in] enablePythonFuture whether a python future should be created and - * subsequently notified. 
- * - * @returns Request to be subsequently checked for the completion and its state. - */ - friend std::shared_ptr createRequestTagMultiRecv( - std::shared_ptr endpoint, const ucp_tag_t tag, const bool enablePythonFuture); - /** * @brief `ucxx::RequestTagMulti` destructor. * diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 1c8b1933..03763017 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -5,10 +5,13 @@ #pragma once #include +#include #include #include #include +#include + namespace ucxx { class Buffer; @@ -32,6 +35,13 @@ typedef enum { UCXX_LOG_LEVEL_PRINT /* Temporary output */ } ucxx_log_level_t; +enum class TransferDirection { Send = 0, Receive }; + +enum Tag : ucp_tag_t {}; +enum TagMask : ucp_tag_t {}; + +static constexpr TagMask TagMaskFull{std::numeric_limits>::max()}; + typedef std::unordered_map ConfigMap; typedef std::function)> RequestCallbackUserFunction; diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index 66804ada..ada15f2c 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -55,8 +55,9 @@ class Worker : public Component { std::shared_ptr _delayedSubmissionCollection{ nullptr}; ///< Collection of enqueued delayed submissions - friend std::shared_ptr createRequestAmRecv( + friend std::shared_ptr createRequestAm( std::shared_ptr endpoint, + const std::variant requestData, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData); @@ -635,7 +636,7 @@ class Worker : public Component { * * @returns `true` if any uncaught messages were received, `false` otherwise. */ - bool tagProbe(const ucp_tag_t tag); + bool tagProbe(const Tag tag); /** * @brief Enqueue a tag receive operation. @@ -653,6 +654,7 @@ class Worker : public Component { * data will be stored. * @param[in] length the size in bytes of the tag message to be received. * @param[in] tag the tag to match. 
+ * @param[in] tagMask the tag mask to use. * @param[in] enableFuture whether a future should be created and subsequently * notified. * @param[in] callbackFunction user-defined callback function to call upon completion. @@ -662,7 +664,8 @@ class Worker : public Component { */ std::shared_ptr tagRecv(void* buffer, size_t length, - ucp_tag_t tag, + Tag tag, + TagMask tagMask, const bool enableFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); diff --git a/cpp/src/delayed_submission.cpp b/cpp/src/delayed_submission.cpp index b005d12f..d953992a 100644 --- a/cpp/src/delayed_submission.cpp +++ b/cpp/src/delayed_submission.cpp @@ -13,15 +13,6 @@ namespace ucxx { -DelayedSubmission::DelayedSubmission(const bool send, - void* buffer, - const size_t length, - const ucp_tag_t tag, - const ucs_memory_type_t memoryType) - : _send(send), _buffer(buffer), _length(length), _tag(tag), _memoryType(memoryType) -{ -} - RequestDelayedSubmissionCollection::RequestDelayedSubmissionCollection(const std::string name, const bool enabled) : BaseDelayedSubmissionCollection< diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index e795f71c..5e40f028 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -310,8 +311,11 @@ std::shared_ptr Endpoint::amSend(void* buffer, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestAmSend( - endpoint, buffer, length, memoryType, enablePythonFuture, callbackFunction, callbackData)); + return registerInflightRequest(createRequestAm(endpoint, + data::AmSend(buffer, length, memoryType), + enablePythonFuture, + callbackFunction, + callbackData)); } std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, @@ -319,8 +323,8 @@ std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, 
RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest( - createRequestAmRecv(endpoint, enablePythonFuture, callbackFunction, callbackData)); + return registerInflightRequest(createRequestAm( + endpoint, data::AmReceive(), enablePythonFuture, callbackFunction, callbackData)); } std::shared_ptr Endpoint::streamSend(void* buffer, @@ -329,7 +333,7 @@ std::shared_ptr Endpoint::streamSend(void* buffer, { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( - createRequestStream(endpoint, true, buffer, length, enablePythonFuture)); + createRequestStream(endpoint, data::StreamSend(buffer, length), enablePythonFuture)); } std::shared_ptr Endpoint::streamRecv(void* buffer, @@ -338,48 +342,58 @@ std::shared_ptr Endpoint::streamRecv(void* buffer, { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( - createRequestStream(endpoint, false, buffer, length, enablePythonFuture)); + createRequestStream(endpoint, data::StreamReceive(buffer, length), enablePythonFuture)); } std::shared_ptr Endpoint::tagSend(void* buffer, size_t length, - ucp_tag_t tag, + Tag tag, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestTag( - endpoint, true, buffer, length, tag, enablePythonFuture, callbackFunction, callbackData)); + return registerInflightRequest(createRequestTag(endpoint, + data::TagSend(buffer, length, tag), + enablePythonFuture, + callbackFunction, + callbackData)); } std::shared_ptr Endpoint::tagRecv(void* buffer, size_t length, - ucp_tag_t tag, + Tag tag, + TagMask tagMask, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto endpoint = 
std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestTag( - endpoint, false, buffer, length, tag, enablePythonFuture, callbackFunction, callbackData)); + return registerInflightRequest(createRequestTag(endpoint, + data::TagReceive(buffer, length, tag, tagMask), + enablePythonFuture, + callbackFunction, + callbackData)); } std::shared_ptr Endpoint::tagMultiSend(const std::vector& buffer, const std::vector& size, const std::vector& isCUDA, - const ucp_tag_t tag, + const Tag tag, const bool enablePythonFuture) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest( - createRequestTagMultiSend(endpoint, buffer, size, isCUDA, tag, enablePythonFuture)); + return registerInflightRequest(createRequestTagMulti( + endpoint, data::TagMultiSend(buffer, size, isCUDA, tag), enablePythonFuture)); } -std::shared_ptr Endpoint::tagMultiRecv(const ucp_tag_t tag, const bool enablePythonFuture) +std::shared_ptr Endpoint::tagMultiRecv(const Tag tag, + const TagMask tagMask, + const bool enablePythonFuture) { auto endpoint = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestTagMultiRecv(endpoint, tag, enablePythonFuture)); + return registerInflightRequest( + createRequestTagMulti(endpoint, data::TagMultiReceive(tag, tagMask), enablePythonFuture)); } std::shared_ptr Endpoint::getWorker() { return ::ucxx::getWorker(_parent); } diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index 873d22e2..4f0cedec 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -15,22 +15,30 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer) - : _amData(amData), _ep(ep), _request(request), _buffer(buffer) + : _amData(amData), _ep(ep), _request(request) { - _request->_delayedSubmission = - std::make_shared(false, _buffer->data(), _buffer->getSize()); + 
std::visit(data::dispatch{ + [this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _request->_requestData); } void RecvAmMessage::setUcpRequest(void* request) { _request->_request = request; } void RecvAmMessage::callback(void* request, ucs_status_t status) { - _request->_buffer = _buffer; - _request->callback(request, status); - { - std::lock_guard lock(_amData->_mutex); - _amData->_recvAmMessageMap.erase(_request.get()); - } + std::visit(data::dispatch{ + [this, request, status](data::AmReceive amReceive) { + _request->callback(request, status); + { + std::lock_guard lock(_amData->_mutex); + _amData->_recvAmMessageMap.erase(_request.get()); + } + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _request->_requestData); } } // namespace internal diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 80365717..22aa8af4 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -17,10 +17,10 @@ namespace ucxx { Request::Request(std::shared_ptr endpointOrWorker, - std::shared_ptr delayedSubmission, + const data::RequestData requestData, const std::string operationName, const bool enablePythonFuture) - : _delayedSubmission(delayedSubmission), + : _requestData(requestData), _operationName(operationName), _enablePythonFuture(enablePythonFuture) { @@ -190,7 +190,14 @@ void Request::setStatus(ucs_status_t status) status, ucs_status_string(status)); - if (_status != UCS_INPROGRESS) ucxx_error("setStatus called but the status was already set"); + if (_status != UCS_INPROGRESS) + ucxx_error( + "setStatus called on request: %p with status: %d (%s) but status: %d (%s) was already set", + this, + status, + ucs_status_string(status), + _status, + ucs_status_string(_status)); _status = status; if (_enablePythonFuture) { diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index ba29374d..3a6fc455 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp 
@@ -16,64 +16,63 @@ namespace ucxx { -std::shared_ptr createRequestAmSend( +std::shared_ptr createRequestAm( std::shared_ptr endpoint, - void* buffer, - size_t length, - ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_HOST, + const std::variant requestData, const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr) { - auto req = std::shared_ptr(new RequestAm( - endpoint, buffer, length, memoryType, enablePythonFuture, callbackFunction, callbackData)); - - // A delayed notification request is not populated immediately, instead it is - // delayed to allow the worker progress thread to set its status, and more - // importantly the Python future later on, so that we don't need the GIL here. - req->_worker->registerDelayedSubmission( - req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + std::shared_ptr req = std::visit( + data::dispatch{ + [endpoint, enablePythonFuture, callbackFunction, callbackData](data::AmSend amSend) { + auto req = std::shared_ptr(new RequestAm( + endpoint, amSend, "amSend", enablePythonFuture, callbackFunction, callbackData)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. 
+ req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; + }, + [endpoint, enablePythonFuture, callbackFunction, callbackData](data::AmReceive amReceive) { + auto worker = endpoint->getWorker(); + + auto createRequest = [endpoint, + amReceive, + enablePythonFuture, + callbackFunction, + callbackData]() { + return std::shared_ptr(new RequestAm( + endpoint, amReceive, "amReceive", enablePythonFuture, callbackFunction, callbackData)); + }; + return worker->getAmRecv(endpoint->getHandle(), createRequest); + }, + }, + requestData); return req; } -std::shared_ptr createRequestAmRecv( - std::shared_ptr endpoint, - const bool enablePythonFuture = false, - RequestCallbackUserFunction callbackFunction = nullptr, - RequestCallbackUserData callbackData = nullptr) -{ - auto worker = endpoint->getWorker(); - - auto createRequest = [endpoint, enablePythonFuture, callbackFunction, callbackData]() { - return std::shared_ptr( - new RequestAm(endpoint, enablePythonFuture, callbackFunction, callbackData)); - }; - return worker->getAmRecv(endpoint->getHandle(), createRequest); -} - -RequestAm::RequestAm(std::shared_ptr endpoint, - void* buffer, - size_t length, - ucs_memory_type_t memoryType, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData) - : Request(endpoint, - std::make_shared(true, buffer, length, 0, memoryType), - std::string("amSend"), - enablePythonFuture) -{ - _callback = callbackFunction; - _callbackData = callbackData; -} - RequestAm::RequestAm(std::shared_ptr endpointOrWorker, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpointOrWorker, nullptr, std::string("amRecv"), enablePythonFuture) + : Request(endpointOrWorker, data::getRequestData(requestData), operationName, 
enablePythonFuture) { + std::visit(data::dispatch{ + [this](data::AmSend amSend) { + if (_endpoint == nullptr) + throw ucxx::Error("An endpoint is required to send active messages"); + }, + [](data::AmReceive amReceive) {}, + }, + requestData); + _callback = callbackFunction; _callbackData = callbackData; } @@ -130,8 +129,8 @@ ucs_status_t RequestAm::recvCallback(void* arg, reqs->second.pop(); ucxx_trace_req("amRecv recvWait: %p", req.get()); } else { - req = std::shared_ptr( - new RequestAm(worker, worker->isFutureEnabled(), nullptr, nullptr)); + req = std::shared_ptr(new RequestAm( + worker, data::AmReceive(), "amReceive", worker->isFutureEnabled(), nullptr, nullptr)); auto [queue, _] = recvPool.try_emplace(ep, std::queue>()); queue->second.push(req); ucxx_trace_req("amRecv recvPool: %p", req.get()); @@ -146,7 +145,7 @@ ucs_status_t RequestAm::recvCallback(void* arg, // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); // return UCS_ERR_UNSUPPORTED; - ucxx_trace_req("No allocator registered for memory type %d, falling back to host memory.", + ucxx_trace_req("No allocator registered for memory type %lu, falling back to host memory.", allocatorType); allocatorType = UCS_MEMORY_TYPE_HOST; } @@ -230,59 +229,93 @@ ucs_status_t RequestAm::recvCallback(void* arg, } } -std::shared_ptr RequestAm::getRecvBuffer() { return _buffer; } +std::shared_ptr RequestAm::getRecvBuffer() +{ + return std::visit( + data::dispatch{ + [](data::AmReceive amReceive) { return amReceive._buffer; }, + [](auto) -> std::shared_ptr { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} void RequestAm::request() { - static const ucp_tag_t tagMask = -1; - - ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_FLAGS | - UCP_OP_ATTR_FIELD_USER_DATA, - .flags = UCP_AM_SEND_FLAG_REPLY, - .datatype = ucp_dt_make_contig(1), - .user_data = this}; - - _sendHeader = _delayedSubmission->_memoryType; - - if (_delayedSubmission->_send) { - 
param.cb.send = _amSendCallback; - void* request = ucp_am_send_nbx(_endpoint->getHandle(), - 0, - &_sendHeader, - sizeof(_sendHeader), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - ¶m); - - std::lock_guard lock(_mutex); - _request = request; - } else { - throw ucxx::UnsupportedError( - "Receiving active messages must be handled by the worker's callback"); - } + std::visit( + data::dispatch{ + [this](data::AmSend amSend) { + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_FLAGS | + UCP_OP_ATTR_FIELD_USER_DATA, + .flags = UCP_AM_SEND_FLAG_REPLY | UCP_AM_SEND_FLAG_COPY_HEADER, + .datatype = ucp_dt_make_contig(1), + .user_data = this}; + + param.cb.send = _amSendCallback; + void* request = ucp_am_send_nbx(_endpoint->getHandle(), + 0, + &amSend._memoryType, + sizeof(amSend._memoryType), + amSend._buffer, + amSend._length, + ¶m); + + std::lock_guard lock(_mutex); + _request = request; + }, + [](auto) { throw ucxx::UnsupportedError("Only send active messages can call request()"); }, + }, + _requestData); } void RequestAm::populateDelayedSubmission() { + bool terminate = + std::visit(data::dispatch{ + [this](data::AmSend amSend) { + if (_endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before message could be sent"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + if (terminate) return; + request(); - if (_enablePythonFuture) - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _future.get(), - _future->getHandle()); - else - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "buffer %p, size %lu, populateDelayedSubmission", - _delayedSubmission->_buffer, - 
_delayedSubmission->_length); + auto log = [this](const void* buffer, const size_t length, const ucs_memory_type_t memoryType) { + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, memoryType: %lu, future %p, future handle %p, " + "populateDelayedSubmission", + buffer, + length, + memoryType, + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, memoryType: %lu, populateDelayedSubmission", + buffer, + length, + memoryType); + }; + + std::visit(data::dispatch{ + [this, &log](data::AmSend amSend) { + log(amSend._buffer, amSend._length, amSend._memoryType); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); process(); } diff --git a/cpp/src/request_data.cpp b/cpp/src/request_data.cpp new file mode 100644 index 00000000..3bc1efd4 --- /dev/null +++ b/cpp/src/request_data.cpp @@ -0,0 +1,73 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#include + +#include + +#include +#include + +namespace ucxx { + +namespace data { + +AmSend::AmSend(const void* buffer, const size_t length, const ucs_memory_type memoryType) + : _buffer(buffer), _length(length), _memoryType(memoryType) +{ +} + +AmReceive::AmReceive() {} + +StreamSend::StreamSend(const void* buffer, const size_t length) : _buffer(buffer), _length(length) +{ + /** + * Stream API does not support zero-sized messages. 
See + * https://github.com/openucx/ucx/blob/6b45097e32c75c9b5d17f4770923204d568548d0/src/ucp/stream/stream_recv.c#L501 + */ + if (buffer == nullptr) throw std::runtime_error("Buffer cannot be a nullptr."); + if (length == 0) throw std::runtime_error("Length has to be a positive value."); +} + +StreamReceive::StreamReceive(void* buffer, const size_t length) : _buffer(buffer), _length(length) +{ + /** + * Stream API does not support zero-sized messages. See + * https://github.com/openucx/ucx/blob/6b45097e32c75c9b5d17f4770923204d568548d0/src/ucp/stream/stream_recv.c#L501 + */ + if (buffer == nullptr) throw std::runtime_error("Buffer cannot be a nullptr."); + if (length == 0) throw std::runtime_error("Length has to be a positive value."); +} + +TagSend::TagSend(const void* buffer, const size_t length, const ::ucxx::Tag tag) + : _buffer(buffer), _length(length), _tag(tag) +{ +} + +TagReceive::TagReceive(void* buffer, + const size_t length, + const ::ucxx::Tag tag, + const ::ucxx::TagMask tagMask) + : _buffer(buffer), _length(length), _tag(tag), _tagMask(tagMask) +{ +} + +TagMultiSend::TagMultiSend(const std::vector& buffer, + const std::vector& length, + const std::vector& isCUDA, + const ::ucxx::Tag tag) + : _buffer(buffer), _length(length), _isCUDA(isCUDA), _tag(tag) +{ + if (length.size() != buffer.size() || isCUDA.size() != buffer.size()) + throw std::runtime_error("All input vectors should be of equal size"); +} + +TagMultiReceive::TagMultiReceive(const ::ucxx::Tag tag, const ::ucxx::TagMask tagMask) + : _tag(tag), _tagMask(tagMask) +{ +} + +} // namespace data + +} // namespace ucxx diff --git a/cpp/src/request_stream.cpp b/cpp/src/request_stream.cpp index 624e3734..7d3db1a4 100644 --- a/cpp/src/request_stream.cpp +++ b/cpp/src/request_stream.cpp @@ -11,28 +11,44 @@ #include namespace ucxx { - RequestStream::RequestStream(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, + const std::variant requestData, + const std::string operationName, 
const bool enablePythonFuture) - : Request(endpoint, - std::make_shared(send, buffer, length), - std::string(send ? "streamSend" : "streamRecv"), - enablePythonFuture), - _length(length) + : Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture) { + std::visit(data::dispatch{ + [this](data::StreamSend streamSend) { + if (_endpoint == nullptr) + throw ucxx::Error("A valid endpoint is required to send stream messages."); + }, + [this](data::StreamReceive streamReceive) { + if (_endpoint == nullptr) + throw ucxx::Error("A valid endpoint is required to receive stream messages."); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + requestData); } -std::shared_ptr createRequestStream(std::shared_ptr endpoint, - bool send, - void* buffer, - size_t length, - const bool enablePythonFuture = false) +std::shared_ptr createRequestStream( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture = false) { - auto req = std::shared_ptr( - new RequestStream(endpoint, send, buffer, length, enablePythonFuture)); + std::shared_ptr req = + std::visit(data::dispatch{ + [&endpoint, &enablePythonFuture](data::StreamSend streamSend) { + return std::shared_ptr( + new RequestStream(endpoint, streamSend, "streamSend", enablePythonFuture)); + }, + [&endpoint, &enablePythonFuture](data::StreamReceive streamReceive) { + return std::shared_ptr(new RequestStream( + endpoint, streamReceive, "streamReceive", enablePythonFuture)); + }, + [](auto) -> decltype(req) { throw std::runtime_error("Unreachable"); }, + }, + requestData); // A delayed notification request is not populated immediately, instead it is // delayed to allow the worker progress thread to set its status, and more @@ -52,20 +68,25 @@ void RequestStream::request() .user_data = this}; void* request = nullptr; - if (_delayedSubmission->_send) { - param.cb.send = streamSendCallback; - request = ucp_stream_send_nbx( - _endpoint->getHandle(), 
_delayedSubmission->_buffer, _delayedSubmission->_length, &param); - } else { - param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; - param.flags = UCP_STREAM_RECV_FLAG_WAITALL; - param.cb.recv_stream = streamRecvCallback; - request = ucp_stream_recv_nbx(_endpoint->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - &_delayedSubmission->_length, - &param); - } + std::visit(data::dispatch{ + [this, &request, &param](data::StreamSend streamSend) { + param.cb.send = streamSendCallback; + request = ucp_stream_send_nbx( + _endpoint->getHandle(), streamSend._buffer, streamSend._length, &param); + }, + [this, &request, &param](data::StreamReceive streamReceive) { + param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; + param.flags = UCP_STREAM_RECV_FLAG_WAITALL; + param.cb.recv_stream = streamRecvCallback; + request = ucp_stream_recv_nbx(_endpoint->getHandle(), + streamReceive._buffer, + streamReceive._length, + &streamReceive._lengthReceived, + &param); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); std::lock_guard lock(_mutex); _request = request; @@ -73,49 +94,83 @@ void RequestStream::request() void RequestStream::populateDelayedSubmission() { - if (_delayedSubmission->_send && _endpoint->getHandle() == nullptr) { - ucxx_warn("Endpoint was closed before message could be sent"); - Request::callback(this, UCS_ERR_CANCELED); - return; - } else if (!_delayedSubmission->_send && _worker->getHandle() == nullptr) { - ucxx_warn("Worker was closed before message could be received"); - Request::callback(this, UCS_ERR_CANCELED); - return; - } + bool terminate = + std::visit(data::dispatch{ + [this](data::StreamSend streamSend) { + if (_endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before message could be sent"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [this](data::StreamReceive streamReceive) { + if (_worker->getHandle() == nullptr) { + ucxx_warn("Worker was closed before message 
could be received"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + if (terminate) return; request(); - if (_enablePythonFuture) - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _future.get(), - _future->getHandle()); - else - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "buffer %p, size %lu, populateDelayedSubmission", - _delayedSubmission->_buffer, - _delayedSubmission->_length); + auto log = [this](const void* buffer, const size_t length) { + if (_enablePythonFuture) + ucxx_trace_req_f( + _ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", + buffer, + length, + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, populateDelayedSubmission", + buffer, + length); + }; + + std::visit( + data::dispatch{ + [this, &log](data::StreamSend streamSend) { log(streamSend._buffer, streamSend._length); }, + [this, &log](data::StreamReceive streamReceive) { + log(streamReceive._buffer, streamReceive._length); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + process(); } void RequestStream::callback(void* request, ucs_status_t status, size_t length) { - status = length == _length ? 
status : UCS_ERR_MESSAGE_TRUNCATED; - - if (status == UCS_ERR_MESSAGE_TRUNCATED) { - const char* fmt = "length mismatch: %llu (got) != %llu (expected)"; - size_t len = std::snprintf(nullptr, 0, fmt, length, _length); - _status_msg = std::string(len + 1, '\0'); // +1 for null terminator - std::snprintf(_status_msg.data(), _status_msg.size(), fmt, length, _length); - } - - Request::callback(request, status); + std::visit(data::dispatch{ + [this, &request, &status, &length](data::StreamReceive streamReceive) { + status = length == streamReceive._length ? status : UCS_ERR_MESSAGE_TRUNCATED; + + if (status == UCS_ERR_MESSAGE_TRUNCATED) { + const char* fmt = "length mismatch: %llu (got) != %llu (expected)"; + size_t len = std::snprintf(nullptr, 0, fmt, length, streamReceive._length); + _status_msg = std::string(len + 1, '\0'); // +1 for null terminator + std::snprintf( + _status_msg.data(), _status_msg.size(), fmt, length, streamReceive._length); + } + + Request::callback(request, status); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); } void RequestStream::streamSendCallback(void* request, ucs_status_t status, void* arg) diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index f1705c76..bcca4df1 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -9,27 +9,40 @@ #include #include +#include #include namespace ucxx { -std::shared_ptr createRequestTag(std::shared_ptr endpointOrWorker, - bool send, - void* buffer, - size_t length, - ucp_tag_t tag, - const bool enablePythonFuture = false, - RequestCallbackUserFunction callbackFunction = nullptr, - RequestCallbackUserData callbackData = nullptr) +std::shared_ptr createRequestTag( + std::shared_ptr endpointOrWorker, + const std::variant requestData, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) { - auto req = std::shared_ptr(new RequestTag(endpointOrWorker, - 
send, - buffer, - length, - tag, - enablePythonFuture, - callbackFunction, - callbackData)); + std::shared_ptr req = + std::visit(data::dispatch{ + [&endpointOrWorker, &enablePythonFuture, &callbackFunction, &callbackData]( + data::TagSend tagSend) { + return std::shared_ptr(new RequestTag(endpointOrWorker, + tagSend, + "tagSend", + enablePythonFuture, + callbackFunction, + callbackData)); + }, + [&endpointOrWorker, &enablePythonFuture, &callbackFunction, &callbackData]( + data::TagReceive tagReceive) { + return std::shared_ptr(new RequestTag(endpointOrWorker, + tagReceive, + "tagRecv", + enablePythonFuture, + callbackFunction, + callbackData)); + }, + }, + requestData); // A delayed notification request is not populated immediately, instead it is // delayed to allow the worker progress thread to set its status, and more @@ -41,21 +54,22 @@ std::shared_ptr createRequestTag(std::shared_ptr endpoint } RequestTag::RequestTag(std::shared_ptr endpointOrWorker, - bool send, - void* buffer, - size_t length, - ucp_tag_t tag, + const std::variant requestData, + const std::string operationName, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpointOrWorker, - std::make_shared(send, buffer, length, tag), - std::string(send ? 
"tagSend" : "tagRecv"), - enablePythonFuture), - _length(length) + : Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture) { - if (send && _endpoint == nullptr) - throw ucxx::Error("An endpoint is required to send tag messages"); + std::visit(data::dispatch{ + [this](data::TagSend tagSend) { + if (_endpoint == nullptr) + throw ucxx::Error("An endpoint is required to send tag messages"); + }, + [](data::TagReceive tagReceive) {}, + }, + requestData); + _callback = callbackFunction; _callbackData = callbackData; } @@ -93,8 +107,6 @@ void RequestTag::tagRecvCallback(void* request, void RequestTag::request() { - static const ucp_tag_t tagMask = -1; - ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_DATATYPE | UCP_OP_ATTR_FIELD_USER_DATA, @@ -102,60 +114,94 @@ void RequestTag::request() .user_data = this}; void* request = nullptr; - if (_delayedSubmission->_send) { - param.cb.send = tagSendCallback; - request = ucp_tag_send_nbx(_endpoint->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _delayedSubmission->_tag, - &param); - } else { - param.cb.recv = tagRecvCallback; - request = ucp_tag_recv_nbx(_worker->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _delayedSubmission->_tag, - tagMask, - &param); - } + std::visit(data::dispatch{ + [this, &request, &param](data::TagSend tagSend) { + param.cb.send = tagSendCallback; + request = ucp_tag_send_nbx( + _endpoint->getHandle(), tagSend._buffer, tagSend._length, tagSend._tag, &param); + }, + [this, &request, &param](data::TagReceive tagReceive) { + param.cb.recv = tagRecvCallback; + request = ucp_tag_recv_nbx(_worker->getHandle(), + tagReceive._buffer, + tagReceive._length, + tagReceive._tag, + tagReceive._tagMask, + &param); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); std::lock_guard lock(_mutex); _request = request; } +static void logPopulateDelayedSubmission() {} + 
void RequestTag::populateDelayedSubmission() { - if (_delayedSubmission->_send && _endpoint->getHandle() == nullptr) { - ucxx_warn("Endpoint was closed before message could be sent"); - Request::callback(this, UCS_ERR_CANCELED); - return; - } else if (!_delayedSubmission->_send && _worker->getHandle() == nullptr) { - ucxx_warn("Worker was closed before message could be received"); - Request::callback(this, UCS_ERR_CANCELED); - return; - } + bool terminate = + std::visit(data::dispatch{ + [this](data::TagSend tagSend) { + if (_endpoint->getHandle() == nullptr) { + ucxx_warn("Endpoint was closed before message could be sent"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [this](data::TagReceive tagReceive) { + if (_worker->getHandle() == nullptr) { + ucxx_warn("Worker was closed before message could be received"); + Request::callback(this, UCS_ERR_CANCELED); + return true; + } + return false; + }, + [](auto) -> decltype(terminate) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + if (terminate) return; request(); - if (_enablePythonFuture) - ucxx_trace_req_f( - _ownerString.c_str(), - _request, - _operationName.c_str(), - "tag 0x%lx, buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", - _delayedSubmission->_tag, - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _future.get(), - _future->getHandle()); - else - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "tag 0x%lx, buffer %p, size %lu, populateDelayedSubmission", - _delayedSubmission->_tag, - _delayedSubmission->_buffer, - _delayedSubmission->_length); + auto log = [this](const void* buffer, const size_t length, const Tag tag, const TagMask tagMask) { + if (_enablePythonFuture) + ucxx_trace_req_f( + _ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer: %p, size: %lu, tag 0x%lx, tagMask: 0x%lx, future %p, future handle %p, " + "populateDelayedSubmission", + buffer, + 
length, + tag, + tagMask, + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f( + _ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer: %p, size: %lu, tag 0x%lx, tagMask: 0x%lx, populateDelayedSubmission", + buffer, + length, + tag, + tagMask); + }; + + std::visit(data::dispatch{ + [this, &log](data::TagSend tagSend) { + log(tagSend._buffer, tagSend._length, tagSend._tag, TagMaskFull); + }, + [this, &log](data::TagReceive tagReceive) { + log(tagReceive._buffer, tagReceive._length, tagReceive._tag, tagReceive._tagMask); + }, + [](auto) { throw std::runtime_error("Unreachable"); }, + }, + _requestData); process(); } diff --git a/cpp/src/request_tag_multi.cpp b/cpp/src/request_tag_multi.cpp index 385f51ab..ae9b4485 100644 --- a/cpp/src/request_tag_multi.cpp +++ b/cpp/src/request_tag_multi.cpp @@ -2,6 +2,7 @@ * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ +#include "ucxx/delayed_submission.h" #include #include #include @@ -11,26 +12,25 @@ #include #include #include +#include #include #include #include namespace ucxx { +typedef std::pair TagPair; + BufferRequest::BufferRequest() { ucxx_trace("BufferRequest created: %p", this); } BufferRequest::~BufferRequest() { ucxx_trace("BufferRequest destroyed: %p", this); } -RequestTagMulti::RequestTagMulti(std::shared_ptr endpoint, - const bool send, - const ucp_tag_t tag, - const bool enablePythonFuture) - : Request(endpoint, - std::make_shared(!send, nullptr, 0, 0), - std::string(send ? 
"tagMultiSend" : "tagMultiRecv"), - enablePythonFuture), - _send(send), - _tag(tag) +RequestTagMulti::RequestTagMulti( + std::shared_ptr endpoint, + const std::variant requestData, + const std::string operationName, + const bool enablePythonFuture) + : Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture) { auto worker = endpoint->getWorker(); if (enablePythonFuture) _future = worker->getFuture(); @@ -54,52 +54,67 @@ RequestTagMulti::~RequestTagMulti() } } -std::shared_ptr createRequestTagMultiSend(std::shared_ptr endpoint, - const std::vector& buffer, - const std::vector& size, - const std::vector& isCUDA, - const ucp_tag_t tag, - const bool enablePythonFuture) +std::shared_ptr createRequestTagMulti( + std::shared_ptr endpoint, + const std::variant requestData, + const bool enablePythonFuture) { - auto ret = - std::shared_ptr(new RequestTagMulti(endpoint, true, tag, enablePythonFuture)); - - if (size.size() != buffer.size() || isCUDA.size() != buffer.size()) - throw std::runtime_error("All input vectors should be of equal size"); - - ret->send(buffer, size, isCUDA); - - return ret; + std::shared_ptr req = + std::visit(data::dispatch{ + [&endpoint, &enablePythonFuture](data::TagMultiSend tagMultiSend) { + auto req = std::shared_ptr(new RequestTagMulti( + endpoint, tagMultiSend, "tagMultiSend", enablePythonFuture)); + req->send(); + return req; + }, + [&endpoint, &enablePythonFuture](data::TagMultiReceive tagMultiReceive) { + auto req = std::shared_ptr(new RequestTagMulti( + endpoint, tagMultiReceive, "tagMultiRecv", enablePythonFuture)); + req->recvCallback(UCS_OK); + return req; + }, + }, + requestData); + + return req; } -std::shared_ptr createRequestTagMultiRecv(std::shared_ptr endpoint, - const ucp_tag_t tag, - const bool enablePythonFuture) +static TagPair checkAndGetTagPair(const data::RequestData& requestData, + const std::string methodName) { - auto ret = - std::shared_ptr(new RequestTagMulti(endpoint, false, tag, 
enablePythonFuture)); - - ret->recvCallback(UCS_OK); - - return ret; + return std::visit( + data::dispatch{ + [](data::TagMultiReceive tagMultiReceive) { + return TagPair{tagMultiReceive._tag, tagMultiReceive._tagMask}; + }, + [&methodName](auto) -> TagPair { + throw std::runtime_error(methodName + "() can only be called by a receive request."); + }, + }, + requestData); } void RequestTagMulti::recvFrames() { - if (_send) throw std::runtime_error("Send requests cannot call recvFrames()"); + auto tagPair = checkAndGetTagPair(_requestData, std::string("recvFrames")); std::vector
headers; - ucxx_trace_req("RequestTagMulti::recvFrames request: %p, tag: %lx, _bufferRequests.size(): %lu", - this, - _tag, - _bufferRequests.size()); + ucxx_trace_req( + "RequestTagMulti::recvFrames request: %p, tag: 0x%lx, tagMask: 0x%lx, _bufferRequests.size(): " + "%lu", + this, + tagPair.first, + tagPair.second, + _bufferRequests.size()); for (auto& br : _bufferRequests) { ucxx_trace_req( - "RequestTagMulti::recvFrames request: %p, tag: %lx, *br->stringBuffer.size(): %lu", + "RequestTagMulti::recvFrames request: %p, tag: 0x%lx, tagMask: 0x%lx, " + "*br->stringBuffer.size(): %lu", this, - _tag, + tagPair.first, + tagPair.second, br->stringBuffer->size()); headers.push_back(Header(*br->stringBuffer)); } @@ -114,26 +129,31 @@ void RequestTagMulti::recvFrames() bufferRequest->request = _endpoint->tagRecv( buf->data(), buf->getSize(), - _tag, + tagPair.first, + tagPair.second, false, [this](ucs_status_t status, RequestCallbackUserData arg) { return this->markCompleted(status, arg); }, bufferRequest); bufferRequest->buffer = buf; - ucxx_trace_req("RequestTagMulti::recvFrames request: %p, tag: %lx, buffer: %p", - this, - _tag, - bufferRequest->buffer); + ucxx_trace_req( + "RequestTagMulti::recvFrames request: %p, tag: 0x%lx, tagMask: 0x%lx, buffer: %p", + this, + tagPair.first, + tagPair.second, + bufferRequest->buffer.get()); } } _isFilled = true; - ucxx_trace_req("RequestTagMulti::recvFrames request: %p, tag: %lx, size: %lu, isFilled: %d", - this, - _tag, - _bufferRequests.size(), - _isFilled); + ucxx_trace_req( + "RequestTagMulti::recvFrames request: %p, tag: 0x%lx, tagMask: 0x%lx, size: %lu, isFilled: %d", + this, + tagPair.first, + tagPair.second, + _bufferRequests.size(), + _isFilled); }; void RequestTagMulti::markCompleted(ucs_status_t status, RequestCallbackUserData request) @@ -151,7 +171,21 @@ void RequestTagMulti::markCompleted(ucs_status_t status, RequestCallbackUserData return; } - ucxx_trace_req("RequestTagMulti::markCompleted request: %p, tag: %lx", 
this, _tag); + TagPair tagPair = std::visit(data::dispatch{ + [](data::TagMultiSend tagMultiSend) { + return TagPair{tagMultiSend._tag, TagMaskFull}; + }, + [](data::TagMultiReceive tagMultiReceive) { + return TagPair{tagMultiReceive._tag, tagMultiReceive._tagMask}; + }, + [](auto) -> TagPair { throw std::runtime_error("Unreachable"); }, + }, + _requestData); + + ucxx_trace_req("RequestTagMulti::markCompleted request: %p, tag: 0x%lx, tagMask: 0x%lx", + this, + tagPair.first, + tagPair.second); std::lock_guard lock(_completedRequestsMutex); if (_finalStatus == UCS_OK && status != UCS_OK) _finalStatus = status; @@ -160,28 +194,35 @@ void RequestTagMulti::markCompleted(ucs_status_t status, RequestCallbackUserData setStatus(_finalStatus); ucxx_trace_req( - "RequestTagMulti::markCompleted request: %p, tag: %lx, completed: %lu/%lu, final status: %d " + "RequestTagMulti::markCompleted request: %p, tag: 0x%lx, tagMask: 0x%lx, completed: %lu/%lu, " + "final status: %d " "(%s)", this, - _tag, + tagPair.first, + tagPair.second, _completedRequests, _totalFrames, _finalStatus, ucs_status_string(_finalStatus)); } else { - ucxx_trace_req("RequestTagMulti::markCompleted request: %p, tag: %lx, completed: %lu/%lu", - this, - _tag, - _completedRequests, - _totalFrames); + ucxx_trace_req( + "RequestTagMulti::markCompleted request: %p, tag: 0x%lx, tagMask: 0x%lx, completed: %lu/%lu", + this, + tagPair.first, + tagPair.second, + _completedRequests, + _totalFrames); } } void RequestTagMulti::recvHeader() { - if (_send) throw std::runtime_error("Send requests cannot call recvHeader()"); + auto tagPair = checkAndGetTagPair(_requestData, std::string("recvHeader")); - ucxx_trace_req("RequestTagMulti::recvHeader entering, request: %p, tag: %lx", this, _tag); + ucxx_trace_req("RequestTagMulti::recvHeader entering, request: %p, tag: 0x%lx, tagMask: 0x%lx", + this, + tagPair.first, + tagPair.second); auto bufferRequest = std::make_shared(); _bufferRequests.push_back(bufferRequest); @@ -189,7 
+230,8 @@ void RequestTagMulti::recvHeader() bufferRequest->request = _endpoint->tagRecv(&bufferRequest->stringBuffer->front(), bufferRequest->stringBuffer->size(), - _tag, + tagPair.first, + tagPair.second, false, [this](ucs_status_t status, RequestCallbackUserData arg) { return this->recvCallback(status); @@ -200,33 +242,44 @@ void RequestTagMulti::recvHeader() bufferRequest->request->checkError(); } - ucxx_trace_req("RequestTagMulti::recvHeader exiting, request: %p, tag: %lx, empty: %d", - this, - _tag, - _bufferRequests.empty()); + ucxx_trace_req( + "RequestTagMulti::recvHeader exiting, request: %p, tag: 0x%lx, tagMask: 0x%lx, empty: %d", + this, + tagPair.first, + tagPair.second, + _bufferRequests.empty()); } void RequestTagMulti::recvCallback(ucs_status_t status) { - if (_send) throw std::runtime_error("Send requests cannot call recvCallback()"); + auto tagPair = checkAndGetTagPair(_requestData, std::string("recvCallback")); - ucxx_trace_req("RequestTagMulti::recvCallback request: %p, tag: %lx", this, _tag); + ucxx_trace_req("RequestTagMulti::recvCallback request: %p, tag: 0x%lx, tagMask: 0x%lx", + this, + tagPair.first, + tagPair.second); if (_bufferRequests.empty()) { recvHeader(); } else { if (status == UCS_OK) { ucxx_trace_req( - "RequestTagMulti::recvCallback header received, multi request: %p, tag: %lx", this, _tag); + "RequestTagMulti::recvCallback header received, multi request: %p, tag: 0x%lx, tagMask: " + "0x%lx", + this, + tagPair.first, + tagPair.second); } else { ucxx_trace_req( "RequestTagMulti::recvCallback failed receiving header with status %d (%s), multi request: " "%p, " - "tag: %lx", + "tag: 0x%lx, " + "tagMask: 0x%lx", status, ucs_status_string(status), this, - _tag); + tagPair.first, + tagPair.second); _status = status; if (_future) _future->notify(status); @@ -243,38 +296,46 @@ void RequestTagMulti::recvCallback(ucs_status_t status) } } -void RequestTagMulti::send(const std::vector& buffer, - const std::vector& size, - const std::vector& 
isCUDA) +void RequestTagMulti::send() { - _totalFrames = buffer.size(); - - if ((size.size() != _totalFrames) || (isCUDA.size() != _totalFrames)) - throw std::length_error("buffer, size and isCUDA must have the same length"); - - auto headers = Header::buildHeaders(size, isCUDA); - - for (const auto& header : headers) { - auto serializedHeader = std::make_shared(header.serialize()); - auto bufferRequest = std::make_shared(); - _bufferRequests.push_back(bufferRequest); - bufferRequest->request = - _endpoint->tagSend(&serializedHeader->front(), serializedHeader->size(), _tag, false); - bufferRequest->stringBuffer = serializedHeader; - } - - for (size_t i = 0; i < _totalFrames; ++i) { - auto bufferRequest = std::make_shared(); - _bufferRequests.push_back(bufferRequest); - bufferRequest->request = _endpoint->tagSend( - buffer[i], size[i], _tag, false, [this](ucs_status_t status, RequestCallbackUserData arg) { - return this->markCompleted(status, arg); - }); - } - - _isFilled = true; - ucxx_trace_req( - "RequestTagMulti::send request: %p, tag: %lx, isFilled: %d", this, _tag, _isFilled); + std::visit( + data::dispatch{ + [this](data::TagMultiSend tagMultiSend) { + _totalFrames = tagMultiSend._buffer.size(); + + auto headers = Header::buildHeaders(tagMultiSend._length, tagMultiSend._isCUDA); + + for (const auto& header : headers) { + auto serializedHeader = std::make_shared(header.serialize()); + auto bufferRequest = std::make_shared(); + _bufferRequests.push_back(bufferRequest); + bufferRequest->request = _endpoint->tagSend( + &serializedHeader->front(), serializedHeader->size(), tagMultiSend._tag, false); + bufferRequest->stringBuffer = serializedHeader; + } + + for (size_t i = 0; i < _totalFrames; ++i) { + auto bufferRequest = std::make_shared(); + _bufferRequests.push_back(bufferRequest); + bufferRequest->request = + _endpoint->tagSend(tagMultiSend._buffer[i], + tagMultiSend._length[i], + tagMultiSend._tag, + false, + [this](ucs_status_t status, 
RequestCallbackUserData arg) { + return this->markCompleted(status, arg); + }); + } + + _isFilled = true; + ucxx_trace_req("RequestTagMulti::send request: %p, tag: 0x%lx, isFilled: %d", + this, + tagMultiSend._tag, + _isFilled); + }, + [](auto) { throw std::runtime_error("send() can only be called by a send request."); }, + }, + _requestData); } void RequestTagMulti::populateDelayedSubmission() {} diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index cb4dc50b..8ccf8a89 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -467,7 +467,7 @@ void Worker::removeInflightRequest(const Request* const request) } } -bool Worker::tagProbe(const ucp_tag_t tag) +bool Worker::tagProbe(const Tag tag) { if (!isProgressThreadRunning()) { progress(); @@ -494,14 +494,18 @@ bool Worker::tagProbe(const ucp_tag_t tag) std::shared_ptr Worker::tagRecv(void* buffer, size_t length, - ucp_tag_t tag, + Tag tag, + TagMask tagMask, const bool enableFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { auto worker = std::dynamic_pointer_cast(shared_from_this()); - return registerInflightRequest(createRequestTag( - worker, false, buffer, length, tag, enableFuture, callbackFunction, callbackData)); + return registerInflightRequest(createRequestTag(worker, + data::TagReceive(buffer, length, tag, tagMask), + enableFuture, + callbackFunction, + callbackData)); } std::shared_ptr
Worker::getAddress() diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index ffc32e7e..16fa4370 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -46,7 +46,7 @@ TEST_F(EndpointTest, IsAlive) ASSERT_TRUE(ep->isAlive()); std::vector buf{123}; - auto send_req = ep->tagSend(buf.data(), buf.size() * sizeof(int), 0); + auto send_req = ep->tagSend(buf.data(), buf.size() * sizeof(int), ucxx::Tag{0}); while (!send_req->isCompleted()) _worker->progress(); diff --git a/cpp/tests/listener.cpp b/cpp/tests/listener.cpp index 0dc7ad19..44a92512 100644 --- a/cpp/tests/listener.cpp +++ b/cpp/tests/listener.cpp @@ -111,16 +111,17 @@ TEST_F(ListenerTest, EndpointSendRecv) std::vector client_buf{123}; std::vector server_buf{0}; - requests.push_back(ep->tagSend(client_buf.data(), client_buf.size() * sizeof(int), 0)); - requests.push_back( - listenerContainer->endpoint->tagRecv(&server_buf.front(), server_buf.size() * sizeof(int), 0)); + requests.push_back(ep->tagSend(client_buf.data(), client_buf.size() * sizeof(int), ucxx::Tag{0})); + requests.push_back(listenerContainer->endpoint->tagRecv( + &server_buf.front(), server_buf.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); ::waitRequests(_worker, requests, progress); ASSERT_EQ(server_buf[0], client_buf[0]); - requests.push_back( - listenerContainer->endpoint->tagSend(&server_buf.front(), server_buf.size() * sizeof(int), 1)); - requests.push_back(ep->tagRecv(client_buf.data(), client_buf.size() * sizeof(int), 1)); + requests.push_back(listenerContainer->endpoint->tagSend( + &server_buf.front(), server_buf.size() * sizeof(int), ucxx::Tag{1})); + requests.push_back(ep->tagRecv( + client_buf.data(), client_buf.size() * sizeof(int), ucxx::Tag{1}, ucxx::TagMaskFull)); ::waitRequests(_worker, requests, progress); ASSERT_EQ(client_buf[0], server_buf[0]); @@ -140,7 +141,7 @@ TEST_F(ListenerTest, IsAlive) ASSERT_TRUE(ep->isAlive()); std::vector buf{123}; - auto send_req = ep->tagSend(buf.data(), buf.size() * 
sizeof(int), 0); + auto send_req = ep->tagSend(buf.data(), buf.size() * sizeof(int), ucxx::Tag{0}); while (!send_req->isCompleted()) _worker->progress(); diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 7838bd6d..5055653d 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -194,15 +194,20 @@ TEST_P(RequestTest, ProgressStream) allocate(); // Submit and wait for transfers to complete - std::vector> requests; - requests.push_back(_ep->streamSend(_sendPtr[0], _messageSize, 0)); - requests.push_back(_ep->streamRecv(_recvPtr[0], _messageSize, 0)); - waitRequests(_worker, requests, _progressWorker); - - copyResults(); - - // Assert data correctness - ASSERT_THAT(_recv[0], ContainerEq(_send[0])); + if (_messageSize == 0) { + EXPECT_THROW(_ep->streamSend(_sendPtr[0], _messageSize, 0), std::runtime_error); + EXPECT_THROW(_ep->streamRecv(_recvPtr[0], _messageSize, 0), std::runtime_error); + } else { + std::vector> requests; + requests.push_back(_ep->streamSend(_sendPtr[0], _messageSize, 0)); + requests.push_back(_ep->streamRecv(_recvPtr[0], _messageSize, 0)); + waitRequests(_worker, requests, _progressWorker); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); + } } TEST_P(RequestTest, ProgressTag) @@ -211,8 +216,8 @@ TEST_P(RequestTest, ProgressTag) // Submit and wait for transfers to complete std::vector> requests; - requests.push_back(_ep->tagSend(_sendPtr[0], _messageSize, 0)); - requests.push_back(_ep->tagRecv(_recvPtr[0], _messageSize, 0)); + requests.push_back(_ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0})); + requests.push_back(_ep->tagRecv(_recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull)); waitRequests(_worker, requests, _progressWorker); copyResults(); @@ -238,8 +243,8 @@ TEST_P(RequestTest, ProgressTagMulti) // Submit and wait for transfers to complete std::vector> requests; - requests.push_back(_ep->tagMultiSend(_sendPtr, multiSize, multiIsCUDA, 0, false)); - 
requests.push_back(_ep->tagMultiRecv(0, false)); + requests.push_back(_ep->tagMultiSend(_sendPtr, multiSize, multiIsCUDA, ucxx::Tag{0}, false)); + requests.push_back(_ep->tagMultiRecv(ucxx::Tag{0}, ucxx::TagMaskFull, false)); waitRequests(_worker, requests, _progressWorker); auto recvRequest = requests[1]; @@ -286,8 +291,10 @@ TEST_P(RequestTest, TagUserCallback) auto recvIndex = std::make_shared(1u); // Submit and wait for transfers to complete - requests[0] = _ep->tagSend(_sendPtr[0], _messageSize, 0, false, checkStatus, sendIndex); - requests[1] = _ep->tagRecv(_recvPtr[0], _messageSize, 0, false, checkStatus, recvIndex); + requests[0] = + _ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0}, false, checkStatus, sendIndex); + requests[1] = _ep->tagRecv( + _recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull, false, checkStatus, recvIndex); waitRequests(_worker, requests, _progressWorker); copyResults(); @@ -309,7 +316,7 @@ INSTANTIATE_TEST_SUITE_P(ProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576))); INSTANTIATE_TEST_SUITE_P(DelayedSubmission, RequestTest, @@ -317,7 +324,7 @@ INSTANTIATE_TEST_SUITE_P(DelayedSubmission, Values(false), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576))); #if UCXX_ENABLE_RMM INSTANTIATE_TEST_SUITE_P(RMMProgressModes, @@ -330,7 +337,7 @@ INSTANTIATE_TEST_SUITE_P(RMMProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576))); INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, RequestTest, @@ -338,7 +345,7 @@ INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, Values(false, true), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 
1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576))); #endif } // namespace diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index d9651beb..254b54ef 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -84,19 +84,19 @@ TEST_F(WorkerTest, TagProbe) auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling); auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); - ASSERT_FALSE(_worker->tagProbe(0)); + ASSERT_FALSE(_worker->tagProbe(ucxx::Tag{0})); std::vector buf{123}; std::vector> requests; - requests.push_back(ep->tagSend(buf.data(), buf.size() * sizeof(int), 0)); + requests.push_back(ep->tagSend(buf.data(), buf.size() * sizeof(int), ucxx::Tag{0})); waitRequests(_worker, requests, progressWorker); // Attempt to progress worker 10 times (arbitrarily defined). // TODO: Maybe a timeout would fit best. - for (size_t i = 0; i < 10 && !_worker->tagProbe(0); ++i) + for (size_t i = 0; i < 10 && !_worker->tagProbe(ucxx::Tag{0}); ++i) progressWorker(); - ASSERT_TRUE(_worker->tagProbe(0)); + ASSERT_TRUE(_worker->tagProbe(ucxx::Tag{0})); } TEST_F(WorkerTest, AmProbe) @@ -113,7 +113,7 @@ TEST_F(WorkerTest, AmProbe) // Attempt to progress worker 10 times (arbitrarily defined). // TODO: Maybe a timeout would fit best. 
- for (size_t i = 0; i < 10 && !_worker->tagProbe(0); ++i) + for (size_t i = 0; i < 10 && !_worker->amProbe(ep->getHandle()); ++i) progressWorker(); ASSERT_TRUE(_worker->amProbe(ep->getHandle())); @@ -169,8 +169,9 @@ TEST_P(WorkerProgressTest, ProgressTag) std::vector recv(1); std::vector> requests; - requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), 0)); - requests.push_back(ep->tagRecv(recv.data(), recv.size() * sizeof(int), 0)); + requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); + requests.push_back( + ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); waitRequests(_worker, requests, _progressWorker); ASSERT_EQ(recv[0], send[0]); @@ -193,8 +194,8 @@ TEST_P(WorkerProgressTest, ProgressTagMulti) std::vector multiIsCUDA(numMulti, false); std::vector> requests; - requests.push_back(ep->tagMultiSend(multiBuffer, multiSize, multiIsCUDA, 0, false)); - requests.push_back(ep->tagMultiRecv(0, false)); + requests.push_back(ep->tagMultiSend(multiBuffer, multiSize, multiIsCUDA, ucxx::Tag{0}, false)); + requests.push_back(ep->tagMultiRecv(ucxx::Tag{0}, ucxx::TagMaskFull, false)); waitRequests(_worker, requests, _progressWorker); for (const auto& br : diff --git a/python/examples/basic.py b/python/examples/basic.py index 18ab4205..239bd7b2 100644 --- a/python/examples/basic.py +++ b/python/examples/basic.py @@ -178,8 +178,8 @@ def listener_callback(conn_request): worker.progress_worker_event() wireup_requests = [ - ep.tag_send(Array(wireup_send_buf), tag=0), - listener_ep.tag_recv(Array(wireup_recv_buf), tag=0), + ep.tag_send(Array(wireup_send_buf), tag=ucx_api.UCXXTag(0)), + listener_ep.tag_recv(Array(wireup_recv_buf), tag=ucx_api.UCXXTag(0)), ] _wait_requests(worker, args.progress_mode, wireup_requests) @@ -201,7 +201,9 @@ def listener_callback(conn_request): # data_ptrs, sizes, is_cuda, tag=0 # ) - send_buffer_requests = listener_ep.tag_send_multi(frames, tag=0) + 
send_buffer_requests = listener_ep.tag_send_multi( + frames, tag=ucx_api.UCXXTag(0) + ) - recv_buffer_requests = ep.tag_recv_multi(0) + recv_buffer_requests = ep.tag_recv_multi(ucx_api.UCXXTag(0)) requests = [send_buffer_requests, recv_buffer_requests] @@ -225,12 +227,12 @@ def listener_callback(conn_request): recv_bufs = recv_buffer_requests.get_py_buffers() else: requests = [ - listener_ep.tag_send(Array(send_bufs[0]), tag=0), - listener_ep.tag_send(Array(send_bufs[1]), tag=1), - listener_ep.tag_send(Array(send_bufs[2]), tag=2), - ep.tag_recv(Array(recv_bufs[0]), tag=0), - ep.tag_recv(Array(recv_bufs[1]), tag=1), - ep.tag_recv(Array(recv_bufs[2]), tag=2), + listener_ep.tag_send(Array(send_bufs[0]), tag=ucx_api.UCXXTag(0)), + listener_ep.tag_send(Array(send_bufs[1]), tag=ucx_api.UCXXTag(1)), + listener_ep.tag_send(Array(send_bufs[2]), tag=ucx_api.UCXXTag(2)), + ep.tag_recv(Array(recv_bufs[0]), tag=ucx_api.UCXXTag(0)), + ep.tag_recv(Array(recv_bufs[1]), tag=ucx_api.UCXXTag(1)), + ep.tag_recv(Array(recv_bufs[2]), tag=ucx_api.UCXXTag(2)), ] if args.asyncio_wait_future: diff --git a/python/ucxx/__init__.py b/python/ucxx/__init__.py index acda9988..6719905d 100644 --- a/python/ucxx/__init__.py +++ b/python/ucxx/__init__.py @@ -16,7 +16,7 @@ logger.debug("Setting env UCX_MEMTYPE_CACHE=n, which is required by UCX") os.environ["UCX_MEMTYPE_CACHE"] = "n" -from . import exceptions, testing # noqa +from . 
import exceptions, types, testing # noqa from ._lib import libucxx # type: ignore from .core import * # noqa from .utils import get_address, get_ucxpy_logger # noqa diff --git a/python/ucxx/_lib/libucxx.pyx b/python/ucxx/_lib/libucxx.pyx index 57d2b650..49ff590d 100644 --- a/python/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/_lib/libucxx.pyx @@ -215,6 +215,26 @@ class PythonRequestNotifierWaitState(enum.Enum): Shutdown = RequestNotifierWaitState.Shutdown +class UCXXTag(): + def __init__(self, tag: int) -> None: + if (tag < 0 or tag.bit_length() > 64): + raise ValueError("`tag` must be a 64-bit integer") + self.value = tag + + +class UCXXTagMask(): + def __init__(self, tag_mask: int) -> None: + if (tag_mask < 0 or tag_mask.bit_length() > 64): + raise ValueError("`tag_mask` must be a 64-bit integer") + self.value = tag_mask + + +############################################################################### +# Constants # +############################################################################### + +UCXXTagMaskFull = UCXXTagMask(2 ** 64 - 1) + ############################################################################### # Classes # ############################################################################### @@ -486,7 +506,7 @@ cdef class UCXWorker(): handle = self._worker.get().getHandle() return int(handle) - + @property def ucxx_ptr(self): cdef Worker* worker @@ -571,11 +591,14 @@ cdef class UCXWorker(): return num_canceled - def tag_probe(self, size_t tag): + def tag_probe(self, tag: UCXXTag): + if not isinstance(tag, UCXXTag): + raise TypeError(f"The `tag` object must be of type {UCXXTag}") cdef bint tag_matched + cdef Tag cpp_tag = tag.value with nogil: - tag_matched = self._worker.get().tagProbe(tag) + tag_matched = self._worker.get().tagProbe(cpp_tag) return tag_matched @@ -629,10 +652,16 @@ cdef class UCXWorker(): def is_python_future_enabled(self): return self._enable_python_future - def tag_recv(self, Array arr, size_t tag): + def tag_recv(self, Array arr, tag: UCXXTag, 
tag_mask: UCXXTagMask = UCXXTagMaskFull): + if not isinstance(tag, UCXXTag): + raise TypeError(f"The `tag` object must be of type {UCXXTag}") + if not isinstance(tag_mask, UCXXTagMask): + raise TypeError(f"The `tag_mask` object must be of type {UCXXTagMask}") cdef void* buf = arr.ptr cdef size_t nbytes = arr.nbytes cdef shared_ptr[Request] req + cdef Tag cpp_tag = tag.value + cdef TagMask cpp_tag_mask = tag_mask.value if not self._context_feature_flags & Feature.TAG.value: raise ValueError("UCXContext must be created with `Feature.TAG`") @@ -641,7 +670,8 @@ cdef class UCXWorker(): req = self._worker.get().tagRecv( buf, nbytes, - tag, + cpp_tag, + cpp_tag_mask, self._enable_python_future ) @@ -1146,10 +1176,13 @@ cdef class UCXEndpoint(): return UCXRequest(&req, self._enable_python_future) - def tag_send(self, Array arr, size_t tag): + def tag_send(self, Array arr, tag: UCXXTag): + if not isinstance(tag, UCXXTag): + raise TypeError(f"The `tag` object must be of type {UCXXTag}") cdef void* buf = arr.ptr cdef size_t nbytes = arr.nbytes cdef shared_ptr[Request] req + cdef Tag cpp_tag = tag.value if not self._context_feature_flags & Feature.TAG.value: raise ValueError("UCXContext must be created with `Feature.TAG`") @@ -1165,16 +1198,22 @@ cdef class UCXEndpoint(): req = self._endpoint.get().tagSend( buf, nbytes, - tag, + cpp_tag, self._enable_python_future ) return UCXRequest(&req, self._enable_python_future) - def tag_recv(self, Array arr, size_t tag): + def tag_recv(self, Array arr, tag: UCXXTag, tag_mask: UCXXTagMask=UCXXTagMaskFull): + if not isinstance(tag, UCXXTag): + raise TypeError(f"The `tag` object must be of type {UCXXTag}") + if not isinstance(tag_mask, UCXXTagMask): + raise TypeError(f"The `tag_mask` object must be of type {UCXXTagMask}") cdef void* buf = arr.ptr cdef size_t nbytes = arr.nbytes cdef shared_ptr[Request] req + cdef Tag cpp_tag = tag.value + cdef TagMask cpp_tag_mask = tag_mask.value if not self._context_feature_flags & 
Feature.TAG.value: raise ValueError("UCXContext must be created with `Feature.TAG`") @@ -1190,17 +1229,21 @@ cdef class UCXEndpoint(): req = self._endpoint.get().tagRecv( buf, nbytes, - tag, + cpp_tag, + cpp_tag_mask, self._enable_python_future ) return UCXRequest(&req, self._enable_python_future) - def tag_send_multi(self, tuple arrays, size_t tag): + def tag_send_multi(self, tuple arrays, tag: UCXXTag): + if not isinstance(tag, UCXXTag): + raise TypeError(f"The `tag` object must be of type {UCXXTag}") cdef vector[void*] v_buffer cdef vector[size_t] v_size cdef vector[int] v_is_cuda cdef shared_ptr[Request] ucxx_buffer_requests + cdef Tag cpp_tag = tag.value for arr in arrays: if not isinstance(arr, Array): @@ -1225,7 +1268,7 @@ cdef class UCXEndpoint(): v_buffer, v_size, v_is_cuda, - tag, + cpp_tag, self._enable_python_future, ) @@ -1233,12 +1276,18 @@ cdef class UCXEndpoint(): &ucxx_buffer_requests, self._enable_python_future, ) - def tag_recv_multi(self, size_t tag): + def tag_recv_multi(self, tag: UCXXTag, tag_mask: UCXXTagMask=UCXXTagMaskFull): + if not isinstance(tag, UCXXTag): + raise TypeError(f"The `tag` object must be of type {UCXXTag}") + if not isinstance(tag_mask, UCXXTagMask): + raise TypeError(f"The `tag_mask` object must be of type {UCXXTagMask}") cdef shared_ptr[Request] ucxx_buffer_requests + cdef Tag cpp_tag = tag.value + cdef TagMask cpp_tag_mask = tag_mask.value with nogil: ucxx_buffer_requests = self._endpoint.get().tagMultiRecv( - tag, self._enable_python_future + cpp_tag, cpp_tag_mask, self._enable_python_future ) return UCXBufferRequests( diff --git a/python/ucxx/_lib/tests/test_cancel.py b/python/ucxx/_lib/tests/test_cancel.py index 865aaf69..f5c234cb 100644 --- a/python/ucxx/_lib/tests/test_cancel.py +++ b/python/ucxx/_lib/tests/test_cancel.py @@ -55,7 +55,7 @@ def _client_cancel(queue): assert ep.is_alive() msg = Array(bytearray(1)) - request = ep.tag_recv(msg, tag=0) + request = ep.tag_recv(msg, tag=ucx_api.UCXXTag(0)) while 
not request.is_completed(): worker.progress() diff --git a/python/ucxx/_lib/tests/test_config.py b/python/ucxx/_lib/tests/test_config.py index 4c9daadd..e2f36913 100644 --- a/python/ucxx/_lib/tests/test_config.py +++ b/python/ucxx/_lib/tests/test_config.py @@ -72,11 +72,11 @@ def test_feature_flags_mismatch(feature_flag): with pytest.raises( ValueError, match="UCXContext must be created with `Feature.TAG`" ): - ep.tag_send(msg, 0) + ep.tag_send(msg, tag=ucx_api.UCXXTag(0)) with pytest.raises( ValueError, match="UCXContext must be created with `Feature.TAG`" ): - ep.tag_recv(msg, 0) + ep.tag_recv(msg, tag=ucx_api.UCXXTag(0)) if feature_flag != ucx_api.Feature.STREAM: with pytest.raises( ValueError, match="UCXContext must be created with `Feature.STREAM`" diff --git a/python/ucxx/_lib/tests/test_endpoint.py b/python/ucxx/_lib/tests/test_endpoint.py index 6717c1ba..beac250b 100644 --- a/python/ucxx/_lib/tests/test_endpoint.py +++ b/python/ucxx/_lib/tests/test_endpoint.py @@ -50,7 +50,7 @@ def _listener_handler(conn_request): worker.progress() wireup_msg = Array(bytearray(WireupMessageSize)) - wireup_request = ep[0].tag_recv(wireup_msg, tag=0) + wireup_request = ep[0].tag_recv(wireup_msg, tag=ucx_api.UCXXTag(0)) wait_requests(worker, "blocking", wireup_request) if server_close_callback is True: @@ -73,7 +73,7 @@ def _client(port, server_close_callback): ) worker.progress() wireup_msg = Array(bytes(os.urandom(WireupMessageSize))) - wireup_request = ep.tag_send(wireup_msg, tag=0) + wireup_request = ep.tag_send(wireup_msg, tag=ucx_api.UCXXTag(0)) wait_requests(worker, "blocking", wireup_request) if server_close_callback is False: closed = [False] diff --git a/python/ucxx/_lib/tests/test_probe.py b/python/ucxx/_lib/tests/test_probe.py index 929720e4..bffa25b4 100644 --- a/python/ucxx/_lib/tests/test_probe.py +++ b/python/ucxx/_lib/tests/test_probe.py @@ -52,7 +52,13 @@ def _listener_handler(conn_request): wireup = bytes(wireup_req.get_recv_buffer()) else: wireup = 
bytearray(len(WireupMessage)) - wait_requests(worker, "blocking", ep.tag_recv(Array(wireup), tag=0)) + wait_requests( + worker, + "blocking", + ep.tag_recv( + Array(wireup), tag=ucx_api.UCXXTag(0), tag_mask=ucx_api.UCXXTagMaskFull + ), + ) queue.put("wireup completed") # Ensure client has disconnected -- endpoint is not alive anymore @@ -67,10 +73,18 @@ def _listener_handler(conn_request): wait_requests(worker, "blocking", recv_req) received = bytes(recv_req.get_recv_buffer()) else: - while worker.tag_probe(0) is False: + while worker.tag_probe(ucx_api.UCXXTag(0)) is False: worker.progress() received = bytearray(len(DataMessage)) - wait_requests(worker, "blocking", ep.tag_recv(Array(received), tag=0)) + wait_requests( + worker, + "blocking", + ep.tag_recv( + Array(received), + tag=ucx_api.UCXXTag(0), + tag_mask=ucx_api.UCXXTagMaskFull, + ), + ) assert wireup == WireupMessage assert received == DataMessage @@ -97,8 +111,8 @@ def _client_probe(queue, transfer_api): ] else: requests = [ - ep.tag_send(Array(WireupMessage), tag=0), - ep.tag_send(Array(DataMessage), tag=0), + ep.tag_send(Array(WireupMessage), tag=ucx_api.UCXXTag(0)), + ep.tag_send(Array(DataMessage), tag=ucx_api.UCXXTag(0)), ] wait_requests(worker, "blocking", requests) diff --git a/python/ucxx/_lib/tests/test_server_client.py b/python/ucxx/_lib/tests/test_server_client.py index 79723723..0ee949a2 100644 --- a/python/ucxx/_lib/tests/test_server_client.py +++ b/python/ucxx/_lib/tests/test_server_client.py @@ -22,7 +22,7 @@ def _send(ep, api, message): elif api == "stream": return ep.stream_send(message) else: - return ep.tag_send(message, tag=0) + return ep.tag_send(message, tag=ucx_api.UCXXTag(0)) def _recv(ep, api, message): @@ -31,7 +31,7 @@ def _recv(ep, api, message): elif api == "stream": return ep.stream_recv(message) else: - return ep.tag_recv(message, tag=0) + return ep.tag_recv(message, tag=ucx_api.UCXXTag(0)) def _echo_server(get_queue, put_queue, transfer_api, msg_size, progress_mode): @@ 
-80,6 +80,13 @@ def _listener_handler(conn_request): msg = Array(bytearray(msg_size)) + if transfer_api == "stream" and msg_size == 0: + with pytest.raises(RuntimeError): + _recv(ep[0], transfer_api, msg) + with pytest.raises(RuntimeError): + _send(ep[0], transfer_api, msg) + return + # We reuse the message buffer, so we must receive, wait, and then send # it back again. requests = [_recv(ep[0], transfer_api, msg)] @@ -131,6 +138,14 @@ def _echo_client(transfer_api, msg_size, progress_mode, port): send_msg = bytes(os.urandom(msg_size)) recv_msg = bytearray(msg_size) + + if transfer_api == "stream" and msg_size == 0: + with pytest.raises(RuntimeError): + _send(ep, transfer_api, Array(send_msg)) + with pytest.raises(RuntimeError): + _recv(ep, transfer_api, Array(recv_msg)) + return + requests = [ _send(ep, transfer_api, Array(send_msg)), _recv(ep, transfer_api, Array(recv_msg)), @@ -146,7 +161,7 @@ def _echo_client(transfer_api, msg_size, progress_mode, port): @pytest.mark.parametrize("transfer_api", ["am", "stream", "tag"]) -@pytest.mark.parametrize("msg_size", [10, 2**24]) +@pytest.mark.parametrize("msg_size", [0, 10, 2**24]) @pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) def test_server_client(transfer_api, msg_size, progress_mode): put_queue, get_queue = mp.Queue(), mp.Queue() diff --git a/python/ucxx/_lib/ucxx_api.pxd b/python/ucxx/_lib/ucxx_api.pxd index e4756113..b0ff3476 100644 --- a/python/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/_lib/ucxx_api.pxd @@ -179,6 +179,13 @@ cdef extern from "" namespace "ucxx" nogil: cdef extern from "" namespace "ucxx" nogil: + cdef enum Tag: + pass + cdef enum TagMask: + pass + # ctypedef Tag CppTag + # ctypedef TagMask CppTagMask + # Using function[Buffer] here doesn't seem possible due to Cython bugs/limitations. The # workaround is to use a raw C function pointer and let it be parsed by the compiler. 
# See https://github.com/cython/cython/issues/2041 and @@ -238,7 +245,7 @@ cdef extern from "" namespace "ucxx" nogil: size_t cancelInflightRequests( uint64_t period, uint64_t maxAttempts ) except +raise_py_error - bint tagProbe(const ucp_tag_t) const + bint tagProbe(const Tag) const void setProgressThreadStartCallback( function[void(void*)] callback, void* callbackArg ) @@ -249,7 +256,11 @@ cdef extern from "" namespace "ucxx" nogil: void runRequestNotifier() except +raise_py_error void populateFuturesPool() except +raise_py_error shared_ptr[Request] tagRecv( - void* buffer, size_t length, ucp_tag_t tag, bint enable_python_future + void* buffer, + size_t length, + Tag tag, + TagMask tag_mask, + bint enable_python_future ) except +raise_py_error bint isDelayedRequestSubmissionEnabled() const bint isFutureEnabled() const @@ -272,20 +283,24 @@ cdef extern from "" namespace "ucxx" nogil: void* buffer, size_t length, bint enable_python_future ) except +raise_py_error shared_ptr[Request] tagSend( - void* buffer, size_t length, ucp_tag_t tag, bint enable_python_future + void* buffer, size_t length, Tag tag, bint enable_python_future ) except +raise_py_error shared_ptr[Request] tagRecv( - void* buffer, size_t length, ucp_tag_t tag, bint enable_python_future + void* buffer, + size_t length, + Tag tag, + TagMask tag_mask, + bint enable_python_future ) except +raise_py_error shared_ptr[Request] tagMultiSend( const vector[void*]& buffer, const vector[size_t]& length, const vector[int]& isCUDA, - ucp_tag_t tag, + Tag tag, bint enable_python_future ) except +raise_py_error shared_ptr[Request] tagMultiRecv( - ucp_tag_t tag, bint enable_python_future + Tag tag, TagMask tagMask, bint enable_python_future ) except +raise_py_error bint isAlive() void raiseOnError() except +raise_py_error @@ -329,7 +344,6 @@ cdef extern from "" namespace "ucxx" nogil: vector[BufferRequestPtr] _bufferRequests bint _isFilled shared_ptr[Endpoint] _endpoint - ucp_tag_t _tag bint _send cpp_bool 
isCompleted() diff --git a/python/ucxx/_lib_async/application_context.py b/python/ucxx/_lib_async/application_context.py index 69843991..a7617b1b 100644 --- a/python/ucxx/_lib_async/application_context.py +++ b/python/ucxx/_lib_async/application_context.py @@ -10,6 +10,7 @@ import ucxx._lib.libucxx as ucx_api from ucxx._lib.arr import Array from ucxx.exceptions import UCXMessageTruncatedError +from ucxx.types import Tag from .continuous_ucx_progress import PollingMode, ThreadMode from .endpoint import Endpoint @@ -450,10 +451,12 @@ async def recv(self, buffer, tag): """ if not isinstance(buffer, Array): buffer = Array(buffer) + if not isinstance(tag, Tag): + tag = Tag(tag) nbytes = buffer.nbytes log = "[Worker Recv] worker: %s, tag: %s, nbytes: %d, type: %s" % ( hex(self.worker.handle), - hex(tag), + hex(tag.value), nbytes, type(buffer.obj), ) diff --git a/python/ucxx/_lib_async/endpoint.py b/python/ucxx/_lib_async/endpoint.py index 302cfbad..78952ca0 100644 --- a/python/ucxx/_lib_async/endpoint.py +++ b/python/ucxx/_lib_async/endpoint.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
# SPDX-License-Identifier: BSD-3-Clause @@ -9,6 +9,7 @@ import ucxx._lib.libucxx as ucx_api from ucxx._lib.arr import Array from ucxx._lib.libucxx import UCXCanceled, UCXCloseError, UCXError +from ucxx.types import Tag, TagMaskFull from .utils import hash64bits @@ -183,6 +184,8 @@ async def send(self, buffer, tag=None, force_tag=False): tag = self._tags["msg_send"] elif not force_tag: tag = hash64bits(self._tags["msg_send"], hash(tag)) + if not isinstance(tag, Tag): + tag = Tag(tag) # Optimization to eliminate producing logger string overhead if logger.isEnabledFor(logging.DEBUG): @@ -190,7 +193,7 @@ async def send(self, buffer, tag=None, force_tag=False): log = "[Send #%03d] ep: 0x%x, tag: 0x%x, nbytes: %d, type: %s" % ( self._send_count, self.uid, - tag, + tag.value, nbytes, type(buffer.obj), ) @@ -235,13 +238,15 @@ async def send_multi(self, buffers, tag=None, force_tag=False): tag = self._tags["msg_send"] elif not force_tag: tag = hash64bits(self._tags["msg_send"], hash(tag)) + if not isinstance(tag, Tag): + tag = Tag(tag) # Optimization to eliminate producing logger string overhead if logger.isEnabledFor(logging.DEBUG): log = "[Send Multi #%03d] ep: 0x%x, tag: 0x%x, nbytes: %s, type: %s" % ( self._send_count, self.uid, - tag, + tag.value, tuple([b.nbytes for b in buffers]), # nbytes, tuple([type(b.obj) for b in buffers]), ) @@ -344,6 +349,8 @@ async def recv(self, buffer, tag=None, force_tag=False): tag = self._tags["msg_recv"] elif not force_tag: tag = hash64bits(self._tags["msg_recv"], hash(tag)) + if not isinstance(tag, Tag): + tag = Tag(tag) try: self._ep.raise_on_error() @@ -364,7 +371,7 @@ async def recv(self, buffer, tag=None, force_tag=False): log = "[Recv #%03d] ep: 0x%x, tag: 0x%x, nbytes: %d, type: %s" % ( self._recv_count, self.uid, - tag, + tag.value, nbytes, type(buffer.obj), ) @@ -372,7 +379,7 @@ async def recv(self, buffer, tag=None, force_tag=False): self._recv_count += 1 - req = self._ep.tag_recv(buffer, tag) + req = self._ep.tag_recv(buffer, 
tag, TagMaskFull) ret = await req.wait() self._finished_recv_count += 1 @@ -403,6 +410,8 @@ async def recv_multi(self, tag=None, force_tag=False): tag = self._tags["msg_recv"] elif not force_tag: tag = hash64bits(self._tags["msg_recv"], hash(tag)) + if not isinstance(tag, Tag): + tag = Tag(tag) try: self._ep.raise_on_error() @@ -419,13 +428,13 @@ async def recv_multi(self, tag=None, force_tag=False): log = "[Recv Multi #%03d] ep: 0x%x, tag: 0x%x" % ( self._recv_count, self.uid, - tag, + tag.value, ) logger.debug(log) self._recv_count += 1 - buffer_requests = self._ep.tag_recv_multi(tag) + buffer_requests = self._ep.tag_recv_multi(tag, TagMaskFull) await buffer_requests.wait() buffer_requests.check_error() for r in buffer_requests.get_requests(): diff --git a/python/ucxx/_lib_async/tests/test_probe.py b/python/ucxx/_lib_async/tests/test_probe.py index c04a70d9..4bb4c056 100644 --- a/python/ucxx/_lib_async/tests/test_probe.py +++ b/python/ucxx/_lib_async/tests/test_probe.py @@ -4,8 +4,9 @@ import asyncio import pytest +from ucxx.types import Tag -import ucxx as ucxx +import ucxx @pytest.mark.asyncio @@ -23,7 +24,7 @@ async def server_node(ep): assert ep._ep.am_probe() is True received = bytes(await ep.am_recv()) else: - assert ep._ctx.worker.tag_probe(ep._tags["msg_recv"]) is True + assert ep._ctx.worker.tag_probe(Tag(ep._tags["msg_recv"])) is True received = bytearray(10) await ep.recv(received) assert received == msg diff --git a/python/ucxx/_lib_async/tests/test_send_recv.py b/python/ucxx/_lib_async/tests/test_send_recv.py index 9671fec6..49c34858 100644 --- a/python/ucxx/_lib_async/tests/test_send_recv.py +++ b/python/ucxx/_lib_async/tests/test_send_recv.py @@ -10,7 +10,7 @@ np = pytest.importorskip("numpy") -msg_sizes = [2**i for i in range(0, 25, 4)] +msg_sizes = [0, *(2**i for i in range(0, 25, 4))] dtypes = ["|u1", "