diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index ec1101032e..f06506888c 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -275,8 +275,11 @@ class mpi_comms : public comms_iface { datatype_t datatype, cudaStream_t stream) const { + RAFT_EXPECTS(size_ <= 2048, + "# NCCL operations between ncclGroupStart & ncclGroupEnd shouldn't exceed 2048."); // From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" - // https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4. + RAFT_NCCL_TRY(ncclGroupStart()); for (int root = 0; root < size_; ++root) { RAFT_NCCL_TRY( ncclBroadcast(sendbuf, @@ -287,6 +290,7 @@ class mpi_comms : public comms_iface { nccl_comm_, stream)); } + RAFT_NCCL_TRY(ncclGroupEnd()); } void gather(const void* sendbuff, diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 1a4cc2fcf9..0d54a7e55c 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -367,6 +367,9 @@ class std_comms : public comms_iface { { // From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" - // https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4. + RAFT_EXPECTS(num_ranks_ <= 2048, + "# NCCL operations between ncclGroupStart & ncclGroupEnd shouldn't exceed 2048."); + RAFT_NCCL_TRY(ncclGroupStart()); for (int root = 0; root < num_ranks_; ++root) { size_t dtype_size = get_datatype_size(datatype); RAFT_NCCL_TRY(ncclBroadcast(sendbuf, @@ -377,6 +380,7 @@ class std_comms : public comms_iface { nccl_comm_, stream)); } + RAFT_NCCL_TRY(ncclGroupEnd()); } void gather(const void* sendbuff,