Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding destructor for std comms and using nccl allreduce for barrier in mpi comms #473

Merged
merged 11 commits into from
Feb 9, 2022
30 changes: 27 additions & 3 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
#include <raft/cudart_utils.h>
#include <raft/error.hpp>
#include <raft/handle.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>

#define RAFT_MPI_TRY(call) \
do { \
Expand Down Expand Up @@ -104,8 +106,14 @@ constexpr MPI_Op get_mpi_op(const op_t op)

class mpi_comms : public comms_iface {
public:
mpi_comms(MPI_Comm comm, const bool owns_mpi_comm)
: owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0)
mpi_comms(MPI_Comm comm, const bool owns_mpi_comm, rmm::cuda_stream_view stream)
: owns_mpi_comm_(owns_mpi_comm),
mpi_comm_(comm),
size_(0),
rank_(1),
status_(stream),
next_request_id_(0),
stream_(stream)
{
int mpi_is_initialized = 0;
RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
Expand All @@ -121,6 +129,12 @@ class mpi_comms : public comms_iface {
RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
}

void initialize()
{
  // Cache the device pointer of the status scalar, then zero it
  // asynchronously so the first barrier() allreduce starts from a
  // known value. data() is independent of the scalar's contents,
  // so the ordering of these two statements is interchangeable.
  buf_ = status_.data();
  status_.set_value_to_zero_async(stream_);
}

virtual ~mpi_comms()
{
// finalizing NCCL
Expand All @@ -139,7 +153,13 @@ class mpi_comms : public comms_iface {
return std::unique_ptr<comms_iface>(new mpi_comms(new_comm, true));
}

void barrier() const { RAFT_MPI_TRY(MPI_Barrier(mpi_comm_)); }
void barrier() const
{
allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_);

ASSERT(sync_stream(stream_) == status_t::SUCCESS,
"ERROR: syncStream failed. This can be caused by a failed rank_.");
}

void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const
{
Expand Down Expand Up @@ -397,6 +417,10 @@ class mpi_comms : public comms_iface {
bool owns_mpi_comm_;
MPI_Comm mpi_comm_;

cudaStream_t stream_;
rmm::device_scalar<int32_t> status_;
int32_t* buf_;

ncclComm_t nccl_comm_;
int size_;
int rank_;
Expand Down
28 changes: 16 additions & 12 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/comms/detail/util.hpp>

#include <raft/handle.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

#include <raft/error.hpp>
Expand Down Expand Up @@ -70,11 +71,11 @@ class std_comms : public comms_iface {
std::shared_ptr<ucp_ep_h*> eps,
int num_ranks,
int rank,
cudaStream_t stream,
rmm::cuda_stream_view stream,
bool subcomms_ucp = true)
: nccl_comm_(nccl_comm),
stream_(stream),
status_(2, stream),
status_(stream),
num_ranks_(num_ranks),
rank_(rank),
subcomms_ucp_(subcomms_ucp),
Expand All @@ -92,10 +93,10 @@ class std_comms : public comms_iface {
* @param rank rank of the current worker
* @param stream stream for ordering collective operations
*/
std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, cudaStream_t stream)
std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream)
: nccl_comm_(nccl_comm),
stream_(stream),
status_(2, stream),
status_(stream),
num_ranks_(num_ranks),
rank_(rank),
subcomms_ucp_(false)
Expand All @@ -105,8 +106,14 @@ class std_comms : public comms_iface {

void initialize()
{
  // NOTE(review): this span contained stale pre-refactor lines assigning
  // sendbuff_/recvbuff_, members that were removed in this change in favor
  // of a single device scalar; only the merged body is kept.
  //
  // Zero the device-side status scalar asynchronously so the first
  // barrier() allreduce operates on a known value, and cache its device
  // pointer for reuse by barrier().
  status_.set_value_to_zero_async(stream_);
  buf_ = status_.data();
}

~std_comms()
{
  // Explicitly drop the request bookkeeping containers before the
  // remaining members are destroyed. (The containers would be destroyed
  // anyway; clearing them here makes the teardown order explicit.)
  free_requests_.clear();
  requests_in_flight_.clear();
}

int get_size() const { return num_ranks_; }
Expand Down Expand Up @@ -179,10 +186,7 @@ class std_comms : public comms_iface {

void barrier() const
{
  // NOTE(review): this span contained stale pre-refactor lines (cudaMemsetAsync
  // on the removed sendbuff_/recvbuff_ members plus the old allreduce call);
  // only the merged post-diff body is kept.
  //
  // Device-side barrier: a 1-element NCCL allreduce on the shared status
  // scalar forces every rank to participate before any can proceed.
  allreduce(buf_, buf_, 1, datatype_t::INT32, op_t::SUM, stream_);

  // Block the host until the collective completes; a failure here usually
  // means a peer rank died mid-collective.
  ASSERT(sync_stream(stream_) == status_t::SUCCESS,
         "ERROR: syncStream failed. This can be caused by a failed rank_.");
}
Expand Down Expand Up @@ -505,9 +509,9 @@ class std_comms : public comms_iface {
ncclComm_t nccl_comm_;
cudaStream_t stream_;

int *sendbuff_, *recvbuff_;
rmm::device_uvector<int> status_;
rmm::device_scalar<int32_t> status_;

int32_t* buf_;
int num_ranks_;
int rank_;

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/comms/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ using mpi_comms = detail::mpi_comms;

inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm)
{
  // NOTE(review): this span contained both the pre- and post-diff
  // definitions of `communicator` (a redefinition as flattened text);
  // only the merged three-argument form is kept.
  //
  // Wrap the caller-owned MPI communicator (owns_mpi_comm = false) in a
  // raft comms_t and attach it to the handle. The handle's stream is passed
  // through so the device-side barrier can be ordered on it.
  auto communicator = std::make_shared<comms_t>(
    std::unique_ptr<comms_iface>(new mpi_comms(comm, false, handle->get_stream())));
  handle->set_comms(communicator);
};

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/comms/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,4 @@ inline void get_unique_id(char* uid, int size)
memcpy(uid, id.internal, size);
}
}; // namespace comms
}; // end namespace raft
}; // end namespace raft