Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpix: gpu stream aware extensions #5905

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions maint/local_python/binding_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def is_pointer_type(param):
return 1
elif RE.match(r'(ATTRIBUTE_VAL\w*|(C_)?BUFFER\d?|EXTRA_STATE\d*|TOOL_MPI_OBJ|(POLY)?FUNCTION\w*)$', param['kind']):
return 1
elif RE.match(r'(GPU_STREAM)$', param['kind']):
return 1
elif param['param_direction'] != 'in':
return 1
elif param['length']:
Expand Down
2 changes: 2 additions & 0 deletions maint/local_python/binding_f77.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,8 @@ def check_func_directives(func):
func['_skip_fortran'] = 1
elif RE.match(r'mpix_grequest_', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
elif RE.match(r'mpix_(Send|Recv)_enqueue', func['name'], re.IGNORECASE):
func['_skip_fortran'] = 1
elif RE.match(r'mpi_\w+_(f|f08|c)2(f|f08|c)$', func['name'], re.IGNORECASE):
# implemented in mpi_f08_types.f90
func['_skip_fortran'] = 1
Expand Down
21 changes: 21 additions & 0 deletions src/binding/c/pt2pt_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,24 @@ MPI_Isendrecv:

MPI_Isendrecv_replace:
.desc: Starts a nonblocking send and receive with a single buffer

MPIX_Send_enqueue:
.desc: Enqueue a send operation on a gpu stream
buf: BUFFER, constant=True, [initial address of send buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
datatype: DATATYPE, [datatype of each send buffer element]
dest: RANK, [rank of destination]
tag: TAG, [message tag]
comm: COMMUNICATOR
stream: GPU_STREAM, [gpu stream to enqueue on]

MPIX_Recv_enqueue:
.desc: Enqueue a receive operation on a gpu stream
buf: BUFFER, direction=out, [initial address of receive buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
datatype: DATATYPE, [datatype of each receive buffer element]
source: RANK, [rank of source or MPI_ANY_SOURCE]
tag: TAG, [message tag or MPI_ANY_TAG]
comm: COMMUNICATOR
status: STATUS, direction=out
stream: GPU_STREAM, [gpu stream to enqueue on]
7 changes: 7 additions & 0 deletions src/binding/custom_mapping.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,35 @@

LIS_KIND_MAP:
GPU_TYPE: integer
GPU_STREAM: integer(kind=cuda_stream_kind)
GREQUEST_CLASS: None

SMALL_F90_KIND_MAP:
GPU_TYPE: INTEGER
GPU_STREAM: integer(kind=cuda_stream_kind)
GREQUEST_CLASS: INTEGER

BIG_F90_KIND_MAP:
GPU_TYPE: INTEGER
GPU_STREAM: integer(kind=cuda_stream_kind)
GREQUEST_CLASS: INTEGER

SMALL_F08_KIND_MAP:
GPU_TYPE: INTEGER
GPU_STREAM: integer(kind=cuda_stream_kind)
GREQUEST_CLASS: INTEGER

BIG_F08_KIND_MAP:
GPU_TYPE: INTEGER
GPU_STREAM: integer(kind=cuda_stream_kind)
GREQUEST_CLASS: INTEGER

SMALL_C_KIND_MAP:
GPU_TYPE: int
GPU_STREAM: void
GREQUEST_CLASS: MPIX_Grequest_class

BIG_C_KIND_MAP:
GPU_TYPE: int
GPU_STREAM: void
GREQUEST_CLASS: MPIX_Grequest_class
7 changes: 7 additions & 0 deletions src/include/mpir_typerep.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,11 @@ int MPIR_Typerep_op(void *source_buf, MPI_Aint source_count, MPI_Datatype source
bool source_is_packed, int mapped_device);
int MPIR_Typerep_reduce(const void *in_buf, void *out_buf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op);

/* Stream-aware pack/unpack entry points.
 *
 * Both take the same arguments as a regular pack/unpack plus an opaque
 * `stream` handle that is handed through to the typerep backend.  The number
 * of bytes processed is reported through actual_pack_bytes /
 * actual_unpack_bytes.  Both return an MPI error code (MPI_SUCCESS on
 * success).
 *
 * NOTE(review): the concrete type behind `stream` is backend specific and is
 * not visible from this header -- confirm against the backend in use. */
int MPIR_Typerep_pack_stream(const void *inbuf, MPI_Aint incount, MPI_Datatype datatype,
MPI_Aint inoffset, void *outbuf, MPI_Aint max_pack_bytes,
MPI_Aint * actual_pack_bytes, void *stream);
int MPIR_Typerep_unpack_stream(const void *inbuf, MPI_Aint insize,
void *outbuf, MPI_Aint outcount, MPI_Datatype datatype,
MPI_Aint outoffset, MPI_Aint * actual_unpack_bytes, void *stream);
#endif /* MPIR_TYPEREP_H_INCLUDED */
22 changes: 22 additions & 0 deletions src/mpi/datatype/typerep/src/typerep_dataloop_pack.c
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,25 @@ int MPIR_Typerep_reduce(const void *in_buf, void *out_buf, MPI_Aint count, MPI_D

return mpi_errno;
}

/* Dataloop backend placeholder for stream-aware packing.
 *
 * This backend has no stream-aware implementation; reaching this function is
 * a programming error and trips MPIR_Assert(0).
 *
 * The output byte count is zeroed first so that, in a build where the
 * assertion does not abort, the caller never reads an indeterminate value.
 *
 * NOTE(review): when the assertion does not abort, this still returns
 * MPI_SUCCESS without packing anything -- consider returning a real error. */
int MPIR_Typerep_pack_stream(const void *inbuf, MPI_Aint incount, MPI_Datatype datatype,
                             MPI_Aint inoffset, void *outbuf, MPI_Aint max_pack_bytes,
                             MPI_Aint * actual_pack_bytes, void *stream)
{
    int mpi_errno = MPI_SUCCESS;

    /* nothing is packed by this stub */
    *actual_pack_bytes = 0;

    MPIR_Assert(0);

    return mpi_errno;
}

/* Dataloop backend placeholder for stream-aware unpacking.
 *
 * This backend has no stream-aware implementation; reaching this function is
 * a programming error and trips MPIR_Assert(0).
 *
 * The output byte count is zeroed first so that, in a build where the
 * assertion does not abort, the caller never reads an indeterminate value.
 *
 * NOTE(review): when the assertion does not abort, this still returns
 * MPI_SUCCESS without unpacking anything -- consider returning a real error. */
int MPIR_Typerep_unpack_stream(const void *inbuf, MPI_Aint insize, void *outbuf,
                               MPI_Aint outcount, MPI_Datatype datatype, MPI_Aint outoffset,
                               MPI_Aint * actual_unpack_bytes, void *stream)
{
    int mpi_errno = MPI_SUCCESS;

    /* nothing is unpacked by this stub */
    *actual_unpack_bytes = 0;

    MPIR_Assert(0);

    return mpi_errno;
}
46 changes: 46 additions & 0 deletions src/mpi/datatype/typerep/src/typerep_yaksa_pack.c
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,49 @@ static int typerep_op_pack(void *source_buf, void *target_buf, MPI_Aint count,
fn_fail:
goto fn_exit;
}

/* Yaksa backend: pack `incount` elements of `datatype` from inbuf into outbuf
 * by forwarding to yaksa_pack_stream() with the caller's stream handle.
 *
 * On success the byte count reported by yaksa is stored in
 * *actual_pack_bytes.  A non-zero yaksa return code is turned into an
 * MPI_ERR_INTERN error ("**yaksa"); *actual_pack_bytes is left untouched in
 * that case. */
int MPIR_Typerep_pack_stream(const void *inbuf, MPI_Aint incount, MPI_Datatype datatype,
                             MPI_Aint inoffset, void *outbuf, MPI_Aint max_pack_bytes,
                             MPI_Aint * actual_pack_bytes, void *stream)
{
    MPIR_FUNC_ENTER;

    int mpi_errno = MPI_SUCCESS;
    uintptr_t nbytes;
    yaksa_type_t yaksa_type = MPII_Typerep_get_yaksa_type(datatype);

    /* no yaksa info object is supplied (NULL); op is a plain replace */
    int yaksa_rc = yaksa_pack_stream(inbuf, incount, yaksa_type, inoffset, outbuf,
                                     max_pack_bytes, &nbytes, NULL, YAKSA_OP__REPLACE, stream);
    MPIR_ERR_CHKANDJUMP(yaksa_rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");

    *actual_pack_bytes = nbytes;

  fn_exit:
    MPIR_FUNC_EXIT;
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

/* Yaksa backend: unpack `insize` bytes from inbuf into `outcount` elements of
 * `datatype` at outbuf by forwarding to yaksa_unpack_stream() with the
 * caller's stream handle.
 *
 * On success the byte count reported by yaksa is stored in
 * *actual_unpack_bytes.  A non-zero yaksa return code is turned into an
 * MPI_ERR_INTERN error ("**yaksa"); *actual_unpack_bytes is left untouched in
 * that case. */
int MPIR_Typerep_unpack_stream(const void *inbuf, MPI_Aint insize, void *outbuf,
                               MPI_Aint outcount, MPI_Datatype datatype, MPI_Aint outoffset,
                               MPI_Aint * actual_unpack_bytes, void *stream)
{
    MPIR_FUNC_ENTER;

    int mpi_errno = MPI_SUCCESS;
    uintptr_t nbytes;
    yaksa_type_t yaksa_type = MPII_Typerep_get_yaksa_type(datatype);

    /* no yaksa info object is supplied (NULL); op is a plain replace */
    int yaksa_rc = yaksa_unpack_stream(inbuf, insize, outbuf, outcount, yaksa_type, outoffset,
                                       &nbytes, NULL, YAKSA_OP__REPLACE, stream);
    MPIR_ERR_CHKANDJUMP(yaksa_rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");

    *actual_unpack_bytes = nbytes;

  fn_exit:
    MPIR_FUNC_EXIT;
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}
1 change: 1 addition & 0 deletions src/mpi/pt2pt/Makefile.mk
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@

mpi_core_sources += \
src/mpi/pt2pt/sendrecv.c \
src/mpi/pt2pt/stream.c \
src/mpi/pt2pt/bsendutil.c
181 changes: 181 additions & 0 deletions src/mpi/pt2pt/stream.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Copyright (C) by Argonne National Laboratory
* See COPYRIGHT in top-level directory
*/

#include "mpiimpl.h"

/* ---- Send stream ---- */
/* State handed from MPIR_Send_enqueue_impl to send_stream_cb.
 * Allocated with MPL_malloc by the enqueue function and released by the
 * callback. */
struct send_data {
/* caller's send arguments; read by the callback only when host_buf is NULL */
const void *buf;
MPI_Aint count;
MPI_Datatype datatype;
/* destination rank, tag and communicator; used on both paths */
int dest;
int tag;
MPIR_Comm *comm_ptr;
/* host staging buffer obtained from MPIR_gpu_malloc_host when buf is device
 * memory; NULL when the data is sent straight from buf */
void *host_buf;
/* staged payload size in bytes (datatype size * count); set on the staged path */
MPI_Aint data_sz;
/* byte count reported by MPIR_Typerep_pack_stream; the callback asserts it
 * equals data_sz */
MPI_Aint actual_pack_bytes;
};

/* Host callback that performs a send queued by MPIR_Send_enqueue_impl.
 *
 * `data` is the struct send_data allocated by the enqueue function.  If a
 * host staging buffer is present, the packed bytes are sent from it as
 * MPI_BYTE; otherwise the caller's original buffer/count/datatype are used.
 * The callback blocks until the send completes, then releases the staging
 * buffer (if any) and the descriptor itself.
 *
 * A void callback has no way to report errors, so failures are only caught
 * by assert().
 *
 * NOTE(review): this runs as a GPU host function (CUDA backend:
 * cudaLaunchHostFunc).  CUDA documents that host functions must not make
 * CUDA API calls; MPIR_gpu_free_host and possibly MPID_Send may do so --
 * confirm this is safe on the targeted runtime. */
static void send_stream_cb(void *data)
{
    int mpi_errno;
    MPIR_Request *request_ptr = NULL;
    struct send_data *p = data;

    if (p->host_buf) {
        /* the pack queued ahead of us must have produced the whole payload */
        assert(p->actual_pack_bytes == p->data_sz);

        mpi_errno = MPID_Send(p->host_buf, p->data_sz, MPI_BYTE, p->dest, p->tag, p->comm_ptr,
                              MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
    } else {
        mpi_errno = MPID_Send(p->buf, p->count, p->datatype, p->dest, p->tag, p->comm_ptr,
                              MPIR_CONTEXT_INTRA_PT2PT, &request_ptr);
    }
    assert(mpi_errno == MPI_SUCCESS);

    /* Only wait when there is a request to wait on.  Presumably the device
     * may complete the send immediately and hand back no request -- TODO
     * confirm against the device layer.  Either way, never pass NULL to
     * MPID_Wait / MPIR_Request_free. */
    if (request_ptr != NULL) {
        mpi_errno = MPID_Wait(request_ptr, MPI_STATUS_IGNORE);
        assert(mpi_errno == MPI_SUCCESS);

        MPIR_Request_free(request_ptr);
    }

    /* keep builds with assertions disabled free of set-but-unused warnings */
    (void) mpi_errno;

    if (p->host_buf) {
        MPIR_gpu_free_host(p->host_buf);
    }
    MPL_free(p);
}

/* Enqueue a send on a GPU stream.
 *
 * The send itself is executed later by send_stream_cb, which is launched on
 * `stream` via MPL_gpu_launch_hostfn.  When `buf` is device memory the data
 * is first packed into a host staging buffer with MPIR_Typerep_pack_stream
 * (queued on the same stream, ahead of the callback); otherwise the callback
 * sends directly from `buf`.
 *
 * buf, count, datatype   send buffer description
 * dest, tag, comm_ptr    message envelope
 * stream                 opaque GPU stream handle, forwarded unchanged
 *
 * Returns MPI_SUCCESS once the work has been queued, or an MPI error code if
 * allocation, packing or launching the host callback fails.  On success,
 * ownership of the descriptor (and staging buffer) passes to the callback. */
int MPIR_Send_enqueue_impl(const void *buf, MPI_Aint count, MPI_Datatype datatype,
                           int dest, int tag, MPIR_Comm * comm_ptr, void *stream)
{
    int mpi_errno = MPI_SUCCESS;
    int rc;
    int pack_enqueued = 0;
    struct send_data *p = NULL;

    p = MPL_malloc(sizeof(struct send_data), MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

    /* populate every field so the callback never reads indeterminate data */
    p->buf = buf;
    p->count = count;
    p->datatype = datatype;
    p->dest = dest;
    p->tag = tag;
    p->comm_ptr = comm_ptr;
    p->host_buf = NULL;
    p->data_sz = 0;
    p->actual_pack_bytes = 0;

    if (MPIR_GPU_query_pointer_is_dev(buf)) {
        MPI_Aint dt_size;
        MPIR_Datatype_get_size_macro(datatype, dt_size);
        p->data_sz = dt_size * count;

        /* The callback selects the staged path by host_buf != NULL, so ask
         * for at least one byte even for an empty message; a NULL result
         * then unambiguously means the allocation failed. */
        MPIR_gpu_malloc_host(&p->host_buf, p->data_sz > 0 ? p->data_sz : 1);
        MPIR_ERR_CHKANDJUMP(!p->host_buf, mpi_errno, MPI_ERR_OTHER, "**nomem");

        mpi_errno = MPIR_Typerep_pack_stream(buf, count, datatype, 0, p->host_buf, p->data_sz,
                                             &p->actual_pack_bytes, stream);
        MPIR_ERR_CHECK(mpi_errno);
        pack_enqueued = 1;
    }

    /* The launch can fail (the non-CUDA backends always return -1); without
     * this check the send would be silently dropped. */
    rc = MPL_gpu_launch_hostfn(stream, send_stream_cb, p);
    MPIR_ERR_CHKANDJUMP(rc != 0, mpi_errno, MPI_ERR_OTHER, "**fail");

  fn_exit:
    return mpi_errno;
  fn_fail:
    if (p) {
        /* Once a pack has been queued, the stream may still write into the
         * staging buffer, so it cannot be released here without
         * synchronizing; it is deliberately left allocated in that case.
         * The descriptor itself is referenced only by the callback, which
         * was never launched on any failure path. */
        if (p->host_buf && !pack_enqueued) {
            MPIR_gpu_free_host(p->host_buf);
        }
        MPL_free(p);
    }
    goto fn_exit;
}

/* ---- Recv stream ---- */
/* State handed from MPIR_Recv_enqueue_impl to recv_stream_cb and, on the
 * staged path, on to recv_stream_cleanup_cb.  Allocated with MPL_malloc by
 * the enqueue function; released by recv_stream_cb when host_buf is NULL and
 * by recv_stream_cleanup_cb otherwise. */
struct recv_data {
/* caller's receive arguments; read by the callback only when host_buf is NULL */
void *buf;
MPI_Aint count;
MPI_Datatype datatype;
/* source rank, tag and communicator; used on both paths */
int source;
int tag;
MPIR_Comm *comm_ptr;
/* caller's status object, filled in by the receive callback */
MPI_Status *status;
/* host staging buffer obtained from MPIR_gpu_malloc_host when buf is device
 * memory; NULL when the data is received straight into buf */
void *host_buf;
/* staged payload size in bytes (datatype size * count); set on the staged path */
MPI_Aint data_sz;
/* byte count reported by MPIR_Typerep_unpack_stream; the cleanup callback
 * asserts it equals data_sz */
MPI_Aint actual_unpack_bytes;
};

/* Host callback that performs a receive queued by MPIR_Recv_enqueue_impl.
 *
 * `data` is the struct recv_data allocated by the enqueue function.  If a
 * host staging buffer is present, up to data_sz bytes are received into it
 * as MPI_BYTE; otherwise the message is received straight into the caller's
 * buffer.  The callback blocks until the receive completes and copies the
 * request status into the caller's status object.
 *
 * Ownership: on the direct path (no staging buffer) the descriptor is freed
 * here.  On the staged path it is left alive, because
 * recv_stream_cleanup_cb, queued later on the same stream, releases both the
 * staging buffer and the descriptor.
 *
 * A void callback has no way to report errors, so failures are only caught
 * by assert().
 *
 * NOTE(review): this runs as a GPU host function (CUDA backend:
 * cudaLaunchHostFunc), which must not make CUDA API calls -- confirm
 * MPID_Recv satisfies that on the targeted runtime.
 * NOTE(review): on the staged path a message shorter than data_sz leaves the
 * tail of host_buf unwritten, yet the enqueue function queues an unpack of
 * data_sz bytes -- confirm this is intended. */
static void recv_stream_cb(void *data)
{
    int mpi_errno;
    MPIR_Request *request_ptr = NULL;
    struct recv_data *p = data;

    if (p->host_buf) {
        mpi_errno = MPID_Recv(p->host_buf, p->data_sz, MPI_BYTE, p->source, p->tag, p->comm_ptr,
                              MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
    } else {
        mpi_errno = MPID_Recv(p->buf, p->count, p->datatype, p->source, p->tag, p->comm_ptr,
                              MPIR_CONTEXT_INTRA_PT2PT, p->status, &request_ptr);
    }
    assert(mpi_errno == MPI_SUCCESS);

    /* Only wait when there is a request to wait on.  Presumably the device
     * may complete the receive immediately, filling p->status itself and
     * handing back no request -- TODO confirm against the device layer.
     * Either way, never pass NULL to MPID_Wait / MPIR_Request_free. */
    if (request_ptr != NULL) {
        mpi_errno = MPID_Wait(request_ptr, MPI_STATUS_IGNORE);
        assert(mpi_errno == MPI_SUCCESS);

        MPIR_Request_extract_status(request_ptr, p->status);
        MPIR_Request_free(request_ptr);
    }

    /* keep builds with assertions disabled free of set-but-unused warnings */
    (void) mpi_errno;

    if (!p->host_buf) {
        /* direct path: nothing else on the stream needs the descriptor */
        MPL_free(p);
    }
}

/* Final host callback of a staged receive: checks that the unpack queued by
 * MPIR_Recv_enqueue_impl covered the whole payload, then releases the host
 * staging buffer and the descriptor itself. */
static void recv_stream_cleanup_cb(void *data)
{
    struct recv_data *desc = data;
    void *staging = desc->host_buf;

    assert(desc->actual_unpack_bytes == desc->data_sz);

    MPIR_gpu_free_host(staging);
    MPL_free(desc);
}

/* Enqueue a receive on a GPU stream.
 *
 * The receive itself is executed later by recv_stream_cb, which is launched
 * on `stream` via MPL_gpu_launch_hostfn.  When `buf` is device memory, three
 * items are queued in order: the receive into a host staging buffer, an
 * unpack from the staging buffer into `buf` (MPIR_Typerep_unpack_stream),
 * and recv_stream_cleanup_cb, which frees the staging buffer and the
 * descriptor.  Otherwise only the receive callback is queued and it receives
 * directly into `buf`.
 *
 * buf, count, datatype   receive buffer description
 * source, tag, comm_ptr  message envelope
 * status                 caller's status object; written when the callback
 *                        runs, so it must stay valid until then
 * stream                 opaque GPU stream handle, forwarded unchanged
 *
 * Returns MPI_SUCCESS once the work has been queued, or an MPI error code if
 * allocation, unpacking or launching a host callback fails. */
int MPIR_Recv_enqueue_impl(void *buf, MPI_Aint count, MPI_Datatype datatype,
                           int source, int tag, MPIR_Comm * comm_ptr, MPI_Status * status,
                           void *stream)
{
    int mpi_errno = MPI_SUCCESS;
    int rc;
    int recv_enqueued = 0;
    struct recv_data *p = NULL;

    p = MPL_malloc(sizeof(struct recv_data), MPL_MEM_OTHER);
    MPIR_ERR_CHKANDJUMP(!p, mpi_errno, MPI_ERR_OTHER, "**nomem");

    /* populate every field so the callbacks never read indeterminate data */
    p->buf = buf;
    p->count = count;
    p->datatype = datatype;
    p->source = source;
    p->tag = tag;
    p->comm_ptr = comm_ptr;
    p->status = status;
    p->host_buf = NULL;
    p->data_sz = 0;
    p->actual_unpack_bytes = 0;

    if (MPIR_GPU_query_pointer_is_dev(buf)) {
        MPI_Aint dt_size;
        MPIR_Datatype_get_size_macro(datatype, dt_size);
        p->data_sz = dt_size * count;

        /* The callbacks select the staged path by host_buf != NULL, so ask
         * for at least one byte even for an empty message; a NULL result
         * then unambiguously means the allocation failed. */
        MPIR_gpu_malloc_host(&p->host_buf, p->data_sz > 0 ? p->data_sz : 1);
        MPIR_ERR_CHKANDJUMP(!p->host_buf, mpi_errno, MPI_ERR_OTHER, "**nomem");

        /* The launch can fail (the non-CUDA backends always return -1). */
        rc = MPL_gpu_launch_hostfn(stream, recv_stream_cb, p);
        MPIR_ERR_CHKANDJUMP(rc != 0, mpi_errno, MPI_ERR_OTHER, "**fail");
        recv_enqueued = 1;

        mpi_errno = MPIR_Typerep_unpack_stream(p->host_buf, p->data_sz, buf, count, datatype, 0,
                                               &p->actual_unpack_bytes, stream);
        MPIR_ERR_CHECK(mpi_errno);

        rc = MPL_gpu_launch_hostfn(stream, recv_stream_cleanup_cb, p);
        MPIR_ERR_CHKANDJUMP(rc != 0, mpi_errno, MPI_ERR_OTHER, "**fail");
    } else {
        rc = MPL_gpu_launch_hostfn(stream, recv_stream_cb, p);
        MPIR_ERR_CHKANDJUMP(rc != 0, mpi_errno, MPI_ERR_OTHER, "**fail");
    }

  fn_exit:
    return mpi_errno;
  fn_fail:
    /* Reclaim the descriptor and staging buffer only while nothing queued on
     * the stream can still touch them.  Once the receive callback has been
     * launched it will dereference both at some later point, so they are
     * deliberately left allocated rather than freed underneath it. */
    if (p && !recv_enqueued) {
        if (p->host_buf) {
            MPIR_gpu_free_host(p->host_buf);
        }
        MPL_free(p);
    }
    goto fn_exit;
}
3 changes: 3 additions & 0 deletions src/mpl/include/mpl_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,7 @@ int MPL_gpu_get_buffer_bounds(const void *ptr, void **pbase, uintptr_t * len);
int MPL_gpu_free_hook_register(void (*free_hook) (void *dptr));
int MPL_gpu_get_dev_count(int *dev_cnt, int *dev_id);

/* Signature of a host-side callback that can be queued on a GPU stream; it
 * receives the `data` pointer given at launch time. */
typedef void (*MPL_gpu_hostfn) (void *data);
/* Queue fn(data) on `stream`.  `stream` is an opaque handle; the CUDA
 * backend treats it as a pointer to a cudaStream_t.  The return value is
 * backend specific: the CUDA backend returns the cudaLaunchHostFunc result,
 * while the fallback and HIP backends always return -1. */
int MPL_gpu_launch_hostfn(void *stream, MPL_gpu_hostfn fn, void *data);

#endif /* ifndef MPL_GPU_H_INCLUDED */
7 changes: 7 additions & 0 deletions src/mpl/src/gpu/mpl_gpu_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,10 @@ cudaError_t CUDARTAPI cudaFree(void *dptr)
result = sys_cudaFree(dptr);
return result;
}

/* CUDA backend: queue fn(data) as a host function on the given stream.
 *
 * `stream` points to a cudaStream_t.  The cudaLaunchHostFunc result is
 * returned unchanged, converted to int. */
int MPL_gpu_launch_hostfn(void *stream, MPL_gpu_hostfn fn, void *data)
{
    cudaStream_t *cuda_stream = (cudaStream_t *) stream;
    cudaError_t launch_status = cudaLaunchHostFunc(*cuda_stream, fn, data);

    return launch_status;
}
5 changes: 5 additions & 0 deletions src/mpl/src/gpu/mpl_gpu_fallback.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,8 @@ int MPL_gpu_free_hook_register(void (*free_hook) (void *dptr))
{
return MPL_SUCCESS;
}

/* Fallback backend (no GPU support): host functions cannot be queued, so
 * every call reports failure with -1 and the callback is never run. */
int MPL_gpu_launch_hostfn(void *stream, MPL_gpu_hostfn fn, void *data)
{
    /* arguments are intentionally unused */
    (void) stream;
    (void) fn;
    (void) data;

    return -1;
}
6 changes: 6 additions & 0 deletions src/mpl/src/gpu/mpl_gpu_hip.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,4 +367,10 @@ hipError_t hipFree(void *dptr)
result = sys_hipFree(dptr);
return result;
}

/* HIP backend: queuing host functions is not implemented here, so every call
 * reports failure with -1 and the callback is never run. */
int MPL_gpu_launch_hostfn(void *stream, MPL_gpu_hostfn fn, void *data)
{
    /* arguments are intentionally unused */
    (void) stream;
    (void) fn;
    (void) data;

    return -1;
}

#endif /* MPL_HAVE_HIP */
Loading