From d275aaabcf2837c768a1b08d4a37ac7575b3aca1 Mon Sep 17 00:00:00 2001
From: Hui Zhou
Date: Wed, 20 Apr 2022 22:35:58 -0500
Subject: [PATCH] stream: add MPIX_I{send,recv}_enqueue and MPIX_Wait{all}_enqueue

---
 maint/local_python/binding_c.py |   3 +
 src/binding/c/stream_api.txt    |  32 ++++
 src/include/mpir_request.h      |   7 +
 src/mpi/stream/stream_impl.c    | 309 +++++++++++++++++++++++++++++++-
 4 files changed, 346 insertions(+), 5 deletions(-)

diff --git a/maint/local_python/binding_c.py b/maint/local_python/binding_c.py
index df5c870e56a..a278f8ac03a 100644
--- a/maint/local_python/binding_c.py
+++ b/maint/local_python/binding_c.py
@@ -696,6 +696,9 @@ def process_func_parameters(func):
             if RE.match(r'mpi_startall', func['name'], re.IGNORECASE):
                 impl_arg_list.append(ptrs_name)
                 impl_param_list.append("MPIR_Request **%s" % ptrs_name)
+            else:
+                impl_arg_list.append(name)
+                impl_param_list.append("MPI_Request %s[]" % name)
         else:
             print("Unhandled handle array: " + name, file=sys.stderr)
     elif "code-handle_ptr-tail" in func and name in func['code-handle_ptr-tail']:
diff --git a/src/binding/c/stream_api.txt b/src/binding/c/stream_api.txt
index a891f627c0c..05d046f280c 100644
--- a/src/binding/c/stream_api.txt
+++ b/src/binding/c/stream_api.txt
@@ -28,3 +28,35 @@ MPIX_Recv_enqueue:
     tag: TAG, [message tag or MPI_ANY_TAG]
     comm: COMMUNICATOR
     status: STATUS, direction=out
+
+MPIX_Isend_enqueue:
+    buf: BUFFER, constant=True, [initial address of send buffer]
+    count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
+    datatype: DATATYPE, [datatype of each send buffer element]
+    dest: RANK, [rank of destination]
+    tag: TAG, [message tag]
+    comm: COMMUNICATOR
+    request: REQUEST, direction=out
+
+MPIX_Irecv_enqueue:
+    buf: BUFFER, direction=out, [initial address of receive buffer]
+    count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
+    datatype: DATATYPE, [datatype of each receive buffer element]
+    source: RANK, [rank of source or MPI_ANY_SOURCE]
+    tag: TAG, [message tag or MPI_ANY_TAG]
+    comm: COMMUNICATOR
+    request: REQUEST, direction=out
+
+MPIX_Wait_enqueue:
+    request: REQUEST, direction=inout, [request]
+    status: STATUS, direction=out
+
+MPIX_Waitall_enqueue:
+    count: ARRAY_LENGTH_NNI, [lists length]
+    array_of_requests: REQUEST, direction=inout, length=count, [array of requests]
+    array_of_statuses: STATUS, direction=out, length=*, pointer=False, [array of status objects]
+{ -- error_check -- array_of_statuses
+    if (count > 0) {
+        MPIR_ERRTEST_ARGNULL(array_of_statuses, "array_of_statuses", mpi_errno);
+    }
+}
diff --git a/src/include/mpir_request.h b/src/include/mpir_request.h
index 27d3074d15d..f8accc7f608 100644
--- a/src/include/mpir_request.h
+++ b/src/include/mpir_request.h
@@ -65,6 +65,7 @@ typedef enum MPIR_Request_kind_t {
     MPIR_REQUEST_KIND__PART_SEND,       /* Partitioned send req returned to user */
     MPIR_REQUEST_KIND__PART_RECV,       /* Partitioned recv req returned to user */
     MPIR_REQUEST_KIND__PART,    /* Partitioned pt2pt internal reqs */
+    MPIR_REQUEST_KIND__ENQUEUE, /* enqueued (to gpu stream) request */
     MPIR_REQUEST_KIND__GREQUEST,
     MPIR_REQUEST_KIND__COLL,
     MPIR_REQUEST_KIND__MPROBE,  /* see NOTE-R1 */
@@ -220,6 +221,12 @@ struct MPIR_Request {
             MPL_atomic_int_t active_flag;       /* flag indicating whether in a start-complete active period.
                                                  * Value is 0 or 1. */
         } part;                 /* kind : MPIR_REQUEST_KIND__PART_SEND or MPIR_REQUEST_KIND__PART_RECV */
+        struct {
+            MPL_gpu_stream_t gpu_stream;
+            struct MPIR_Request *real_request;
+            bool is_send;
+            void *data;
+        } enqueue;
         struct {
             MPIR_Win *win;
         } rma;                  /* kind : MPIR_REQUEST_KIND__RMA */
diff --git a/src/mpi/stream/stream_impl.c b/src/mpi/stream/stream_impl.c
index 6c358468b4d..49989508fad 100644
--- a/src/mpi/stream/stream_impl.c
+++ b/src/mpi/stream/stream_impl.c
@@ -210,6 +210,29 @@ static int get_local_gpu_stream(MPIR_Comm * comm_ptr, MPL_gpu_stream_t * gpu_str
     goto fn_exit;
 }
 
+static int allocate_enqueue_request(MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    int mpi_errno = MPI_SUCCESS;
+
+    MPIR_Stream *stream_ptr = NULL;
+    if (comm_ptr->stream_comm_type == MPIR_STREAM_COMM_SINGLE) {
+        stream_ptr = comm_ptr->stream_comm.single.stream;
+    } else if (comm_ptr->stream_comm_type == MPIR_STREAM_COMM_MULTIPLEX) {
+        stream_ptr = comm_ptr->stream_comm.multiplex.local_streams[comm_ptr->rank];
+    }
+    MPIR_Assert(stream_ptr);
+
+    int vci = stream_ptr->vci;
+    MPIR_Assert(vci > 0);
+
+    /* stream vci are only accessed within a serialized context */
+    (*req) = MPIR_Request_create_from_pool_safe(MPIR_REQUEST_KIND__ENQUEUE, vci, 1);
+    (*req)->u.enqueue.gpu_stream = stream_ptr->u.gpu_stream;
+    (*req)->u.enqueue.real_request = NULL;
+
+    return mpi_errno;
+}
+
 /* send enqueue */
 struct send_data {
     const void *buf;
@@ -221,9 +244,11 @@ struct send_data {
     void *host_buf;
     MPI_Aint data_sz;
     MPI_Aint actual_pack_bytes;
+    /* for isend */
+    MPIR_Request *req;
 };
 
-static void send_stream_cb(void *data)
+static void send_enqueue_cb(void *data)
 {
     int mpi_errno;
     MPIR_Request *request_ptr = NULL;
@@ -286,7 +311,7 @@ int MPIR_Send_enqueue_impl(const void *buf, MPI_Aint count, MPI_Datatype datatyp
         p->datatype = datatype;
     }
 
-    MPL_gpu_launch_hostfn(gpu_stream, send_stream_cb, p);
+    MPL_gpu_launch_hostfn(gpu_stream, send_enqueue_cb, p);
 
   fn_exit:
     return mpi_errno;
@@ -306,9 +331,11 @@ struct recv_data {
     void *host_buf;
     MPI_Aint data_sz;
    MPI_Aint actual_unpack_bytes;
+    /* for irecv */
+    MPIR_Request *req;
 };
 
-static void recv_stream_cb(void *data)
+static void recv_enqueue_cb(void *data)
 {
     int mpi_errno;
     MPIR_Request *request_ptr = NULL;
@@ -370,7 +397,7 @@ int MPIR_Recv_enqueue_impl(void *buf, MPI_Aint count, MPI_Datatype datatype,
 
         MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);
 
-        MPL_gpu_launch_hostfn(gpu_stream, recv_stream_cb, p);
+        MPL_gpu_launch_hostfn(gpu_stream, recv_enqueue_cb, p);
 
         mpi_errno = MPIR_Typerep_unpack_stream(p->host_buf, p->data_sz, buf, count, datatype, 0,
                                                &p->actual_unpack_bytes, &gpu_stream);
@@ -383,7 +410,7 @@ int MPIR_Recv_enqueue_impl(void *buf, MPI_Aint count, MPI_Datatype datatype,
         p->count = count;
         p->datatype = datatype;
 
-        MPL_gpu_launch_hostfn(gpu_stream, recv_stream_cb, p);
+        MPL_gpu_launch_hostfn(gpu_stream, recv_enqueue_cb, p);
     }
 
   fn_exit:
@@ -391,3 +418,275 @@ int MPIR_Recv_enqueue_impl(void *buf, MPI_Aint count, MPI_Datatype datatype,
     return mpi_errno;
   fn_fail:
     goto fn_exit;
 }
+
+/* ---- isend enqueue ---- */
+static void isend_enqueue_cb(void *data)
+{
+    int mpi_errno;
+    MPIR_Request *request_ptr = NULL;
+
+    struct send_data *p = data;
+    if (p->host_buf) {
+        assert(p->actual_pack_bytes == p->data_sz);
+
+        mpi_errno = MPID_Send(p->host_buf, p->data_sz, MPI_BYTE, p->dest, p->tag, p->comm_ptr,
+                              MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
+    } else {
+        mpi_errno = MPID_Send(p->buf, p->count, p->datatype, p->dest, p->tag, p->comm_ptr,
+                              MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
+    }
+    assert(mpi_errno == MPI_SUCCESS);
+    assert(request_ptr != NULL);
+
+    p->req->u.enqueue.real_request = request_ptr;
+}
+
+int MPIR_Isend_enqueue_impl(const void *buf, MPI_Aint count, MPI_Datatype datatype,
+                            int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    int mpi_errno = MPI_SUCCESS;
+
+    MPL_gpu_stream_t gpu_stream;
+    mpi_errno = get_local_gpu_stream(comm_ptr, &gpu_stream);
+    MPIR_ERR_CHECK(mpi_errno);
+
+    struct send_data *p;
+    p = MPL_malloc(sizeof(struct send_data), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    mpi_errno = allocate_enqueue_request(comm_ptr, req);
+    MPIR_ERR_CHECK(mpi_errno);
+    (*req)->u.enqueue.is_send = true;
+    (*req)->u.enqueue.data = p;
+
+    p->req = *req;
+    p->dest = dest;
+    p->tag = tag;
+    p->comm_ptr = comm_ptr;
+
+    if (MPIR_GPU_query_pointer_is_dev(buf)) {
+        MPI_Aint dt_size;
+        MPIR_Datatype_get_size_macro(datatype, dt_size);
+        p->data_sz = dt_size * count;
+
+        MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);
+
+        mpi_errno = MPIR_Typerep_pack_stream(buf, count, datatype, 0, p->host_buf, p->data_sz,
+                                             &p->actual_pack_bytes, &gpu_stream);
+        MPIR_ERR_CHECK(mpi_errno);
+    } else {
+        p->host_buf = NULL;
+        p->buf = buf;
+        p->count = count;
+        p->datatype = datatype;
+    }
+
+    MPL_gpu_launch_hostfn(gpu_stream, isend_enqueue_cb, p);
+
+  fn_exit:
+    return mpi_errno;
+  fn_fail:
+    goto fn_exit;
+}
+
+/* ---- irecv enqueue ---- */
+static void irecv_enqueue_cb(void *data)
+{
+    int mpi_errno;
+    MPIR_Request *request_ptr = NULL;
+
+    struct recv_data *p = data;
+    if (p->host_buf) {
+        mpi_errno = MPID_Recv(p->host_buf, p->data_sz, MPI_BYTE, p->source, p->tag, p->comm_ptr,
+                              MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
+    } else {
+        mpi_errno = MPID_Recv(p->buf, p->count, p->datatype, p->source, p->tag, p->comm_ptr,
+                              MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
+    }
+    assert(mpi_errno == MPI_SUCCESS);
+    assert(request_ptr != NULL);
+
+    p->req->u.enqueue.real_request = request_ptr;
+}
+
+int MPIR_Irecv_enqueue_impl(void *buf, MPI_Aint count, MPI_Datatype datatype,
+                            int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
+{
+    int mpi_errno = MPI_SUCCESS;
+
+    MPL_gpu_stream_t gpu_stream;
+    mpi_errno = get_local_gpu_stream(comm_ptr, &gpu_stream);
+    MPIR_ERR_CHECK(mpi_errno);
+
+    struct recv_data *p;
+    p = MPL_malloc(sizeof(struct recv_data), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    mpi_errno = allocate_enqueue_request(comm_ptr, req);
+    MPIR_ERR_CHECK(mpi_errno);
+    (*req)->u.enqueue.is_send = false;
+    (*req)->u.enqueue.data = p;
+
+    p->req = *req;
+    p->source = source;
+    p->tag = tag;
+    p->comm_ptr = comm_ptr;
+    p->status = MPI_STATUS_IGNORE;
+
+    if (MPIR_GPU_query_pointer_is_dev(buf)) {
+        MPI_Aint dt_size;
+        MPIR_Datatype_get_size_macro(datatype, dt_size);
+        p->data_sz = dt_size * count;
+
+        MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);
+
+        MPL_gpu_launch_hostfn(gpu_stream, recv_enqueue_cb, p);
+
+        mpi_errno = MPIR_Typerep_unpack_stream(p->host_buf, p->data_sz, buf, count, datatype, 0,
+                                               &p->actual_unpack_bytes, &gpu_stream);
+        MPIR_ERR_CHECK(mpi_errno);
+
+        MPL_gpu_launch_hostfn(gpu_stream, recv_stream_cleanup_cb, p);
+    } else {
+        p->host_buf = NULL;
+        p->buf = buf;
+        p->count = count;
+        p->datatype = datatype;
+
+        MPL_gpu_launch_hostfn(gpu_stream, irecv_enqueue_cb, p);
+    }
+
+  fn_exit:
+    return mpi_errno;
+  fn_fail:
+    goto fn_exit;
+}
+
+/* ---- wait enqueue ---- */
+static void wait_enqueue_cb(void *data)
+{
+    int mpi_errno;
+    MPIR_Request *enqueue_req = data;
+    MPIR_Request *real_req = enqueue_req->u.enqueue.real_request;
+
+    if (enqueue_req->u.enqueue.is_send) {
+        struct send_data *p = enqueue_req->u.enqueue.data;
+
+        mpi_errno = MPID_Wait(real_req, MPI_STATUS_IGNORE);
+        assert(mpi_errno == MPI_SUCCESS);
+
+        MPIR_Request_free(real_req);
+
+        if (p->host_buf) {
+            MPIR_gpu_free_host(p->host_buf);
+        }
+        MPL_free(p);
+    } else {
+        struct recv_data *p = enqueue_req->u.enqueue.data;
+
+        mpi_errno = MPID_Wait(real_req, MPI_STATUS_IGNORE);
+        assert(mpi_errno == MPI_SUCCESS);
+
+        MPIR_Request_extract_status(real_req, p->status);
+        MPIR_Request_free(real_req);
+
+        if (!p->host_buf) {
+            MPL_free(p);
+        }
+    }
+    MPIR_Request_free(enqueue_req);
+}
+
+int MPIR_Wait_enqueue_impl(MPIR_Request * req_ptr, MPI_Status * status)
+{
+    int mpi_errno = MPI_SUCCESS;
+    MPIR_Assert(req_ptr && req_ptr->kind == MPIR_REQUEST_KIND__ENQUEUE);
+
+    MPL_gpu_stream_t gpu_stream = req_ptr->u.enqueue.gpu_stream;
+    if (!req_ptr->u.enqueue.is_send) {
+        struct recv_data *p = req_ptr->u.enqueue.data;
+        p->status = status;
+    }
+
+    MPL_gpu_launch_hostfn(gpu_stream, wait_enqueue_cb, req_ptr);
+
+    return mpi_errno;
+}
+
+/* ---- waitall enqueue ---- */
+struct waitall_data {
+    int count;
+    MPI_Request *array_of_requests;
+    MPI_Status *array_of_statuses;
+};
+
+static void waitall_enqueue_cb(void *data)
+{
+    struct waitall_data *p = data;
+
+    MPI_Request *reqs = MPL_malloc(p->count * sizeof(MPI_Request), MPL_MEM_OTHER);
+    MPIR_Assert(reqs);
+
+    for (int i = 0; i < p->count; i++) {
+        MPIR_Request *enqueue_req;
+        MPIR_Request_get_ptr(p->array_of_requests[i], enqueue_req);
+        reqs[i] = enqueue_req->u.enqueue.real_request->handle;
+    }
+
+    MPIR_Waitall(p->count, reqs, p->array_of_statuses);
+
+    for (int i = 0; i < p->count; i++) {
+        MPIR_Request *enqueue_req;
+        MPIR_Request_get_ptr(p->array_of_requests[i], enqueue_req);
+
+        if (enqueue_req->u.enqueue.is_send) {
+            struct send_data *p2 = enqueue_req->u.enqueue.data;
+            if (p2->host_buf) {
+                MPIR_gpu_free_host(p2->host_buf);
+            }
+            MPL_free(p2);
+        } else {
+            struct recv_data *p2 = enqueue_req->u.enqueue.data;
+            if (!p2->host_buf) {
+                MPL_free(p2);
+            }
+        }
+        MPIR_Request_free(enqueue_req);
+    }
+    MPL_free(reqs);
+    MPL_free(p);
+}
+
+int MPIR_Waitall_enqueue_impl(int count, MPI_Request * array_of_requests,
+                              MPI_Status * array_of_statuses)
+{
+    int mpi_errno = MPI_SUCCESS;
+
+    MPL_gpu_stream_t gpu_stream = MPL_GPU_STREAM_DEFAULT;
+    for (int i = 0; i < count; i++) {
+        MPIR_Request *enqueue_req;
+        MPIR_Request_get_ptr(array_of_requests[i], enqueue_req);
+
+        MPIR_Assert(enqueue_req && enqueue_req->kind == MPIR_REQUEST_KIND__ENQUEUE);
+        if (i == 0) {
+            gpu_stream = enqueue_req->u.enqueue.gpu_stream;
+        } else {
+            MPIR_Assert(gpu_stream == enqueue_req->u.enqueue.gpu_stream);
+        }
+    }
+
+    struct waitall_data *p;
+    p = MPL_malloc(sizeof(struct waitall_data), MPL_MEM_OTHER);
+    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");
+
+    p->count = count;
+    p->array_of_requests = array_of_requests;
+    p->array_of_statuses = array_of_statuses;
+
+    MPL_gpu_launch_hostfn(gpu_stream, waitall_enqueue_cb, p);
+
+  fn_exit:
+    return mpi_errno;
+  fn_fail:
+    goto fn_exit;
+}
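
A minimal usage sketch of the new nonblocking enqueue calls (not part of the patch): it assumes a CUDA build where stream_comm is a stream communicator already bound to the CUDA stream s (for example created with MPIX_Stream_create and MPIX_Stream_comm_create), and the device buffers, element count n, and peer rank are placeholders.

/* Hypothetical exchange using the enqueued nonblocking API.
 * Assumes stream_comm is backed by the CUDA stream s; d_sendbuf and
 * d_recvbuf are device buffers of n doubles; peer is the partner rank. */
#include <mpi.h>
#include <cuda_runtime.h>

void exchange(MPI_Comm stream_comm, cudaStream_t s,
              double *d_sendbuf, double *d_recvbuf, int n, int peer)
{
    MPI_Request reqs[2];
    MPI_Status stats[2];

    /* Both operations are enqueued onto the GPU stream: the send and receive
     * are not issued until preceding work on s has completed. */
    MPIX_Isend_enqueue(d_sendbuf, n, MPI_DOUBLE, peer, 0, stream_comm, &reqs[0]);
    MPIX_Irecv_enqueue(d_recvbuf, n, MPI_DOUBLE, peer, 0, stream_comm, &reqs[1]);

    /* Completion is enqueued as well; the host only blocks when it
     * synchronizes the stream. */
    MPIX_Waitall_enqueue(2, reqs, stats);

    cudaStreamSynchronize(s);
}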