diff --git a/src/binding/c/stream_api.txt b/src/binding/c/stream_api.txt
index 21df4495fcd..d7ab09133f0 100644
--- a/src/binding/c/stream_api.txt
+++ b/src/binding/c/stream_api.txt
@@ -187,6 +187,8 @@ MPIX_Isend_enqueue:
     tag: TAG, [message tag]
     comm: COMMUNICATOR
     request: REQUEST, direction=out
+    .impl: mpid
+    .decl: MPIR_Isend_enqueue_impl
 
 MPIX_Irecv_enqueue:
     buf: BUFFER, direction=out, [initial address of receive buffer]
@@ -194,9 +196,11 @@ MPIX_Irecv_enqueue:
     datatype: DATATYPE, [datatype of each receive buffer element]
     source: RANK, [rank of source or MPI_ANY_SOURCE]
     tag: TAG, [message tag or MPI_ANY_TAG]
     comm: COMMUNICATOR
     request: REQUEST, direction=out
+    .impl: mpid
+    .decl: MPIR_Irecv_enqueue_impl
 
 MPIX_Wait_enqueue:
     request: REQUEST, direction=inout, [request]
 
diff --git a/src/mpid/ch3/include/mpidpre.h b/src/mpid/ch3/include/mpidpre.h
index 49d7633ca85..c51a82a70d3 100644
--- a/src/mpid/ch3/include/mpidpre.h
+++ b/src/mpid/ch3/include/mpidpre.h
@@ -817,6 +817,10 @@ int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
                       int dest, int tag, MPIR_Comm * comm_ptr);
 int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
                       int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status);
+int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
+int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
 
 void MPID_Progress_start(MPID_Progress_state * state);
 int MPID_Progress_wait(MPID_Progress_state * state);
diff --git a/src/mpid/ch3/src/ch3_stream_enqueue.c b/src/mpid/ch3/src/ch3_stream_enqueue.c
index a12a5588b8a..9bcf10b93c3 100644
--- a/src/mpid/ch3/src/ch3_stream_enqueue.c
+++ b/src/mpid/ch3/src/ch3_stream_enqueue.c
@@ -16,3 +16,18 @@ int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
 {
     return MPIR_Recv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, status);
 }
+
+
+/* ch3 forwards the call unchanged to MPIR_Isend_enqueue_impl. */
+int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    return MPIR_Isend_enqueue_impl(buf, count, datatype, dest, tag, comm_ptr, req);
+}
+
+/* ch3 forwards the call unchanged to MPIR_Irecv_enqueue_impl. */
+int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    return MPIR_Irecv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, req);
+}
diff --git a/src/mpid/ch4/include/mpidch4.h b/src/mpid/ch4/include/mpidch4.h
index d7f1705d716..199f7cb5713 100644
--- a/src/mpid/ch4/include/mpidch4.h
+++ b/src/mpid/ch4/include/mpidch4.h
@@ -317,6 +317,10 @@ int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
                       int dest, int tag, MPIR_Comm * comm_ptr);
 int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
                       int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status);
+int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
+int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
 int MPID_Abort(struct MPIR_Comm *comm, int mpi_errno, int exit_code, const char *error_msg);
 
 /* This function is not exposed to the upper layers but functions in a way
diff --git a/src/mpid/ch4/src/ch4_stream_enqueue.c b/src/mpid/ch4/src/ch4_stream_enqueue.c
index 706bb8055df..5ab60cf050c 100644
--- a/src/mpid/ch4/src/ch4_stream_enqueue.c
+++ b/src/mpid/ch4/src/ch4_stream_enqueue.c
@@ -195,3 +195,169 @@ int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
   fn_fail:
     goto fn_exit;
 }
+
+/* ---- isend enqueue ---- */
+/* Work-queue op callback: issues MPID_Send with the arguments saved in data (a
+ * struct send_data) and stores the new request in req->u.enqueue.real_request. */
+static void isend_enqueue_cb(void *data)
+{
+    int mpi_errno;
+    MPIR_Request *request_ptr = NULL;
+
+    struct send_data *p = data;
+    mpi_errno = MPID_Send(p->buf, p->count, p->datatype, p->dest, p->tag, p->comm_ptr,
+                          MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
+    MPIR_Assertp(mpi_errno == MPI_SUCCESS);
+    MPIR_Assertp(request_ptr != NULL);
+
+    p->req->u.enqueue.real_request = request_ptr;
+}
+
+/* Enqueue a nonblocking send on the stream obtained from comm_ptr.
+ * When MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ is off, defers to
+ * MPIR_Isend_enqueue_impl. Otherwise an enqueue request is returned in *req
+ * and isend_enqueue_cb is registered as the callback of a work-queue op.
+ * On a failure in the work-queue path, p and op are freed and nothing is enqueued. */
+int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    int mpi_errno = MPI_SUCCESS;
+    struct send_data *p = NULL;
+    MPIDU_stream_workq_op_t *op = NULL;
+    MPIR_FUNC_ENTER;
+
+    if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
+        mpi_errno = MPIR_Isend_enqueue_impl(buf, count, datatype, dest, tag, comm_ptr, req);
+        goto fn_exit;
+    }
+
+    MPL_gpu_stream_t gpu_stream;
+    MPIDU_stream_workq_t *workq;
+    GET_STREAM_AND_WORKQ(comm_ptr, gpu_stream, workq);
+
+    p = MPL_malloc(sizeof(*p), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    op = MPL_malloc(sizeof(*op), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    /* Allocate the request before the trigger event so that a failure here
+     * leaves only p and op to release. */
+    mpi_errno = MPIR_allocate_enqueue_request(comm_ptr, req);
+    MPIR_ERR_CHECK(mpi_errno);
+    (*req)->u.enqueue.is_send = true;
+    (*req)->u.enqueue.data = p;
+
+    MPL_gpu_event_t *trigger_event;
+    MPIDU_stream_workq_alloc_event(&trigger_event);
+
+    p->buf = buf;
+    p->count = count;
+    p->datatype = datatype;
+    p->dest = dest;
+    p->tag = tag;
+    p->comm_ptr = comm_ptr;
+    p->req = *req;
+
+    op->cb = isend_enqueue_cb;
+    op->data = p;
+    op->trigger_event = trigger_event;
+    op->done_event = NULL;
+    op->request = &(*req)->u.enqueue.real_request;
+    op->status = NULL;
+
+    MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
+    MPIDU_stream_workq_enqueue(workq, op);
+
+  fn_exit:
+    MPIR_FUNC_EXIT;
+    return mpi_errno;
+  fn_fail:
+    /* Ownership of p and op has not been transferred yet, so release them. */
+    MPL_free(op);
+    MPL_free(p);
+    goto fn_exit;
+}
+
+/* ---- irecv enqueue ---- */
+/* Work-queue op callback: issues MPID_Irecv with the arguments saved in data (a
+ * struct recv_data) and stores the new request in req->u.enqueue.real_request. */
+static void irecv_enqueue_cb(void *data)
+{
+    int mpi_errno;
+    MPIR_Request *request_ptr = NULL;
+
+    struct recv_data *p = data;
+    mpi_errno = MPID_Irecv(p->buf, p->count, p->datatype, p->source, p->tag, p->comm_ptr,
+                           MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
+    MPIR_Assertp(mpi_errno == MPI_SUCCESS);
+    MPIR_Assertp(request_ptr != NULL);
+
+    p->req->u.enqueue.real_request = request_ptr;
+}
+
+/* Enqueue a nonblocking receive on the stream obtained from comm_ptr.
+ * When MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ is off, defers to
+ * MPIR_Irecv_enqueue_impl. Otherwise an enqueue request is returned in *req
+ * and irecv_enqueue_cb is registered as the callback of a work-queue op.
+ * On a failure in the work-queue path, p and op are freed and nothing is enqueued. */
+int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
+                       int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    int mpi_errno = MPI_SUCCESS;
+    struct recv_data *p = NULL;
+    MPIDU_stream_workq_op_t *op = NULL;
+    MPIR_FUNC_ENTER;
+
+    if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
+        mpi_errno = MPIR_Irecv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, req);
+        goto fn_exit;
+    }
+
+    MPL_gpu_stream_t gpu_stream;
+    MPIDU_stream_workq_t *workq;
+    GET_STREAM_AND_WORKQ(comm_ptr, gpu_stream, workq);
+
+    p = MPL_malloc(sizeof(*p), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    op = MPL_malloc(sizeof(*op), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    /* Allocate the request before the trigger event so that a failure here
+     * leaves only p and op to release. */
+    mpi_errno = MPIR_allocate_enqueue_request(comm_ptr, req);
+    MPIR_ERR_CHECK(mpi_errno);
+    (*req)->u.enqueue.is_send = false;
+    (*req)->u.enqueue.data = p;
+
+    MPL_gpu_event_t *trigger_event;
+    MPIDU_stream_workq_alloc_event(&trigger_event);
+
+    p->buf = buf;
+    p->count = count;
+    p->datatype = datatype;
+    p->source = source;
+    p->tag = tag;
+    p->comm_ptr = comm_ptr;
+    p->req = *req;
+
+    op->cb = irecv_enqueue_cb;
+    op->data = p;
+    op->trigger_event = trigger_event;
+    op->done_event = NULL;
+    op->request = &(*req)->u.enqueue.real_request;
+    op->status = NULL;
+
+    MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
+    MPIDU_stream_workq_enqueue(workq, op);
+
+  fn_exit:
+    MPIR_FUNC_EXIT;
+    return mpi_errno;
+  fn_fail:
+    /* Ownership of p and op has not been transferred yet, so release them. */
+    MPL_free(op);
+    MPL_free(p);
+    goto fn_exit;
+}