From fd2595c6281cb6136ac4224117d3b99bd8a4c29f Mon Sep 17 00:00:00 2001 From: Seunghwa Kang <45857425+seunghwak@users.noreply.github.com> Date: Mon, 18 Jul 2022 18:35:29 -0700 Subject: [PATCH] Add wrapper functions for ncclGroupStart() and ncclGroupEnd() (#742) This PR adds group_start() and group_end() functions. These functions internally call ncclGroupStart() and ncclGroupEnd(), respectively. Authors: - Seunghwa Kang (https://github.com/seunghwak) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/raft/pull/742 --- cpp/include/raft/comms/detail/mpi_comms.hpp | 4 ++++ cpp/include/raft/comms/detail/std_comms.hpp | 4 ++++ cpp/include/raft/core/comms.hpp | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index f06506888c..3bf5438296 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -419,6 +419,10 @@ class mpi_comms : public comms_iface { RAFT_NCCL_TRY(ncclGroupEnd()); } + void group_start() const { RAFT_NCCL_TRY(ncclGroupStart()); } + + void group_end() const { RAFT_NCCL_TRY(ncclGroupEnd()); } + private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 0d54a7e55c..2be1310c50 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -509,6 +509,10 @@ class std_comms : public comms_iface { RAFT_NCCL_TRY(ncclGroupEnd()); } + void group_start() const { RAFT_NCCL_TRY(ncclGroupStart()); } + + void group_end() const { RAFT_NCCL_TRY(ncclGroupEnd()); } + private: ncclComm_t nccl_comm_; cudaStream_t stream_; diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index bf2f7af777..7f0aa74960 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -209,6 +209,10 @@ class comms_iface { std::vector const& recvoffsets, std::vector const& sources, cudaStream_t stream) const = 0; + + virtual void group_start() const = 0; + + virtual void group_end() const = 0; }; class comms_t { @@ -625,6 +629,20 @@ class comms_t { stream); } + /** + * Multiple collectives & device send/receive operations placed between group_start() and + * group_end() are merged into one big operation. Internally, this function is a wrapper for + * ncclGroupStart(). + */ + void group_start() const { impl_->group_start(); } + + /** + * Multiple collectives & device send/receive operations placed between group_start() and + * group_end() are merged into one big operation. Internally, this function is a wrapper for + * ncclGroupEnd(). + */ + void group_end() const { impl_->group_end(); } + private: std::unique_ptr impl_; };