Skip to content

Commit

Permalink
Merge pull request #12376 from wenduwan/han_gatherv
Browse files Browse the repository at this point in the history
Implement hierarchical MPI_Gatherv and MPI_Scatterv
  • Loading branch information
wenduwan authored Mar 23, 2024
2 parents 0353f7e + 2152b61 commit 984944d
Show file tree
Hide file tree
Showing 11 changed files with 1,184 additions and 7 deletions.
5 changes: 4 additions & 1 deletion ompi/mca/coll/han/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ coll_han_barrier.c \
coll_han_bcast.c \
coll_han_reduce.c \
coll_han_scatter.c \
coll_han_scatterv.c \
coll_han_gather.c \
coll_han_gatherv.c \
coll_han_allreduce.c \
coll_han_allgather.c \
coll_han_component.c \
Expand All @@ -31,7 +33,8 @@ coll_han_algorithms.c \
coll_han_dynamic.c \
coll_han_dynamic_file.c \
coll_han_topo.c \
coll_han_subcomms.c
coll_han_subcomms.c \
coll_han_utils.c

# Make the output library in this directory, and name it either
# mca_<type>_<name>.la (for DSO builds) or libmca_<type>_<name>.la
Expand Down
48 changes: 44 additions & 4 deletions ompi/mca/coll/han/coll_han.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* reserved.
* Copyright (c) 2022 IBM Corporation. All rights reserved
* Copyright (c) 2020-2022 Bull S.A.S. All rights reserved.
* Copyright (c) Amazon.com, Inc. or its affiliates.
* All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand Down Expand Up @@ -189,7 +191,9 @@ typedef struct mca_coll_han_op_module_name_t {
mca_coll_han_op_up_low_module_name_t allreduce;
mca_coll_han_op_up_low_module_name_t allgather;
mca_coll_han_op_up_low_module_name_t gather;
mca_coll_han_op_up_low_module_name_t gatherv;
mca_coll_han_op_up_low_module_name_t scatter;
mca_coll_han_op_up_low_module_name_t scatterv;
} mca_coll_han_op_module_name_t;

/**
Expand Down Expand Up @@ -233,10 +237,18 @@ typedef struct mca_coll_han_component_t {
uint32_t han_gather_up_module;
/* low level module for gather */
uint32_t han_gather_low_module;
/* up level module for gatherv */
uint32_t han_gatherv_up_module;
/* low level module for gatherv */
uint32_t han_gatherv_low_module;
/* up level module for scatter */
uint32_t han_scatter_up_module;
/* low level module for scatter */
uint32_t han_scatter_low_module;
/* up level module for scatterv */
uint32_t han_scatterv_up_module;
/* low level module for scatterv */
uint32_t han_scatterv_low_module;
/* name of the modules */
mca_coll_han_op_module_name_t han_op_module_name;
/* whether we need reproducible results
Expand Down Expand Up @@ -277,8 +289,10 @@ typedef struct mca_coll_han_single_collective_fallback_s {
mca_coll_base_module_barrier_fn_t barrier;
mca_coll_base_module_bcast_fn_t bcast;
mca_coll_base_module_gather_fn_t gather;
mca_coll_base_module_gatherv_fn_t gatherv;
mca_coll_base_module_reduce_fn_t reduce;
mca_coll_base_module_scatter_fn_t scatter;
mca_coll_base_module_scatterv_fn_t scatterv;
} module_fn;
mca_coll_base_module_t* module;
} mca_coll_han_single_collective_fallback_t;
Expand All @@ -296,7 +310,9 @@ typedef struct mca_coll_han_collectives_fallback_s {
mca_coll_han_single_collective_fallback_t bcast;
mca_coll_han_single_collective_fallback_t reduce;
mca_coll_han_single_collective_fallback_t gather;
mca_coll_han_single_collective_fallback_t gatherv;
mca_coll_han_single_collective_fallback_t scatter;
mca_coll_han_single_collective_fallback_t scatterv;
} mca_coll_han_collectives_fallback_t;

/** Coll han module */
Expand Down Expand Up @@ -369,9 +385,14 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
#define previous_gather fallback.gather.module_fn.gather
#define previous_gather_module fallback.gather.module

#define previous_gatherv fallback.gatherv.module_fn.gatherv
#define previous_gatherv_module fallback.gatherv.module

#define previous_scatter fallback.scatter.module_fn.scatter
#define previous_scatter_module fallback.scatter.module

#define previous_scatterv fallback.scatterv.module_fn.scatterv
#define previous_scatterv_module fallback.scatterv.module

/* macro to correctly load a fallback collective module */
#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \
Expand All @@ -391,7 +412,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \
HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \
Expand Down Expand Up @@ -432,11 +455,16 @@ int *mca_coll_han_topo_init(struct ompi_communicator_t *comm, mca_coll_han_modul

/* Utils */
static inline void
mca_coll_han_get_ranks(int *vranks, int root, int low_size,
int *root_low_rank, int *root_up_rank)
mca_coll_han_get_ranks(int *vranks, int w_rank, int low_size,
int *low_rank, int *up_rank)
{
*root_up_rank = vranks[root] / low_size;
*root_low_rank = vranks[root] % low_size;
if (up_rank) {
*up_rank = vranks[w_rank] / low_size;
}

if (low_rank) {
*low_rank = vranks[w_rank] % low_size;
}
}

const char* mca_coll_han_topo_lvl_to_str(TOPO_LVL_T topo_lvl);
Expand Down Expand Up @@ -469,11 +497,17 @@ int
mca_coll_han_gather_intra_dynamic(GATHER_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_gatherv_intra_dynamic(GATHERV_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_reduce_intra_dynamic(REDUCE_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_scatter_intra_dynamic(SCATTER_BASE_ARGS,
mca_coll_base_module_t *module);
int
mca_coll_han_scatterv_intra_dynamic(SCATTERV_BASE_ARGS,
mca_coll_base_module_t *module);

int mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);
Expand All @@ -486,4 +520,10 @@ ompi_coll_han_reorder_gather(const void *sbuf,
struct ompi_communicator_t *comm,
int * topo);

size_t
coll_han_utils_gcd(const size_t *numerators, const size_t size);

int
coll_han_utils_create_contiguous_datatype(size_t count, const ompi_datatype_t *oldType,
ompi_datatype_t **newType);
#endif /* MCA_COLL_HAN_EXPORT_H */
8 changes: 8 additions & 0 deletions ompi/mca/coll/han/coll_han_algorithms.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,19 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] =
{"simple", (fnptr_t) &mca_coll_han_scatter_intra_simple}, // 2-level
{ 0 }
},
[SCATTERV] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_scatterv_intra}, // 2-level
{ 0 }
},
[GATHER] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_gather_intra}, // 2-level
{"simple", (fnptr_t) &mca_coll_han_gather_intra_simple}, // 2-level
{ 0 }
},
[GATHERV] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t) &mca_coll_han_gatherv_intra}, // 2-level
{ 0 }
},
[ALLGATHER] = (mca_coll_han_algorithm_value_t[]){
{"intra", (fnptr_t)&mca_coll_han_allgather_intra}, // 2-level
{"simple", (fnptr_t)&mca_coll_han_allgather_intra_simple}, // 2-level
Expand Down
17 changes: 17 additions & 0 deletions ompi/mca/coll/han/coll_han_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,16 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount,
struct ompi_communicator_t *comm,
mca_coll_base_module_t * module);

/* Scatterv */
int
mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts,
const int *displs, struct ompi_datatype_t *sdtype,
void *rbuf, int rcount,
struct ompi_datatype_t *rdtype,
int root,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

/* Gather */
int
mca_coll_han_gather_intra(const void *sbuf, int scount,
Expand All @@ -176,6 +186,13 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module);

/* Gatherv */
int
mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
void *rbuf, const int *rcounts, const int *displs,
struct ompi_datatype_t *rdtype, int root,
struct ompi_communicator_t *comm, mca_coll_base_module_t *module);

/* Allgather */
int
mca_coll_han_allgather_intra(const void *sbuf, int scount,
Expand Down
34 changes: 34 additions & 0 deletions ompi/mca/coll/han/coll_han_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,21 @@ static int han_close(void)
free(mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name = NULL;

free(mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name);
mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name = NULL;
free(mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name);
mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name = NULL;

return OMPI_SUCCESS;
}

Expand Down Expand Up @@ -344,6 +354,18 @@ static int han_register(void)
OPAL_INFO_LVL_9, &cs->han_gather_low_module,
&cs->han_op_module_name.gather.han_op_low_module_name);

cs->han_gatherv_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "gatherv_up_module",
"up level module for gatherv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_gatherv_up_module,
&cs->han_op_module_name.gatherv.han_op_up_module_name);

cs->han_gatherv_low_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "gatherv_low_module",
"low level module for gatherv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_gatherv_low_module,
&cs->han_op_module_name.gatherv.han_op_low_module_name);

cs->han_scatter_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatter_up_module",
"up level module for scatter, 0 libnbc, 1 adapt",
Expand All @@ -356,6 +378,18 @@ static int han_register(void)
OPAL_INFO_LVL_9, &cs->han_scatter_low_module,
&cs->han_op_module_name.scatter.han_op_low_module_name);

cs->han_scatterv_up_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatterv_up_module",
"up level module for scatterv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_scatterv_up_module,
&cs->han_op_module_name.scatterv.han_op_up_module_name);

cs->han_scatterv_low_module = 0;
(void) mca_coll_han_query_module_from_mca(c, "scatterv_low_module",
"low level module for scatterv, 0 basic",
OPAL_INFO_LVL_9, &cs->han_scatterv_low_module,
&cs->han_op_module_name.scatterv.han_op_low_module_name);

cs->han_reproducible = 0;
(void) mca_base_component_var_register(c, "reproducible",
"whether we need reproducible results "
Expand Down
Loading

0 comments on commit 984944d

Please sign in to comment.