Summary of this patch (free text ahead of the first "diff --git" header is not part of any hunk).

What the patch does:

* detail/mpi_comms.hpp: the constructor takes an rmm::cuda_stream_view and keeps it in
  stream_. barrier() no longer calls MPI_Barrier; it runs an in-place one-element INT32
  SUM allreduce on the device scalar status_ (through buf_) and then synchronizes stream_.
* detail/std_comms.hpp: both constructors take an rmm::cuda_stream_view. The two-element
  status_ vector with separate sendbuff_/recvbuff_ pointers becomes a single device scalar
  reached through buf_; it is zeroed once in initialize() instead of being memset on every
  barrier() call. A destructor clears requests_in_flight_ and free_requests_.
* mpi_comms.hpp: initialize_mpi_comms() forwards handle->get_stream() to the constructor.
* std_comms.hpp: adds the missing newline at end of file.

Review fixes applied on top of the patch as received (all in detail/mpi_comms.hpp):

1. comm_split() still called the two-argument constructor that this patch removes; it now
   passes stream_ as the third argument.
2. No visible hunk called the new mpi_comms::initialize(), which is the only place buf_ is
   assigned, so barrier() would have used an uninitialized pointer. The constructor now
   calls initialize() after ncclCommInitRank.
3. The constructor's initializer list names stream_ and status_ in the order the members
   are declared, matching the std_comms constructors.

Notes for whoever applies this:

* Hunk headers for detail/mpi_comms.hpp are renumbered for the lines added by fixes 1 and 2
  and by the added comments; the "index" hashes are still those of the patch as received.
* This copy of the patch had lost its line breaks, indentation, and all text inside angle
  brackets (include targets, template arguments). Line structure is reconstructed from the
  hunk line counts; the angle-bracket text is left as found and must be restored from the
  real headers before the patch can apply.

diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp
index b0da532f0a..423beace7f 100644
--- a/cpp/include/raft/comms/detail/mpi_comms.hpp
+++ b/cpp/include/raft/comms/detail/mpi_comms.hpp
@@ -31,6 +31,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #define RAFT_MPI_TRY(call) \
   do { \
@@ -104,8 +106,14 @@ constexpr MPI_Op get_mpi_op(const op_t op)
 
 class mpi_comms : public comms_iface {
  public:
-  mpi_comms(MPI_Comm comm, const bool owns_mpi_comm)
-    : owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0)
+  mpi_comms(MPI_Comm comm, const bool owns_mpi_comm, rmm::cuda_stream_view stream)
+    : owns_mpi_comm_(owns_mpi_comm),
+      mpi_comm_(comm),
+      stream_(stream),
+      status_(stream),
+      size_(0),
+      rank_(1),
+      next_request_id_(0)
   {
     int mpi_is_initialized = 0;
     RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
@@ -121,6 +129,16 @@ class mpi_comms : public comms_iface {
     RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
+
+    // Zero the barrier scalar and cache its device pointer before barrier() can be called.
+    initialize();
   }
 
+  // Queues a zero-fill of status_ on stream_ and stores its device pointer in buf_.
+  void initialize()
+  {
+    status_.set_value_to_zero_async(stream_);
+    buf_ = status_.data();
+  }
+
   virtual ~mpi_comms()
   {
     // finalizing NCCL
@@ -139,7 +157,14 @@ class mpi_comms : public comms_iface {
-    return std::unique_ptr(new mpi_comms(new_comm, true));
+    return std::unique_ptr(new mpi_comms(new_comm, true, stream_));
   }
 
-  void barrier() const { RAFT_MPI_TRY(MPI_Barrier(mpi_comm_)); }
+  // Runs an in-place one-element INT32 SUM allreduce on buf_, then synchronizes stream_.
+  void barrier() const
+  {
+    allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_);
+
+    ASSERT(sync_stream(stream_) == status_t::SUCCESS,
+           "ERROR: syncStream failed. This can be caused by a failed rank_.");
+  }
 
   void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const
   {
@@ -397,6 +422,10 @@ class mpi_comms : public comms_iface {
   bool owns_mpi_comm_;
   MPI_Comm mpi_comm_;
 
+  cudaStream_t stream_;
+  rmm::device_scalar status_;
+  int32_t* buf_;
+
   ncclComm_t nccl_comm_;
   int size_;
   int rank_;
diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp
index d8b0f2090c..1a4cc2fcf9 100644
--- a/cpp/include/raft/comms/detail/std_comms.hpp
+++ b/cpp/include/raft/comms/detail/std_comms.hpp
@@ -21,6 +21,7 @@
 
 #include
 #include
+#include
 
 #include
 #include
@@ -70,11 +71,11 @@ class std_comms : public comms_iface {
             std::shared_ptr eps,
             int num_ranks,
             int rank,
-            cudaStream_t stream,
+            rmm::cuda_stream_view stream,
             bool subcomms_ucp = true)
     : nccl_comm_(nccl_comm),
       stream_(stream),
-      status_(2, stream),
+      status_(stream),
       num_ranks_(num_ranks),
       rank_(rank),
       subcomms_ucp_(subcomms_ucp),
@@ -92,10 +93,10 @@ class std_comms : public comms_iface {
    * @param rank rank of the current worker
    * @param stream stream for ordering collective operations
    */
-  std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, cudaStream_t stream)
+  std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream)
     : nccl_comm_(nccl_comm),
       stream_(stream),
-      status_(2, stream),
+      status_(stream),
       num_ranks_(num_ranks),
       rank_(rank),
       subcomms_ucp_(false)
@@ -105,8 +106,14 @@ class std_comms : public comms_iface {
 
   void initialize()
   {
-    sendbuff_ = status_.data();
-    recvbuff_ = status_.data() + 1;
+    status_.set_value_to_zero_async(stream_);
+    buf_ = status_.data();
+  }
+
+  ~std_comms()
+  {
+    requests_in_flight_.clear();
+    free_requests_.clear();
   }
 
   int get_size() const { return num_ranks_; }
@@ -179,10 +186,7 @@ class std_comms : public comms_iface {
 
   void barrier() const
   {
-    RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_));
-    RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_));
-
-    allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_);
+    allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_);
 
     ASSERT(sync_stream(stream_) == status_t::SUCCESS,
            "ERROR: syncStream failed. This can be caused by a failed rank_.");
@@ -505,9 +509,9 @@ class std_comms : public comms_iface {
 
   ncclComm_t nccl_comm_;
   cudaStream_t stream_;
-  int *sendbuff_, *recvbuff_;
 
-  rmm::device_uvector status_;
+  rmm::device_scalar status_;
+  int32_t* buf_;
 
   int num_ranks_;
   int rank_;
diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp
index bb1e30afc8..3fab04c441 100644
--- a/cpp/include/raft/comms/mpi_comms.hpp
+++ b/cpp/include/raft/comms/mpi_comms.hpp
@@ -26,8 +26,8 @@ using mpi_comms = detail::mpi_comms;
 
 inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm)
 {
-  auto communicator =
-    std::make_shared(std::unique_ptr(new mpi_comms(comm, false)));
+  auto communicator = std::make_shared(
+    std::unique_ptr(new mpi_comms(comm, false, handle->get_stream())));
   handle->set_comms(communicator);
 };
 
diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp
index b4aa72d53e..6fa0f7e37b 100644
--- a/cpp/include/raft/comms/std_comms.hpp
+++ b/cpp/include/raft/comms/std_comms.hpp
@@ -103,4 +103,4 @@ inline void get_unique_id(char* uid, int size)
   memcpy(uid, id.internal, size);
 }
 }; // namespace comms
-}; // end namespace raft
\ No newline at end of file
+}; // end namespace raft