Skip to content

Commit

Permalink
pt2pt/stream: skip packing if the buffer is a host buffer
Browse files Browse the repository at this point in the history
If the buffer isn't a device buffer, we can directly enqueue the send/recv
operations. This can later be extended to cases where the lower layer can
handle GPU communication inside host function callbacks.
  • Loading branch information
hzhou committed Mar 24, 2022
1 parent e4f12ba commit 2830361
Showing 1 changed file with 67 additions and 28 deletions.
95 changes: 67 additions & 28 deletions src/mpi/pt2pt/stream.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

/* ---- Send stream ---- */
struct send_data {
const void *buf;
MPI_Aint count;
MPI_Datatype datatype;
int dest;
int tag;
MPIR_Comm *comm_ptr;
Expand All @@ -21,10 +24,15 @@ static void send_stream_cb(void *data)
MPIR_Request *request_ptr = NULL;

struct send_data *p = data;
assert(p->actual_pack_bytes == p->data_sz);

mpi_errno = MPID_Send(p->host_buf, p->data_sz, MPI_BYTE, p->dest, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
if (p->host_buf) {
assert(p->actual_pack_bytes == p->data_sz);

mpi_errno = MPID_Send(p->host_buf, p->data_sz, MPI_BYTE, p->dest, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
} else {
mpi_errno = MPID_Send(p->buf, p->count, p->datatype, p->dest, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
}
assert(mpi_errno == MPI_SUCCESS);
assert(request_ptr != NULL);

Expand All @@ -33,7 +41,9 @@ static void send_stream_cb(void *data)

MPIR_Request_free(request_ptr);

MPIR_gpu_free_host(p->host_buf);
if (p->host_buf) {
MPIR_gpu_free_host(p->host_buf);
}
MPL_free(data);
}

Expand All @@ -46,20 +56,27 @@ int MPIR_Send_stream_impl(const void *buf, MPI_Aint count, MPI_Datatype datatype
p = MPL_malloc(sizeof(struct send_data), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPI_Aint dt_size;
MPIR_Datatype_get_size_macro(datatype, dt_size);
p->data_sz = dt_size * count;

MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);

mpi_errno = MPIR_Typerep_pack_stream(buf, count, datatype, 0, p->host_buf, p->data_sz,
&p->actual_pack_bytes, stream);
MPIR_ERR_CHECK(mpi_errno);

p->dest = dest;
p->tag = tag;
p->comm_ptr = comm_ptr;

if (MPIR_GPU_query_pointer_is_dev(buf)) {
MPI_Aint dt_size;
MPIR_Datatype_get_size_macro(datatype, dt_size);
p->data_sz = dt_size * count;

MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);

mpi_errno = MPIR_Typerep_pack_stream(buf, count, datatype, 0, p->host_buf, p->data_sz,
&p->actual_pack_bytes, stream);
MPIR_ERR_CHECK(mpi_errno);
} else {
p->host_buf = NULL;
p->buf = buf;
p->count = count;
p->datatype = datatype;
}

MPL_gpu_launch_hostfn(stream, send_stream_cb, p);

fn_exit:
Expand All @@ -70,6 +87,9 @@ int MPIR_Send_stream_impl(const void *buf, MPI_Aint count, MPI_Datatype datatype

/* ---- Recv stream ---- */
struct recv_data {
void *buf;
MPI_Aint count;
MPI_Datatype datatype;
int source;
int tag;
MPIR_Comm *comm_ptr;
Expand All @@ -85,8 +105,13 @@ static void recv_stream_cb(void *data)
MPIR_Request *request_ptr = NULL;

struct recv_data *p = data;
mpi_errno = MPID_Recv(p->host_buf, p->data_sz, MPI_BYTE, p->source, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
if (p->host_buf) {
mpi_errno = MPID_Recv(p->host_buf, p->data_sz, MPI_BYTE, p->source, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
} else {
mpi_errno = MPID_Recv(p->buf, p->count, p->datatype, p->source, p->tag, p->comm_ptr,
MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
}
assert(mpi_errno == MPI_SUCCESS);
assert(request_ptr != NULL);

Expand All @@ -95,6 +120,11 @@ static void recv_stream_cb(void *data)

MPIR_Request_extract_status(request_ptr, p->status);
MPIR_Request_free(request_ptr);

if (!p->host_buf) {
/* we are done */
MPL_free(p);
}
}

static void recv_stream_cleanup_cb(void *data)
Expand All @@ -116,24 +146,33 @@ int MPIR_Recv_stream_impl(void *buf, MPI_Aint count, MPI_Datatype datatype,
p = MPL_malloc(sizeof(struct recv_data), MPL_MEM_OTHER);
MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

MPI_Aint dt_size;
MPIR_Datatype_get_size_macro(datatype, dt_size);
p->data_sz = dt_size * count;

MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);

p->source = source;
p->tag = tag;
p->comm_ptr = comm_ptr;
p->status = status;

MPL_gpu_launch_hostfn(stream, recv_stream_cb, p);
if (MPIR_GPU_query_pointer_is_dev(buf)) {
MPI_Aint dt_size;
MPIR_Datatype_get_size_macro(datatype, dt_size);
p->data_sz = dt_size * count;

MPIR_gpu_malloc_host(&p->host_buf, p->data_sz);

MPL_gpu_launch_hostfn(stream, recv_stream_cb, p);

mpi_errno = MPIR_Typerep_unpack_stream(p->host_buf, p->data_sz, buf, count, datatype, 0,
&p->actual_unpack_bytes, stream);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPIR_Typerep_unpack_stream(p->host_buf, p->data_sz, buf, count, datatype, 0,
&p->actual_unpack_bytes, stream);
MPIR_ERR_CHECK(mpi_errno);
MPL_gpu_launch_hostfn(stream, recv_stream_cleanup_cb, p);
} else {
p->host_buf = NULL;
p->buf = buf;
p->count = count;
p->datatype = datatype;

MPL_gpu_launch_hostfn(stream, recv_stream_cleanup_cb, p);
MPL_gpu_launch_hostfn(stream, recv_stream_cb, p);
}

fn_exit:
return mpi_errno;
Expand Down

0 comments on commit 2830361

Please sign in to comment.