Skip to content

Commit

Permalink
ch4: match num_vnis/vsis to MPIDI_global.n_total_vcis
Browse files Browse the repository at this point in the history
Initialize netmod to support MPIDI_global.n_total_vcis.

Now that all netmods and shmmods support multiple vcis, it is simpler to
move the mod logic into the ch4-layer hashing functions. A netmod can still
apply another mod or simply overwrite the vci if it doesn't support
multiple vcis or supports fewer vcis. For now, we remove the netmod-level
mods for cleaner code.

MPIR_CVAR_CH4_OFI_MAX_VNIS and MPIR_CVAR_CH4_UCX_MAX_VNIS are removed
since we can't have arbitrary vnis anyway.

Moving the mod into hashing functions allows implementing the reserved
vci logic.
  • Loading branch information
hzhou committed Apr 15, 2022
1 parent cd9c4b0 commit 4168acb
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 80 deletions.
4 changes: 3 additions & 1 deletion src/mpid/ch4/netmod/ofi/ofi_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_get_vni(int flag, MPIR_Comm * comm_ptr,
#if MPIDI_CH4_MAX_VCIS == 1
return 0;
#else
return MPIDI_get_vci(flag, comm_ptr, src_rank, dst_rank, tag) % MPIDI_OFI_global.num_vnis;
int vni = MPIDI_get_vci(flag, comm_ptr, src_rank, dst_rank, tag);
MPIR_Assert(vni < MPIDI_OFI_global.num_vnis);
return vni;
#endif
}

Expand Down
30 changes: 3 additions & 27 deletions src/mpid/ch4/netmod/ofi/ofi_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,18 +242,6 @@ categories :
minor version of the OFI library used with MPICH. If using this CVAR,
it is recommended that the user also specifies a specific OFI provider.
- name : MPIR_CVAR_CH4_OFI_MAX_VNIS
category : CH4_OFI
type : int
default : 0
class : none
verbosity : MPI_T_VERBOSITY_USER_BASIC
scope : MPI_T_SCOPE_LOCAL
description : >-
If set to positive, this CVAR specifies the maximum number of CH4 VNIs
that OFI netmod exposes. If set to 0 (the default) or bigger than
MPIR_CVAR_CH4_NUM_VCIS, the number of exposed VNIs is set to MPIR_CVAR_CH4_NUM_VCIS.
- name : MPIR_CVAR_CH4_OFI_MAX_RMA_SEP_CTX
category : CH4_OFI
type : int
Expand Down Expand Up @@ -613,22 +601,10 @@ int MPIDI_OFI_init_local(int *tag_bits)
/* Create transport level communication contexts. */
/* ------------------------------------------------------------------------ */

int num_vnis = 1;
if (MPIR_CVAR_CH4_OFI_MAX_VNIS == 0 || MPIR_CVAR_CH4_OFI_MAX_VNIS > MPIDI_global.n_vcis) {
num_vnis = MPIDI_global.n_vcis;
} else {
num_vnis = MPIR_CVAR_CH4_OFI_MAX_VNIS;
}

/* TODO: update num_vnis according to provider capabilities, such as
* prov_use->domain_attr->{tx,rx}_ctx_cnt
/* TODO: check provider capabilities, such as prov_use->domain_attr->{tx,rx}_ctx_cnt,
* abort if we can't support the requested number of vnis.
*/
if (num_vnis > MPIDI_OFI_MAX_VNIS) {
num_vnis = MPIDI_OFI_MAX_VNIS;
}
/* for best performance, we ensure 1-to-1 vci/vni mapping. ref: MPIDI_OFI_vci_to_vni */
/* TODO: allow less num_vnis. Option 1. runtime MOD; 2. override MPIDI_global.n_vcis */
MPIR_Assert(num_vnis == MPIDI_global.n_vcis);
int num_vnis = MPIDI_global.n_total_vcis;

/* Multiple vni without using domain require MPIDI_OFI_ENABLE_SCALABLE_ENDPOINTS */
#ifndef MPIDI_OFI_VNI_USE_DOMAIN
Expand Down
2 changes: 1 addition & 1 deletion src/mpid/ch4/netmod/ofi/ofi_win.c
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ static int win_set_per_win_sync(MPIR_Win * win)

static void win_init_am(MPIR_Win * win)
{
MPIDI_WIN(win, am_vci) %= MPIDI_OFI_global.num_vnis;
MPIR_Assert(MPIDI_WIN(win, am_vci) < MPIDI_OFI_global.num_vnis);
}

/*
Expand Down
12 changes: 8 additions & 4 deletions src/mpid/ch4/netmod/ucx/ucx_am.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_isend(int rank,

MPIR_FUNC_ENTER;

int src_vni = src_vci % MPIDI_UCX_global.num_vnis;
int dst_vni = dst_vci % MPIDI_UCX_global.num_vnis;
int src_vni = src_vci;
int dst_vni = dst_vci;
MPIR_Assert(src_vni < MPIDI_UCX_global.num_vnis);
MPIR_Assert(dst_vni < MPIDI_UCX_global.num_vnis);
ep = MPIDI_UCX_COMM_TO_EP(comm, rank, src_vni, dst_vni);

int dt_contig;
Expand Down Expand Up @@ -186,8 +188,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_NM_am_send_hdr(int rank,

MPIR_FUNC_ENTER;

int src_vni = src_vci % MPIDI_UCX_global.num_vnis;
int dst_vni = dst_vci % MPIDI_UCX_global.num_vnis;
int src_vni = src_vci;
int dst_vni = dst_vci;
MPIR_Assert(src_vni < MPIDI_UCX_global.num_vnis);
MPIR_Assert(dst_vni < MPIDI_UCX_global.num_vnis);
ep = MPIDI_UCX_COMM_TO_EP(comm, rank, src_vni, dst_vni);

/* initialize our portion of the hdr */
Expand Down
5 changes: 4 additions & 1 deletion src/mpid/ch4/netmod/ucx/ucx_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_vci_to_vni(int vci)
MPL_STATIC_INLINE_PREFIX int MPIDI_UCX_get_vni(int flag, MPIR_Comm * comm_ptr,
int src_rank, int dst_rank, int tag)
{
return MPIDI_get_vci(flag, comm_ptr, src_rank, dst_rank, tag) % MPIDI_UCX_global.num_vnis;
int vni;
vni = MPIDI_get_vci(flag, comm_ptr, src_rank, dst_rank, tag);
MPIR_Assert(vni < MPIDI_UCX_global.num_vnis);
return vni;
}

/* for rma, we need ensure rkey is consistent with the per-vni ep,
Expand Down
37 changes: 2 additions & 35 deletions src/mpid/ch4/netmod/ucx/ucx_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,6 @@
#include "mpidu_bc.h"
#include <ucp/api/ucp.h>

/*
=== BEGIN_MPI_T_CVAR_INFO_BLOCK ===
categories :
- name : CH4_UCX
description : A category for CH4 UCX netmod variables
cvars:
- name : MPIR_CVAR_CH4_UCX_MAX_VNIS
category : CH4_UCX
type : int
default : 0
class : none
verbosity : MPI_T_VERBOSITY_USER_BASIC
scope : MPI_T_SCOPE_LOCAL
description : >-
If set to positive, this CVAR specifies the maximum number of CH4 VNIs
that UCX netmod exposes. If set to 0 (the default) or bigger than
MPIR_CVAR_CH4_NUM_VCIS, the number of exposed VNIs is set to MPIR_CVAR_CH4_NUM_VCIS.
=== END_MPI_T_CVAR_INFO_BLOCK ===
*/

static void request_init_callback(void *request);

static void request_init_callback(void *request)
Expand All @@ -43,18 +20,8 @@ static void request_init_callback(void *request)

static void init_num_vnis(void)
{
int num_vnis = 1;
if (MPIR_CVAR_CH4_UCX_MAX_VNIS == 0 || MPIR_CVAR_CH4_UCX_MAX_VNIS > MPIDI_global.n_vcis) {
num_vnis = MPIDI_global.n_vcis;
} else {
num_vnis = MPIR_CVAR_CH4_UCX_MAX_VNIS;
}

/* for best performance, we ensure 1-to-1 vci/vni mapping. ref: MPIDI_OFI_vci_to_vni */
/* TODO: allow less num_vnis. Option 1. runtime MOD; 2. override MPIDI_global.n_vcis */
MPIR_Assert(num_vnis == MPIDI_global.n_vcis);

MPIDI_UCX_global.num_vnis = num_vnis;
/* TODO: check capabilities, abort if we can't support the requested number of vnis. */
MPIDI_UCX_global.num_vnis = MPIDI_global.n_total_vcis;
}

static int init_worker(int vni)
Expand Down
2 changes: 1 addition & 1 deletion src/mpid/ch4/netmod/ucx/ucx_win.c
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ static int win_init(MPIR_Win * win)
int mpi_errno = MPI_SUCCESS;
MPIR_FUNC_ENTER;

MPIDI_WIN(win, am_vci) %= MPIDI_UCX_global.num_vnis;
MPIR_Assert(MPIDI_WIN(win, am_vci) < MPIDI_UCX_global.num_vnis);

memset(&MPIDI_UCX_WIN(win), 0, sizeof(MPIDI_UCX_win_t));

Expand Down
17 changes: 7 additions & 10 deletions src/mpid/ch4/src/ch4_vci.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
#define MPIDI_Request_get_vci(req) MPIR_REQUEST_POOL(req)
#define MPIDI_VCI_INVALID (-1)

/* VCI hashing function (fast path)
* NOTE: The returned vci should always MOD NUMVCIS, where NUMVCIS is
* the number of VCIs determined at init time
* Potentially, we'd like to make it config constants of power of 2
* TODO: move the MOD here.
*/
/* VCI hashing function (fast path) */

/* For consistent hashing, we may need to differentiate between src and dst vci and whether
* it is being called from sender side or receiver side (consider intercomm). We use an
Expand Down Expand Up @@ -47,7 +42,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr,
MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr,
int src_rank, int dst_rank, int tag)
{
return comm_ptr->seq;
return comm_ptr->seq % MPIDI_global.n_vcis;
}

#elif MPIDI_CH4_VCI_METHOD == MPICH_VCI__TAG
Expand All @@ -59,13 +54,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr,
MPL_STATIC_INLINE_PREFIX int MPIDI_get_vci(int flag, MPIR_Comm * comm_ptr,
int src_rank, int dst_rank, int tag)
{
int vci;
if (!(flag & 0x1)) {
/* src */
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f);
vci = (tag == MPI_ANY_TAG) ? 0 : ((tag >> 10) & 0x1f);
} else {
/* dst */
return (tag == MPI_ANY_TAG) ? 0 : ((tag >> 5) & 0x1f);
vci = (tag == MPI_ANY_TAG) ? 0 : ((tag >> 5) & 0x1f);
}
return vci % MPIDI_global.n_vcis;
}

#elif MPIDI_CH4_VCI_METHOD == MPICH_VCI__IMPLICIT
Expand Down Expand Up @@ -183,7 +180,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_get_receiver_vci(MPIR_Comm * comm,
if (is_vci_restricted_to_zero(comm)) {
vci_idx = 0;
} else if (use_user_defined_vci) {
vci_idx = comm->hints[MPIR_COMM_HINT_RECEIVER_VCI];
vci_idx = comm->hints[MPIR_COMM_HINT_RECEIVER_VCI] % MPIDI_global.n_vcis;
} else {
/* If mpi_any_tag and mpi_any_source can be used for recv, all messages
* should be received on a single vci. Otherwise, messages sent from a
Expand Down

0 comments on commit 4168acb

Please sign in to comment.