Skip to content

Commit

Permalink
Move ucxx::Request status setting to base class
Browse files Browse the repository at this point in the history
Setting the status from a derived class before `callback()`, and,
therefore `setStatus()`, can be dangerous in situations where the user
code is not holding a reference to the request. In such cases, the user
code may be checking for the status to change to immediately continue by
releasing the request, when that happens and `setStatus()` is eventually
called, that will be operating on an invalid `ucxx::Request` object.
  • Loading branch information
pentschev committed Jun 14, 2023
1 parent 9e6145b commit 874b4c3
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 14 deletions.
14 changes: 4 additions & 10 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,14 @@ void Request::setStatus(ucs_status_t status)
if (_endpoint != nullptr) _endpoint->removeInflightRequest(this);
_worker->removeInflightRequest(this);

if (_status == UCS_INPROGRESS) {
// If the status is not `UCS_INPROGRESS`, the derived class has already set the
// status, a truncated message for example.
_status.store(status);
}

ucs_status_t s = _status;

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback called with status %d (%s)",
s,
ucs_status_string(s));
status,
ucs_status_string(status));

_status.store(status);

if (_enablePythonFuture) {
auto future = std::static_pointer_cast<ucxx::Future>(_future);
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/request_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ void RequestStream::populateDelayedSubmission()

void RequestStream::callback(void* request, ucs_status_t status, size_t length)
{
status = length == _length ? status : UCS_ERR_MESSAGE_TRUNCATED;
_status = status;
status = length == _length ? status : UCS_ERR_MESSAGE_TRUNCATED;

if (status == UCS_ERR_MESSAGE_TRUNCATED) {
const char* fmt = "length mismatch: %llu (got) != %llu (expected)";
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ void RequestTag::callback(void* request, ucs_status_t status, const ucp_tag_recv
std::snprintf(_status_msg.data(), _status_msg.size(), fmt, info->length, _length);
}

_status = status;

Request::callback(request, status);
}

Expand Down

0 comments on commit 874b4c3

Please sign in to comment.