Skip to content

Commit

Permalink
ADI: add MPID_Isend_enqueue and MPID_Irecv_enqueue
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhou committed Jul 14, 2022
1 parent 558efd5 commit 8a5419a
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/binding/c/stream_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,19 @@ MPIX_Isend_enqueue:
tag: TAG, [message tag]
comm: COMMUNICATOR
request: REQUEST, direction=out
.impl: mpid
.decl: MPIR_Isend_enqueue_impl

MPIX_Irecv_enqueue:
buf: BUFFER, direction=out, [initial address of receive buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
datatype: DATATYPE, [datatype of each receive buffer element]
source: RANK, [rank of source or MPI_ANY_SOURCE]
tag: TAG, [message tag or MPI_ANY_TAG]

comm: COMMUNICATOR
request: REQUEST, direction=out
.impl: mpid
.decl: MPIR_Irecv_enqueue_impl

MPIX_Wait_enqueue:
request: REQUEST, direction=inout, [request]
Expand Down
4 changes: 4 additions & 0 deletions src/mpid/ch3/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,10 @@ int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr);
int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status);
int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);

void MPID_Progress_start(MPID_Progress_state * state);
int MPID_Progress_wait(MPID_Progress_state * state);
Expand Down
13 changes: 13 additions & 0 deletions src/mpid/ch3/src/ch3_stream_enqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,16 @@ int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
{
return MPIR_Recv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, status);
}


int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
{
return MPIR_Isend_enqueue_impl(buf, count, datatype, dest, tag, comm_ptr, req);
}

int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
{
return MPIR_Irecv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, req);
}
4 changes: 4 additions & 0 deletions src/mpid/ch4/include/mpidch4.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr);
int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status);
int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req);
int MPID_Abort(struct MPIR_Comm *comm, int mpi_errno, int exit_code, const char *error_msg);

/* This function is not exposed to the upper layers but functions in a way
Expand Down
142 changes: 142 additions & 0 deletions src/mpid/ch4/src/ch4_stream_enqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,145 @@ int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
fn_fail:
goto fn_exit;
}

/* ---- isend enqueue ---- */
static void isend_enqueue_cb(void *data)
{
int mpi_errno;
MPIR_Request *request_ptr = NULL;

struct send_data *p = data;
mpi_errno = MPID_Send(p->buf, p->count, p->datatype, p->dest, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
MPIR_Assertp(mpi_errno == MPI_SUCCESS);
MPIR_Assertp(request_ptr != NULL);

p->req->u.enqueue.real_request = request_ptr;
}

int MPID_Isend_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
mpi_errno = MPIR_Isend_enqueue_impl(buf, count, datatype, dest, tag, comm_ptr, req);
goto fn_exit;
}

MPL_gpu_stream_t gpu_stream;
MPIDU_stream_workq_t *workq;
GET_STREAM_AND_WORKQ(comm_ptr, gpu_stream, workq);

struct send_data *p;
p = MPL_malloc(sizeof(struct send_data), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPIDU_stream_workq_op_t *op;
op = MPL_malloc(sizeof(MPIDU_stream_workq_op_t), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPL_gpu_event_t *trigger_event;
MPIDU_stream_workq_alloc_event(&trigger_event);

mpi_errno = MPIR_allocate_enqueue_request(comm_ptr, req);
MPIR_ERR_CHECK(mpi_errno);
(*req)->u.enqueue.is_send = true;
(*req)->u.enqueue.data = p;

p->buf = buf;
p->count = count;
p->datatype = datatype;
p->dest = dest;
p->tag = tag;
p->comm_ptr = comm_ptr;
p->req = *req;

op->cb = isend_enqueue_cb;
op->data = p;
op->trigger_event = trigger_event;
op->done_event = NULL;
op->request = &(*req)->u.enqueue.real_request;
op->status = NULL;

MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
MPIDU_stream_workq_enqueue(workq, op);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

/* ---- irecv enqueue ---- */
static void irecv_enqueue_cb(void *data)
{
int mpi_errno;
MPIR_Request *request_ptr = NULL;

struct recv_data *p = data;
mpi_errno = MPID_Irecv(p->buf, p->count, p->datatype, p->source, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
MPIR_Assertp(mpi_errno == MPI_SUCCESS);
MPIR_Assertp(request_ptr != NULL);

p->req->u.enqueue.real_request = request_ptr;
}

int MPID_Irecv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPIR_Request ** req)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
mpi_errno = MPIR_Irecv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, req);
goto fn_exit;
}

MPL_gpu_stream_t gpu_stream;
MPIDU_stream_workq_t *workq;
GET_STREAM_AND_WORKQ(comm_ptr, gpu_stream, workq);

struct recv_data *p;
p = MPL_malloc(sizeof(struct recv_data), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPIDU_stream_workq_op_t *op;
op = MPL_malloc(sizeof(MPIDU_stream_workq_op_t), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPL_gpu_event_t *trigger_event;
MPIDU_stream_workq_alloc_event(&trigger_event);

mpi_errno = MPIR_allocate_enqueue_request(comm_ptr, req);
MPIR_ERR_CHECK(mpi_errno);
(*req)->u.enqueue.is_send = false;
(*req)->u.enqueue.data = p;

p->buf = buf;
p->count = count;
p->datatype = datatype;
p->source = source;
p->tag = tag;
p->comm_ptr = comm_ptr;
p->req = *req;

op->cb = irecv_enqueue_cb;
op->data = p;
op->trigger_event = trigger_event;
op->done_event = NULL;
op->request = &(*req)->u.enqueue.real_request;
op->status = NULL;

MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
MPIDU_stream_workq_enqueue(workq, op);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

0 comments on commit 8a5419a

Please sign in to comment.