diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index f06506888c..3bf5438296 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -419,6 +419,10 @@ class mpi_comms : public comms_iface { RAFT_NCCL_TRY(ncclGroupEnd()); } + void group_start() const { RAFT_NCCL_TRY(ncclGroupStart()); } + + void group_end() const { RAFT_NCCL_TRY(ncclGroupEnd()); } + private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 0d54a7e55c..2be1310c50 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -509,6 +509,10 @@ class std_comms : public comms_iface { RAFT_NCCL_TRY(ncclGroupEnd()); } + void group_start() const { RAFT_NCCL_TRY(ncclGroupStart()); } + + void group_end() const { RAFT_NCCL_TRY(ncclGroupEnd()); } + private: ncclComm_t nccl_comm_; cudaStream_t stream_; diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index bf2f7af777..7f0aa74960 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -209,6 +209,10 @@ class comms_iface { std::vector const& recvoffsets, std::vector const& sources, cudaStream_t stream) const = 0; + + virtual void group_start() const = 0; + + virtual void group_end() const = 0; }; class comms_t { @@ -625,6 +629,20 @@ class comms_t { stream); } + /** + * Opens a group: collectives & device send/receive operations placed between group_start() + * and group_end() are merged into one big operation. This forwards to impl_->group_start(); + * the mpi_comms and std_comms implementations wrap ncclGroupStart(). + */ + void group_start() const { impl_->group_start(); } + + /** + * Closes a group: collectives & device send/receive operations placed between group_start() + * and group_end() are merged into one big operation. This forwards to impl_->group_end(); + * the mpi_comms and std_comms implementations wrap ncclGroupEnd(). 
+ */ + void group_end() const { impl_->group_end(); } + private: std::unique_ptr impl_; };