Skip to content

Commit

Permalink
Add ucxx::Request callback to constructor arguments (#194)
Browse files Browse the repository at this point in the history
Callbacks had to be assigned during the constructor because the member attributes are part of the base class. Now the base class exposes those arguments via the constructor and they can then be passed directly to the base class' constructor.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #194
  • Loading branch information
pentschev authored Jun 28, 2024
1 parent 49b56d9 commit c461ea9
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 25 deletions.
4 changes: 3 additions & 1 deletion cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ class Request : public Component {
Request(std::shared_ptr<Component> endpointOrWorker,
const data::RequestData requestData,
const std::string operationName,
const bool enablePythonFuture = false);
const bool enablePythonFuture = false,
RequestCallbackUserFunction callbackFunction = nullptr,
RequestCallbackUserData callbackData = nullptr);

/**
* @brief Perform initial processing of the request to determine if immediate completion.
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ namespace ucxx {
Request::Request(std::shared_ptr<Component> endpointOrWorker,
const data::RequestData requestData,
const std::string operationName,
const bool enablePythonFuture)
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: _requestData(requestData),
_operationName(operationName),
_enablePythonFuture(enablePythonFuture)
_enablePythonFuture(enablePythonFuture),
_callback(callbackFunction),
_callbackData(callbackData)
{
_endpoint = std::dynamic_pointer_cast<Endpoint>(endpointOrWorker);
_worker =
Expand Down Expand Up @@ -210,8 +214,8 @@ void Request::setStatus(ucs_status_t status)

if (_status != UCS_INPROGRESS)
ucxx_error(
"ucxx::Request: %p, setStatus called with status: %d (%s) but status: %d (%s) was already "
"set",
"ucxx::Request: %p, setStatus called with status: %d (%s) but status: %d (%s) was "
"already set",
this,
status,
ucs_status_string(status),
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ RequestAm::RequestAm(std::shared_ptr<Component> endpointOrWorker,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture)
: Request(endpointOrWorker,
data::getRequestData(requestData),
operationName,
enablePythonFuture,
callbackFunction,
callbackData)
{
std::visit(data::dispatch{
[this](data::AmSend amSend) {
Expand All @@ -147,9 +152,6 @@ RequestAm::RequestAm(std::shared_ptr<Component> endpointOrWorker,
[](data::AmReceive amReceive) {},
},
requestData);

_callback = callbackFunction;
_callbackData = callbackData;
}

static void _amSendCallback(void* request, ucs_status_t status, void* user_data)
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/request_endpoint_close.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,11 @@ RequestEndpointClose::RequestEndpointClose(std::shared_ptr<Endpoint> endpoint,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpoint, requestData, operationName, enablePythonFuture)
: Request(
endpoint, requestData, operationName, enablePythonFuture, callbackFunction, callbackData)
{
if (_endpoint == nullptr && _worker == nullptr)
throw ucxx::Error("A valid endpoint or worker is required for a close operation.");

_callback = callbackFunction;
_callbackData = callbackData;
}

void RequestEndpointClose::endpointCloseCallback(void* request, ucs_status_t status, void* arg)
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/request_flush.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ RequestFlush::RequestFlush(std::shared_ptr<Component> endpointOrWorker,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpointOrWorker, requestData, operationName, enablePythonFuture)
: Request(endpointOrWorker,
requestData,
operationName,
enablePythonFuture,
callbackFunction,
callbackData)
{
if (_endpoint == nullptr && _worker == nullptr)
throw ucxx::Error("A valid endpoint or worker is required for a flush operation.");

_callback = callbackFunction;
_callbackData = callbackData;
}

void RequestFlush::flushCallback(void* request, ucs_status_t status, void* arg)
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/request_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ RequestMem::RequestMem(std::shared_ptr<Endpoint> endpoint,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpoint, data::getRequestData(requestData), operationName, enablePythonFuture)
: Request(endpoint,
data::getRequestData(requestData),
operationName,
enablePythonFuture,
callbackFunction,
callbackData)
{
std::visit(data::dispatch{
[this](data::MemPut memPut) {
Expand All @@ -63,9 +68,6 @@ RequestMem::RequestMem(std::shared_ptr<Endpoint> endpoint,
[](auto) { throw std::runtime_error("Unreachable"); },
},
requestData);

_callback = callbackFunction;
_callbackData = callbackData;
}

void RequestMem::memPutCallback(void* request, ucs_status_t status, void* arg)
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ RequestTag::RequestTag(std::shared_ptr<Component> endpointOrWorker,
const bool enablePythonFuture,
RequestCallbackUserFunction callbackFunction,
RequestCallbackUserData callbackData)
: Request(endpointOrWorker, data::getRequestData(requestData), operationName, enablePythonFuture)
: Request(endpointOrWorker,
data::getRequestData(requestData),
operationName,
enablePythonFuture,
callbackFunction,
callbackData)
{
std::visit(data::dispatch{
[this](data::TagSend tagSend) {
Expand All @@ -69,9 +74,6 @@ RequestTag::RequestTag(std::shared_ptr<Component> endpointOrWorker,
[](data::TagReceive tagReceive) {},
},
requestData);

_callback = callbackFunction;
_callbackData = callbackData;
}

void RequestTag::callback(void* request, ucs_status_t status, const ucp_tag_recv_info_t* info)
Expand Down

0 comments on commit c461ea9

Please sign in to comment.