From b22e5fa8e3fd681f3b9a8a90223287b4e34ae091 Mon Sep 17 00:00:00 2001 From: Wenduo Wang Date: Mon, 26 Feb 2024 15:50:03 +0000 Subject: [PATCH 1/3] coll/han: refactor mca_coll_han_get_ranks function Relax the function requirement to allow null low/up_rank output pointers, and rename the arguments because the function works for non-root ranks as well. Signed-off-by: Wenduo Wang --- ompi/mca/coll/han/coll_han.h | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h index b317dc5185b..d9eada8cff1 100644 --- a/ompi/mca/coll/han/coll_han.h +++ b/ompi/mca/coll/han/coll_han.h @@ -4,6 +4,8 @@ * reserved. * Copyright (c) 2022 IBM Corporation. All rights reserved * Copyright (c) 2020-2022 Bull S.A.S. All rights reserved. + * Copyright (c) Amazon.com, Inc. or its affiliates. + * All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -432,11 +434,16 @@ int *mca_coll_han_topo_init(struct ompi_communicator_t *comm, mca_coll_han_modul /* Utils */ static inline void -mca_coll_han_get_ranks(int *vranks, int root, int low_size, - int *root_low_rank, int *root_up_rank) +mca_coll_han_get_ranks(int *vranks, int w_rank, int low_size, + int *low_rank, int *up_rank) { - *root_up_rank = vranks[root] / low_size; - *root_low_rank = vranks[root] % low_size; + if (up_rank) { + *up_rank = vranks[w_rank] / low_size; + } + + if (low_rank) { + *low_rank = vranks[w_rank] % low_size; + } } const char* mca_coll_han_topo_lvl_to_str(TOPO_LVL_T topo_lvl); From 48c125e470fefe73dfa7242c8b7504221d473dab Mon Sep 17 00:00:00 2001 From: Wenduo Wang Date: Mon, 26 Feb 2024 15:54:31 +0000 Subject: [PATCH 2/3] coll/han: implement hierarchical gatherv Add gatherv implementation to optimize large-scale communications on multiple nodes and multiple processes per node, by avoiding high-incast traffic on the root process. 
Because *V collectives do not have equal datatype/count on every process, it does not natively support message-size based tuning without an additional global communication. Similar to gather and allgather, the hierarchical gatherv requires a temporary buffer and memory copy to handle out-of-order data, or non-contiguous placement on the output buffer, which results in worse performance for large messages compared to the linear implementation. Signed-off-by: Wenduo Wang --- ompi/mca/coll/han/Makefile.am | 4 +- ompi/mca/coll/han/coll_han.h | 20 ++ ompi/mca/coll/han/coll_han_algorithms.c | 4 + ompi/mca/coll/han/coll_han_algorithms.h | 7 + ompi/mca/coll/han/coll_han_component.c | 17 ++ ompi/mca/coll/han/coll_han_dynamic.c | 100 ++++++ ompi/mca/coll/han/coll_han_gatherv.c | 385 ++++++++++++++++++++++++ ompi/mca/coll/han/coll_han_module.c | 7 +- ompi/mca/coll/han/coll_han_subcomms.c | 6 + ompi/mca/coll/han/coll_han_utils.c | 72 +++++ 10 files changed, 620 insertions(+), 2 deletions(-) create mode 100644 ompi/mca/coll/han/coll_han_gatherv.c create mode 100644 ompi/mca/coll/han/coll_han_utils.c diff --git a/ompi/mca/coll/han/Makefile.am b/ompi/mca/coll/han/Makefile.am index acaaab5c749..b75513b9130 100644 --- a/ompi/mca/coll/han/Makefile.am +++ b/ompi/mca/coll/han/Makefile.am @@ -22,6 +22,7 @@ coll_han_bcast.c \ coll_han_reduce.c \ coll_han_scatter.c \ coll_han_gather.c \ +coll_han_gatherv.c \ coll_han_allreduce.c \ coll_han_allgather.c \ coll_han_component.c \ @@ -31,7 +32,8 @@ coll_han_algorithms.c \ coll_han_dynamic.c \ coll_han_dynamic_file.c \ coll_han_topo.c \ -coll_han_subcomms.c +coll_han_subcomms.c \ +coll_han_utils.c # Make the output library in this directory, and name it either # mca__.la (for DSO builds) or libmca__.la diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h index d9eada8cff1..f79cb06105c 100644 --- a/ompi/mca/coll/han/coll_han.h +++ b/ompi/mca/coll/han/coll_han.h @@ -191,6 +191,7 @@ typedef struct mca_coll_han_op_module_name_t 
{ mca_coll_han_op_up_low_module_name_t allreduce; mca_coll_han_op_up_low_module_name_t allgather; mca_coll_han_op_up_low_module_name_t gather; + mca_coll_han_op_up_low_module_name_t gatherv; mca_coll_han_op_up_low_module_name_t scatter; } mca_coll_han_op_module_name_t; @@ -235,6 +236,10 @@ typedef struct mca_coll_han_component_t { uint32_t han_gather_up_module; /* low level module for gather */ uint32_t han_gather_low_module; + /* up level module for gatherv */ + uint32_t han_gatherv_up_module; + /* low level module for gatherv */ + uint32_t han_gatherv_low_module; /* up level module for scatter */ uint32_t han_scatter_up_module; /* low level module for scatter */ @@ -279,6 +284,7 @@ typedef struct mca_coll_han_single_collective_fallback_s { mca_coll_base_module_barrier_fn_t barrier; mca_coll_base_module_bcast_fn_t bcast; mca_coll_base_module_gather_fn_t gather; + mca_coll_base_module_gatherv_fn_t gatherv; mca_coll_base_module_reduce_fn_t reduce; mca_coll_base_module_scatter_fn_t scatter; } module_fn; @@ -298,6 +304,7 @@ typedef struct mca_coll_han_collectives_fallback_s { mca_coll_han_single_collective_fallback_t bcast; mca_coll_han_single_collective_fallback_t reduce; mca_coll_han_single_collective_fallback_t gather; + mca_coll_han_single_collective_fallback_t gatherv; mca_coll_han_single_collective_fallback_t scatter; } mca_coll_han_collectives_fallback_t; @@ -371,6 +378,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t); #define previous_gather fallback.gather.module_fn.gather #define previous_gather_module fallback.gather.module +#define previous_gatherv fallback.gatherv.module_fn.gatherv +#define previous_gatherv_module fallback.gatherv.module + #define previous_scatter fallback.scatter.module_fn.scatter #define previous_scatter_module fallback.scatter.module @@ -394,6 +404,7 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t); HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, 
COMM, gather); \ + HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allreduce); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, allgather); \ @@ -476,6 +487,9 @@ int mca_coll_han_gather_intra_dynamic(GATHER_BASE_ARGS, mca_coll_base_module_t *module); int +mca_coll_han_gatherv_intra_dynamic(GATHERV_BASE_ARGS, + mca_coll_base_module_t *module); +int mca_coll_han_reduce_intra_dynamic(REDUCE_BASE_ARGS, mca_coll_base_module_t *module); int @@ -493,4 +507,10 @@ ompi_coll_han_reorder_gather(const void *sbuf, struct ompi_communicator_t *comm, int * topo); +size_t +coll_han_utils_gcd(const size_t *numerators, const size_t size); + +int +coll_han_utils_create_contiguous_datatype(size_t count, const ompi_datatype_t *oldType, + ompi_datatype_t **newType); #endif /* MCA_COLL_HAN_EXPORT_H */ diff --git a/ompi/mca/coll/han/coll_han_algorithms.c b/ompi/mca/coll/han/coll_han_algorithms.c index bc2bd5ebade..dbeb7ebe07d 100644 --- a/ompi/mca/coll/han/coll_han_algorithms.c +++ b/ompi/mca/coll/han/coll_han_algorithms.c @@ -64,6 +64,10 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] = {"simple", (fnptr_t) &mca_coll_han_gather_intra_simple}, // 2-level { 0 } }, + [GATHERV] = (mca_coll_han_algorithm_value_t[]){ + {"intra", (fnptr_t) &mca_coll_han_gatherv_intra}, // 2-level + { 0 } + }, [ALLGATHER] = (mca_coll_han_algorithm_value_t[]){ {"intra", (fnptr_t)&mca_coll_han_allgather_intra}, // 2-level {"simple", (fnptr_t)&mca_coll_han_allgather_intra_simple}, // 2-level diff --git a/ompi/mca/coll/han/coll_han_algorithms.h b/ompi/mca/coll/han/coll_han_algorithms.h index 3a247ff9fbd..d73250d5963 100644 --- a/ompi/mca/coll/han/coll_han_algorithms.h +++ b/ompi/mca/coll/han/coll_han_algorithms.h @@ -176,6 +176,13 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t *module); +/* Gatherv */ +int 
+mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, + void *rbuf, const int *rcounts, const int *displs, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module); + /* Allgather */ int mca_coll_han_allgather_intra(const void *sbuf, int scount, diff --git a/ompi/mca/coll/han/coll_han_component.c b/ompi/mca/coll/han/coll_han_component.c index 1d14baf538a..e4ffd6fdfca 100644 --- a/ompi/mca/coll/han/coll_han_component.c +++ b/ompi/mca/coll/han/coll_han_component.c @@ -146,6 +146,11 @@ static int han_close(void) free(mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name); mca_coll_han_component.han_op_module_name.gather.han_op_low_module_name = NULL; + free(mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name); + mca_coll_han_component.han_op_module_name.gatherv.han_op_up_module_name = NULL; + free(mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name); + mca_coll_han_component.han_op_module_name.gatherv.han_op_low_module_name = NULL; + free(mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name); mca_coll_han_component.han_op_module_name.scatter.han_op_up_module_name = NULL; free(mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name); @@ -344,6 +349,18 @@ static int han_register(void) OPAL_INFO_LVL_9, &cs->han_gather_low_module, &cs->han_op_module_name.gather.han_op_low_module_name); + cs->han_gatherv_up_module = 0; + (void) mca_coll_han_query_module_from_mca(c, "gatherv_up_module", + "up level module for gatherv, 0 basic", + OPAL_INFO_LVL_9, &cs->han_gatherv_up_module, + &cs->han_op_module_name.gatherv.han_op_up_module_name); + + cs->han_gatherv_low_module = 0; + (void) mca_coll_han_query_module_from_mca(c, "gatherv_low_module", + "low level module for gatherv, 0 basic", + OPAL_INFO_LVL_9, &cs->han_gatherv_low_module, + &cs->han_op_module_name.gatherv.han_op_low_module_name); + 
cs->han_scatter_up_module = 0; (void) mca_coll_han_query_module_from_mca(c, "scatter_up_module", "up level module for scatter, 0 libnbc, 1 adapt", diff --git a/ompi/mca/coll/han/coll_han_dynamic.c b/ompi/mca/coll/han/coll_han_dynamic.c index 9063cf926fe..2c73dfea456 100644 --- a/ompi/mca/coll/han/coll_han_dynamic.c +++ b/ompi/mca/coll/han/coll_han_dynamic.c @@ -26,6 +26,8 @@ #include "ompi/mca/coll/han/coll_han_algorithms.h" #include "ompi/mca/coll/base/coll_base_util.h" +#define MCA_COLL_HAN_ANY_MESSAGE_SIZE 0 + /* * Tests if a dynamic collective is implemented * Useful for file reading warnings and MCA parameter generation @@ -41,6 +43,7 @@ bool mca_coll_han_is_coll_dynamic_implemented(COLLTYPE_T coll_id) case BARRIER: case BCAST: case GATHER: + case GATHERV: case REDUCE: case SCATTER: return true; @@ -1045,6 +1048,103 @@ mca_coll_han_gather_intra_dynamic(const void *sbuf, int scount, sub_module); } +/* + * Gatherv selector: + * On a sub-communicator, checks the stored rules to find the module to use + * On the global communicator, calls the han collective implementation, or + * calls the correct module if fallback mechanism is activated + */ +int mca_coll_han_gatherv_intra_dynamic(const void *sbuf, int scount, struct ompi_datatype_t *sdtype, + void *rbuf, const int *rcounts, const int *displs, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_han_module_t *han_module = (mca_coll_han_module_t*) module; + TOPO_LVL_T topo_lvl = han_module->topologic_level; + mca_coll_base_module_gatherv_fn_t gatherv; + mca_coll_base_module_t *sub_module; + int rank, verbosity = 0; + + if (!han_module->enabled) { + return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, + root, comm, han_module->previous_gatherv_module); + } + + /* v collectives do not support message-size based dynamic rules */ + sub_module = get_module(GATHERV, MCA_COLL_HAN_ANY_MESSAGE_SIZE, comm, 
han_module); + + /* First errors are always printed by rank 0 */ + rank = ompi_comm_rank(comm); + if( (0 == rank) && (han_module->dynamic_errors < mca_coll_han_component.max_dynamic_errors) ) { + verbosity = 30; + } + + if(NULL == sub_module) { + /* + * No valid collective module from dynamic rules + * nor from mca parameter + */ + han_module->dynamic_errors++; + opal_output_verbose(verbosity, mca_coll_han_component.han_output, + "coll:han:mca_coll_han_gatherv_intra_dynamic " + "HAN did not find any valid module for collective %d (%s) " + "with topological level %d (%s) on communicator (%s/%s). " + "Please check dynamic file/mca parameters\n", + GATHERV, mca_coll_base_colltype_to_str(GATHERV), + topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl), + ompi_comm_print_cid(comm), comm->c_name); + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "HAN/GATHERV: No module found for the sub-communicator. " + "Falling back to another component\n")); + gatherv = han_module->previous_gatherv; + sub_module = han_module->previous_gatherv_module; + } else if (NULL == sub_module->coll_gatherv) { + /* + * No valid collective from dynamic rules + * nor from mca parameter + */ + han_module->dynamic_errors++; + opal_output_verbose(verbosity, mca_coll_han_component.han_output, + "coll:han:mca_coll_han_gatherv_intra_dynamic " + "HAN found valid module for collective %d (%s) " + "with topological level %d (%s) on communicator (%s/%s) " + "but this module cannot handle this collective. " + "Please check dynamic file/mca parameters\n", + GATHERV, mca_coll_base_colltype_to_str(GATHERV), + topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl), + ompi_comm_print_cid(comm), comm->c_name); + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "HAN/GATHERV: the module found for the sub-" + "communicator cannot handle the GATHERV operation. 
" + "Falling back to another component\n")); + gatherv = han_module->previous_gatherv; + sub_module = han_module->previous_gatherv_module; + } else if (GLOBAL_COMMUNICATOR == topo_lvl && sub_module == module) { + /* + * No fallback mechanism activated for this configuration + * sub_module is valid + * sub_module->coll_gatherv is valid and point to this function + * Call han topological collective algorithm + */ + int algorithm_id = get_algorithm(GATHERV, MCA_COLL_HAN_ANY_MESSAGE_SIZE, comm, han_module); + gatherv = (mca_coll_base_module_gatherv_fn_t) mca_coll_han_algorithm_id_to_fn(GATHERV, algorithm_id); + if (NULL == gatherv) { /* default behaviour */ + gatherv = mca_coll_han_gatherv_intra; + } + } else { + /* + * If we get here: + * sub_module is valid + * sub_module->coll_gatherv is valid + * They points to the collective to use, according to the dynamic rules + * Selector's job is done, call the collective + */ + gatherv = sub_module->coll_gatherv; + } + return gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, root, comm, sub_module); +} + /* * Reduce selector: diff --git a/ompi/mca/coll/han/coll_han_gatherv.c b/ompi/mca/coll/han/coll_han_gatherv.c new file mode 100644 index 00000000000..6ec095c90bc --- /dev/null +++ b/ompi/mca/coll/han/coll_han_gatherv.c @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2018-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Bull S.A.S. All rights reserved. + * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) Amazon.com, Inc. or its affiliates. + * All rights reserved. 
+ * $COPYRIGHT$
+ *
+ * Additional copyrights may follow
+ *
+ * $HEADER$
+ */
+
+#include "coll_han.h"
+#include "coll_han_trigger.h"
+#include "ompi/mca/coll/base/coll_base_functions.h"
+#include "ompi/mca/coll/base/coll_tags.h"
+#include "ompi/mca/pml/pml.h"
+
+/*
+ * @file
+ *
+ * This file contains the hierarchical implementations of gatherv.
+ * Only works in the regular situation (each node has an equal number of processes).
+ */
+
+/*
+ * Implement hierarchical Gatherv to optimize large-scale communications where multiple nodes and
+ * multiple processes per node send non-zero sized messages to the root, i.e. high incast.
+ *
+ * In Gatherv, only the root (receiver) process has the information of the amount of data, i.e.
+ * datatype and count, from each sender process. Therefore node leaders need an additional step to
+ * collect the expected data from their local peers. In summary, the steps are:
+ * 1. Root
+ *      a. Receive data from local peers (Low Gatherv)
+ *      b. Receive data from other node leaders (Up Gatherv)
+ *      c. If necessary, reorder data from node leaders (see discussion below)
+ * 2. Root's local peers
+ *      a. Send data to root. (Low Gatherv)
+ * 3. Node leaders:
+ *      a. Collect the data transfer sizes (in bytes) from local peers (Low Gather)
+ *      b. Receive data from local peers (Low Gatherv)
+ *      c. Send data to the root (Up Gatherv)
+ * 4. Node followers:
+ *      a. Send the data transfer size (in bytes) to the node leader (Low Gather)
+ *      b. Send data to the node leader (Low Gatherv)
+ *
+ * Note on reordering:
+ * In Up Gatherv, data from each node is stored in a contiguous buffer sorted by the sender's
+ * local rank, and MUST be reordered according to the root's displacement requirement on the output
+ * buffer. Concretely, reordering can be avoided if and only if both of the following conditions are
+ * met:
+ * 1. Data from processes on each node, other than the root's node, are placed in the output buffer
+ *    in the same **order** as their local ranks.
Note, it is possible to receive the data in the
+ *    correct order even if the processes are NOT mapped by core.
+ * 2. No **gap** exists between data from the same node, other than the root's node, in the output
+ *    buffer - it is ok if data from different nodes has gaps.
+ */
+int mca_coll_han_gatherv_intra(const void *sbuf, int scount, struct ompi_datatype_t *sdtype,
+                               void *rbuf, const int *rcounts, const int *displs,
+                               struct ompi_datatype_t *rdtype, int root,
+                               struct ompi_communicator_t *comm, mca_coll_base_module_t *module)
+{
+    mca_coll_han_module_t *han_module = (mca_coll_han_module_t *) module;
+    int w_rank, w_size; /* information about the global communicator */
+    int root_low_rank, root_up_rank; /* root ranks for both sub-communicators */
+    int err, *vranks, low_rank, low_size, up_rank, up_size, *topo;
+    int *low_rcounts = NULL, *low_displs = NULL;
+
+    /* Create the subcommunicators */
+    err = mca_coll_han_comm_create(comm, han_module);
+    if (OMPI_SUCCESS != err) {
+        OPAL_OUTPUT_VERBOSE(
+            (30, mca_coll_han_component.han_output,
+             "han cannot handle gatherv with this communicator. Fall back on another component\n"));
+        /* HAN cannot work with this communicator so fallback on all collectives */
+        HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
+        return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype,
+                                            root, comm, han_module->previous_gatherv_module);
+    }
+
+    /* Topo must be initialized to know rank distribution which then is used to determine if han can
+     * be used */
+    topo = mca_coll_han_topo_init(comm, han_module, 2);
+    if (han_module->are_ppn_imbalanced) {
+        OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
+                             "han cannot handle gatherv with this communicator (imbalance). Fall "
+                             "back on another component\n"));
+        /* Put back the fallback collective support and call it once. All
+         * future calls will then be automatically redirected.
+ */ + HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gatherv); + return han_module->previous_gatherv(sbuf, scount, sdtype, rbuf, rcounts, displs, rdtype, + root, comm, han_module->previous_gatherv_module); + } + + w_rank = ompi_comm_rank(comm); + w_size = ompi_comm_size(comm); + + /* create the subcommunicators */ + ompi_communicator_t *low_comm + = han_module->cached_low_comms[mca_coll_han_component.han_gatherv_low_module]; + ompi_communicator_t *up_comm + = han_module->cached_up_comms[mca_coll_han_component.han_gatherv_up_module]; + + /* Get the 'virtual ranks' mapping corresponding to the communicators */ + vranks = han_module->cached_vranks; + /* information about sub-communicators */ + low_rank = ompi_comm_rank(low_comm); + low_size = ompi_comm_size(low_comm); + up_rank = ompi_comm_rank(up_comm); + up_size = ompi_comm_size(up_comm); + /* Get root ranks for low and up comms */ + mca_coll_han_get_ranks(vranks, root, low_size, &root_low_rank, &root_up_rank); + + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "[%d]: Han Gatherv root %d root_low_rank %d root_up_rank %d\n", w_rank, + root, root_low_rank, root_up_rank)); + + err = OMPI_SUCCESS; + /* #################### Root ########################### */ + if (root == w_rank) { + int need_bounce_buf = 0, total_up_rcounts = 0, *up_displs = NULL, *up_rcounts = NULL, + *up_peer_lb = NULL, *up_peer_ub = NULL; + char *bounce_buf = NULL; + + low_rcounts = malloc(low_size * sizeof(int)); + low_displs = malloc(low_size * sizeof(int)); + if (!low_rcounts || !low_displs) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto root_out; + } + + int low_peer, up_peer, w_peer; + for (w_peer = 0; w_peer < w_size; ++w_peer) { + mca_coll_han_get_ranks(vranks, w_peer, low_size, &low_peer, &up_peer); + if (root_up_rank != up_peer) { + /* Not a local peer */ + continue; + } + low_displs[low_peer] = displs[w_peer]; + low_rcounts[low_peer] = rcounts[w_peer]; + } + + /* Low Gatherv */ + low_comm->c_coll->coll_gatherv(sbuf, scount, 
sdtype, rbuf, low_rcounts, low_displs, rdtype, + root_low_rank, low_comm, + low_comm->c_coll->coll_gatherv_module); + + size_t rdsize; + char *tmp_rbuf = rbuf; + + ompi_datatype_type_size(rdtype, &rdsize); + + up_rcounts = calloc(up_size, sizeof(int)); + up_displs = malloc(up_size * sizeof(int)); + up_peer_ub = calloc(up_size, sizeof(int)); + if (!up_rcounts || !up_displs || !up_peer_ub) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto root_out; + } + + for (up_peer = 0; up_peer < up_size; ++up_peer) { + up_displs[up_peer] = INT_MAX; + } + + /* Calculate recv counts for the inter-node gatherv - no need to gather + * from self again because the data is already in place */ + for (w_peer = 0; w_peer < w_size; ++w_peer) { + mca_coll_han_get_ranks(vranks, w_peer, low_size, NULL, &up_peer); + + if (!need_bounce_buf && root_up_rank != up_peer && 0 < rcounts[w_peer] && 0 < w_peer + && displs[w_peer] < displs[w_peer - 1]) { + /* Data is not placed in the rank order so reordering is needed */ + need_bounce_buf = 1; + } + + if (root_up_rank == up_peer) { + /* No need to gather data on the same node again */ + continue; + } + + up_peer_ub[up_peer] = 0 < rcounts[w_peer] + && displs[w_peer] + rcounts[w_peer] > up_peer_ub[up_peer] + ? displs[w_peer] + rcounts[w_peer] + : up_peer_ub[up_peer]; + + up_rcounts[up_peer] += rcounts[w_peer]; + total_up_rcounts += rcounts[w_peer]; + + /* Optimize for the happy path */ + up_displs[up_peer] = 0 < rcounts[w_peer] && displs[w_peer] < up_displs[up_peer] + ? 
displs[w_peer] + : up_displs[up_peer]; + } + + /* If the data is not placed contiguously on recv buf, then we will need temp buf to store + * the gap data and recover it later */ + for (up_peer = 0; up_peer < up_size; ++up_peer) { + if (root_up_rank == up_peer) { + continue; + } + if (!need_bounce_buf && 0 < up_rcounts[up_peer] + && up_rcounts[up_peer] < up_peer_ub[up_peer] - up_displs[up_peer]) { + need_bounce_buf = 1; + break; + } + } + + if (need_bounce_buf) { + bounce_buf = malloc(rdsize * total_up_rcounts); + if (!bounce_buf) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto root_out; + } + + /* Calculate displacements for the inter-node gatherv */ + for (up_peer = 0; up_peer < up_size; ++up_peer) { + up_displs[up_peer] = 0 < up_peer ? up_displs[up_peer - 1] + up_rcounts[up_peer - 1] + : 0; + } + + tmp_rbuf = bounce_buf; + } + + /* Up Gatherv */ + up_comm->c_coll->coll_gatherv(sbuf, 0, sdtype, tmp_rbuf, up_rcounts, up_displs, rdtype, + root_up_rank, up_comm, up_comm->c_coll->coll_gatherv_module); + + /* Use a temp buffer to reorder the output buffer if needed */ + if (need_bounce_buf) { + ptrdiff_t offset = 0; + + for (int i = 0; i < w_size; ++i) { + up_peer = topo[2 * i]; + if (root_up_rank == up_peer) { + continue; + } + + w_peer = topo[2 * i + 1]; + + ompi_datatype_copy_content_same_ddt(rdtype, (size_t) rcounts[w_peer], + (char *) rbuf + + (size_t) displs[w_peer] * rdsize, + bounce_buf + offset); + offset += rdsize * (size_t) rcounts[w_peer]; + } + } + + root_out: + if (low_displs) { + free(low_displs); + } + if (low_rcounts) { + free(low_rcounts); + } + if (up_displs) { + free(up_displs); + } + if (up_rcounts) { + free(up_rcounts); + } + if (up_peer_lb) { + free(up_peer_lb); + } + if (up_peer_ub) { + free(up_peer_ub); + } + if (bounce_buf) { + free(bounce_buf); + } + + return err; + } + + /* #################### Root's local peers ########################### */ + if (root_up_rank == up_rank) { + /* Low Gatherv */ + low_comm->c_coll->coll_gatherv(sbuf, scount, 
sdtype, NULL, NULL, NULL, NULL, root_low_rank, + low_comm, low_comm->c_coll->coll_gatherv_module); + return OMPI_SUCCESS; + } + + size_t sdsize = 0; + uint64_t send_size = 0; + + ompi_datatype_type_size(sdtype, &sdsize); + send_size = (uint64_t) sdsize * (uint64_t) scount; + + /* #################### Other node followers ########################### */ + if (root_low_rank != low_rank) { + /* Low Gather - Gather each local peer's send data size */ + low_comm->c_coll->coll_gather((const void *) &send_size, 1, MPI_UINT64_T, NULL, 1, + MPI_UINT64_T, root_low_rank, low_comm, + low_comm->c_coll->coll_gather_module); + /* Low Gatherv */ + low_comm->c_coll->coll_gatherv(sbuf, scount, sdtype, NULL, NULL, NULL, NULL, root_low_rank, + low_comm, low_comm->c_coll->coll_gatherv_module); + return OMPI_SUCCESS; + } + + /* #################### Node leaders ########################### */ + + uint64_t *low_data_size = NULL; + char *tmp_buf = NULL; + ompi_datatype_t *temptype = MPI_BYTE; + + /* Allocate a temporary array to gather the data size, i.e. data type size x count, + * in bytes from local peers */ + low_data_size = malloc(low_size * sizeof(uint64_t)); + if (!low_data_size) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto node_leader_out; + } + + /* Low Gather - Gather local peers' send data sizes */ + low_comm->c_coll->coll_gather((const void *) &send_size, 1, MPI_UINT64_T, + (void *) low_data_size, 1, MPI_UINT64_T, root_low_rank, low_comm, + low_comm->c_coll->coll_gather_module); + + /* Determine if we need to create a custom datatype instead of MPI_BYTE, + * to avoid count(type int) overflow + * TODO: Remove this logic once we adopt large-count, i.e. count will become 64-bit. + */ + int total_up_scount = 0; + size_t rsize = 0, datatype_size = 1, max_data_size = 0; + for (int i = 0; i < low_size; ++i) { + rsize += (size_t) low_data_size[i]; + max_data_size = (size_t) low_data_size[i] > max_data_size ? 
(size_t) low_data_size[i] + : max_data_size; + } + + if (max_data_size > (size_t) INT_MAX) { + datatype_size = coll_han_utils_gcd(low_data_size, low_size); + } + + low_rcounts = malloc(low_size * sizeof(int)); + low_displs = malloc(low_size * sizeof(int)); + tmp_buf = (char *) malloc(rsize); /* tmp_buf is still valid if rsize is 0 */ + if (!tmp_buf || !low_rcounts || !low_displs) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto node_leader_out; + } + + for (int i = 0; i < low_size; ++i) { + low_rcounts[i] = (int) ((size_t) low_data_size[i] / datatype_size); + low_displs[i] = i > 0 ? low_displs[i - 1] + low_rcounts[i - 1] : 0; + total_up_scount += low_rcounts[i]; + } + + if (1 < datatype_size) { + coll_han_utils_create_contiguous_datatype(datatype_size, MPI_BYTE, &temptype); + ompi_datatype_commit(&temptype); + } + + /* Low Gatherv */ + low_comm->c_coll->coll_gatherv(sbuf, scount, sdtype, (void *) tmp_buf, low_rcounts, low_displs, + temptype, root_low_rank, low_comm, + low_comm->c_coll->coll_gatherv_module); + + /* Up Gatherv */ + up_comm->c_coll->coll_gatherv(tmp_buf, total_up_scount, temptype, NULL, NULL, NULL, NULL, + root_up_rank, up_comm, up_comm->c_coll->coll_gatherv_module); + +node_leader_out: + if (low_rcounts) { + free(low_rcounts); + } + if (low_displs) { + free(low_displs); + } + if (low_data_size) { + free(low_data_size); + } + if (tmp_buf) { + free(tmp_buf); + } + if (MPI_BYTE != temptype) { + ompi_datatype_destroy(&temptype); + } + + return err; +} diff --git a/ompi/mca/coll/han/coll_han_module.c b/ompi/mca/coll/han/coll_han_module.c index bd24d5ec1a1..782ebcdb760 100644 --- a/ompi/mca/coll/han/coll_han_module.c +++ b/ompi/mca/coll/han/coll_han_module.c @@ -52,6 +52,7 @@ static void han_module_clear(mca_coll_han_module_t *han_module) CLEAN_PREV_COLL(han_module, bcast); CLEAN_PREV_COLL(han_module, reduce); CLEAN_PREV_COLL(han_module, gather); + CLEAN_PREV_COLL(han_module, gatherv); CLEAN_PREV_COLL(han_module, scatter); han_module->reproducible_reduce = NULL; 
@@ -148,6 +149,7 @@ mca_coll_han_module_destruct(mca_coll_han_module_t * module) OBJ_RELEASE_IF_NOT_NULL(module->previous_allreduce_module); OBJ_RELEASE_IF_NOT_NULL(module->previous_bcast_module); OBJ_RELEASE_IF_NOT_NULL(module->previous_gather_module); + OBJ_RELEASE_IF_NOT_NULL(module->previous_gatherv_module); OBJ_RELEASE_IF_NOT_NULL(module->previous_reduce_module); OBJ_RELEASE_IF_NOT_NULL(module->previous_scatter_module); @@ -250,7 +252,6 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) han_module->super.coll_alltoallv = NULL; han_module->super.coll_alltoallw = NULL; han_module->super.coll_exscan = NULL; - han_module->super.coll_gatherv = NULL; han_module->super.coll_reduce_scatter = NULL; han_module->super.coll_scan = NULL; han_module->super.coll_scatterv = NULL; @@ -258,6 +259,7 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) han_module->super.coll_scatter = mca_coll_han_scatter_intra_dynamic; han_module->super.coll_reduce = mca_coll_han_reduce_intra_dynamic; han_module->super.coll_gather = mca_coll_han_gather_intra_dynamic; + han_module->super.coll_gatherv = mca_coll_han_gatherv_intra_dynamic; han_module->super.coll_bcast = mca_coll_han_bcast_intra_dynamic; han_module->super.coll_allreduce = mca_coll_han_allreduce_intra_dynamic; han_module->super.coll_allgather = mca_coll_han_allgather_intra_dynamic; @@ -311,6 +313,7 @@ han_module_enable(mca_coll_base_module_t * module, HAN_SAVE_PREV_COLL_API(barrier); HAN_SAVE_PREV_COLL_API(bcast); HAN_SAVE_PREV_COLL_API(gather); + HAN_SAVE_PREV_COLL_API(gatherv); HAN_SAVE_PREV_COLL_API(reduce); HAN_SAVE_PREV_COLL_API(scatter); @@ -326,6 +329,7 @@ han_module_enable(mca_coll_base_module_t * module, OBJ_RELEASE_IF_NOT_NULL(han_module->previous_allreduce_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_bcast_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gather_module); + OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); 
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); @@ -347,6 +351,7 @@ mca_coll_han_module_disable(mca_coll_base_module_t * module, OBJ_RELEASE_IF_NOT_NULL(han_module->previous_barrier_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_bcast_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gather_module); + OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); diff --git a/ompi/mca/coll/han/coll_han_subcomms.c b/ompi/mca/coll/han/coll_han_subcomms.c index fe6e197cfb8..2476e1d2cc9 100644 --- a/ompi/mca/coll/han/coll_han_subcomms.c +++ b/ompi/mca/coll/han/coll_han_subcomms.c @@ -78,6 +78,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, bcast); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, reduce); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, scatter); /** @@ -105,6 +106,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); han_module->enabled = false; /* entire module set to pass-through from now on */ return OMPI_ERR_NOT_SUPPORTED; @@ -181,6 +183,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, 
han_module, gather); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); OBJ_DESTRUCT(&comm_info); @@ -236,6 +239,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, bcast); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, reduce); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, scatter); /** @@ -260,6 +264,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); han_module->enabled = false; /* entire module set to pass-through from now on */ return OMPI_ERR_NOT_SUPPORTED; @@ -348,6 +353,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, bcast); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, reduce); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); OBJ_DESTRUCT(&comm_info); diff --git a/ompi/mca/coll/han/coll_han_utils.c b/ompi/mca/coll/han/coll_han_utils.c new file mode 100644 index 00000000000..e9cc8c10a3c --- /dev/null +++ b/ompi/mca/coll/han/coll_han_utils.c @@ -0,0 +1,72 @@ +/* + * Copyright (c) Amazon.com, Inc. or its affiliates. + * All rights reserved. 
+ * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +/** + * @file + * + * Shared utility functions + */ + +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "coll_han.h" + +/** + * Calculate the Greatest Common Denominator of a list of non-negative integers + * + * @param[in] numerators A list of numerators that should be divisible by + * the denominator + * @param[in] size Number of numerators + * @returns The GCD, where 1 <= GCD + */ +size_t coll_han_utils_gcd(const size_t *numerators, const size_t size) +{ + size_t denominator = numerators[0], numerator, tmp; + + for (size_t i = 1; i < size; ++i) { + numerator = numerators[i]; + + if (0 == denominator) { + denominator = numerator; + continue; + } + + if (0 == numerator) { + continue; + } + + while (0 < numerator % denominator && 0 < denominator % numerator) { + tmp = MIN(numerator, denominator); + denominator = MAX(numerator, denominator) - tmp; + numerator = tmp; + } + } + + if (0 == denominator) { + denominator = 1; + } + + return denominator; +} + +int coll_han_utils_create_contiguous_datatype(size_t count, const ompi_datatype_t *oldType, + ompi_datatype_t **newType) +{ + ompi_datatype_t *pdt; + + if ((0 == count) || (0 == oldType->super.size)) { + return ompi_datatype_duplicate(&ompi_mpi_datatype_null.dt, newType); + } + + pdt = ompi_datatype_create(oldType->super.desc.used + 2); + opal_datatype_add(&(pdt->super), &(oldType->super), count, 0, + (oldType->super.ub - oldType->super.lb)); + *newType = pdt; + return OMPI_SUCCESS; +} From 2152b61a1dd2e6d18cdfbce1aa4c91a3df543d66 Mon Sep 17 00:00:00 2001 From: Jessie Yang Date: Wed, 28 Feb 2024 13:36:00 -0800 Subject: [PATCH 3/3] coll/han: implement hierarchical scatterv Add scatterv implementation to optimize large-scale communications on multiple nodes and multiple processes per node, by avoiding high-incast traffic on the root process. 
Because *V collectives do not have equal datatype/count on every process, it does not natively support message-size based tuning without an additional global communication. Similar to scatter, the hierarchical scatterv requires a temporary buffer and memory copy to handle out-of-order data, or non-contiguous placement on the send buffer, which results in worse performance for large messages compared to the linear implementation. Signed-off-by: Jessie Yang --- ompi/mca/coll/han/Makefile.am | 1 + ompi/mca/coll/han/coll_han.h | 13 + ompi/mca/coll/han/coll_han_algorithms.c | 4 + ompi/mca/coll/han/coll_han_algorithms.h | 10 + ompi/mca/coll/han/coll_han_component.c | 17 ++ ompi/mca/coll/han/coll_han_dynamic.c | 111 +++++++ ompi/mca/coll/han/coll_han_module.c | 7 +- ompi/mca/coll/han/coll_han_scatterv.c | 385 ++++++++++++++++++++++++ ompi/mca/coll/han/coll_han_subcomms.c | 6 + 9 files changed, 553 insertions(+), 1 deletion(-) create mode 100644 ompi/mca/coll/han/coll_han_scatterv.c diff --git a/ompi/mca/coll/han/Makefile.am b/ompi/mca/coll/han/Makefile.am index b75513b9130..e9ca89d055c 100644 --- a/ompi/mca/coll/han/Makefile.am +++ b/ompi/mca/coll/han/Makefile.am @@ -21,6 +21,7 @@ coll_han_barrier.c \ coll_han_bcast.c \ coll_han_reduce.c \ coll_han_scatter.c \ +coll_han_scatterv.c \ coll_han_gather.c \ coll_han_gatherv.c \ coll_han_allreduce.c \ diff --git a/ompi/mca/coll/han/coll_han.h b/ompi/mca/coll/han/coll_han.h index f79cb06105c..e7c14efaeb5 100644 --- a/ompi/mca/coll/han/coll_han.h +++ b/ompi/mca/coll/han/coll_han.h @@ -193,6 +193,7 @@ typedef struct mca_coll_han_op_module_name_t { mca_coll_han_op_up_low_module_name_t gather; mca_coll_han_op_up_low_module_name_t gatherv; mca_coll_han_op_up_low_module_name_t scatter; + mca_coll_han_op_up_low_module_name_t scatterv; } mca_coll_han_op_module_name_t; /** @@ -244,6 +245,10 @@ typedef struct mca_coll_han_component_t { uint32_t han_scatter_up_module; /* low level module for scatter */ uint32_t han_scatter_low_module; + /* 
up level module for scatterv */ + uint32_t han_scatterv_up_module; + /* low level module for scatterv */ + uint32_t han_scatterv_low_module; /* name of the modules */ mca_coll_han_op_module_name_t han_op_module_name; /* whether we need reproducible results @@ -287,6 +292,7 @@ typedef struct mca_coll_han_single_collective_fallback_s { mca_coll_base_module_gatherv_fn_t gatherv; mca_coll_base_module_reduce_fn_t reduce; mca_coll_base_module_scatter_fn_t scatter; + mca_coll_base_module_scatterv_fn_t scatterv; } module_fn; mca_coll_base_module_t* module; } mca_coll_han_single_collective_fallback_t; @@ -306,6 +312,7 @@ typedef struct mca_coll_han_collectives_fallback_s { mca_coll_han_single_collective_fallback_t gather; mca_coll_han_single_collective_fallback_t gatherv; mca_coll_han_single_collective_fallback_t scatter; + mca_coll_han_single_collective_fallback_t scatterv; } mca_coll_han_collectives_fallback_t; /** Coll han module */ @@ -384,6 +391,8 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t); #define previous_scatter fallback.scatter.module_fn.scatter #define previous_scatter_module fallback.scatter.module +#define previous_scatterv fallback.scatterv.module_fn.scatterv +#define previous_scatterv_module fallback.scatterv.module /* macro to correctly load a fallback collective module */ #define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \ @@ -403,6 +412,7 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t); HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, barrier); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, bcast); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatter); \ + HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, scatterv); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gather); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, gatherv); \ HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, reduce); \ @@ -495,6 +505,9 @@ mca_coll_han_reduce_intra_dynamic(REDUCE_BASE_ARGS, int mca_coll_han_scatter_intra_dynamic(SCATTER_BASE_ARGS, mca_coll_base_module_t *module); +int 
+mca_coll_han_scatterv_intra_dynamic(SCATTERV_BASE_ARGS, + mca_coll_base_module_t *module); int mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm, mca_coll_base_module_t *module); diff --git a/ompi/mca/coll/han/coll_han_algorithms.c b/ompi/mca/coll/han/coll_han_algorithms.c index dbeb7ebe07d..9ebc04588c1 100644 --- a/ompi/mca/coll/han/coll_han_algorithms.c +++ b/ompi/mca/coll/han/coll_han_algorithms.c @@ -59,6 +59,10 @@ mca_coll_han_algorithm_value_t* mca_coll_han_available_algorithms[COLLCOUNT] = {"simple", (fnptr_t) &mca_coll_han_scatter_intra_simple}, // 2-level { 0 } }, + [SCATTERV] = (mca_coll_han_algorithm_value_t[]){ + {"intra", (fnptr_t) &mca_coll_han_scatterv_intra}, // 2-level + { 0 } + }, [GATHER] = (mca_coll_han_algorithm_value_t[]){ {"intra", (fnptr_t) &mca_coll_han_gather_intra}, // 2-level {"simple", (fnptr_t) &mca_coll_han_gather_intra_simple}, // 2-level diff --git a/ompi/mca/coll/han/coll_han_algorithms.h b/ompi/mca/coll/han/coll_han_algorithms.h index d73250d5963..414b7293e09 100644 --- a/ompi/mca/coll/han/coll_han_algorithms.h +++ b/ompi/mca/coll/han/coll_han_algorithms.h @@ -159,6 +159,16 @@ mca_coll_han_scatter_intra_simple(const void *sbuf, int scount, struct ompi_communicator_t *comm, mca_coll_base_module_t * module); +/* Scatterv */ +int +mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, + const int *displs, struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, + int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module); + /* Gather */ int mca_coll_han_gather_intra(const void *sbuf, int scount, diff --git a/ompi/mca/coll/han/coll_han_component.c b/ompi/mca/coll/han/coll_han_component.c index e4ffd6fdfca..43cdf9c96c2 100644 --- a/ompi/mca/coll/han/coll_han_component.c +++ b/ompi/mca/coll/han/coll_han_component.c @@ -156,6 +156,11 @@ static int han_close(void) free(mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name); 
mca_coll_han_component.han_op_module_name.scatter.han_op_low_module_name = NULL; + free(mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name); + mca_coll_han_component.han_op_module_name.scatterv.han_op_up_module_name = NULL; + free(mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name); + mca_coll_han_component.han_op_module_name.scatterv.han_op_low_module_name = NULL; + return OMPI_SUCCESS; } @@ -373,6 +378,18 @@ static int han_register(void) OPAL_INFO_LVL_9, &cs->han_scatter_low_module, &cs->han_op_module_name.scatter.han_op_low_module_name); + cs->han_scatterv_up_module = 0; + (void) mca_coll_han_query_module_from_mca(c, "scatterv_up_module", + "up level module for scatterv, 0 basic", + OPAL_INFO_LVL_9, &cs->han_scatterv_up_module, + &cs->han_op_module_name.scatterv.han_op_up_module_name); + + cs->han_scatterv_low_module = 0; + (void) mca_coll_han_query_module_from_mca(c, "scatterv_low_module", + "low level module for scatterv, 0 basic", + OPAL_INFO_LVL_9, &cs->han_scatterv_low_module, + &cs->han_op_module_name.scatterv.han_op_low_module_name); + cs->han_reproducible = 0; (void) mca_base_component_var_register(c, "reproducible", "whether we need reproducible results " diff --git a/ompi/mca/coll/han/coll_han_dynamic.c b/ompi/mca/coll/han/coll_han_dynamic.c index 2c73dfea456..b3cc832b4a7 100644 --- a/ompi/mca/coll/han/coll_han_dynamic.c +++ b/ompi/mca/coll/han/coll_han_dynamic.c @@ -46,6 +46,7 @@ bool mca_coll_han_is_coll_dynamic_implemented(COLLTYPE_T coll_id) case GATHERV: case REDUCE: case SCATTER: + case SCATTERV: return true; default: return false; @@ -1397,3 +1398,113 @@ mca_coll_han_scatter_intra_dynamic(const void *sbuf, int scount, root, comm, sub_module); } + + +/* + * Scatterv selector: + * On a sub-communicator, checks the stored rules to find the module to use + * On the global communicator, calls the han collective implementation, or + * calls the correct module if fallback mechanism is activated + */ +int 
+mca_coll_han_scatterv_intra_dynamic(const void *sbuf, const int *scounts, + const int *displs, struct ompi_datatype_t *sdtype, + void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, + int root, + struct ompi_communicator_t *comm, + mca_coll_base_module_t *module) +{ + mca_coll_han_module_t *han_module = (mca_coll_han_module_t*) module; + TOPO_LVL_T topo_lvl = han_module->topologic_level; + mca_coll_base_module_scatterv_fn_t scatterv; + mca_coll_base_module_t *sub_module; + int rank, verbosity = 0; + + if (!han_module->enabled) { + return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, + root, comm, han_module->previous_scatterv_module); + } + + /* v collectives do not support message-size based dynamic rules */ + sub_module = get_module(SCATTERV, + MCA_COLL_HAN_ANY_MESSAGE_SIZE, + comm, + han_module); + + /* First errors are always printed by rank 0 */ + rank = ompi_comm_rank(comm); + if( (0 == rank) && (han_module->dynamic_errors < mca_coll_han_component.max_dynamic_errors) ) { + verbosity = 30; + } + + if(NULL == sub_module) { + /* + * No valid collective module from dynamic rules + * nor from mca parameter + */ + han_module->dynamic_errors++; + opal_output_verbose(verbosity, mca_coll_han_component.han_output, + "coll:han:mca_coll_han_scatterv_intra_dynamic " + "HAN did not find any valid module for collective %d (%s) " + "with topological level %d (%s) on communicator (%s/%s). " + "Please check dynamic file/mca parameters\n", + SCATTERV, mca_coll_base_colltype_to_str(SCATTERV), + topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl), + ompi_comm_print_cid(comm), comm->c_name); + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "HAN/SCATTERV: No module found for the sub-communicator. 
" + "Falling back to another component\n")); + scatterv = han_module->previous_scatterv; + sub_module = han_module->previous_scatterv_module; + } else if (NULL == sub_module->coll_scatterv) { + /* + * No valid collective from dynamic rules + * nor from mca parameter + */ + han_module->dynamic_errors++; + opal_output_verbose(verbosity, mca_coll_han_component.han_output, + "coll:han:mca_coll_han_scatterv_intra_dynamic " + "HAN found valid module for collective %d (%s) " + "with topological level %d (%s) on communicator (%s/%s) " + "but this module cannot handle this collective. " + "Please check dynamic file/mca parameters\n", + SCATTERV, mca_coll_base_colltype_to_str(SCATTERV), + topo_lvl, mca_coll_han_topo_lvl_to_str(topo_lvl), + ompi_comm_print_cid(comm), comm->c_name); + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "HAN/SCATTERV: the module found for the sub-" + "communicator cannot handle the SCATTERV operation. " + "Falling back to another component\n")); + scatterv = han_module->previous_scatterv; + sub_module = han_module->previous_scatterv_module; + } else if (GLOBAL_COMMUNICATOR == topo_lvl && sub_module == module) { + /* + * No fallback mechanism activated for this configuration + * sub_module is valid + * sub_module->coll_scatterv is valid and point to this function + * Call han topological collective algorithm + */ + int algorithm_id = get_algorithm(SCATTERV, + MCA_COLL_HAN_ANY_MESSAGE_SIZE, + comm, + han_module); + scatterv = (mca_coll_base_module_scatterv_fn_t)mca_coll_han_algorithm_id_to_fn(SCATTERV, algorithm_id); + if (NULL == scatterv) { /* default behaviour */ + scatterv = mca_coll_han_scatterv_intra; + } + } else { + /* + * If we get here: + * sub_module is valid + * sub_module->coll_scatterv is valid + * They point to the collective to use, according to the dynamic rules + * Selector's job is done, call the collective + */ + scatterv = sub_module->coll_scatterv; + } + + return scatterv(sbuf, scounts, displs, sdtype, + rbuf, 
rcount, rdtype, + root, comm, sub_module); +} diff --git a/ompi/mca/coll/han/coll_han_module.c b/ompi/mca/coll/han/coll_han_module.c index 782ebcdb760..31ee2d3fb84 100644 --- a/ompi/mca/coll/han/coll_han_module.c +++ b/ompi/mca/coll/han/coll_han_module.c @@ -54,6 +54,7 @@ static void han_module_clear(mca_coll_han_module_t *han_module) CLEAN_PREV_COLL(han_module, gather); CLEAN_PREV_COLL(han_module, gatherv); CLEAN_PREV_COLL(han_module, scatter); + CLEAN_PREV_COLL(han_module, scatterv); han_module->reproducible_reduce = NULL; han_module->reproducible_reduce_module = NULL; @@ -152,6 +153,7 @@ mca_coll_han_module_destruct(mca_coll_han_module_t * module) OBJ_RELEASE_IF_NOT_NULL(module->previous_gatherv_module); OBJ_RELEASE_IF_NOT_NULL(module->previous_reduce_module); OBJ_RELEASE_IF_NOT_NULL(module->previous_scatter_module); + OBJ_RELEASE_IF_NOT_NULL(module->previous_scatterv_module); han_module_clear(module); } @@ -254,7 +256,7 @@ mca_coll_han_comm_query(struct ompi_communicator_t * comm, int *priority) han_module->super.coll_exscan = NULL; han_module->super.coll_reduce_scatter = NULL; han_module->super.coll_scan = NULL; - han_module->super.coll_scatterv = NULL; + han_module->super.coll_scatterv = mca_coll_han_scatterv_intra_dynamic; han_module->super.coll_barrier = mca_coll_han_barrier_intra_dynamic; han_module->super.coll_scatter = mca_coll_han_scatter_intra_dynamic; han_module->super.coll_reduce = mca_coll_han_reduce_intra_dynamic; @@ -316,6 +318,7 @@ han_module_enable(mca_coll_base_module_t * module, HAN_SAVE_PREV_COLL_API(gatherv); HAN_SAVE_PREV_COLL_API(reduce); HAN_SAVE_PREV_COLL_API(scatter); + HAN_SAVE_PREV_COLL_API(scatterv); /* set reproducible algos */ mca_coll_han_reduce_reproducible_decision(comm, module); @@ -332,6 +335,7 @@ han_module_enable(mca_coll_base_module_t * module, OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); 
OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); + OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module); return OMPI_ERROR; } @@ -354,6 +358,7 @@ mca_coll_han_module_disable(mca_coll_base_module_t * module, OBJ_RELEASE_IF_NOT_NULL(han_module->previous_gatherv_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_reduce_module); OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatter_module); + OBJ_RELEASE_IF_NOT_NULL(han_module->previous_scatterv_module); han_module_clear(han_module); diff --git a/ompi/mca/coll/han/coll_han_scatterv.c b/ompi/mca/coll/han/coll_han_scatterv.c new file mode 100644 index 00000000000..df01c9aa8ac --- /dev/null +++ b/ompi/mca/coll/han/coll_han_scatterv.c @@ -0,0 +1,385 @@ +/* + * Copyright (c) 2018-2023 The University of Tennessee and The University + * of Tennessee Research Foundation. All rights + * reserved. + * Copyright (c) 2020 Bull S.A.S. All rights reserved. + * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved. + * Copyright (c) 2022 IBM Corporation. All rights reserved + * Copyright (c) Amazon.com, Inc. or its affiliates. + * All rights reserved. + * $COPYRIGHT$ + * + * Additional copyrights may follow + * + * $HEADER$ + */ + +#include "coll_han.h" +#include "coll_han_trigger.h" +#include "ompi/mca/coll/base/coll_base_functions.h" +#include "ompi/mca/coll/base/coll_tags.h" +#include "ompi/mca/pml/pml.h" + +/* + * @file + * + * This files contains the hierarchical implementations of scatterv. + * Only work with regular situation (each node has equal number of processes). + */ + +/* + * Implement hierarchical Scatterv to optimize large-scale communications where root sends + * non-zero sized messages to multiple nodes and multiple processes per node, i.e. high incast. + * + * In Scatterv, only the root(sender) process has the information of the amount of data, i.e. + * datatype and count, to each receiver process. 
 Therefore node leaders need an additional step to + * collect the expected data from their local peers. In summary, the steps are: + * 1. Root: + * a. If necessary, reorder and sort data (See discussion below) + * b. Send data to other node leaders (Up Iscatterv) + * c. Send data to local peers (Low Scatterv) + * 2. Root's local peers: + * a. Receive data from root. (Low Scatterv) + * 3. Node leaders: + * a. Collect the data transfer sizes(in bytes) from local peers (Low Gather) + * b. Receive data from the root (Up Iscatterv) + * c. Send data to local peers (Low Scatterv) + * 4. Node followers: + * a. Send the data transfer size(in bytes) to the node leader (Low Gather) + * b. Receive data from the node leader (Low Scatterv) + * + * Note on reordering: + * In Up Iscatterv, reordering the send buffer can be avoided if and only if both of the following + * conditions are met: + * 1. The data for each node is sorted in the same order as peer local ranks. Note, it is possible + * to send the data in the correct order even if the processes are NOT mapped by core. + * 2. In the send buffer, other than the root's node, data destined to the same node are contiguous + * - it is ok if data to different nodes have gaps. 
+ */ +int mca_coll_han_scatterv_intra(const void *sbuf, const int *scounts, const int *displs, + struct ompi_datatype_t *sdtype, void *rbuf, int rcount, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, mca_coll_base_module_t *module) +{ + mca_coll_han_module_t *han_module = (mca_coll_han_module_t *) module; + int w_rank, w_size; /* information about the global communicator */ + int root_low_rank, root_up_rank; /* root ranks for both sub-communicators */ + int err, *vranks, low_rank, low_size, up_rank, up_size, *topo; + int *low_scounts = NULL, *low_displs = NULL; + ompi_request_t *iscatterv_req = NULL; + + /* Create the subcommunicators */ + err = mca_coll_han_comm_create(comm, han_module); + if (OMPI_SUCCESS != err) { + OPAL_OUTPUT_VERBOSE(( + 30, mca_coll_han_component.han_output, + "han cannot handle scatterv with this communicator. Fall back on another component\n")); + /* HAN cannot work with this communicator so fallback on all collectives */ + HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm); + return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, + root, comm, han_module->previous_scatterv_module); + } + + /* Topo must be initialized to know rank distribution which then is used to determine if han can + * be used */ + topo = mca_coll_han_topo_init(comm, han_module, 2); + if (han_module->are_ppn_imbalanced) { + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "han cannot handle scatterv with this communicator (imbalance). Fall " + "back on another component\n")); + /* Put back the fallback collective support and call it once. All + * future calls will then be automatically redirected. 
+ */ + HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, scatterv); + return han_module->previous_scatterv(sbuf, scounts, displs, sdtype, rbuf, rcount, rdtype, + root, comm, han_module->previous_scatterv_module); + } + + w_rank = ompi_comm_rank(comm); + w_size = ompi_comm_size(comm); + + /* create the subcommunicators */ + ompi_communicator_t *low_comm + = han_module->cached_low_comms[mca_coll_han_component.han_scatterv_low_module]; + ompi_communicator_t *up_comm + = han_module->cached_up_comms[mca_coll_han_component.han_scatterv_up_module]; + + /* Get the 'virtual ranks' mapping corresponding to the communicators */ + vranks = han_module->cached_vranks; + /* information about sub-communicators */ + low_rank = ompi_comm_rank(low_comm); + low_size = ompi_comm_size(low_comm); + up_rank = ompi_comm_rank(up_comm); + up_size = ompi_comm_size(up_comm); + /* Get root ranks for low and up comms */ + mca_coll_han_get_ranks(vranks, root, low_size, &root_low_rank, &root_up_rank); + + OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output, + "[%d]: Han scatterv root %d root_low_rank %d root_up_rank %d\n", w_rank, + root, root_low_rank, root_up_rank)); + + err = OMPI_SUCCESS; + /* #################### Root ########################### */ + if (root == w_rank) { + int low_peer, up_peer, w_peer; + int need_bounce_buf = 0, total_up_scounts = 0, *up_displs = NULL, *up_scounts = NULL, + *up_peer_lb = NULL, *up_peer_ub = NULL; + char *reorder_sbuf = (char *) sbuf, *bounce_buf = NULL; + size_t sdsize; + + low_scounts = malloc(low_size * sizeof(int)); + low_displs = malloc(low_size * sizeof(int)); + if (!low_scounts || !low_displs) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto root_out; + } + + for (w_peer = 0; w_peer < w_size; ++w_peer) { + mca_coll_han_get_ranks(vranks, w_peer, low_size, &low_peer, &up_peer); + if (root_up_rank != up_peer) { + /* Not a local peer */ + continue; + } + low_displs[low_peer] = displs[w_peer]; + low_scounts[low_peer] = scounts[w_peer]; + } + + 
ompi_datatype_type_size(sdtype, &sdsize); + + up_scounts = calloc(up_size, sizeof(int)); + up_displs = malloc(up_size * sizeof(int)); + up_peer_ub = calloc(up_size, sizeof(int)); + if (!up_scounts || !up_displs || !up_peer_ub) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto root_out; + } + + for (up_peer = 0; up_peer < up_size; ++up_peer) { + up_displs[up_peer] = INT_MAX; + } + + /* Calculate send counts for the inter-node scatterv */ + for (w_peer = 0; w_peer < w_size; ++w_peer) { + mca_coll_han_get_ranks(vranks, w_peer, low_size, NULL, &up_peer); + + if (!need_bounce_buf && root_up_rank != up_peer && 0 < scounts[w_peer] && 0 < w_peer + && displs[w_peer] < displs[w_peer - 1]) { + /* Data is not placed in the rank order so reordering is needed */ + need_bounce_buf = 1; + } + + if (root_up_rank == up_peer) { + /* No need to scatter data on the same node again */ + continue; + } + + up_peer_ub[up_peer] = 0 < scounts[w_peer] + && displs[w_peer] + scounts[w_peer] > up_peer_ub[up_peer] + ? displs[w_peer] + scounts[w_peer] + : up_peer_ub[up_peer]; + + up_scounts[up_peer] += scounts[w_peer]; + total_up_scounts += scounts[w_peer]; + + /* Optimize for the happy path */ + up_displs[up_peer] = 0 < scounts[w_peer] && displs[w_peer] < up_displs[up_peer] + ? displs[w_peer] + : up_displs[up_peer]; + } + + /* If the data is not placed contiguously on send buf without overlaping, then we need a + * temp buf without gaps */ + for (up_peer = 0; up_peer < up_size; ++up_peer) { + if (root_up_rank == up_peer) { + continue; + } + if (!need_bounce_buf && 0 < up_scounts[up_peer] + && up_scounts[up_peer] != up_peer_ub[up_peer] - up_displs[up_peer]) { + need_bounce_buf = 1; + break; + } + } + + if (need_bounce_buf) { + bounce_buf = malloc(sdsize * total_up_scounts); + if (!bounce_buf) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto root_out; + } + + /* Calculate displacements for the inter-node scatterv */ + for (up_peer = 0; up_peer < up_size; ++up_peer) { + up_displs[up_peer] = 0 < up_peer ? 
up_displs[up_peer - 1] + up_scounts[up_peer - 1] + : 0; + } + + /* Use a temp buffer to reorder the send buffer if needed */ + ptrdiff_t offset = 0; + + for (int i = 0; i < w_size; ++i) { + up_peer = topo[2 * i]; + if (root_up_rank == up_peer) { + continue; + } + + w_peer = topo[2 * i + 1]; + + ompi_datatype_copy_content_same_ddt(sdtype, (size_t) scounts[w_peer], + bounce_buf + offset, + (char *) sbuf + + (size_t) displs[w_peer] * sdsize); + offset += sdsize * (size_t) scounts[w_peer]; + } + + reorder_sbuf = bounce_buf; + } + + /* Up Iscatterv */ + up_comm->c_coll->coll_iscatterv((const char *) reorder_sbuf, up_scounts, up_displs, sdtype, + rbuf, rcount, rdtype, root_up_rank, up_comm, &iscatterv_req, + up_comm->c_coll->coll_iscatterv_module); + + /* Low Scatterv */ + low_comm->c_coll->coll_scatterv(sbuf, low_scounts, low_displs, sdtype, rbuf, rcount, rdtype, + root_low_rank, low_comm, + low_comm->c_coll->coll_scatterv_module); + + ompi_request_wait(&iscatterv_req, MPI_STATUS_IGNORE); + + root_out: + if (low_displs) { + free(low_displs); + } + if (low_scounts) { + free(low_scounts); + } + if (up_displs) { + free(up_displs); + } + if (up_scounts) { + free(up_scounts); + } + if (up_peer_lb) { + free(up_peer_lb); + } + if (up_peer_ub) { + free(up_peer_ub); + } + if (bounce_buf) { + free(bounce_buf); + } + + return err; + } + + /* #################### Root's local peers ########################### */ + if (root_up_rank == up_rank) { + /* Low Scatterv */ + low_comm->c_coll->coll_scatterv(NULL, NULL, NULL, NULL, rbuf, rcount, rdtype, root_low_rank, + low_comm, low_comm->c_coll->coll_scatterv_module); + return OMPI_SUCCESS; + } + + size_t rdsize = 0; + uint64_t receive_size = 0; + + ompi_datatype_type_size(rdtype, &rdsize); + receive_size = (uint64_t) rdsize * (uint64_t) rcount; + + /* #################### Other node followers ########################### */ + if (root_low_rank != low_rank) { + /* Low Gather - Gather each local peer's receive data size */ + 
low_comm->c_coll->coll_gather((const void *) &receive_size, 1, MPI_UINT64_T, NULL, 1, + MPI_UINT64_T, root_low_rank, low_comm, + low_comm->c_coll->coll_gather_module); + /* Low Scatterv */ + low_comm->c_coll->coll_scatterv(NULL, NULL, NULL, NULL, rbuf, rcount, rdtype, root_low_rank, + low_comm, low_comm->c_coll->coll_scatterv_module); + return OMPI_SUCCESS; + } + + /* #################### Node leaders ########################### */ + + uint64_t *low_data_size = NULL; + char *tmp_buf = NULL; + ompi_datatype_t *temptype = MPI_BYTE; + + /* Allocate a temporary array to gather the data size, i.e. data type size x count, + * in bytes from local peers */ + low_data_size = malloc(low_size * sizeof(uint64_t)); + if (!low_data_size) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto node_leader_out; + } + + /* Low Gather - Gather local peers' receive data sizes */ + low_comm->c_coll->coll_gather((const void *) &receive_size, 1, MPI_UINT64_T, + (void *) low_data_size, 1, MPI_UINT64_T, root_low_rank, low_comm, + low_comm->c_coll->coll_gather_module); + + /* Determine if we need to create a custom datatype instead of MPI_BYTE, + * to avoid count(type int) overflow + * TODO: Remove this logic once we adopt large-count, i.e. count will become 64-bit. + */ + int total_up_scount = 0; + size_t rsize = 0, datatype_size = 1, max_data_size = 0; + for (int i = 0; i < low_size; ++i) { + rsize += (size_t) low_data_size[i]; + max_data_size = (size_t) low_data_size[i] > max_data_size ? 
(size_t) low_data_size[i] + : max_data_size; + } + + if (max_data_size > (size_t) INT_MAX) { + datatype_size = coll_han_utils_gcd(low_data_size, low_size); + } + + low_scounts = malloc(low_size * sizeof(int)); + low_displs = malloc(low_size * sizeof(int)); + tmp_buf = (char *) malloc(rsize); /* tmp_buf is still valid if rsize is 0 */ + if (!tmp_buf || !low_scounts || !low_displs) { + err = OMPI_ERR_OUT_OF_RESOURCE; + goto node_leader_out; + } + + for (int i = 0; i < low_size; ++i) { + low_scounts[i] = (int) ((size_t) low_data_size[i] / datatype_size); + low_displs[i] = i > 0 ? low_displs[i - 1] + low_scounts[i - 1] : 0; + total_up_scount += low_scounts[i]; + } + + if (1 < datatype_size) { + coll_han_utils_create_contiguous_datatype(datatype_size, MPI_BYTE, &temptype); + ompi_datatype_commit(&temptype); + } + + /* Up Iscatterv */ + up_comm->c_coll->coll_iscatterv(NULL, NULL, NULL, NULL, (void *) tmp_buf, total_up_scount, + temptype, root_up_rank, up_comm, &iscatterv_req, + up_comm->c_coll->coll_iscatterv_module); + + ompi_request_wait(&iscatterv_req, MPI_STATUS_IGNORE); + + /* Low Scatterv */ + low_comm->c_coll->coll_scatterv((void *) tmp_buf, low_scounts, low_displs, temptype, rbuf, + rcount, rdtype, root_low_rank, low_comm, + low_comm->c_coll->coll_scatterv_module); + +node_leader_out: + if (low_scounts) { + free(low_scounts); + } + if (low_displs) { + free(low_displs); + } + if (low_data_size) { + free(low_data_size); + } + if (tmp_buf) { + free(tmp_buf); + } + if (MPI_BYTE != temptype) { + ompi_datatype_destroy(&temptype); + } + + return err; +} diff --git a/ompi/mca/coll/han/coll_han_subcomms.c b/ompi/mca/coll/han/coll_han_subcomms.c index 2476e1d2cc9..47ef348975e 100644 --- a/ompi/mca/coll/han/coll_han_subcomms.c +++ b/ompi/mca/coll/han/coll_han_subcomms.c @@ -80,6 +80,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gather); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, 
gatherv); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, scatterv); /** * HAN is not yet optimized for a single process per node case, we should @@ -108,6 +109,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); han_module->enabled = false; /* entire module set to pass-through from now on */ return OMPI_ERR_NOT_SUPPORTED; } @@ -185,6 +187,7 @@ int mca_coll_han_comm_create_new(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); OBJ_DESTRUCT(&comm_info); return OMPI_SUCCESS; @@ -241,6 +244,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gather); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_SAVE_COLLECTIVE(fallbacks, comm, han_module, scatterv); /** * HAN is not yet optimized for a single process per node case, we should @@ -266,6 +270,7 @@ int mca_coll_han_comm_create(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); han_module->enabled = false; /* entire module set to pass-through from now on */ return OMPI_ERR_NOT_SUPPORTED; } @@ -355,6 +360,7 @@ 
int mca_coll_han_comm_create(struct ompi_communicator_t *comm, HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gather); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, gatherv); HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatter); + HAN_SUBCOM_LOAD_COLLECTIVE(fallbacks, comm, han_module, scatterv); OBJ_DESTRUCT(&comm_info); return OMPI_SUCCESS;