From 409c00d953687786ac614117f1302dd0f55b97e4 Mon Sep 17 00:00:00 2001 From: Jianxin Xiong Date: Fri, 6 Dec 2024 16:34:17 -0800 Subject: [PATCH] prov/ucx: Fix segfault in ucx_send_callback In one code path, the request was not initialized before the callback function is called. As a result, NULL cq was dereferenced, leading to segfault. Signed-off-by: Jianxin Xiong --- prov/ucx/src/ucx_core.c | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/prov/ucx/src/ucx_core.c b/prov/ucx/src/ucx_core.c index d1ffc947296..058770b29b9 100644 --- a/prov/ucx/src/ucx_core.c +++ b/prov/ucx/src/ucx_core.c @@ -94,6 +94,25 @@ ssize_t ucx_do_sendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg, return ucx_translate_errcode(*(ucs_status_t*)status); } + if (UCS_PTR_STATUS(status) != UCS_OK) { + struct ucx_request *req = (struct ucx_request *)status; + + /* + * Set up the req fields before the callback function is called + * (in ucp_worker_progress or ucp_worker_flush). + */ + req->ep = u_ep; + if (!no_completion) { + req->completion.op_context = msg->context; + req->completion.flags = FI_SEND | + (mode == UCX_MSG ? FI_MSG : FI_TAGGED); + req->completion.len = msg->msg_iov[0].iov_len; + req->completion.buf = msg->msg_iov[0].iov_base; + req->completion.tag = msg->tag; + req->cq = cq; + } + } + if (flags & FI_INJECT) { if(UCS_PTR_STATUS(status) != UCS_OK) { while ((cstatus = ucp_request_check_status(status)) @@ -110,13 +129,6 @@ ssize_t ucx_do_sendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg, goto done; } - if (no_completion) { - if (UCS_PTR_STATUS(status) != UCS_OK) - goto fence; - - goto done; - } - if (msg->context) { struct fi_context *ctx = ((struct fi_context*)(msg->context)); @@ -129,16 +141,6 @@ ssize_t ucx_do_sendmsg(struct fid_ep *ep, const struct fi_msg_tagged *msg, * Not done yet. completion will be handled by the callback * function. */ - struct ucx_request *req = (struct ucx_request *)status; - - req->completion.op_context = msg->context; - req->completion.flags = FI_SEND | - (mode == UCX_MSG ? FI_MSG : FI_TAGGED); - req->completion.len = msg->msg_iov[0].iov_len; - req->completion.buf = msg->msg_iov[0].iov_base; - req->completion.tag = msg->tag; - req->ep = u_ep; - req->cq = cq; goto fence; }