From 27d49f00e5191d0303b39c26de91037d1c18c6de Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 1 Feb 2022 11:11:30 -0500 Subject: [PATCH 1/9] Adding destructor for std comms to clear out requests_in_flight and free_requests --- cpp/include/raft/comms/mpi_comms.hpp | 2 ++ cpp/include/raft/comms/std_comms.hpp | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 432f250b59..c39e1b3324 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -126,6 +126,8 @@ class mpi_comms : public comms_iface { // finalizing NCCL RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_)); if (owns_mpi_comm_) { RAFT_MPI_TRY_NO_THROW(MPI_Comm_free(&mpi_comm_)); } + + free_requests_.clear(); } int get_size() const { return size_; } diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 99f15643a1..4067ddfe0b 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -102,6 +102,12 @@ class std_comms : public comms_iface { initialize(); }; + ~std_comms() { + + requests_in_flight_.clear(); + free_requests_.clear(); + } + void initialize() { sendbuff_ = status_.data(); From dd9e15b29d3dd8ac1fabc88bdc686cab50fbb63b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 1 Feb 2022 11:16:22 -0500 Subject: [PATCH 2/9] Style --- cpp/include/raft/comms/std_comms.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 4067ddfe0b..bed60a1d96 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -102,10 +102,10 @@ class std_comms : public comms_iface { initialize(); }; - ~std_comms() { - - requests_in_flight_.clear(); - free_requests_.clear(); + ~std_comms() + { + requests_in_flight_.clear(); + free_requests_.clear(); } void initialize() From e87fe820c117069ad5e632bcdab28fd86e15aedf Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 2 Feb 2022 14:26:40 -0500 Subject: [PATCH 3/9] Updating mpi comms to accept a stream --- cpp/include/raft/comms/mpi_comms.hpp | 32 +++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index c39e1b3324..99ba4174f7 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #define RAFT_MPI_TRY(call) \ do { \ @@ -104,8 +105,14 @@ constexpr MPI_Op get_mpi_op(const op_t op) class mpi_comms : public comms_iface { public: - mpi_comms(MPI_Comm comm, const bool owns_mpi_comm) - : owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0) + mpi_comms(MPI_Comm comm, const bool owns_mpi_comm, cudaStream_t stream) + : owns_mpi_comm_(owns_mpi_comm), + mpi_comm_(comm), + size_(0), + rank_(1), + stream_(stream), + status(2, stream), + next_request_id_(0) { int mpi_is_initialized = 0; RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized)); @@ -121,6 +128,12 @@ class mpi_comms : public comms_iface { RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_)); } + void initialize() + { + sendbuff_ = status_.data(); + recvbuff_ = status_.data() + 1; + } + virtual ~mpi_comms() { // finalizing NCCL @@ -141,7 +154,16 @@ class mpi_comms : public comms_iface { return std::unique_ptr(new mpi_comms(new_comm, true)); } - void barrier() const { RAFT_MPI_TRY(MPI_Barrier(mpi_comm_)); } + void barrier() const + { + RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); + RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_)); + + allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); + + ASSERT(sync_stream(stream_) == status_t::SUCCESS, + "ERROR: syncStream failed. This can be caused by a failed rank_."); + } void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const { @@ -429,6 +451,10 @@ class mpi_comms : public comms_iface { private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; + cudaStream_t stream_; + + int *sendbuff_, *recvbuff_; + rmm::device_uvector status_; ncclComm_t nccl_comm_; int size_; From b01468917952b155f57b765e17d284b8086c825c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 2 Feb 2022 15:04:55 -0500 Subject: [PATCH 4/9] removing recvbuff from barrier --- cpp/include/raft/comms/mpi_comms.hpp | 6 ++---- cpp/include/raft/comms/std_comms.hpp | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 99ba4174f7..63643dd6de 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -131,7 +131,6 @@ class mpi_comms : public comms_iface { void initialize() { sendbuff_ = status_.data(); - recvbuff_ = status_.data() + 1; } virtual ~mpi_comms() @@ -157,9 +156,8 @@ class mpi_comms : public comms_iface { void barrier() const { RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); - RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_)); - allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); + allreduce(sendbuff_, sendbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); @@ -453,7 +451,7 @@ class mpi_comms : public comms_iface { MPI_Comm mpi_comm_; cudaStream_t stream_; - int *sendbuff_, *recvbuff_; + int *sendbuff_; rmm::device_uvector status_; ncclComm_t nccl_comm_; diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index bed60a1d96..841e58da05 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -111,7 +111,6 @@ class std_comms : public comms_iface { void initialize() { sendbuff_ = status_.data(); - recvbuff_ = status_.data() + 1; } int get_size() const { return num_ranks_; } @@ -185,9 +184,8 @@ class std_comms : public comms_iface { void barrier() const { RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); - RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_)); - allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); + allreduce(sendbuff_, sendbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); @@ -541,7 +539,7 @@ class std_comms : public comms_iface { ncclComm_t nccl_comm_; cudaStream_t stream_; - int *sendbuff_, *recvbuff_; + int *sendbuff_; rmm::device_uvector status_; int num_ranks_; From dd5c473512f9fdb760d4b65be1dbc6fac4a4f492 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 2 Feb 2022 15:24:44 -0500 Subject: [PATCH 5/9] Adding destructor back to std comms (lost in merge) --- cpp/include/raft/comms/detail/std_comms.hpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 08b7800fa8..3916c7602b 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -103,6 +103,12 @@ class std_comms : public comms_iface { initialize(); }; + ~std_comms() + { + requests_in_flight_.clear(); + free_requests_.clear(); + } + void initialize() { sendbuff_ = status_.data(); } int get_size() const { return num_ranks_; } From 12d4d982158a00bd854e538bd0ea8f6fae5834bc Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 2 Feb 2022 16:44:15 -0500 Subject: [PATCH 6/9] Updates based on review feedback --- cpp/include/raft/comms/detail/mpi_comms.hpp | 17 ++++++----------- cpp/include/raft/comms/detail/std_comms.hpp | 18 +++++++----------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index 5460a13640..6a3b13e9f2 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -31,7 +31,8 @@ #include #include #include -#include +#include +#include #define RAFT_MPI_TRY(call) \ do { \ @@ -105,12 +106,12 @@ constexpr MPI_Op get_mpi_op(const op_t op) class mpi_comms : public comms_iface { public: - mpi_comms(MPI_Comm comm, const bool owns_mpi_comm, cudaStream_t stream) + mpi_comms(MPI_Comm comm, const bool owns_mpi_comm, rmm::cuda_stream_view stream) : owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), - status_(2, stream), + status_(stream), next_request_id_(0), stream_(stream) { @@ -126,11 +127,8 @@ class mpi_comms : public comms_iface { // initializing NCCL RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_)); - initialize(); } - void initialize() { sendbuff_ = status_.data(); } - virtual ~mpi_comms() { // finalizing NCCL @@ -151,9 +149,7 @@ class mpi_comms : public comms_iface { void barrier() const { - RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); - - allreduce(sendbuff_, sendbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); + allreduce(status_.data(), status_.data(), 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); @@ -447,8 +443,7 @@ class mpi_comms : public comms_iface { MPI_Comm mpi_comm_; cudaStream_t stream_; - int* sendbuff_; - rmm::device_uvector status_; + rmm::device_scalar status_; ncclComm_t nccl_comm_; int size_; diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 3916c7602b..a4ba27ccc8 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -70,11 +71,11 @@ class std_comms : public comms_iface { std::shared_ptr eps, int num_ranks, int rank, - cudaStream_t stream, + rmm::cuda_stream_view stream, bool subcomms_ucp = true) : nccl_comm_(nccl_comm), stream_(stream), - status_(2, stream), + status_(stream), num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(subcomms_ucp), @@ -92,10 +93,10 @@ class std_comms : public comms_iface { * @param rank rank of the current worker * @param stream stream for ordering collective operations */ - std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, cudaStream_t stream) + std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream) : nccl_comm_(nccl_comm), stream_(stream), - status_(2, stream), + status_(stream), num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(false) @@ -109,8 +110,6 @@ class std_comms : public comms_iface { free_requests_.clear(); } - void initialize() { sendbuff_ = status_.data(); } - int get_size() const { return num_ranks_; } int get_rank() const { return rank_; } @@ -181,9 +180,7 @@ class std_comms : public comms_iface { void barrier() const { - RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); - - allreduce(sendbuff_, sendbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); + allreduce(status_.data(), status_.data(), 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); @@ -537,8 +534,7 @@ class std_comms : public comms_iface { ncclComm_t nccl_comm_; cudaStream_t stream_; - int* sendbuff_; - rmm::device_uvector status_; + rmm::device_scalar status_; int num_ranks_; int rank_; From 9df05da722d803c8240aa2787041ce5c9131b580 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 2 Feb 2022 18:39:50 -0500 Subject: [PATCH 7/9] Still fixing --- cpp/include/raft/comms/detail/std_comms.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index a4ba27ccc8..d3e311dd6f 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -82,9 +82,7 @@ class std_comms : public comms_iface { ucp_worker_(ucp_worker), ucp_eps_(eps), next_request_id_(0) - { - initialize(); - }; + {}; /** * @brief constructor for collective-only operation @@ -180,7 +178,8 @@ class std_comms : public comms_iface { void barrier() const { - allreduce(status_.data(), status_.data(), 1, datatype_t::INT32, op_t::SUM, stream_); + void *s = status_.data(); + allreduce(status_.data(), s, 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); From b50d37170f05c132141c9c2512aefa1f07555526 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 9 Feb 2022 10:20:13 -0500 Subject: [PATCH 8/9] Updates --- cpp/include/raft/comms/detail/mpi_comms.hpp | 9 ++++++++- cpp/include/raft/comms/detail/std_comms.hpp | 14 +++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index 3cd907a80b..423beace7f 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -129,6 +129,12 @@ class mpi_comms : public comms_iface { RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_)); } + void initialize() + { + status_.set_value_to_zero_async(stream_); + buf_ = status_.data(); + } + virtual ~mpi_comms() { // finalizing NCCL @@ -149,7 +155,7 @@ class mpi_comms : public comms_iface { void barrier() const { - allreduce(status_.data(), status_.data(), 1, datatype_t::INT32, op_t::SUM, stream_); + allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); @@ -413,6 +419,7 @@ class mpi_comms : public comms_iface { cudaStream_t stream_; rmm::device_scalar status_; + int32_t* buf_; ncclComm_t nccl_comm_; int size_; diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index e5e1d292c8..1a4cc2fcf9 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -82,7 +82,9 @@ class std_comms : public comms_iface { ucp_worker_(ucp_worker), ucp_eps_(eps), next_request_id_(0) - {}; + { + initialize(); + }; /** * @brief constructor for collective-only operation @@ -102,6 +104,12 @@ class std_comms : public comms_iface { initialize(); }; + void initialize() + { + status_.set_value_to_zero_async(stream_); + buf_ = status_.data(); + } + ~std_comms() { requests_in_flight_.clear(); @@ -178,8 +186,7 @@ class std_comms : public comms_iface { void barrier() const { - void *s = status_.data(); - allreduce(status_.data(), s, 1, datatype_t::INT32, op_t::SUM, stream_); + allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_); ASSERT(sync_stream(stream_) == status_t::SUCCESS, "ERROR: syncStream failed. This can be caused by a failed rank_."); @@ -504,6 +511,7 @@ class std_comms : public comms_iface { rmm::device_scalar status_; + int32_t* buf_; int num_ranks_; int rank_; From 638a8bb9d16651b5b4ca45e27bcb547d0df14227 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 9 Feb 2022 15:11:11 -0500 Subject: [PATCH 9/9] Fixing factory function so it's no longer a breaking change --- cpp/include/raft/comms/mpi_comms.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 8e2ad459ee..3fab04c441 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -24,10 +24,10 @@ namespace comms { using mpi_comms = detail::mpi_comms; -inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm, cudaStream_t stream) +inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) { - auto communicator = - std::make_shared(std::unique_ptr(new mpi_comms(comm, false, stream))); + auto communicator = std::make_shared( + std::unique_ptr(new mpi_comms(comm, false, handle->get_stream()))); handle->set_comms(communicator); };