diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9d98fc22..6b72f865 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -113,9 +113,11 @@ add_library( src/endpoint.cpp src/header.cpp src/inflight_requests.cpp + src/internal/request_am.cpp src/listener.cpp src/log.cpp src/request.cpp + src/request_am.cpp src/request_helper.cpp src/request_stream.cpp src/request_tag.cpp diff --git a/cpp/include/ucxx/constructors.h b/cpp/include/ucxx/constructors.h index 3e1d6206..0d0876dc 100644 --- a/cpp/include/ucxx/constructors.h +++ b/cpp/include/ucxx/constructors.h @@ -19,6 +19,7 @@ class Future; class Listener; class Notifier; class Request; +class RequestAm; class RequestStream; class RequestTag; class RequestTagMulti; @@ -54,6 +55,19 @@ std::shared_ptr createWorker(std::shared_ptr context, const bool enableFuture); // Transfers +std::shared_ptr createRequestAmSend(std::shared_ptr endpoint, + void* buffer, + size_t length, + ucs_memory_type_t memoryType, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + +std::shared_ptr createRequestAmRecv(std::shared_ptr endpoint, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + std::shared_ptr createRequestStream(std::shared_ptr endpoint, bool send, void* buffer, diff --git a/cpp/include/ucxx/delayed_submission.h b/cpp/include/ucxx/delayed_submission.h index 9b6da2be..01595913 100644 --- a/cpp/include/ucxx/delayed_submission.h +++ b/cpp/include/ucxx/delayed_submission.h @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -17,14 +18,13 @@ namespace ucxx { typedef std::function DelayedSubmissionCallbackType; -typedef std::shared_ptr DelayedSubmissionCallbackPtrType; - class DelayedSubmission { public: bool _send{false}; ///< Whether this is a send (`true`) operation or recv (`false`) void* _buffer{nullptr}; ///< Raw pointer to data buffer size_t _length{0}; 
///< Length of the message in bytes ucp_tag_t _tag{0}; ///< Tag to match + ucs_memory_type_t _memoryType{UCS_MEMORY_TYPE_UNKNOWN}; ///< Buffer memory type DelayedSubmission() = delete; @@ -42,17 +42,23 @@ class DelayedSubmission { * a multi-threaded application for blocking while waiting for the UCX spinlock, since * all transfer operations may be pushed to the worker progress thread. * - * @param[in] send whether this is a send (`true`) or receive (`false`) operation. - * @param[in] buffer a raw pointer to the data being transferred. - * @param[in] length the size in bytes of the message being transfer. - * @param[in] tag tag to match for this operation (only applies for tag operations). + * @param[in] send whether this is a send (`true`) or receive (`false`) operation. + * @param[in] buffer a raw pointer to the data being transferred. + * @param[in] length the size in bytes of the message being transferred. + * @param[in] tag tag to match for this operation (only applies for tag + * operations). + * @param[in] memoryType the memory type of the buffer. */ - DelayedSubmission(const bool send, void* buffer, const size_t length, const ucp_tag_t tag = 0); + DelayedSubmission(const bool send, + void* buffer, + const size_t length, + const ucp_tag_t tag = 0, + const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_UNKNOWN); }; class DelayedSubmissionCollection { private: - std::vector + std::vector, DelayedSubmissionCallbackType>> _collection{}; ///< The collection of all known delayed submission operations. std::mutex _mutex{}; ///< Mutex to provide access to the collection. @@ -87,10 +93,12 @@ class DelayedSubmissionCollection { * Register a request for delayed submission with a callback that will be executed when * the request is in fact submitted when `process()` is called. * + * @param[in] request the request to which the callback belongs, ensuring it remains + * alive until the callback is invoked.
* @param[in] callback the callback that will be executed by `process()` when the * operation is submitted. */ - void registerRequest(DelayedSubmissionCallbackType callback); + void registerRequest(std::shared_ptr request, DelayedSubmissionCallbackType callback); }; } // namespace ucxx diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 85e8bf07..357f012b 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -253,6 +253,60 @@ class Endpoint : public Component { */ void setCloseCallback(std::function closeCallback, void* closeCallbackArg); + /** + * @brief Enqueue an active message send operation. + * + * Enqueue an active message send operation, returning a `std::shared_ptr` + * that can be later awaited and checked for errors. This is a non-blocking operation, and + * the status of the transfer must be verified from the resulting request object before + * the data can be released. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the tag message to be sent. + * @param[in] memoryType the memory type of the buffer. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion and its state. 
+ */ + std::shared_ptr amSend(void* buffer, + size_t length, + ucs_memory_type_t memoryType, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Enqueue an active message receive operation. + * + * Enqueue an active message receive operation, returning a + * `std::shared_ptr` that can be later awaited and checked for errors, + * making data available via the return value's `getRecvBuffer()` method once the + * operation completes successfully. This is a non-blocking operation, and the status of + * the transfer must be verified from the resulting request object before the data can be + * consumed. + * + * Using a Python future may be requested by specifying `enablePythonFuture`. If a + * Python future is requested, the Python application must then await on this future to + * ensure the transfer has completed. Requires UCXX Python support. + * + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns Request to be subsequently checked for the completion state and data. + */ + std::shared_ptr amRecv(const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + /** * @brief Enqueue a stream send operation. * diff --git a/cpp/include/ucxx/internal/request_am.h b/cpp/include/ucxx/internal/request_am.h new file mode 100644 index 00000000..73c6d8dd --- /dev/null +++ b/cpp/include/ucxx/internal/request_am.h @@ -0,0 +1,100 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace ucxx { + +class Buffer; +class InflightRequests; +class RequestAm; +class Request; +class Worker; + +namespace internal { + +class AmData; + +class RecvAmMessage { + public: + internal::AmData* _amData{nullptr}; ///< Active messages data + ucp_ep_h _ep{nullptr}; ///< Handle containing address of the reply endpoint + std::shared_ptr _request{ + nullptr}; ///< Request which will later be notified/delivered to user + std::shared_ptr _buffer{nullptr}; ///< Buffer containing the received data + + RecvAmMessage() = delete; + RecvAmMessage(const RecvAmMessage&) = delete; + RecvAmMessage& operator=(RecvAmMessage const&) = delete; + RecvAmMessage(RecvAmMessage&& o) = delete; + RecvAmMessage& operator=(RecvAmMessage&& o) = delete; + + /** + * @brief Constructor of `ucxx::RecvAmMessage`. + * + * Construct the object, setting attributes that are later needed by the callback. + * + * @param[in] amData active messages worker data. + * @param[in] ep handle containing address of the reply endpoint (i.e., endpoint + * where user is requesting to receive). + * @param[in] request request to be later notified/delivered to user. + * @param[in] buffer buffer containing the received data + */ + RecvAmMessage(internal::AmData* amData, + ucp_ep_h ep, + std::shared_ptr request, + std::shared_ptr buffer); + + /** + * @brief Set the UCP request. + * + * Set the underlying UCP request (`_request` attribute) of the `RequestAm`. + * + * @param[in] request the UCP request associated to the active message receive operation. + */ + void setUcpRequest(void* request); + + /** + * @brief Execute the `ucxx::Request::callback()`. + * + * Execute the `ucxx::Request::callback()` method to set the status of the request, the + * buffer containing the data received and release the reference to this object from + * `AmData`. 
+ * + * @param[in] request the UCP request associated to the active message receive operation. + * @param[in] status the completion status of the UCP request. + */ + void callback(void* request, ucs_status_t status); +}; + +typedef std::unordered_map>> AmPoolType; +typedef std::unordered_map> RecvAmMessageMapType; + +class AmData { + public: + std::weak_ptr _worker{}; ///< The worker to which the Active Message callback belongs + std::string _ownerString{}; ///< The owner string used for logging + AmPoolType _recvPool{}; ///< The pool of completed receive requests (waiting for user request) + AmPoolType _recvWait{}; ///< The pool of user receive requests (waiting for message arrival) + RecvAmMessageMapType + _recvAmMessageMap{}; ///< The active messages waiting to be handled by callback + std::mutex _mutex{}; ///< Mutex to provide access to pools/maps + std::function)> + _registerInflightRequest{}; ///< Worker function to register inflight requests with + std::unordered_map + _allocators{}; ///< Default and user-defined active message allocators +}; + +} // namespace internal + +} // namespace ucxx diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index e9326403..3f4e6eb2 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -192,6 +192,20 @@ class Request : public Component { * @returns the formatted string containing the owner type and its handle. */ const std::string& getOwnerString() const; + + /** + * @brief Get the received buffer. + * + * This method is used to get the received buffer for applicable derived classes (e.g., + * `RequestAm` receive operations), in all other cases this will return `nullptr`. 
Before + * getting the received buffer it's necessary to check that the request completed + * successfully either by validating `getStatus() == UCS_OK` or by checking the request + * completed with `isCompleted() == true` and that it did not error with `checkError()`, + * if any of those is unsuccessful this call returns `nullptr`. + * + * @return The received buffer (if applicable) or `nullptr`. + */ + virtual std::shared_ptr getRecvBuffer(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h new file mode 100644 index 00000000..ecf3f845 --- /dev/null +++ b/cpp/include/ucxx/request_am.h @@ -0,0 +1,199 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once +#include +#include + +#include + +#include +#include +#include + +namespace ucxx { + +class Buffer; + +namespace internal { +class RecvAmMessage; +} // namespace internal + +class RequestAm : public Request { + private: + friend class internal::RecvAmMessage; + + ucs_memory_type_t _sendHeader{}; ///< The header to send + std::shared_ptr _buffer{nullptr}; ///< The AM received message buffer + + /** + * @brief Private constructor of `ucxx::RequestAm` send. + * + * This is the internal implementation of `ucxx::RequestAm` send constructor, made private + * not to be called directly. This constructor is made private to ensure all UCXX objects + * are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::amSend()` + * - `ucxx::createRequestAmSend()` + * + * @throws ucxx::Error if `endpoint` is not a valid `std::shared_ptr`. + * + * @param[in] endpoint the parent endpoint. + * @param[in] buffer a raw pointer to the data to be sent. + * @param[in] length the size in bytes of the active message to be sent. + * @param[in] memoryType the memory type of the buffer. 
+ * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestAm(std::shared_ptr endpoint, + void* buffer, + size_t length, + ucs_memory_type_t memoryType, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Private constructor of `ucxx::RequestAm` receive. + * + * This is the internal implementation of `ucxx::RequestAm` receive constructor, made + * private not to be called directly. This constructor is made private to ensure all UCXX + * objects are shared pointers and the correct lifetime management of each one. + * + * Instead the user should use one of the following: + * + * - `ucxx::Endpoint::amRecv()` + * - `ucxx::createRequestAmRecv()` + * + * @throws ucxx::Error if `endpointOrWorker` is not a valid + * `std::shared_ptr` or + * `std::shared_ptr`. + * + * @param[in] endpointOrWorker the parent component, which may either be a + * `std::shared_ptr` or + * `std::shared_ptr`. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + */ + RequestAm(std::shared_ptr endpointOrWorker, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); + + public: + /** + * @brief Constructor for `std::shared_ptr` send. + * + * The constructor for a `std::shared_ptr` object, creating a send active + * message request, returning a pointer to a request object that can be later awaited and + * checked for errors. 
This is a non-blocking operation, and the status of the transfer + * must be verified from the resulting request object before the data can be + * released. + * + * @throws ucxx::Error if `endpoint` is not a valid + * `std::shared_ptr`. + * + * @param[in] endpoint the parent endpoint. + * @param[in] buffer a raw pointer to the data to be transferred. + * @param[in] length the size in bytes of the tag message to be transferred. + * @param[in] memoryType the memory type of the buffer. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. + * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRequestAmSend( + std::shared_ptr endpoint, + void* buffer, + size_t length, + ucs_memory_type_t memoryType, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + /** + * @brief Constructor for `std::shared_ptr` receive. + * + * The constructor for a `std::shared_ptr` object, creating a receive + * active message request, returning a pointer to a request object that can be later + * awaited and checked for errors. This is a non-blocking operation, and the status of + * the transfer must be verified from the resulting request object before the data can be + * consumed, the data is available via the `getRecvBuffer()` method if the transfer + * completed successfully. + * + * @throws ucxx::Error if `endpoint` is not a valid + * `std::shared_ptr`. + * + * @param[in] endpoint the parent endpoint. + * @param[in] enablePythonFuture whether a python future should be created and + * subsequently notified. + * @param[in] callbackFunction user-defined callback function to call upon completion. + * @param[in] callbackData user-defined data to pass to the `callbackFunction`. 
+ * + * @returns The `shared_ptr` object + */ + friend std::shared_ptr createRequestAmRecv( + std::shared_ptr endpoint, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + + virtual void populateDelayedSubmission(); + + /** + * @brief Create and submit an active message send request. + * + * This is the method that should be called to actually submit an active message send + * request. It is meant to be called from `populateDelayedSubmission()`, which is decided + * at the discretion of `std::shared_ptr`. See `populateDelayedSubmission()` + * for more details. + */ + void request(); + + /** + * @brief Receive callback registered by `ucxx::Worker`. + * + * This is the receive callback registered by the `ucxx::Worker` to handle incoming active + * messages. For each incoming active message, a proper buffer will be allocated based on + * the header sent by the remote endpoint using the default allocator or one registered by + * the user via `ucxx::Worker::registerAmAllocator()`. Following that, the message is + * immediately received onto the buffer and a `UCS_OK` or the proper error status is set + * onto the `RequestAm` that is returned to the user, or will be later handled by another + * callback when the message is ready. If the callback is executed when a user has already + * requested received of the active message, the buffer and status will be set on the + * earliest request, otherwise a new request is created and saved in a pool that will be + * already populated and ready for consumption or waiting for the internal callback when + * requested. + * + * This is always called by `ucp_worker_progress()`, and thus will happen in the same + * thread that is called from, when using the worker progress thread, this is called from + * the progress thread. + * + * param[in,out] arg pointer to the `AmData` object held by the `ucxx::Worker` who + * registered this callback. 
+ * @param[in] header pointer to the header containing the sender buffer's memory type. + * @param[in] header_length length in bytes of the receive header. + * @param[in] data pointer to the buffer containing the remote endpoint's send data. + * @param[in] length the length in bytes of the message to be received. + * @param[in] param UCP parameters of the active message being received. + */ + static ucs_status_t recvCallback(void* arg, + const void* header, + size_t header_length, + void* data, + size_t length, + const ucp_am_recv_param_t* param); + + std::shared_ptr getRecvBuffer() override; +}; + +} // namespace ucxx diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 6cea3576..1c8b1933 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -11,6 +11,7 @@ namespace ucxx { +class Buffer; class Request; // Logging levels @@ -36,4 +37,6 @@ typedef std::unordered_map ConfigMap; typedef std::function)> RequestCallbackUserFunction; typedef std::shared_ptr RequestCallbackUserData; +typedef std::function(size_t)> AmAllocatorType; + } // namespace ucxx diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index da280a40..520be6c7 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -25,8 +25,14 @@ namespace ucxx { class Address; +class Buffer; class Endpoint; class Listener; +class RequestAm; + +namespace internal { +class AmData; +} // namespace internal class Worker : public Component { private: @@ -48,6 +54,12 @@ class Worker : public Component { std::shared_ptr _delayedSubmissionCollection{ nullptr}; ///< Collection of enqueued delayed submissions + friend std::shared_ptr createRequestAmRecv( + std::shared_ptr endpoint, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData); + protected: bool _enableFuture{ false}; ///< Boolean identifying whether the worker was created with future capability @@ -55,6 +67,8 @@ class Worker :
public Component { std::queue> _futuresPool{}; ///< Futures pool to prevent running out of fresh futures std::shared_ptr _notifier{nullptr}; ///< Notifier object + std::shared_ptr + _amData; ///< Worker data made available to Active Messages callback private: /** @@ -65,6 +79,23 @@ class Worker : public Component { */ void drainWorkerTagRecv(); + /** + * @brief Get active message receive request. + * + * Returns an active message request from the pool if the worker has already begun + * handling a request with the active messages callback, otherwise creates a new request + * that is later populated with status and buffer by the active messages callback. + * + * @param[in] ep the endpoint handle where receiving the message, the same handle that + * will later be used to reply to the message. + * @param[in] createAmRecvRequestFunction function to create a new request if one is not + * already available in the pool. + * + * @returns Request to be subsequently checked for the completion state and data. + */ + std::shared_ptr getAmRecv( + ucp_ep_h ep, std::function()> createAmRecvRequestFunction); + /** * @brief Stop the progress thread if running without raising warnings. * @@ -355,10 +386,13 @@ class Worker : public Component { * thread, thus decreasing computation on the caller thread, but potentially increasing * transfer latency. * + * @param[in] request the request to which the callback belongs, ensuring it remains + * alive until the callback is invoked. * @param[in] callback the callback set to execute the UCP transfer routine during the * worker thread loop. */ - void registerDelayedSubmission(DelayedSubmissionCallbackType callback); + void registerDelayedSubmission(std::shared_ptr request, + DelayedSubmissionCallbackType callback); /** * @brief Inquire if worker has been created with delayed submission enabled. @@ -535,7 +569,7 @@ class Worker : public Component { * * @returns `true` if any uncaught messages were received, `false` otherwise.
*/ - bool tagProbe(ucp_tag_t tag); + bool tagProbe(const ucp_tag_t tag); /** * @brief Enqueue a tag receive operation. @@ -659,6 +693,56 @@ class Worker : public Component { std::shared_ptr createListener(uint16_t port, ucp_listener_conn_callback_t callback, void* callbackArgs); + + /** + * @brief Register allocator for active messages. + * + * Register a new allocator for active messages. By default, only one allocator is defined + * for host memory (`UCS_MEMORY_TYPE_HOST`), and is used as a fallback when an allocator + * for the source's memory type is unavailable. In many circumstances relying exclusively + * on the host allocator is undesirable, for example when transferring CUDA buffers the + * destination is always going to be a host buffer and prevent the use of transports such + * as NVLink or InfiniBand+GPUDirectRDMA. For that reason it's important that the user + * defines those allocators that are important for the application. + * + * If the `memoryType` has already been registered, the previous allocator will be + * replaced by the new one. Be careful when doing this after transfers have started, there + * are no guarantees that inflight messages have not already been allocated with the old + * allocator for that type. + * + * @code{.cpp} + * // context is `std::shared_ptr` + * auto worker = context->createWorker(false); + * + * worker->registerAmAllocator(`UCS_MEMORY_TYPE_CUDA`, ucxx::RMMBuffer); + * @endcode + * + * @param[in] memoryType the memory type the allocator will be used for. + * @param[in] allocator the allocator callable that will be used to allocate new + * active message buffers. + */ + void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator); + + /** + * @brief Check for uncaught active messages. + * + * Checks the worker for any uncaught active messages. 
An uncaught active message is any + * active message that has been fully or partially received by the worker, but not matched + * by a corresponding `createRequestAmRecv()` call. + * + * @code{.cpp} + * // `worker` is `std::shared_ptr` + * // `ep` is a remote `std::shared_ptramProbe(ep->getHandle())); + * + * ep->amSend(buffer, length); + * + * assert(worker->amProbe(0)); + * @endcode + * + * @returns `true` if any uncaught messages were received, `false` otherwise. + */ + bool amProbe(const ucp_ep_h endpointHandle) const; }; } // namespace ucxx diff --git a/cpp/python/src/worker.cpp b/cpp/python/src/worker.cpp index 59cf1d71..9b8233df 100644 --- a/cpp/python/src/worker.cpp +++ b/cpp/python/src/worker.cpp @@ -6,9 +6,11 @@ #include #include #include +#include #include +#include #include #include #include @@ -31,8 +33,20 @@ std::shared_ptr<::ucxx::Worker> createWorker(std::shared_ptr context, const bool enableDelayedSubmission, const bool enableFuture) { - return std::shared_ptr<::ucxx::Worker>( + auto worker = std::shared_ptr<::ucxx::python::Worker>( new ::ucxx::python::Worker(context, enableDelayedSubmission, enableFuture)); + + // We can only get a `shared_ptr` for the Active Messages callback after it's + // been created, thus this cannot be in the constructor. 
+ if (worker->_amData != nullptr) { + worker->_amData->_worker = worker; + + std::stringstream ownerStream; + ownerStream << "worker " << worker->getHandle(); + worker->_amData->_ownerString = ownerStream.str(); + } + + return worker; } void Worker::populateFuturesPool() diff --git a/cpp/src/delayed_submission.cpp b/cpp/src/delayed_submission.cpp index f7eb57c4..ee768333 100644 --- a/cpp/src/delayed_submission.cpp +++ b/cpp/src/delayed_submission.cpp @@ -16,8 +16,9 @@ namespace ucxx { DelayedSubmission::DelayedSubmission(const bool send, void* buffer, const size_t length, - const ucp_tag_t tag) - : _send(send), _buffer(buffer), _length(length), _tag(tag) + const ucp_tag_t tag, + const ucs_memory_type_t memoryType) + : _send(send), _buffer(buffer), _length(length), _tag(tag), _memoryType(memoryType) { } @@ -34,26 +35,25 @@ void DelayedSubmissionCollection::process() toProcess = std::move(_collection); } - for (auto& callbackPtr : toProcess) { - auto& callback = *callbackPtr; + for (auto& pair : toProcess) { + auto& req = pair.first; + auto& callback = pair.second; - ucxx_trace_req("Submitting request: %p", callback.target)>()); + ucxx_trace_req("Submitting request: %p", req.get()); if (callback) callback(); } } } -void DelayedSubmissionCollection::registerRequest(DelayedSubmissionCallbackType callback) +void DelayedSubmissionCollection::registerRequest(std::shared_ptr request, + DelayedSubmissionCallbackType callback) { - auto r = std::make_shared(callback); - { std::lock_guard lock(_mutex); - _collection.push_back(r); + _collection.push_back({request, callback}); } - ucxx_trace_req("Registered submit request: %p", - callback.target)>()); + ucxx_trace_req("Registered submit request: %p", request.get()); } } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index af550812..0bc6f40b 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -205,6 +206,27 @@ 
void Endpoint::removeInflightRequest(const Request* const request) size_t Endpoint::cancelInflightRequests() { return _inflightRequests->cancelAll(); } +std::shared_ptr Endpoint::amSend(void* buffer, + size_t length, + ucs_memory_type_t memoryType, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest(createRequestAmSend( + endpoint, buffer, length, memoryType, enablePythonFuture, callbackFunction, callbackData)); +} + +std::shared_ptr Endpoint::amRecv(const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) +{ + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); + return registerInflightRequest( + createRequestAmRecv(endpoint, enablePythonFuture, callbackFunction, callbackData)); +} + std::shared_ptr Endpoint::streamSend(void* buffer, size_t length, const bool enablePythonFuture) diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp new file mode 100644 index 00000000..873d22e2 --- /dev/null +++ b/cpp/src/internal/request_am.cpp @@ -0,0 +1,38 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include +#include + +namespace ucxx { + +namespace internal { + +RecvAmMessage::RecvAmMessage(internal::AmData* amData, + ucp_ep_h ep, + std::shared_ptr request, + std::shared_ptr buffer) + : _amData(amData), _ep(ep), _request(request), _buffer(buffer) +{ + _request->_delayedSubmission = + std::make_shared(false, _buffer->data(), _buffer->getSize()); +} + +void RecvAmMessage::setUcpRequest(void* request) { _request->_request = request; } + +void RecvAmMessage::callback(void* request, ucs_status_t status) +{ + _request->_buffer = _buffer; + _request->callback(request, status); + { + std::lock_guard lock(_amData->_mutex); + _amData->_recvAmMessageMap.erase(_request.get()); + } +} + +} // namespace internal + +} // namespace ucxx diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 3adcfaa4..f2453b17 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -105,10 +105,11 @@ void Request::callback(void* request, ucs_status_t status) _callback.target()); if (_callback) _callback(status, _callbackData); - if (_request != nullptr) ucp_request_free(request); + if (request != nullptr) ucp_request_free(request); ucxx_trace("Request completed: %p, handle: %p", this, request); setStatus(status); + ucxx_trace("Request %p, isCompleted: %d", this, isCompleted()); } void Request::process() @@ -184,4 +185,6 @@ void Request::setStatus(ucs_status_t status) const std::string& Request::getOwnerString() const { return _ownerString; } +std::shared_ptr Request::getRecvBuffer() { return nullptr; } + } // namespace ucxx diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp new file mode 100644 index 00000000..c09bfcab --- /dev/null +++ b/cpp/src/request_am.cpp @@ -0,0 +1,287 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace ucxx { + +std::shared_ptr createRequestAmSend( + std::shared_ptr endpoint, + void* buffer, + size_t length, + ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_HOST, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + auto req = std::shared_ptr(new RequestAm( + endpoint, buffer, length, memoryType, enablePythonFuture, callbackFunction, callbackData)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. + req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; +} + +std::shared_ptr createRequestAmRecv( + std::shared_ptr endpoint, + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr) +{ + auto worker = endpoint->getWorker(); + + auto createRequest = [endpoint, enablePythonFuture, callbackFunction, callbackData]() { + return std::shared_ptr( + new RequestAm(endpoint, enablePythonFuture, callbackFunction, callbackData)); + }; + return worker->getAmRecv(endpoint->getHandle(), createRequest); +} + +RequestAm::RequestAm(std::shared_ptr endpoint, + void* buffer, + size_t length, + ucs_memory_type_t memoryType, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) + : Request(endpoint, + std::make_shared(true, buffer, length, 0, memoryType), + std::string("amSend"), + enablePythonFuture) +{ + _callback = callbackFunction; + _callbackData = callbackData; +} + +RequestAm::RequestAm(std::shared_ptr 
endpointOrWorker, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) + : Request(endpointOrWorker, nullptr, std::string("amRecv"), enablePythonFuture) +{ + _callback = callbackFunction; + _callbackData = callbackData; +} + +static void _amSendCallback(void* request, ucs_status_t status, void* user_data) +{ + Request* req = reinterpret_cast(user_data); + ucxx_trace_req_f(req->getOwnerString().c_str(), request, "amSend", "_amSendCallback"); + req->callback(request, status); +} + +static void _recvCompletedCallback(void* request, + ucs_status_t status, + size_t length, + void* user_data) +{ + internal::RecvAmMessage* recvAmMessage = static_cast(user_data); + ucxx_trace_req_f( + recvAmMessage->_request->getOwnerString().c_str(), request, "amRecv", "amRecvCallback"); + recvAmMessage->callback(request, status); +} + +ucs_status_t RequestAm::recvCallback(void* arg, + const void* header, + size_t header_length, + void* data, + size_t length, + const ucp_am_recv_param_t* param) +{ + internal::AmData* amData = static_cast(arg); + auto worker = amData->_worker.lock(); + auto& ownerString = amData->_ownerString; + auto& recvPool = amData->_recvPool; + auto& recvWait = amData->_recvWait; + + if ((param->recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP) == 0) + ucxx_error("UCP_AM_RECV_ATTR_FIELD_REPLY_EP not set"); + + ucp_ep_h ep = param->reply_ep; + + bool is_rndv = param->recv_attr & UCP_AM_RECV_ATTR_FLAG_RNDV; + + std::shared_ptr buf{nullptr}; + auto allocatorType = *static_cast(header); + + std::shared_ptr req{nullptr}; + + { + std::lock_guard lock(amData->_mutex); + + auto reqs = recvWait.find(ep); + if (reqs != recvWait.end() && !reqs->second.empty()) { + req = reqs->second.front(); + reqs->second.pop(); + ucxx_trace_req("amRecv recvWait: %p", req.get()); + } else { + req = std::shared_ptr( + new RequestAm(worker, worker->isFutureEnabled(), nullptr, nullptr)); + auto [queue, _] = recvPool.try_emplace(ep, 
std::queue>()); + queue->second.push(req); + ucxx_trace_req("amRecv recvPool: %p", req.get()); + } + } + + if (is_rndv) { + if (amData->_allocators.find(allocatorType) == amData->_allocators.end()) { + // TODO: Is a hard failure better? + // ucxx_debug("Unsupported memory type %d", allocatorType); + // internal::RecvAmMessage recvAmMessage(amData, ep, req, nullptr); + // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); + // return UCS_ERR_UNSUPPORTED; + + ucxx_trace_req("No allocator registered for memory type %d, falling back to host memory.", + allocatorType); + allocatorType = UCS_MEMORY_TYPE_HOST; + } + + std::shared_ptr buf = amData->_allocators.at(allocatorType)(length); + + auto recvAmMessage = std::make_shared(amData, ep, req, buf); + + ucp_request_param_t request_param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | + UCP_OP_ATTR_FLAG_NO_IMM_CMPL, + .cb = {.recv_am = _recvCompletedCallback}, + .user_data = recvAmMessage.get()}; + + ucs_status_ptr_t status = + ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &request_param); + + if (req->_enablePythonFuture) + ucxx_trace_req_f(ownerString.c_str(), + status, + "amRecv rndv", + "ep %p, buffer %p, size %lu, future %p, future handle %p, recvCallback", + ep, + buf->data(), + length, + req->_future.get(), + req->_future->getHandle()); + else + ucxx_trace_req_f(ownerString.c_str(), + status, + "amRecv rndv", + "ep %p, buffer %p, size %lu, recvCallback", + ep, + buf->data(), + length); + + if (req->isCompleted()) { + // The request completed/errored immediately + ucs_status_t s = UCS_PTR_STATUS(status); + recvAmMessage->callback(nullptr, s); + + return s; + } else { + // The request will be handled by the callback + recvAmMessage->setUcpRequest(status); + amData->_registerInflightRequest(req); + + { + std::lock_guard lock(amData->_mutex); + amData->_recvAmMessageMap.emplace(req.get(), recvAmMessage); + } + + return UCS_INPROGRESS; + } + } else { + 
std::shared_ptr buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); + if (length > 0) memcpy(buf->data(), data, length); + + if (req->_enablePythonFuture) + ucxx_trace_req_f(ownerString.c_str(), + nullptr, + "amRecv eager", + "ep: %p, buffer %p, size %lu, future %p, future handle %p, recvCallback", + ep, + buf->data(), + length, + req->_future.get(), + req->_future->getHandle()); + else + ucxx_trace_req_f(ownerString.c_str(), + nullptr, + "amRecv eager", + "ep: %p, buffer %p, size %lu, recvCallback", + ep, + buf->data(), + length); + + internal::RecvAmMessage recvAmMessage(amData, ep, req, buf); + recvAmMessage.callback(nullptr, UCS_OK); + return UCS_OK; + } +} + +std::shared_ptr RequestAm::getRecvBuffer() { return _buffer; } + +void RequestAm::request() +{ + static const ucp_tag_t tagMask = -1; + + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_FLAGS | + UCP_OP_ATTR_FIELD_USER_DATA, + .flags = UCP_AM_SEND_FLAG_REPLY, + .datatype = ucp_dt_make_contig(1), + .user_data = this}; + + _sendHeader = _delayedSubmission->_memoryType; + + if (_delayedSubmission->_send) { + param.cb.send = _amSendCallback; + _request = ucp_am_send_nbx(_endpoint->getHandle(), + 0, + &_sendHeader, + sizeof(_sendHeader), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + ¶m); + } else { + throw ucxx::UnsupportedError( + "Receiving active messages must be handled by the worker's callback"); + } +} + +void RequestAm::populateDelayedSubmission() +{ + request(); + + if (_enablePythonFuture) + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, future %p, future handle %p, populateDelayedSubmission", + _delayedSubmission->_buffer, + _delayedSubmission->_length, + _future.get(), + _future->getHandle()); + else + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "buffer %p, size %lu, populateDelayedSubmission", + _delayedSubmission->_buffer, + 
_delayedSubmission->_length); + + process(); +} + +} // namespace ucxx diff --git a/cpp/src/request_stream.cpp b/cpp/src/request_stream.cpp index 6cfb5f61..c93b447d 100644 --- a/cpp/src/request_stream.cpp +++ b/cpp/src/request_stream.cpp @@ -23,13 +23,6 @@ RequestStream::RequestStream(std::shared_ptr endpoint, enablePythonFuture), _length(length) { - auto worker = endpoint->getWorker(); - - // A delayed notification request is not populated immediately, instead it is - // delayed to allow the worker progress thread to set its status, and more - // importantly the Python future later on, so that we don't need the GIL here. - worker->registerDelayedSubmission( - std::bind(std::mem_fn(&Request::populateDelayedSubmission), this)); } std::shared_ptr createRequestStream(std::shared_ptr endpoint, @@ -38,8 +31,16 @@ std::shared_ptr createRequestStream(std::shared_ptr end size_t length, const bool enablePythonFuture = false) { - return std::shared_ptr( + auto req = std::shared_ptr( new RequestStream(endpoint, send, buffer, length, enablePythonFuture)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. 
+ req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; } void RequestStream::request() diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index 12130c9f..b6a4cf5e 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -22,14 +22,22 @@ std::shared_ptr createRequestTag(std::shared_ptr endpoint RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr) { - return std::shared_ptr(new RequestTag(endpointOrWorker, - send, - buffer, - length, - tag, - enablePythonFuture, - callbackFunction, - callbackData)); + auto req = std::shared_ptr(new RequestTag(endpointOrWorker, + send, + buffer, + length, + tag, + enablePythonFuture, + callbackFunction, + callbackData)); + + // A delayed notification request is not populated immediately, instead it is + // delayed to allow the worker progress thread to set its status, and more + // importantly the Python future later on, so that we don't need the GIL here. + req->_worker->registerDelayedSubmission( + req, std::bind(std::mem_fn(&Request::populateDelayedSubmission), req.get())); + + return req; } RequestTag::RequestTag(std::shared_ptr endpointOrWorker, @@ -50,12 +58,6 @@ RequestTag::RequestTag(std::shared_ptr endpointOrWorker, throw ucxx::Error("An endpoint is required to send tag messages"); _callback = callbackFunction; _callbackData = callbackData; - - // A delayed notification request is not populated immediately, instead it is - // delayed to allow the worker progress thread to set its status, and more - // importantly the Python future later on, so that we don't need the GIL here. 
- _worker->registerDelayedSubmission( - std::bind(std::mem_fn(&Request::populateDelayedSubmission), this)); } void RequestTag::callback(void* request, ucs_status_t status, const ucp_tag_recv_info_t* info) diff --git a/cpp/src/request_tag_multi.cpp b/cpp/src/request_tag_multi.cpp index eca92695..1883edbb 100644 --- a/cpp/src/request_tag_multi.cpp +++ b/cpp/src/request_tag_multi.cpp @@ -202,7 +202,6 @@ void RequestTagMulti::callback(ucs_status_t status) { if (_send) throw std::runtime_error("Send requests cannot call callback()"); - // TODO: Remove arg ucxx_trace_req("RequestTagMulti::callback request: %p, tag: %lx", this, _tag); if (_bufferRequests.empty()) { diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index c0edc35b..bc7baf82 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -15,6 +16,9 @@ #include #include +#include +#include +#include #include #include #include @@ -37,6 +41,24 @@ Worker::Worker(std::shared_ptr context, if (enableDelayedSubmission) _delayedSubmissionCollection = std::make_shared(); + if (context->getFeatureFlags() & UCP_FEATURE_AM) { + unsigned int AM_MSG_ID = 0; + _amData = std::make_shared(); + _amData->_registerInflightRequest = [this](std::shared_ptr req) { + this->registerInflightRequest(req); + }; + registerAmAllocator(UCS_MEMORY_TYPE_HOST, + [](size_t length) { return std::make_shared(length); }); + + ucp_am_handler_param_t am_handler_param = {.field_mask = UCP_AM_HANDLER_PARAM_FIELD_ID | + UCP_AM_HANDLER_PARAM_FIELD_CB | + UCP_AM_HANDLER_PARAM_FIELD_ARG, + .id = AM_MSG_ID, + .cb = RequestAm::recvCallback, + .arg = _amData.get()}; + utils::ucsErrorThrow(ucp_worker_set_am_recv_handler(_handle, &am_handler_param)); + } + ucxx_trace("Worker created: %p, enableDelayedSubmission: %d, enableFuture: %d", this, enableDelayedSubmission, @@ -84,11 +106,44 @@ void Worker::drainWorkerTagRecv() } } +std::shared_ptr Worker::getAmRecv( + ucp_ep_h ep, 
std::function()> createAmRecvRequestFunction) +{ + std::lock_guard lock(_amData->_mutex); + + auto& recvPool = _amData->_recvPool; + auto& recvWait = _amData->_recvWait; + + auto reqs = recvPool.find(ep); + if (reqs != recvPool.end() && !reqs->second.empty()) { + auto req = reqs->second.front(); + reqs->second.pop(); + return req; + } else { + auto req = createAmRecvRequestFunction(); + auto [queue, _] = recvWait.try_emplace(ep, std::queue>()); + queue->second.push(req); + return req; + } +} + std::shared_ptr createWorker(std::shared_ptr context, const bool enableDelayedSubmission, const bool enableFuture) { - return std::shared_ptr(new Worker(context, enableDelayedSubmission, enableFuture)); + auto worker = std::shared_ptr(new Worker(context, enableDelayedSubmission, enableFuture)); + + // We can only get a `shared_ptr` for the Active Messages callback after it's + // been created, thus this cannot be in the constructor. + if (worker->_amData != nullptr) { + worker->_amData->_worker = worker; + + std::stringstream ownerStream; + ownerStream << "worker " << worker->getHandle(); + worker->_amData->_ownerString = ownerStream.str(); + } + + return worker; } Worker::~Worker() @@ -206,12 +261,13 @@ bool Worker::progress() return ret; } -void Worker::registerDelayedSubmission(DelayedSubmissionCallbackType callback) +void Worker::registerDelayedSubmission(std::shared_ptr request, + DelayedSubmissionCallbackType callback) { if (_delayedSubmissionCollection == nullptr) { callback(); } else { - _delayedSubmissionCollection->registerRequest(callback); + _delayedSubmissionCollection->registerRequest(request, callback); /* Waking the progress event is needed here because the UCX request is * not dispatched immediately. 
Thus we must signal the progress task so @@ -319,7 +375,7 @@ void Worker::removeInflightRequest(const Request* const request) } } -bool Worker::tagProbe(ucp_tag_t tag) +bool Worker::tagProbe(const ucp_tag_t tag) { // TODO: Fix temporary workaround, if progress thread is active we must wait for it // to progress the worker instead. @@ -378,4 +434,16 @@ std::shared_ptr Worker::createListener(uint16_t port, return listener; } +void Worker::registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator) +{ + if (_amData == nullptr) + throw std::runtime_error("Active Messages was not enabled during context creation"); + _amData->_allocators.insert_or_assign(memoryType, allocator); +} + +bool Worker::amProbe(const ucp_ep_h endpointHandle) const +{ + return _amData->_recvPool.find(endpointHandle) != _amData->_recvPool.end(); +} + } // namespace ucxx diff --git a/cpp/tests/include/utils.h b/cpp/tests/include/utils.h index aeb7ef82..9110af50 100644 --- a/cpp/tests/include/utils.h +++ b/cpp/tests/include/utils.h @@ -22,13 +22,23 @@ enum class ProgressMode { void createCudaContextCallback(void* callbackArg); -void waitRequests(std::shared_ptr worker, - const std::vector>& requests, - const std::function& progressWorker); - -void waitRequestsTagMulti(std::shared_ptr worker, - const std::vector>& requests, - const std::function& progressWorker); +template +inline void waitRequests(std::shared_ptr worker, + const std::vector>& requests, + const std::function& progressWorker) +{ + auto remainingRequests = requests; + while (!remainingRequests.empty()) { + auto updatedRequests = std::exchange(remainingRequests, decltype(remainingRequests)()); + for (auto const& r : updatedRequests) { + if (progressWorker) progressWorker(); + if (!r->isCompleted()) + remainingRequests.push_back(r); + else + r->checkError(); + } + } +} std::function getProgressFunction(std::shared_ptr worker, ProgressMode progressMode); diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 
182c3f23..2aafde20 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -21,20 +21,22 @@ using ::testing::Combine; using ::testing::ContainerEq; using ::testing::Values; -class RequestTest - : public ::testing::TestWithParam> { +class RequestTest : public ::testing::TestWithParam< + std::tuple> { protected: - std::shared_ptr _context{ - ucxx::createContext({}, ucxx::Context::defaultFeatureFlags)}; + std::shared_ptr _context{nullptr}; std::shared_ptr _worker{nullptr}; std::shared_ptr _ep{nullptr}; std::function _progressWorker; ucxx::BufferType _bufferType; + ucs_memory_type_t _memoryType; + bool _registerCustomAmAllocator; bool _enableDelayedSubmission; ProgressMode _progressMode; size_t _messageLength; size_t _messageSize; + size_t _rndvThresh{8192}; size_t _numBuffers{0}; std::vector> _send; @@ -52,10 +54,18 @@ class RequestTest #endif } - std::tie(_bufferType, _enableDelayedSubmission, _progressMode, _messageLength) = GetParam(); + std::tie(_bufferType, + _registerCustomAmAllocator, + _enableDelayedSubmission, + _progressMode, + _messageLength) = GetParam(); + _memoryType = + (_bufferType == ucxx::BufferType::RMM) ? 
UCS_MEMORY_TYPE_CUDA : UCS_MEMORY_TYPE_HOST; _messageSize = _messageLength * sizeof(int); - _worker = _context->createWorker(_enableDelayedSubmission); + _context = ucxx::createContext({{"RNDV_THRESH", std::to_string(_rndvThresh)}}, + ucxx::Context::defaultFeatureFlags); + _worker = _context->createWorker(_enableDelayedSubmission); if (_progressMode == ProgressMode::Blocking) { _worker->initBlockingProgressMode(); @@ -134,6 +144,41 @@ class RequestTest } }; +TEST_P(RequestTest, ProgressAm) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); + } + + allocate(1, false); + + // Submit and wait for transfers to complete + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + auto recvReq = requests[1]; + _recvPtr[0] = recvReq->getRecvBuffer()->data(); + + // Messages larger than `_rndvThresh` are rendezvous and will use custom allocator, + // smaller messages are eager and will always be host-allocated. + ASSERT_THAT(recvReq->getRecvBuffer()->getType(), + (_registerCustomAmAllocator && _messageSize >= _rndvThresh) ? 
_bufferType + : ucxx::BufferType::Host); + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTest, ProgressStream) { allocate(); @@ -185,7 +230,7 @@ TEST_P(RequestTest, ProgressTagMulti) std::vector> requests; requests.push_back(_ep->tagMultiSend(_sendPtr, multiSize, multiIsCUDA, 0, false)); requests.push_back(_ep->tagMultiRecv(0, false)); - waitRequestsTagMulti(_worker, requests, _progressWorker); + waitRequests(_worker, requests, _progressWorker); auto recvRequest = requests[1]; @@ -246,39 +291,43 @@ TEST_P(RequestTest, TagUserCallback) INSTANTIATE_TEST_SUITE_P(ProgressModes, RequestTest, Combine(Values(ucxx::BufferType::Host), + Values(false), Values(false), Values(ProgressMode::Polling, ProgressMode::Blocking, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 1048576))); + Values(1, 1024, 2048, 1048576))); INSTANTIATE_TEST_SUITE_P(DelayedSubmission, RequestTest, Combine(Values(ucxx::BufferType::Host), + Values(false), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 1048576))); + Values(1, 1024, 2048, 1048576))); #if UCXX_ENABLE_RMM INSTANTIATE_TEST_SUITE_P(RMMProgressModes, RequestTest, Combine(Values(ucxx::BufferType::RMM), + Values(false, true), Values(false), Values(ProgressMode::Polling, ProgressMode::Blocking, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 1048576))); + Values(1, 1024, 2048, 1048576))); INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, RequestTest, Combine(Values(ucxx::BufferType::RMM), + Values(false, true), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(1, 1024, 1048576))); + Values(1, 1024, 2048, 1048576))); #endif } // namespace diff --git a/cpp/tests/utils.cpp b/cpp/tests/utils.cpp index 7c5a9a0f..daf4ab97 100644 --- a/cpp/tests/utils.cpp 
+++ b/cpp/tests/utils.cpp @@ -13,30 +13,6 @@ void createCudaContextCallback(void* callbackArg) cudaFree(0); } -void waitRequests(std::shared_ptr worker, - const std::vector>& requests, - const std::function& progressWorker) -{ - for (auto& r : requests) { - do { - if (progressWorker) progressWorker(); - } while (!r->isCompleted()); - r->checkError(); - } -} - -void waitRequestsTagMulti(std::shared_ptr worker, - const std::vector>& requests, - const std::function& progressWorker) -{ - for (auto& r : requests) { - do { - if (progressWorker) progressWorker(); - } while (!r->isCompleted()); - r->checkError(); - } -} - std::function getProgressFunction(std::shared_ptr worker, ProgressMode progressMode) { diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 862008de..b113a131 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -69,6 +69,16 @@ class WorkerProgressTest : public WorkerTest, TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); } +TEST_P(WorkerCapabilityTest, CheckCapability) +{ + ASSERT_EQ(_worker->isDelayedSubmissionEnabled(), _enableDelayedSubmission); + ASSERT_EQ(_worker->isFutureEnabled(), _enableFuture); +} + +INSTANTIATE_TEST_SUITE_P(Capabilities, + WorkerCapabilityTest, + Combine(Values(false, true), Values(false, true))); + TEST_F(WorkerTest, TagProbe) { auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling); @@ -89,15 +99,52 @@ TEST_F(WorkerTest, TagProbe) ASSERT_TRUE(_worker->tagProbe(0)); } -TEST_P(WorkerCapabilityTest, CheckCapability) +TEST_F(WorkerTest, AmProbe) { - ASSERT_EQ(_worker->isDelayedSubmissionEnabled(), _enableDelayedSubmission); - ASSERT_EQ(_worker->isFutureEnabled(), _enableFuture); + auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling); + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + ASSERT_FALSE(_worker->amProbe(ep->getHandle())); + + std::vector buf{123}; + std::vector> requests; + 
requests.push_back(ep->amSend(buf.data(), buf.size() * sizeof(int), UCS_MEMORY_TYPE_HOST)); + waitRequests(_worker, requests, progressWorker); + + // Attempt to progress worker 10 times (arbitrarily defined). + // TODO: Maybe a timeout would fit best. + for (size_t i = 0; i < 10 && !_worker->amProbe(ep->getHandle()); ++i) + progressWorker(); + + ASSERT_TRUE(_worker->amProbe(ep->getHandle())); } -INSTANTIATE_TEST_SUITE_P(Capabilities, - WorkerCapabilityTest, - Combine(Values(false, true), Values(false, true))); +TEST_P(WorkerProgressTest, ProgressAm) +{ + if (_progressMode == ProgressMode::Wait) { + // TODO: Is this the same reason as TagMulti? + GTEST_SKIP() << "Wait mode not supported"; + } + + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + + std::vector send{123}; + + std::vector> requests; + requests.push_back(ep->amSend(send.data(), send.size() * sizeof(int), UCS_MEMORY_TYPE_HOST)); + requests.push_back(ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + auto recvReq = requests[1]; + auto recvBuffer = recvReq->getRecvBuffer(); + + ASSERT_EQ(recvBuffer->getType(), ucxx::BufferType::Host); + ASSERT_EQ(recvBuffer->getSize(), send.size() * sizeof(int)); + + std::vector recvAbstract(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + send.size()); + ASSERT_EQ(recvAbstract[0], send[0]); +} TEST_P(WorkerProgressTest, ProgressStream) { @@ -148,7 +195,7 @@ TEST_P(WorkerProgressTest, ProgressTagMulti) std::vector> requests; requests.push_back(ep->tagMultiSend(multiBuffer, multiSize, multiIsCUDA, 0, false)); requests.push_back(ep->tagMultiRecv(0, false)); - waitRequestsTagMulti(_worker, requests, _progressWorker); + waitRequests(_worker, requests, _progressWorker); for (const auto& br : requests[1]->_bufferRequests) { // br->buffer == nullptr are headers diff --git a/python/ucxx/_lib/libucxx.pyx b/python/ucxx/_lib/libucxx.pyx index 5987c244..ec94fc25 100644 --- a/python/ucxx/_lib/libucxx.pyx +++ 
b/python/ucxx/_lib/libucxx.pyx @@ -18,6 +18,7 @@ from libcpp.functional cimport function from libcpp.map cimport map as cpp_map from libcpp.memory cimport ( dynamic_pointer_cast, + make_shared, make_unique, shared_ptr, unique_ptr, @@ -60,6 +61,11 @@ def _get_host_buffer(uintptr_t recv_buffer_ptr): return ptr_to_ndarray(host_buffer.release(), size) +cdef shared_ptr[Buffer] _rmm_am_allocator(size_t length): + cdef shared_ptr[RMMBuffer] rmm_buffer = make_shared[RMMBuffer](length) + return dynamic_pointer_cast[Buffer, RMMBuffer](rmm_buffer) + + ############################################################################### # Exceptions # ############################################################################### @@ -439,6 +445,10 @@ cdef class UCXWorker(): ): cdef bint ucxx_enable_delayed_submission = enable_delayed_submission cdef bint ucxx_enable_python_future = enable_python_future + cdef AmAllocatorType rmm_am_allocator + + self._context_feature_flags = (context.feature_flags) + with nogil: self._worker = createPythonWorker( context._context, @@ -448,7 +458,9 @@ cdef class UCXWorker(): self._enable_delayed_submission = self._worker.get().isDelayedSubmissionEnabled() self._enable_python_future = self._worker.get().isFutureEnabled() - self._context_feature_flags = (context.feature_flags) + if self._context_feature_flags & UCP_FEATURE_AM: + rmm_am_allocator = (&_rmm_am_allocator) + self._worker.get().registerAmAllocator(UCS_MEMORY_TYPE_CUDA, rmm_am_allocator) @property def handle(self): @@ -661,6 +673,22 @@ cdef class UCXRequest(): else: await self.wait_yield() + def get_recv_buffer(self): + cdef shared_ptr[Buffer] buf + cdef BufferType bufType + + with nogil: + buf = self._request.get().getRecvBuffer() + bufType = buf.get().getType() + + # If buf == NULL, it's not allocated by the request but rather the user + if buf == NULL: + return None + elif bufType == BufferType.RMM: + return _get_rmm_buffer(buf.get()) + else: + return _get_host_buffer(buf.get()) + cdef 
class UCXBufferRequest: cdef: @@ -951,6 +979,48 @@ cdef class UCXEndpoint(): with nogil: self._endpoint.get().close() + def am_probe(self): + cdef ucp_ep_h handle + cdef shared_ptr[Worker] worker + cdef bint ep_matched + + with nogil: + handle = self._endpoint.get().getHandle() + worker = self._endpoint.get().getWorker() + ep_matched = worker.get().amProbe(handle) + + return ep_matched + + def am_send(self, Array arr): + cdef void* buf = arr.ptr + cdef size_t nbytes = arr.nbytes + cdef bint cuda_array = arr.cuda + cdef shared_ptr[Request] req + + if not self._context_feature_flags & Feature.AM.value: + raise ValueError("UCXContext must be created with `Feature.AM`") + + with nogil: + req = self._endpoint.get().amSend( + buf, + nbytes, + UCS_MEMORY_TYPE_CUDA if cuda_array else UCS_MEMORY_TYPE_HOST, + self._enable_python_future + ) + + return UCXRequest(&req, self._enable_python_future) + + def am_recv(self): + cdef shared_ptr[Request] req + + if not self._context_feature_flags & Feature.AM.value: + raise ValueError("UCXContext must be created with `Feature.AM`") + + with nogil: + req = self._endpoint.get().amRecv(self._enable_python_future) + + return UCXRequest(&req, self._enable_python_future) + def stream_send(self, Array arr): cdef void* buf = arr.ptr cdef size_t nbytes = arr.nbytes diff --git a/python/ucxx/_lib/tests/test_probe.py b/python/ucxx/_lib/tests/test_probe.py index 5cb74223..929720e4 100644 --- a/python/ucxx/_lib/tests/test_probe.py +++ b/python/ucxx/_lib/tests/test_probe.py @@ -3,6 +3,7 @@ import multiprocessing as mp +import pytest from ucxx._lib import libucxx as ucx_api from ucxx._lib.arr import Array from ucxx.testing import terminate_process, wait_requests @@ -13,14 +14,17 @@ DataMessage = bytearray(b"0" * 10) -def _server_probe(queue): +def _server_probe(queue, transfer_api): """Server that probes and receives message after client disconnected. 
Note that since it is illegal to call progress() in callback functions, we keep a reference to the endpoint after the listener callback has terminated, this way we can progress even after Python blocking calls. """ - ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) + feature_flags = ( + ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG, + ) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) worker = ucx_api.UCXWorker(ctx) # Keep endpoint to be used from outside the listener callback @@ -42,8 +46,13 @@ def _listener_handler(conn_request): ep = ep[0] # Ensure wireup and inform client before it can disconnect - wireup = bytearray(len(WireupMessage)) - wait_requests(worker, "blocking", ep.tag_recv(Array(wireup), tag=0)) + if transfer_api == "am": + wireup_req = ep.am_recv() + wait_requests(worker, "blocking", wireup_req) + wireup = bytes(wireup_req.get_recv_buffer()) + else: + wireup = bytearray(len(WireupMessage)) + wait_requests(worker, "blocking", ep.tag_recv(Array(wireup), tag=0)) queue.put("wireup completed") # Ensure client has disconnected -- endpoint is not alive anymore @@ -51,17 +60,27 @@ def _listener_handler(conn_request): worker.progress() # Probe/receive message even after the remote endpoint has disconnected - while worker.tag_probe(0) is False: - worker.progress() - received = bytearray(len(DataMessage)) - wait_requests(worker, "blocking", ep.tag_recv(Array(received), tag=0)) + if transfer_api == "am": + while ep.am_probe() is False: + worker.progress() + recv_req = ep.am_recv() + wait_requests(worker, "blocking", recv_req) + received = bytes(recv_req.get_recv_buffer()) + else: + while worker.tag_probe(0) is False: + worker.progress() + received = bytearray(len(DataMessage)) + wait_requests(worker, "blocking", ep.tag_recv(Array(received), tag=0)) assert wireup == WireupMessage assert received == DataMessage -def _client_probe(queue): - ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) +def 
_client_probe(queue, transfer_api): + feature_flags = ( + ucx_api.Feature.AM if transfer_api == "am" else ucx_api.Feature.TAG, + ) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) worker = ucx_api.UCXWorker(ctx) port = queue.get() ep = ucx_api.UCXEndpoint.create( @@ -71,21 +90,28 @@ def _client_probe(queue): endpoint_error_handling=True, ) - requests = [ - ep.tag_send(Array(WireupMessage), tag=0), - ep.tag_send(Array(DataMessage), tag=0), - ] + if transfer_api == "am": + requests = [ + ep.am_send(Array(WireupMessage)), + ep.am_send(Array(DataMessage)), + ] + else: + requests = [ + ep.tag_send(Array(WireupMessage), tag=0), + ep.tag_send(Array(DataMessage), tag=0), + ] wait_requests(worker, "blocking", requests) # Wait for wireup before disconnecting assert queue.get() == "wireup completed" -def test_message_probe(): +@pytest.mark.parametrize("transfer_api", ["am", "tag"]) +def test_message_probe(transfer_api): queue = mp.Queue() - server = mp.Process(target=_server_probe, args=(queue,)) + server = mp.Process(target=_server_probe, args=(queue, transfer_api)) server.start() - client = mp.Process(target=_client_probe, args=(queue,)) + client = mp.Process(target=_client_probe, args=(queue, transfer_api)) client.start() client.join(timeout=10) server.join(timeout=10) diff --git a/python/ucxx/_lib/tests/test_server_client.py b/python/ucxx/_lib/tests/test_server_client.py index d7cbcf52..79723723 100644 --- a/python/ucxx/_lib/tests/test_server_client.py +++ b/python/ucxx/_lib/tests/test_server_client.py @@ -17,14 +17,18 @@ def _send(ep, api, message): - if api == "stream": + if api == "am": + return ep.am_send(message) + elif api == "stream": return ep.stream_send(message) else: return ep.tag_send(message, tag=0) def _recv(ep, api, message): - if api == "stream": + if api == "am": + return ep.am_recv() + elif api == "stream": return ep.stream_recv(message) else: return ep.tag_recv(message, tag=0) @@ -38,7 +42,9 @@ def _echo_server(get_queue, put_queue, transfer_api, 
msg_size, progress_mode): outside of the callback function. """ feature_flags = [ucx_api.Feature.WAKEUP] - if transfer_api == "stream": + if transfer_api == "am": + feature_flags.append(ucx_api.Feature.AM) + elif transfer_api == "stream": feature_flags.append(ucx_api.Feature.STREAM) else: feature_flags.append(ucx_api.Feature.TAG) @@ -78,6 +84,8 @@ def _listener_handler(conn_request): # it back again. requests = [_recv(ep[0], transfer_api, msg)] wait_requests(worker, progress_mode, requests) + if transfer_api == "am": + msg = Array(requests[0].get_recv_buffer()) requests = [_send(ep[0], transfer_api, msg)] wait_requests(worker, progress_mode, requests) @@ -92,6 +100,8 @@ def _listener_handler(conn_request): def _echo_client(transfer_api, msg_size, progress_mode, port): feature_flags = [ucx_api.Feature.WAKEUP] + if transfer_api == "am": + feature_flags.append(ucx_api.Feature.AM) if transfer_api == "stream": feature_flags.append(ucx_api.Feature.STREAM) else: @@ -127,10 +137,15 @@ def _echo_client(transfer_api, msg_size, progress_mode, port): ] wait_requests(worker, progress_mode, requests) - assert recv_msg == send_msg + if transfer_api == "am": + recv_msg = requests[1].get_recv_buffer() + + assert bytes(recv_msg) == send_msg + else: + assert recv_msg == send_msg -@pytest.mark.parametrize("transfer_api", ["stream", "tag"]) +@pytest.mark.parametrize("transfer_api", ["am", "stream", "tag"]) @pytest.mark.parametrize("msg_size", [10, 2**24]) @pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) def test_server_client(transfer_api, msg_size, progress_mode): diff --git a/python/ucxx/_lib/ucxx_api.pxd b/python/ucxx/_lib/ucxx_api.pxd index be2faa95..2bd5e1a7 100644 --- a/python/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/_lib/ucxx_api.pxd @@ -71,9 +71,15 @@ cdef extern from "ucp/api/ucp.h" nogil: ctypedef enum ucs_status_t: pass + ctypedef enum ucs_memory_type_t: + pass + # Constants ucs_status_t UCS_OK + ucs_memory_type_t UCS_MEMORY_TYPE_HOST + ucs_memory_type_t 
UCS_MEMORY_TYPE_CUDA + int UCP_FEATURE_TAG int UCP_FEATURE_WAKEUP int UCP_FEATURE_STREAM @@ -152,16 +158,19 @@ cdef extern from "" namespace "ucxx" nogil: Invalid "ucxx::BufferType::Invalid" cdef cppclass Buffer: + Buffer(const BufferType bufferType, const size_t size_t) BufferType getType() size_t getSize() cdef cppclass HostBuffer: + HostBuffer(const size_t size_t) BufferType getType() size_t getSize() void* release() except +raise_py_error void* data() except +raise_py_error cdef cppclass RMMBuffer: + RMMBuffer(const size_t size_t) BufferType getType() size_t getSize() unique_ptr[device_buffer] release() except +raise_py_error @@ -177,6 +186,12 @@ cdef extern from "" namespace "ucxx" nogil: cdef extern from "" namespace "ucxx" nogil: + # Using function[Buffer] here doesn't seem possible due to Cython bugs/limitations. The + # workaround is to use a raw C function pointer and let it be parsed by the compiler. + # See https://github.com/cython/cython/issues/2041 and + # https://github.com/cython/cython/issues/3193 + ctypedef shared_ptr[Buffer] (*AmAllocatorType)(size_t) + ctypedef cpp_unordered_map[string, string] ConfigMap shared_ptr[Context] createContext( @@ -228,7 +243,7 @@ cdef extern from "" namespace "ucxx" nogil: ) except +raise_py_error void stopProgressThread() except +raise_py_error size_t cancelInflightRequests() except +raise_py_error - bint tagProbe(ucp_tag_t) + bint tagProbe(const ucp_tag_t) const void setProgressThreadStartCallback( function[void(void*)] callback, void* callbackArg ) @@ -243,10 +258,18 @@ cdef extern from "" namespace "ucxx" nogil: ) except +raise_py_error bint isDelayedSubmissionEnabled() const bint isFutureEnabled() const + bint amProbe(ucp_ep_h) const + void registerAmAllocator(ucs_memory_type_t memoryType, AmAllocatorType allocator) cdef cppclass Endpoint(Component): ucp_ep_h getHandle() void close() + shared_ptr[Request] amSend( + void* buffer, size_t length, ucs_memory_type_t memory_type, bint 
enable_python_future + ) except +raise_py_error + shared_ptr[Request] amRecv( + bint enable_python_future + ) except +raise_py_error shared_ptr[Request] streamSend( void* buffer, size_t length, bint enable_python_future ) except +raise_py_error @@ -274,6 +297,7 @@ cdef extern from "" namespace "ucxx" nogil: void setCloseCallback( function[void(void*)] close_callback, void* close_callback_arg ) + shared_ptr[Worker] getWorker() cdef cppclass Listener(Component): shared_ptr[Endpoint] createEndpointFromConnRequest( @@ -292,6 +316,7 @@ cdef extern from "" namespace "ucxx" nogil: ucs_status_t getStatus() void checkError() except +raise_py_error void* getFuture() except +raise_py_error + shared_ptr[Buffer] getRecvBuffer() except +raise_py_error cdef extern from "" namespace "ucxx" nogil: diff --git a/python/ucxx/_lib_async/endpoint.py b/python/ucxx/_lib_async/endpoint.py index 60b74c87..cf9049bd 100644 --- a/python/ucxx/_lib_async/endpoint.py +++ b/python/ucxx/_lib_async/endpoint.py @@ -59,7 +59,7 @@ def abort(self): To do that, use `Endpoint.close()` """ if self._ep is not None: - logger.debug("Endpoint.abort(): %s" % hex(self.uid)) + logger.debug("Endpoint.abort(): 0x%x" % self.uid) self._ep.close() self._ep = None self._ctx = None @@ -85,6 +85,43 @@ async def close(self): await asyncio.sleep(0) self.abort() + async def am_send(self, buffer): + """Send `buffer` to connected peer via active messages. + + Parameters + ---------- + buffer: exposing the buffer protocol or array/cuda interface + The buffer to send. Raise ValueError if buffer is smaller + than nbytes. 
+ """ + self._ep.raise_on_error() + if self.closed(): + raise UCXCloseError("Endpoint closed") + if not isinstance(buffer, Array): + buffer = Array(buffer) + + # Optimization to eliminate producing logger string overhead + if logger.isEnabledFor(logging.DEBUG): + nbytes = buffer.nbytes + log = "[AM Send #%03d] ep: 0x%x, nbytes: %d, type: %s" % ( + self._send_count, + self.uid, + nbytes, + type(buffer.obj), + ) + logger.debug(log) + + self._send_count += 1 + + try: + request = self._ep.am_send(buffer) + return await request.wait() + except UCXCanceled as e: + # If self._ep has already been closed and destroyed, we reraise the + # UCXCanceled exception. + if self._ep is None: + raise e + # @ucx_api.nvtx_annotate("UCXPY_SEND", color="green", domain="ucxpy") async def send(self, buffer, tag=None, force_tag=False): """Send `buffer` to connected peer. @@ -94,7 +131,6 @@ async def send(self, buffer, tag=None, force_tag=False): buffer: exposing the buffer protocol or array/cuda interface The buffer to send. Raise ValueError if buffer is smaller than nbytes. - tag: hashable, optional tag: hashable, optional Set a tag that the receiver must match. Currently the tag is hashed together with the internal Endpoint tag that is @@ -118,10 +154,10 @@ async def send(self, buffer, tag=None, force_tag=False): # Optimization to eliminate producing logger string overhead if logger.isEnabledFor(logging.DEBUG): nbytes = buffer.nbytes - log = "[Send #%03d] ep: %s, tag: %s, nbytes: %d, type: %s" % ( + log = "[Send #%03d] ep: 0x%x, tag: 0x%x, nbytes: %d, type: %s" % ( self._send_count, - hex(self.uid), - hex(tag), + self.uid, + tag, nbytes, type(buffer.obj), ) @@ -146,7 +182,6 @@ async def send_multi(self, buffers, tag=None, force_tag=False): buffer: exposing the buffer protocol or array/cuda interface The buffer to send. Raise ValueError if buffer is smaller than nbytes. - tag: hashable, optional tag: hashable, optional Set a tag that the receiver must match. 
Currently the tag is hashed together with the internal Endpoint tag that is @@ -170,10 +205,10 @@ async def send_multi(self, buffers, tag=None, force_tag=False): # Optimization to eliminate producing logger string overhead if logger.isEnabledFor(logging.DEBUG): - log = "[Send Multi #%03d] ep: %s, tag: %s, nbytes: %s, type: %s" % ( + log = "[Send Multi #%03d] ep: 0x%x, tag: 0x%x, nbytes: %s, type: %s" % ( self._send_count, - hex(self.uid), - hex(tag), + self.uid, + tag, tuple([b.nbytes for b in buffers]), # nbytes, tuple([type(b.obj) for b in buffers]), ) @@ -214,6 +249,44 @@ async def send_obj(self, obj, tag=None): await self.send(nbytes, tag=tag) await self.send(obj, tag=tag) + async def am_recv(self): + """Receive from connected peer via active messages.""" + if not self._ep.am_probe(): + self._ep.raise_on_error() + if self.closed(): + raise UCXCloseError("Endpoint closed") + + # Optimization to eliminate producing logger string overhead + if logger.isEnabledFor(logging.DEBUG): + log = "[AM Recv #%03d] ep: 0x%x" % ( + self._recv_count, + self.uid, + ) + logger.debug(log) + + self._recv_count += 1 + + req = self._ep.am_recv() + await req.wait() + buffer = req.get_recv_buffer() + + if logger.isEnabledFor(logging.DEBUG): + log = "[AM Recv Completed #%03d] ep: 0x%x, nbytes: %d, type: %s" % ( + self._recv_count, + self.uid, + buffer.nbytes, + type(buffer), + ) + logger.debug(log) + + self._finished_recv_count += 1 + if ( + self._close_after_n_recv is not None + and self._finished_recv_count >= self._close_after_n_recv + ): + self.abort() + return buffer + # @ucx_api.nvtx_annotate("UCXPY_RECV", color="red", domain="ucxpy") async def recv(self, buffer, tag=None, force_tag=False): """Receive from connected peer into `buffer`. 
@@ -255,10 +328,10 @@ async def recv(self, buffer, tag=None, force_tag=False): # Optimization to eliminate producing logger string overhead if logger.isEnabledFor(logging.DEBUG): nbytes = buffer.nbytes - log = "[Recv #%03d] ep: %s, tag: %s, nbytes: %d, type: %s" % ( + log = "[Recv #%03d] ep: 0x%x, tag: 0x%x, nbytes: %d, type: %s" % ( self._recv_count, - hex(self.uid), - hex(tag), + self.uid, + tag, nbytes, type(buffer.obj), ) @@ -310,10 +383,10 @@ async def recv_multi(self, tag=None, force_tag=False): # Optimization to eliminate producing logger string overhead if logger.isEnabledFor(logging.DEBUG): - log = "[Recv Multi #%03d] ep: %s, tag: %s" % ( + log = "[Recv Multi #%03d] ep: 0x%x, tag: 0x%x" % ( self._recv_count, - hex(self.uid), - hex(tag), + self.uid, + tag, ) logger.debug(log) diff --git a/python/ucxx/_lib_async/tests/test_info.py b/python/ucxx/_lib_async/tests/test_info.py index f95c4f47..d3ab8b59 100644 --- a/python/ucxx/_lib_async/tests/test_info.py +++ b/python/ucxx/_lib_async/tests/test_info.py @@ -25,11 +25,11 @@ def test_worker_info(): @pytest.mark.parametrize( "transports", - ["posix", "tcp", "posix,tcp"], + ["self", "tcp", "self,tcp"], ) def test_check_transport(transports): transports_list = transports.split(",") - inactive_transports = list(set(["posix", "tcp"]) - set(transports_list)) + inactive_transports = list(set(["self", "tcp"]) - set(transports_list)) ucxx.reset() options = {"TLS": transports, "NET_DEVICES": "all"} diff --git a/python/ucxx/_lib_async/tests/test_probe.py b/python/ucxx/_lib_async/tests/test_probe.py index a6e524b6..c56d06cc 100644 --- a/python/ucxx/_lib_async/tests/test_probe.py +++ b/python/ucxx/_lib_async/tests/test_probe.py @@ -12,9 +12,6 @@ @pytest.mark.parametrize("transfer_api", ["am", "tag"]) @pytest.mark.xfail(reason="https://github.com/rapidsai/ucxx/issues/19") async def test_message_probe(transfer_api): - if transfer_api == "am": - pytest.skip("AM not implemented yet") - msg = bytearray(b"0" * 10) async def 
server_node(ep): @@ -25,7 +22,7 @@ async def server_node(ep): if transfer_api == "am": assert ep._ep.am_probe() is True - received = await ep.am_recv() + received = bytes(await ep.am_recv()) else: assert ep._ctx.worker.tag_probe(ep._tags["msg_recv"]) is True received = bytearray(10) diff --git a/python/ucxx/_lib_async/tests/test_send_recv_am.py b/python/ucxx/_lib_async/tests/test_send_recv_am.py new file mode 100644 index 00000000..3a250448 --- /dev/null +++ b/python/ucxx/_lib_async/tests/test_send_recv_am.py @@ -0,0 +1,95 @@ +import asyncio +from functools import partial + +import numpy as np +import pytest +from utils import wait_listener_client_handlers + +import ucxx + +msg_sizes = [0] + [2**i for i in range(0, 25, 4)] + + +def _bytearray_assert_equal(a, b): + assert a == b + + +def get_data(): + ret = [ + { + "allocator": bytearray, + "generator": lambda n: bytearray(b"m" * n), + "validator": lambda recv, exp: _bytearray_assert_equal(bytes(recv), exp), + "memory_type": "host", + }, + { + "allocator": partial(np.ones, dtype=np.uint8), + "generator": partial(np.arange, dtype=np.int64), + "validator": lambda recv, exp: np.testing.assert_equal( + recv.view(np.int64), exp + ), + "memory_type": "host", + }, + ] + + try: + import cupy as cp + + ret.append( + { + "allocator": partial(cp.ones, dtype=np.uint8), + "generator": partial(cp.arange, dtype=np.int64), + "validator": lambda recv, exp: cp.testing.assert_array_equal( + cp.asarray(recv).view(np.int64), exp + ), + "memory_type": "cuda", + } + ) + except ImportError: + pass + + return ret + + +def simple_server(size, recv): + async def server(ep): + recv.append(await ep.am_recv()) + await ep.close() + + return server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("size", msg_sizes) +@pytest.mark.parametrize("recv_wait", [True, False]) +@pytest.mark.parametrize("data", get_data()) +async def test_send_recv_am(size, recv_wait, data): + rndv_thresh = 8192 + ucxx.init(options={"RNDV_THRESH": str(rndv_thresh)}) + + 
msg = data["generator"](size) + + recv = [] + listener = ucxx.create_listener(simple_server(size, recv)) + num_clients = 1 + clients = [ + await ucxx.create_endpoint(ucxx.get_address(), listener.port) + for i in range(num_clients) + ] + for c in clients: + if recv_wait: + # By sleeping here we ensure that the listener's + # ep.am_recv call will have to wait, rather than return + # immediately as receive data is already available. + await asyncio.sleep(1) + await c.am_send(msg) + for c in clients: + await c.close() + await wait_listener_client_handlers(listener) + + if data["memory_type"] == "cuda" and msg.nbytes < rndv_thresh: + # Eager messages are always received on the host, if no custom host + # allocator is registered, UCXX defaults to `np.array`. + np.testing.assert_equal(recv[0].view(np.int64), msg.get()) + else: + data["validator"](recv[0], msg)