diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index c164ecf6..0e02cf88 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -77,7 +77,9 @@ class Request : public Component { Request(std::shared_ptr endpointOrWorker, const data::RequestData requestData, const std::string operationName, - const bool enablePythonFuture = false); + const bool enablePythonFuture = false, + RequestCallbackUserFunction callbackFunction = nullptr, + RequestCallbackUserData callbackData = nullptr); /** * @brief Perform initial processing of the request to determine if immediate completion. diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 981984ff..22e135e0 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -19,10 +19,14 @@ namespace ucxx { Request::Request(std::shared_ptr endpointOrWorker, const data::RequestData requestData, const std::string operationName, - const bool enablePythonFuture) + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) : _requestData(requestData), _operationName(operationName), - _enablePythonFuture(enablePythonFuture) + _enablePythonFuture(enablePythonFuture), + _callback(callbackFunction), + _callbackData(callbackData) { _endpoint = std::dynamic_pointer_cast(endpointOrWorker); _worker = @@ -210,8 +214,8 @@ void Request::setStatus(ucs_status_t status) if (_status != UCS_INPROGRESS) ucxx_error( - "ucxx::Request: %p, setStatus called with status: %d (%s) but status: %d (%s) was already " - "set", + "ucxx::Request: %p, setStatus called with status: %d (%s) but status: %d (%s) was " + "already set", this, status, ucs_status_string(status), diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index cae3cea4..713f6bdc 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -137,7 +137,12 @@ RequestAm::RequestAm(std::shared_ptr endpointOrWorker, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture) + : Request(endpointOrWorker, + data::getRequestData(requestData), + operationName, + enablePythonFuture, + callbackFunction, + callbackData) { std::visit(data::dispatch{ [this](data::AmSend amSend) { @@ -147,9 +152,6 @@ RequestAm::RequestAm(std::shared_ptr endpointOrWorker, [](data::AmReceive amReceive) {}, }, requestData); - - _callback = callbackFunction; - _callbackData = callbackData; } static void _amSendCallback(void* request, ucs_status_t status, void* user_data) diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp index 537f4f2f..09aa6403 100644 --- a/cpp/src/request_endpoint_close.cpp +++ b/cpp/src/request_endpoint_close.cpp @@ -41,13 +41,11 @@ RequestEndpointClose::RequestEndpointClose(std::shared_ptr endpoint, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpoint, requestData, operationName, enablePythonFuture) + : Request( + endpoint, requestData, operationName, enablePythonFuture, callbackFunction, callbackData) { if (_endpoint == nullptr && _worker == nullptr) throw ucxx::Error("A valid endpoint or worker is required for a close operation."); - - _callback = callbackFunction; - _callbackData = callbackData; } void RequestEndpointClose::endpointCloseCallback(void* request, ucs_status_t status, void* arg) diff --git a/cpp/src/request_flush.cpp b/cpp/src/request_flush.cpp index 5bc0ea9e..95b5021b 100644 --- a/cpp/src/request_flush.cpp +++ b/cpp/src/request_flush.cpp @@ -39,13 +39,15 @@ RequestFlush::RequestFlush(std::shared_ptr endpointOrWorker, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpointOrWorker, requestData, operationName, enablePythonFuture) + : Request(endpointOrWorker, + requestData, + operationName, + enablePythonFuture, + callbackFunction, + callbackData) { if (_endpoint == nullptr && _worker == nullptr) throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); - - _callback = callbackFunction; - _callbackData = callbackData; } void RequestFlush::flushCallback(void* request, ucs_status_t status, void* arg) diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp index 682e9084..6fe883e1 100644 --- a/cpp/src/request_mem.cpp +++ b/cpp/src/request_mem.cpp @@ -49,7 +49,12 @@ RequestMem::RequestMem(std::shared_ptr endpoint, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture) + : Request(endpoint, + data::getRequestData(requestData), + operationName, + enablePythonFuture, + callbackFunction, + callbackData) { std::visit(data::dispatch{ [this](data::MemPut memPut) { @@ -63,9 +68,6 @@ RequestMem::RequestMem(std::shared_ptr endpoint, [](auto) { throw std::runtime_error("Unreachable"); }, }, requestData); - - _callback = callbackFunction; - _callbackData = callbackData; } void RequestMem::memPutCallback(void* request, ucs_status_t status, void* arg) diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index f8dfc8e2..9ffdf72f 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -59,7 +59,12 @@ RequestTag::RequestTag(std::shared_ptr endpointOrWorker, const bool enablePythonFuture, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) - : Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture) + : Request(endpointOrWorker, + data::getRequestData(requestData), + operationName, + enablePythonFuture, + callbackFunction, + callbackData) { std::visit(data::dispatch{ [this](data::TagSend tagSend) { @@ -69,9 +74,6 @@ RequestTag::RequestTag(std::shared_ptr endpointOrWorker, [](data::TagReceive tagReceive) {}, }, requestData); - - _callback = callbackFunction; - _callbackData = callbackData; } void RequestTag::callback(void* request, ucs_status_t status, const ucp_tag_recv_info_t* info)