Skip to content

Commit

Permalink
ADI: add MPID_Wait_enqueue and MPID_Waitall_enqueue
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhou committed Jul 14, 2022
1 parent 8a5419a commit 7dd59d1
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/binding/c/stream_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,15 @@ MPIX_Irecv_enqueue:
MPIX_Wait_enqueue:
request: REQUEST, direction=inout, [request]
status: STATUS, direction=out
.impl: mpid
.decl: MPIR_Wait_enqueue_impl

MPIX_Waitall_enqueue:
count: ARRAY_LENGTH_NNI, [lists length]
array_of_requests: REQUEST, direction=inout, length=count, [array of requests]
array_of_statuses: STATUS, direction=out, length=*, pointer=False, [array of status objects]
.impl: mpid
.decl: MPIR_Waitall_enqueue_impl
{ -- error_check -- array_of_statuses
if (count > 0) {
MPIR_ERRTEST_ARGNULL(array_of_statuses, "array_of_statuses", mpi_errno);
Expand Down
3 changes: 3 additions & 0 deletions src/mpid/ch3/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,9 @@ int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Wait_enqueue(MPIR_Request * req_ptr, MPI_Status * status);
int MPID_Waitall_enqueue(int count, MPI_Request * array_of_requests,
MPI_Status * array_of_statuses);

void MPID_Progress_start(MPID_Progress_state * state);
int MPID_Progress_wait(MPID_Progress_state * state);
Expand Down
11 changes: 11 additions & 0 deletions src/mpid/ch3/src/ch3_stream_enqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,14 @@ int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
{
return MPIR_Irecv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, req);
}

int MPID_Wait_enqueue(MPIR_Request * req_ptr, MPI_Status * status)
{
return MPIR_Wait_enqueue_impl(req_ptr, status);
}

int MPID_Waitall_enqueue(int count, MPI_Request * array_of_requests,
MPI_Status * array_of_statuses)
{
return MPIR_Waitall_enqueue_impl(count, array_of_requests, array_of_statuses);
}
3 changes: 3 additions & 0 deletions src/mpid/ch4/include/mpidch4.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Wait_enqueue(MPIR_Request * req_ptr, MPI_Status * status);
int MPID_Waitall_enqueue(int count, MPI_Request * array_of_requests,
MPI_Status * array_of_statuses);
int MPID_Abort(struct MPIR_Comm *comm, int mpi_errno, int exit_code, const char *error_msg);

/* This function is not exposed to the upper layers but functions in a way
Expand Down
113 changes: 113 additions & 0 deletions src/mpid/ch4/src/ch4_stream_enqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
workq = stream_ptr->dev.workq; \
} while (0)

#define REQUEST_GET_STREAM_AND_WORKQ(req, gpu_stream, workq) \
do { \
MPIR_Stream *stream_ptr = req->u.enqueue.stream_ptr; \
gpu_stream = stream_ptr->u.gpu_stream; \
workq = stream_ptr->dev.workq; \
} while (0)

/* ---- send enqueue ---- */
struct send_data {
const void *buf;
Expand Down Expand Up @@ -337,3 +344,109 @@ int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
fn_fail:
goto fn_exit;
}

/* ---- wait enqueue ---- */
int MPID_Wait_enqueue(MPIR_Request * req_ptr, MPI_Status * status)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
mpi_errno = MPIR_Wait_enqueue_impl(req_ptr, status);
goto fn_exit;
}

MPL_gpu_stream_t gpu_stream;
MPIDU_stream_workq_t *workq;
REQUEST_GET_STREAM_AND_WORKQ(req_ptr, gpu_stream, workq);

MPIDU_stream_workq_op_t *op;
op = MPL_malloc(sizeof(MPIDU_stream_workq_op_t), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPL_gpu_event_t *trigger_event;
MPL_gpu_event_t *done_event;
MPIDU_stream_workq_alloc_event(&trigger_event);
MPIDU_stream_workq_alloc_event(&done_event);

op->cb = NULL;
op->data = NULL;
op->trigger_event = trigger_event;
op->done_event = done_event;
op->request = &req_ptr->u.enqueue.real_request;
if (status != MPI_STATUS_IGNORE) {
op->status = status;
} else {
op->status = NULL;
}

MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
MPL_gpu_enqueue_wait(done_event, gpu_stream);
MPIDU_stream_workq_enqueue(workq, op);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

/* ---- waitall enqueue ---- */
int MPID_Waitall_enqueue(int count, MPI_Request * array_of_requests, MPI_Status * array_of_statuses)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
mpi_errno = MPIR_Waitall_enqueue_impl(count, array_of_requests, array_of_statuses);
goto fn_exit;
}

MPL_gpu_event_t *trigger_event;
MPL_gpu_event_t *done_event;
MPIDU_stream_workq_alloc_event(&trigger_event);
MPIDU_stream_workq_alloc_event(&done_event);
MPL_gpu_event_init_count(done_event, count);

MPL_gpu_stream_t the_gpu_stream;
for (int i = 0; i < count; i++) {
MPIR_Request *req_ptr;
MPIR_Request_get_ptr(array_of_requests[i], req_ptr);
MPIR_Assert(req_ptr && req_ptr->kind == MPIR_REQUEST_KIND__ENQUEUE);

MPL_gpu_stream_t gpu_stream;
MPIDU_stream_workq_t *workq;
REQUEST_GET_STREAM_AND_WORKQ(req_ptr, gpu_stream, workq);

if (i == 0) {
MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
MPL_gpu_enqueue_wait(done_event, gpu_stream);
the_gpu_stream = gpu_stream;
} else {
MPIR_Assertp(the_gpu_stream == gpu_stream);
}

MPIDU_stream_workq_op_t *op;
op = MPL_malloc(sizeof(MPIDU_stream_workq_op_t), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");

op->cb = NULL;
op->data = NULL;
op->trigger_event = trigger_event;
op->done_event = done_event;
op->request = &req_ptr->u.enqueue.real_request;
if (array_of_statuses != MPI_STATUSES_IGNORE) {
op->status = &array_of_statuses[i];
} else {
op->status = NULL;
}

MPIDU_stream_workq_enqueue(workq, op);
}

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

0 comments on commit 7dd59d1

Please sign in to comment.