diff --git a/prov/ucx/src/ucx_core.c b/prov/ucx/src/ucx_core.c index d1ffc947296..058770b29b9 100644 --- a/prov/ucx/src/ucx_core.c +++ b/prov/ucx/src/ucx_core.c @@ -94,6 +94,25 @@ ssize_t ucx_do_sendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg, return ucx_translate_errcode(*(ucs_status_t*)status); } + if (UCS_PTR_STATUS(status) != UCS_OK) { + struct ucx_request *req = (struct ucx_request *)status; + + /* + * Set up the req fields before the callback function is called + * (in ucp_worker_progress or ucp_worker_flush). + */ + req->ep = u_ep; + if (!no_completion) { + req->completion.op_context = msg->context; + req->completion.flags = FI_SEND | + (mode == UCX_MSG ? FI_MSG : FI_TAGGED); + req->completion.len = msg->msg_iov[0].iov_len; + req->completion.buf = msg->msg_iov[0].iov_base; + req->completion.tag = msg->tag; + req->cq = cq; + } + } + if (flags & FI_INJECT) { if(UCS_PTR_STATUS(status) != UCS_OK) { while ((cstatus = ucp_request_check_status(status)) @@ -110,13 +129,6 @@ ssize_t ucx_do_sendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg, goto done; } - if (no_completion) { - if (UCS_PTR_STATUS(status) != UCS_OK) - goto fence; - - goto done; - } - if (msg->context) { struct fi_context *ctx = ((struct fi_context*)(msg->context)); @@ -129,16 +141,6 @@ ssize_t ucx_do_sendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg, * Not done yet. completion will be handled by the callback * function. */ - struct ucx_request *req = (struct ucx_request *)status; - - req->completion.op_context = msg->context; - req->completion.flags = FI_SEND | - (mode == UCX_MSG ? FI_MSG : FI_TAGGED); - req->completion.len = msg->msg_iov[0].iov_len; - req->completion.buf = msg->msg_iov[0].iov_base; - req->completion.tag = msg->tag; - req->ep = u_ep; - req->cq = cq; goto fence; }