Skip to content

Commit

Permalink
Adding destructor for std comms and using nccl allreduce for barrier …
Browse files Browse the repository at this point in the history
…in mpi comms (#473)

Closes #281

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)

URL: #473
  • Loading branch information
cjnolet authored Feb 9, 2022
1 parent 9b0208b commit 6963de9
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 18 deletions.
30 changes: 27 additions & 3 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <raft/cudart_utils.h>
#include <raft/error.hpp>
#include <raft/handle.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>

#define RAFT_MPI_TRY(call) \
do { \
Expand Down Expand Up @@ -104,8 +106,14 @@ constexpr MPI_Op get_mpi_op(const op_t op)

class mpi_comms : public comms_iface {
public:
mpi_comms(MPI_Comm comm, const bool owns_mpi_comm)
: owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0)
mpi_comms(MPI_Comm comm, const bool owns_mpi_comm, rmm::cuda_stream_view stream)
: owns_mpi_comm_(owns_mpi_comm),
mpi_comm_(comm),
size_(0),
rank_(1),
status_(stream),
next_request_id_(0),
stream_(stream)
{
int mpi_is_initialized = 0;
RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
Expand All @@ -121,6 +129,12 @@ class mpi_comms : public comms_iface {
RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
}

// Prepare the device-side scratch used by barrier(): cache the address of
// the single-int32 device scalar and clear it asynchronously on stream_.
void initialize()
{
  buf_ = status_.data();
  status_.set_value_to_zero_async(stream_);
}

virtual ~mpi_comms()
{
// finalizing NCCL
Expand All @@ -139,7 +153,13 @@ class mpi_comms : public comms_iface {
return std::unique_ptr<comms_iface>(new mpi_comms(new_comm, true));
}

/**
 * @brief Device-ordered barrier across all ranks.
 *
 * Performs a single-element NCCL allreduce on the cached device scalar
 * (instead of a host-side MPI_Barrier) so the barrier is ordered with the
 * other collectives enqueued on stream_, then blocks until the stream has
 * drained.
 *
 * NOTE(review): requires initialize() to have run so buf_ points at
 * status_ — confirm the factory always calls it before first use.
 */
void barrier() const
{
  // In-place sum of one int32 per rank; completion implies every rank reached
  // this point and all prior work on stream_ has been ordered before it.
  allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_);

  // A rank failure typically surfaces here as a stream-sync error.
  ASSERT(sync_stream(stream_) == status_t::SUCCESS,
         "ERROR: syncStream failed. This can be caused by a failed rank_.");
}

void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const
{
Expand Down Expand Up @@ -397,6 +417,10 @@ class mpi_comms : public comms_iface {
bool owns_mpi_comm_;
MPI_Comm mpi_comm_;

cudaStream_t stream_;
rmm::device_scalar<int32_t> status_;
int32_t* buf_;

ncclComm_t nccl_comm_;
int size_;
int rank_;
Expand Down
28 changes: 16 additions & 12 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/comms/detail/util.hpp>

#include <raft/handle.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

#include <raft/error.hpp>
Expand Down Expand Up @@ -70,11 +71,11 @@ class std_comms : public comms_iface {
std::shared_ptr<ucp_ep_h*> eps,
int num_ranks,
int rank,
cudaStream_t stream,
rmm::cuda_stream_view stream,
bool subcomms_ucp = true)
: nccl_comm_(nccl_comm),
stream_(stream),
status_(2, stream),
status_(stream),
num_ranks_(num_ranks),
rank_(rank),
subcomms_ucp_(subcomms_ucp),
Expand All @@ -92,10 +93,10 @@ class std_comms : public comms_iface {
* @param rank rank of the current worker
* @param stream stream for ordering collective operations
*/
std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, cudaStream_t stream)
std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream)
: nccl_comm_(nccl_comm),
stream_(stream),
status_(2, stream),
status_(stream),
num_ranks_(num_ranks),
rank_(rank),
subcomms_ucp_(false)
Expand All @@ -105,8 +106,14 @@ class std_comms : public comms_iface {

// Prepare the device-side scratch used by barrier(): zero the single-int32
// device scalar asynchronously on stream_ and cache its address. (The former
// sendbuff_/recvbuff_ pair was replaced by this single in-place buffer; the
// stale assignments referenced members that no longer exist.)
void initialize()
{
  status_.set_value_to_zero_async(stream_);
  buf_ = status_.data();
}

// Release request bookkeeping before the communicator is torn down.
// NOTE(review): nccl_comm_ is not destroyed here — presumably ownership
// remains with whoever created it; confirm against the std_comms factory.
~std_comms()
{
requests_in_flight_.clear();
free_requests_.clear();
}

int get_size() const { return num_ranks_; }
Expand Down Expand Up @@ -179,10 +186,7 @@ class std_comms : public comms_iface {

void barrier() const
{
RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_));
RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_));

allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_);
allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_);

ASSERT(sync_stream(stream_) == status_t::SUCCESS,
"ERROR: syncStream failed. This can be caused by a failed rank_.");
Expand Down Expand Up @@ -505,9 +509,9 @@ class std_comms : public comms_iface {
ncclComm_t nccl_comm_;
cudaStream_t stream_;

int *sendbuff_, *recvbuff_;
rmm::device_uvector<int> status_;
rmm::device_scalar<int32_t> status_;

int32_t* buf_;
int num_ranks_;
int rank_;

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/comms/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ using mpi_comms = detail::mpi_comms;

/**
 * @brief Wrap an existing MPI communicator in a raft comms_t and attach it
 * to the handle.
 *
 * The mpi_comms instance does not take ownership of @p comm (owns flag is
 * false) and uses the handle's stream for device-side collectives. The
 * previous scrape left two conflicting declarations of `communicator`
 * (pre- and post-change diff lines); only the current three-argument form
 * is kept.
 *
 * @param handle raft handle whose comms and stream are used
 * @param comm   caller-owned MPI communicator
 */
inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm)
{
  auto communicator = std::make_shared<comms_t>(
    std::unique_ptr<comms_iface>(new mpi_comms(comm, false, handle->get_stream())));
  handle->set_comms(communicator);
};

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/comms/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ inline void get_unique_id(char* uid, int size)
memcpy(uid, id.internal, size);
}
}; // namespace comms
}; // end namespace raft
}; // end namespace raft

0 comments on commit 6963de9

Please sign in to comment.