Skip to content

Commit

Permalink
ADI: add MPID_Recv_enqueue
Browse files Browse the repository at this point in the history
  • Loading branch information
hzhou committed Jul 13, 2022
1 parent e8377c6 commit f3b61e8
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/binding/c/stream_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ MPIX_Recv_enqueue:
tag: TAG, [message tag or MPI_ANY_TAG]
comm: COMMUNICATOR
status: STATUS, direction=out
.impl: mpid
.decl: MPIR_Recv_enqueue_impl

MPIX_Isend_enqueue:
buf: BUFFER, constant=True, [initial address of send buffer]
Expand Down
2 changes: 2 additions & 0 deletions src/mpid/ch3/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,8 @@ int MPID_Win_sync(MPIR_Win *win);

int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr);
int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status);

void MPID_Progress_start(MPID_Progress_state * state);
int MPID_Progress_wait(MPID_Progress_state * state);
Expand Down
6 changes: 6 additions & 0 deletions src/mpid/ch3/src/ch3_stream_enqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
{
return MPIR_Send_enqueue_impl(buf, count, datatype, dest, tag, comm_ptr);
}

int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status)
{
return MPIR_Recv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, status);
}
2 changes: 2 additions & 0 deletions src/mpid/ch4/include/mpidch4.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ MPL_STATIC_INLINE_PREFIX int MPID_Iscatterv(const void *, const MPI_Aint *, cons
MPIR_Comm *, MPIR_Request **) MPL_STATIC_INLINE_SUFFIX;
int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int tag, MPIR_Comm * comm_ptr);
int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status);
int MPID_Abort(struct MPIR_Comm *comm, int mpi_errno, int exit_code, const char *error_msg);

/* This function is not exposed to the upper layers but functions in a way
Expand Down
83 changes: 83 additions & 0 deletions src/mpid/ch4/src/ch4_stream_enqueue.c
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,86 @@ int MPID_Send_enqueue(const void *buf, MPI_Aint count, MPI_Datatype datatype,
fn_fail:
goto fn_exit;
}

/* ---- recv enqueue ---- */
struct recv_data {
void *buf;
MPI_Aint count;
MPI_Datatype datatype;
int source;
int tag;
MPIR_Comm *comm_ptr;
/* req is a recv request for Recv, but an enqueue request for Irecv */
MPIR_Request *req;
};

static void recv_enqueue_cb(void *data)
{
int mpi_errno;
MPIR_Request *request_ptr = NULL;

struct recv_data *p = data;
mpi_errno = MPID_Irecv(p->buf, p->count, p->datatype, p->source, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
MPIR_Assertp(mpi_errno == MPI_SUCCESS);
MPIR_Assertp(request_ptr != NULL);

p->req = request_ptr;
}

int MPID_Recv_enqueue(void *buf, MPI_Aint count, MPI_Datatype datatype,
int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status)
{
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

if (!MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ) {
mpi_errno = MPIR_Recv_enqueue_impl(buf, count, datatype, source, tag, comm_ptr, status);
goto fn_exit;
}

MPL_gpu_stream_t gpu_stream;
MPIDU_stream_workq_t *workq;
GET_STREAM_AND_WORKQ(comm_ptr, gpu_stream, workq);

struct recv_data *p;
p = MPL_malloc(sizeof(struct recv_data), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPIDU_stream_workq_op_t *op;
op = MPL_malloc(sizeof(MPIDU_stream_workq_op_t), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!op, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPL_gpu_event_t *trigger_event;
MPL_gpu_event_t *done_event;
MPIDU_stream_workq_alloc_event(&trigger_event);
MPIDU_stream_workq_alloc_event(&done_event);

p->buf = buf;
p->count = count;
p->datatype = datatype;
p->source = source;
p->tag = tag;
p->comm_ptr = comm_ptr;

op->cb = recv_enqueue_cb;
op->data = p;
op->trigger_event = trigger_event;
op->done_event = done_event;
op->request = &(p->req);
if (status != MPI_STATUSES_IGNORE) {
op->status = status;
} else {
op->status = NULL;
}

MPL_gpu_enqueue_trigger(trigger_event, gpu_stream);
MPL_gpu_enqueue_wait(done_event, gpu_stream);
MPIDU_stream_workq_enqueue(workq, op);

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

0 comments on commit f3b61e8

Please sign in to comment.