Skip to content

Commit

Permalink
Merge pull request #5946 from hzhou/2204_stream_multiplex
Browse files Browse the repository at this point in the history
stream: implement multiplex stream comm

Approved-by: Ken Raffenetti
  • Loading branch information
hzhou authored May 2, 2022
2 parents 6b09399 + 221285f commit 650ca82
Show file tree
Hide file tree
Showing 25 changed files with 454 additions and 34 deletions.
1 change: 1 addition & 0 deletions dummy
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1
10 changes: 5 additions & 5 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def process_func_parameters(func):
if kind == "REQUEST":
if RE.match(r'mpi_startall', func_name, re.IGNORECASE):
do_handle_ptr = 3
elif RE.match(r'mpi_(wait|test)', func_name, re.IGNORECASE):
elif RE.match(r'mpix?_(wait|test)', func_name, re.IGNORECASE):
do_handle_ptr = 3
elif kind == "RANK":
validation_list.append({'kind': "RANK-ARRAY", 'name': name})
Expand Down Expand Up @@ -588,7 +588,7 @@ def process_func_parameters(func):
do_handle_ptr = 1
if kind == "INFO" and not RE.match(r'mpi_(info_.*|.*_set_info)$', func_name, re.IGNORECASE):
p['can_be_null'] = "MPI_INFO_NULL"
elif kind == "REQUEST" and RE.match(r'mpi_(wait|test|request_get_status|parrived)', func_name, re.IGNORECASE):
elif kind == "REQUEST" and RE.match(r'mpix?_(wait|test|request_get_status|parrived)', func_name, re.IGNORECASE):
p['can_be_null'] = "MPI_REQUEST_NULL"
elif kind == "STREAM" and RE.match(r'mpix?_stream_comm_create', func_name, re.IGNORECASE):
p['can_be_null'] = "MPIX_STREAM_NULL"
Expand Down Expand Up @@ -1674,7 +1674,7 @@ def dump_early_return_pt2pt_proc_null(func):
def dump_handle_ptr_var(func, p):
(kind, name) = (p['kind'], p['name'])
if kind == "REQUEST" and p['length']:
if RE.match(r'mpi_(test|wait)all', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(test|wait)all', func['name'], re.IGNORECASE):
# FIXME: we do not convert pointers for MPI_Testall and MPI_Waitall
# (for performance reasons). This can probably be changed.
pass
Expand Down Expand Up @@ -1748,7 +1748,7 @@ def dump_convert_handle(func, p):
name = "*" + p['name']

if kind == "REQUEST" and p['length']:
if RE.match(r'mpi_(test|wait)all', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(test|wait)all', func['name'], re.IGNORECASE):
# We do not convert pointers for MPI_Testall and MPI_Waitall
pass
else:
Expand Down Expand Up @@ -1792,7 +1792,7 @@ def dump_validate_handle_ptr(func, p):
mpir = G.handle_mpir_types[kind]
if kind == "REQUEST" and p['length']:
G.err_codes['MPI_ERR_REQUEST'] = 1
if RE.match(r'mpi_(test|wait)all', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(test|wait)all', func['name'], re.IGNORECASE):
# MPI_Testall and MPI_Waitall do pointer conversion inside MPIR_{Test,Wait}all
pass
else:
Expand Down
8 changes: 4 additions & 4 deletions maint/local_python/binding_f08.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def process_status(p):
arg_1 = ":STATUS:"
arg_2 = ":STATUS:"
length = p['_array_length']
if RE.match(r'mpi_(test|wait)some', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(test|wait)some', func['name'], re.IGNORECASE):
length = "outcount_c"
p['_status_convert'] = "%s(1:%s) = %s_c(1:%s)" % (p['name'], length, p['name'], length)
else:
Expand Down Expand Up @@ -571,7 +571,7 @@ def process_array(p):
if RE.match(r'in|inout', p['param_direction']):
convert_list_1.append("%s = %s" % (argv_2, argv_1))
if RE.match(r'out|inout', p['param_direction']):
if RE.match(r'mpi_(test|wait)some', func['name'], re.IGNORECASE) and p['name'] == "array_of_indices":
if RE.match(r'mpix?_(test|wait)some', func['name'], re.IGNORECASE) and p['name'] == "array_of_indices":
argv_1 = "array_of_indices(1:outcount_c)"
argv_2 = "array_of_indices_c(1:outcount_c)"
convert_list_2.append("%s = %s" % (argv_1, argv_2))
Expand Down Expand Up @@ -1495,9 +1495,9 @@ def get_string():
def get_array_decl():
# Arrays: we'll use assumptions (since only with limited num of functions)
length = get_F_decl_length(p)
if RE.match(r'mpi_(Test|Wait)(all|any)', func['name'], re.IGNORECASE):
if RE.match(r'mpix?_(Test|Wait)(all|any)', func['name'], re.IGNORECASE):
length = 'count'
elif RE.match(r'mpi_(Test|Wait)(some)', func['name'], re.IGNORECASE):
elif RE.match(r'mpix?_(Test|Wait)(some)', func['name'], re.IGNORECASE):
length = 'incount'
elif RE.match(r'mpi_cart_(rank|sub)', func['name'], re.IGNORECASE):
length = 'cart_dim'
Expand Down
4 changes: 2 additions & 2 deletions maint/local_python/binding_f77.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,9 @@ def process_func_parameters():
if p['param_direction'] == 'out':
if p['length'] is None:
dump_status(p['name'], False, True)
elif RE.match(r'mpi_(wait|test)all', func['name'], re.IGNORECASE):
elif RE.match(r'mpix?_(wait|test)all', func['name'], re.IGNORECASE):
dump_statuses(p['name'], "(*count)", "(*count)", False, True)
elif RE.match(r'mpi_(wait|test)some', func['name'], re.IGNORECASE):
elif RE.match(r'mpix?_(wait|test)some', func['name'], re.IGNORECASE):
dump_statuses(p['name'], "(*incount)", "(*outcount)", False, True)
else:
raise Exception("Unhandled: %s - %s" % (func['name'], p['name']))
Expand Down
136 changes: 134 additions & 2 deletions src/binding/c/stream_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,137 @@ MPIX_Stream_comm_create:
stream: STREAM, [stream object]
newcomm: COMMUNICATOR, direction=out, [new stream-associated communicator]

MPIX_Stream_comm_create_multiplex:
comm: COMMUNICATOR, [communicator]
count: ARRAY_LENGTH_NNI, [list length]
array_of_streams: STREAM, length=count, [stream object array]
newcomm: COMMUNICATOR, direction=out, [new stream-associated communicator]

MPIX_Stream_send:
buf: BUFFER, constant=True, [initial address of send buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
datatype: DATATYPE, [datatype of each send buffer element]
dest: RANK, [rank of destination]
tag: TAG, [message tag]
comm: COMMUNICATOR
source_stream_index: INDEX
dest_stream_index: INDEX
{
/* Blocking send that selects the local (source) and remote (destination)
 * streams by index: start the send, then wait for it to complete. */
MPIR_Request *request_ptr = NULL;

/* Build the pt2pt attr from the (rank, stream index) pairs. For a send the
 * calling process (comm_ptr->rank) is the source side and dest is the
 * destination side. */
int attr;
mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, comm_ptr->rank, dest,
source_stream_index, dest_stream_index, &attr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPID_Send(buf, count, datatype, dest, tag, comm_ptr, attr, &request_ptr);
MPIR_ERR_CHECK(mpi_errno);

/* No request was handed back, so there is nothing to wait on.
 * NOTE(review): presumably this means the send already completed -- confirm
 * against the MPID_Send contract. */
if (request_ptr == NULL) {
goto fn_exit;
}

mpi_errno = MPID_Wait(request_ptr, MPI_STATUS_IGNORE);
MPIR_ERR_CHECK(mpi_errno);

/* Read the request's error code before the request is freed; the code is
 * checked only after the free. */
mpi_errno = request_ptr->status.MPI_ERROR;
MPIR_Request_free(request_ptr);

MPIR_ERR_CHECK(mpi_errno);
}

MPIX_Stream_isend:
buf: BUFFER, asynchronous=True, constant=True, [initial address of send buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
datatype: DATATYPE, [datatype of each send buffer element]
dest: RANK, [rank of destination]
tag: TAG, [message tag]
comm: COMMUNICATOR
source_stream_index: INDEX
dest_stream_index: INDEX
request: REQUEST, direction=out
{
/* Nonblocking send that selects the local (source) and remote (destination)
 * streams by index and returns the request handle to the caller. */
MPIR_Request *request_ptr = NULL;

/* Build the pt2pt attr from the (rank, stream index) pairs. For a send the
 * calling process (comm_ptr->rank) is the source side and dest is the
 * destination side. */
int attr;
mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, comm_ptr->rank, dest,
source_stream_index, dest_stream_index, &attr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPID_Isend(buf, count, datatype, dest, tag, comm_ptr,
attr, &request_ptr);
MPIR_ERR_CHECK(mpi_errno);

MPII_SENDQ_REMEMBER(request_ptr, dest, tag, comm_ptr->context_id, buf, count);

/* return the handle of the request to the user */
/* MPIU_OBJ_HANDLE_PUBLISH is unnecessary for isend, lower-level access is
* responsible for its own consistency, while upper-level field access is
* controlled by the completion counter */
*request = request_ptr->handle;
}

MPIX_Stream_recv:
buf: BUFFER, direction=out, [initial address of receive buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
datatype: DATATYPE, [datatype of each receive buffer element]
source: RANK, [rank of source or MPI_ANY_SOURCE]
tag: TAG, [message tag or MPI_ANY_TAG]
comm: COMMUNICATOR
source_stream_index: INDEX
dest_stream_index: INDEX
status: STATUS, direction=out
{
/* Blocking receive that selects the remote (source) and local (destination)
 * streams by index: start the receive, then wait for it to complete. */
MPIR_Request *request_ptr = NULL;

/* Build the pt2pt attr from the (rank, stream index) pairs. For a receive
 * source is the source side and the calling process (comm_ptr->rank) is the
 * destination side.
 * NOTE(review): source is passed through as a rank index; the parameter doc
 * allows MPI_ANY_SOURCE -- confirm wildcards are rejected or handled before
 * this call. */
int attr;
mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, source, comm_ptr->rank,
source_stream_index, dest_stream_index, &attr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPID_Recv(buf, count, datatype, source, tag, comm_ptr,
attr, status, &request_ptr);
MPIR_ERR_CHECK(mpi_errno);

/* No request was handed back, so there is nothing to wait on.
 * NOTE(review): presumably the receive already completed and status was
 * filled in by MPID_Recv -- confirm against the MPID_Recv contract. */
if (request_ptr == NULL) {
goto fn_exit;
}

mpi_errno = MPID_Wait(request_ptr, MPI_STATUS_IGNORE);
MPIR_ERR_CHECK(mpi_errno);

/* Read the error code and copy the status out of the request before the
 * request is freed; the error code is checked only after the free. */
mpi_errno = request_ptr->status.MPI_ERROR;
MPIR_Request_extract_status(request_ptr, status);
MPIR_Request_free(request_ptr);

MPIR_ERR_CHECK(mpi_errno);
}

MPIX_Stream_irecv:
buf: BUFFER, direction=out, asynchronous=True, suppress=f08_intent, [initial address of receive buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
datatype: DATATYPE, [datatype of each receive buffer element]
source: RANK, [rank of source or MPI_ANY_SOURCE]
tag: TAG, [message tag or MPI_ANY_TAG]
comm: COMMUNICATOR
source_stream_index: INDEX
dest_stream_index: INDEX
request: REQUEST, direction=out
{
/* Nonblocking receive that selects the remote (source) and local
 * (destination) streams by index and returns the request handle to the
 * caller. */
MPIR_Request *request_ptr = NULL;

/* Build the pt2pt attr from the (rank, stream index) pairs. For a receive
 * source is the source side and the calling process (comm_ptr->rank) is the
 * destination side.
 * NOTE(review): source is passed through as a rank index; the parameter doc
 * allows MPI_ANY_SOURCE -- confirm wildcards are rejected or handled before
 * this call. */
int attr;
mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, source, comm_ptr->rank,
source_stream_index, dest_stream_index, &attr);
MPIR_ERR_CHECK(mpi_errno);

mpi_errno = MPID_Irecv(buf, count, datatype, source, tag, comm_ptr,
attr, &request_ptr);

/* The handle is stored before mpi_errno from MPID_Irecv is checked.
 * NOTE(review): this dereferences request_ptr even when MPID_Irecv reported
 * an error; confirm MPID_Irecv always returns a non-NULL request, or that
 * this ordering is intentional. */
*request = request_ptr->handle;
MPIR_ERR_CHECK(mpi_errno);
}

MPIX_Send_enqueue:
buf: BUFFER, constant=True, [initial address of send buffer]
count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
Expand Down Expand Up @@ -44,14 +175,15 @@ MPIX_Irecv_enqueue:
datatype: DATATYPE, [datatype of each receive buffer element]
source: RANK, [rank of source or MPI_ANY_SOURCE]
tag: TAG, [message tag or MPI_ANY_TAG]

comm: COMMUNICATOR
request: REQUEST, direction=out

MPI_Wait_enqueue:
MPIX_Wait_enqueue:
request: REQUEST, direction=inout, [request]
status: STATUS, direction=out

MPI_Waitall_enqueue:
MPIX_Waitall_enqueue:
count: ARRAY_LENGTH_NNI, [lists length]
array_of_requests: REQUEST, direction=inout, length=count, [array of requests]
array_of_statuses: STATUS, direction=out, length=*, pointer=False, [array of status objects]
Expand Down
2 changes: 1 addition & 1 deletion src/include/mpiimpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ typedef struct MPIR_Stream MPIR_Stream;
/*****************************************************************************/

#include "mpir_thread.h" /* come first as mutexes are often depended on, e.g. request */
#include "mpir_err.h"
#include "mpir_attr.h"
#include "mpir_group.h"
#include "mpir_comm.h"
Expand All @@ -211,7 +212,6 @@ typedef struct MPIR_Stream MPIR_Stream;
#include "mpir_coll.h"
#include "mpir_csel.h"
#include "mpir_func.h"
#include "mpir_err.h"
#include "mpir_nbc.h"
#include "mpir_bsend.h"
#include "mpir_process.h"
Expand Down
33 changes: 31 additions & 2 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ struct MPIR_Comm {
} single;
struct {
struct MPIR_Stream **local_streams;
int *vci_displs;
int *vci_table;
MPI_Aint *vci_displs; /* comm size + 1 */
int *vci_table; /* comm size */
} multiplex;
} stream_comm;

Expand Down Expand Up @@ -300,6 +300,35 @@ static inline int MPIR_Comm_release(MPIR_Comm * comm_ptr)
return mpi_errno;
}

/* Compute the pt2pt attribute for a transfer on a multiplex stream
 * communicator.
 *
 * comm      - communicator; must be of type MPIR_STREAM_COMM_MULTIPLEX
 * src_rank  - rank of the sending process in comm
 * dst_rank  - rank of the receiving process in comm
 * src_index - index into the sender's list of streams
 * dst_index - index into the receiver's list of streams
 * attr_out  - on success, receives the attribute with the source and
 *             destination VCIs encoded by MPIR_PT2PT_ATTR_SET_VCIS
 *
 * Returns MPI_SUCCESS, or an MPI_ERR_OTHER-class error code if comm is not a
 * multiplex stream communicator or if either stream index reaches past the
 * entries belonging to its rank. *attr_out is left untouched on error.
 *
 * vci_displs[r] .. vci_displs[r + 1] - 1 is the slice of vci_table that
 * belongs to rank r, so each (rank, index) pair is resolved as
 * vci_table[vci_displs[rank] + index].
 *
 * NOTE(review): negative stream indexes and non-rank values such as
 * MPI_ANY_SOURCE / MPI_PROC_NULL are not rejected here; confirm callers
 * validate them before calling. */
MPL_STATIC_INLINE_PREFIX int MPIR_Stream_comm_set_attr(MPIR_Comm * comm, int src_rank, int dst_rank,
                                                       int src_index, int dst_index, int *attr_out)
{
    int mpi_errno = MPI_SUCCESS;

    MPIR_ERR_CHKANDJUMP(comm->stream_comm_type != MPIR_STREAM_COMM_MULTIPLEX,
                        mpi_errno, MPI_ERR_OTHER, "**streamcomm_notmult");

    MPI_Aint *displs = comm->stream_comm.multiplex.vci_displs;

    /* Each index must stay inside the slice owned by its own rank. */
    MPIR_ERR_CHKANDJUMP(displs[src_rank] + src_index >= displs[src_rank + 1],
                        mpi_errno, MPI_ERR_OTHER, "**streamcomm_srcidx");
    MPIR_ERR_CHKANDJUMP(displs[dst_rank] + dst_index >= displs[dst_rank + 1],
                        mpi_errno, MPI_ERR_OTHER, "**streamcomm_dstidx");

    int src_vci = comm->stream_comm.multiplex.vci_table[displs[src_rank] + src_index];
    /* The destination VCI lives in the destination rank's slice; it must be
     * looked up with displs[dst_rank] to match the bounds check above. */
    int dst_vci = comm->stream_comm.multiplex.vci_table[displs[dst_rank] + dst_index];

    int attr = MPIR_CONTEXT_INTRA_PT2PT;
    MPIR_PT2PT_ATTR_SET_VCIS(attr, src_vci, dst_vci);

    *attr_out = attr;

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}


/* MPIR_Comm_release_always is the same as MPIR_Comm_release except it uses
MPIR_Comm_release_ref_always instead.
Expand Down
3 changes: 3 additions & 0 deletions src/mpi/errhan/errnames.txt
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,9 @@ is too big (> MPIU_SHMW_GHND_SZ)
**missinggpustream:Info hint 'type' is set, but info hint 'value' is missing.
**invalidgpustream:Info hint 'type' is set, but info hint 'value' is invalid.
**notgpustream:The communicator does not have a local gpu stream attached.
**streamcomm_notmult:The communicator is not a multiplex stream communicator.
**streamcomm_srcidx:The source stream index exceeds the number of streams.
**streamcomm_dstidx:The destination stream index exceeds the number of streams.

# -----------------------------------------------------------------------------
# The following names are defined but not used (see the -careful option
Expand Down
Loading

0 comments on commit 650ca82

Please sign in to comment.