
Commit

Add device_send, device_recv, device_sendrecv, device_multicast_sendrecv (#144)

- Undo temporarily exposing a RAFT communication object's private NCCL communicator.
- Add device_send/device_recv (if only sending or only receiving), device_sendrecv (if sending and receiving), device_multicast_sendrecv (if sending and receiving multiple messages); a usage sketch follows below.
- Add test suites for newly added raft::comms_t routines.

Authors:
  - Seunghwa Kang (@seunghwak)

Approvers:
  - Alex Fender (@afender)

URL: #144
seunghwak authored Feb 22, 2021
1 parent 82d3437 commit a3461b2
Showing 7 changed files with 504 additions and 15 deletions.
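
The following sketch is an editorial illustration, not code from this commit: it shows how the new point-to-point routines on raft::comms_t might be called in a ring-exchange pattern. The raft::handle_t accessors (get_comms(), get_stream()) and the include paths are assumptions about the surrounding RAFT API rather than part of this change.

// Hedged usage sketch: each rank sends one buffer to its right neighbor and
// receives one from its left neighbor. Because every rank sends and receives
// at the same time, device_sendrecv is used instead of a separate
// device_send/device_recv pair, which could deadlock.
#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>

void ring_exchange(raft::handle_t const& handle, float const* d_send,
                   float* d_recv, size_t n) {
  auto const& comm = handle.get_comms();      // assumed accessor on handle_t
  cudaStream_t stream = handle.get_stream();  // assumed accessor on handle_t

  int const rank = comm.get_rank();
  int const size = comm.get_size();
  int const right = (rank + 1) % size;        // rank we send to
  int const left = (rank + size - 1) % size;  // rank we receive from

  // Sizes are element counts; comms_t converts them to bytes internally.
  comm.device_sendrecv(d_send, n, right, d_recv, n, left, stream);
}
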
125 changes: 116 additions & 9 deletions cpp/include/raft/comms/comms.hpp
@@ -18,10 +18,8 @@

#include <raft/cudart_utils.h>

// FIXME: for get_nccl_comm(), should be removed
#include <nccl.h>

#include <memory>
#include <vector>

namespace raft {
namespace comms {
@@ -96,9 +94,6 @@ class comms_iface {
virtual int get_size() const = 0;
virtual int get_rank() const = 0;

// FIXME: a temporary hack, should be removed
virtual ncclComm_t get_nccl_comm() const = 0;

virtual std::unique_ptr<comms_iface> comm_split(int color, int key) const = 0;
virtual void barrier() const = 0;

@@ -142,6 +137,25 @@ class comms_iface {
virtual void reducescatter(const void* sendbuff, void* recvbuff,
size_t recvcount, datatype_t datatype, op_t op,
cudaStream_t stream) const = 0;

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
virtual void device_send(const void* buf, size_t size, int dest,
cudaStream_t stream) const = 0;

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
virtual void device_recv(void* buf, size_t size, int source,
cudaStream_t stream) const = 0;

virtual void device_sendrecv(const void* sendbuf, size_t sendsize, int dest,
void* recvbuf, size_t recvsize, int source,
cudaStream_t stream) const = 0;

virtual void device_multicast_sendrecv(
const void* sendbuf, std::vector<size_t> const& sendsizes,
std::vector<size_t> const& sendoffsets, std::vector<int> const& dests,
void* recvbuf, std::vector<size_t> const& recvsizes,
std::vector<size_t> const& recvoffsets, std::vector<int> const& sources,
cudaStream_t stream) const = 0;
};

class comms_t {
@@ -166,9 +180,6 @@
*/
int get_rank() const { return impl_->get_rank(); }

// FIXME: a temporary hack, should be removed
ncclComm_t get_nccl_comm() const { return impl_->get_nccl_comm(); }

/**
* Splits the current communicator clique into sub-cliques matching
* the given color and key
@@ -380,6 +391,102 @@ class comms_t {
get_type<value_t>(), op, stream);
}

/**
* Performs a point-to-point send
*
* if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock.
*
* @tparam value_t the type of data to send
* @param buf pointer to array of data to send
* @param size number of elements in buf
* @param dest destination rank
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void device_send(const value_t* buf, size_t size, int dest,
cudaStream_t stream) const {
impl_->device_send(static_cast<const void*>(buf), size * sizeof(value_t),
dest, stream);
}

/**
* Performs a point-to-point receive
*
* if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock.
*
* @tparam value_t the type of data to be received
* @param buf pointer to (initialized) array that will hold received data
* @param size number of elements in buf
* @param source source rank
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void device_recv(value_t* buf, size_t size, int source,
cudaStream_t stream) const {
impl_->device_recv(static_cast<void*>(buf), size * sizeof(value_t), source,
stream);
}

/**
* Performs a point-to-point send/receive
*
* @tparam value_t the type of data to be sent & received
* @param sendbuf pointer to array of data to send
* @param sendsize number of elements in sendbuf
* @param dest destination rank
* @param recvbuf pointer to (initialized) array that will hold received data
* @param recvsize number of elements in recvbuf
* @param source source rank
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void device_sendrecv(const value_t* sendbuf, size_t sendsize, int dest,
value_t* recvbuf, size_t recvsize, int source,
cudaStream_t stream) const {
impl_->device_sendrecv(
static_cast<const void*>(sendbuf), sendsize * sizeof(value_t), dest,
static_cast<void*>(recvbuf), recvsize * sizeof(value_t), source, stream);
}

/**
* Performs a multicast send/receive
*
* @tparam value_t the type of data to be sent & received
* @param sendbuf pointer to array of data to send
* @param sendsizes numbers of elements to send
* @param sendoffsets offsets in a number of elements from sendbuf
* @param dests destination ranks
* @param recvbuf pointer to (initialized) array that will hold received data
* @param recvsizes numbers of elements to recv
* @param recvoffsets offsets in a number of elements from recvbuf
* @param sources source ranks
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void device_multicast_sendrecv(
const value_t* sendbuf, std::vector<size_t> const& sendsizes,
std::vector<size_t> const& sendoffsets, std::vector<int> const& dests,
value_t* recvbuf, std::vector<size_t> const& recvsizes,
std::vector<size_t> const& recvoffsets, std::vector<int> const& sources,
cudaStream_t stream) const {
auto sendbytesizes = sendsizes;
auto sendbyteoffsets = sendoffsets;
for (size_t i = 0; i < sendsizes.size(); ++i) {
sendbytesizes[i] *= sizeof(value_t);
sendbyteoffsets[i] *= sizeof(value_t);
}
auto recvbytesizes = recvsizes;
auto recvbyteoffsets = recvoffsets;
for (size_t i = 0; i < recvsizes.size(); ++i) {
recvbytesizes[i] *= sizeof(value_t);
recvbyteoffsets[i] *= sizeof(value_t);
}
impl_->device_multicast_sendrecv(static_cast<const void*>(sendbuf),
sendbytesizes, sendbyteoffsets, dests,
static_cast<void*>(recvbuf), recvbytesizes,
recvbyteoffsets, sources, stream);
}

private:
std::unique_ptr<comms_iface> impl_;
};
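As an editorial aside (not part of the diff), the element-count semantics of the comms_t wrapper above can be made concrete with a small caller sketch; the helper name and the prefix-sum layout are illustrative assumptions.

#include <vector>
#include <raft/comms/comms.hpp>

// Hedged sketch: scatter variable-sized slices of one device buffer to several
// peers while gathering slices from others. Counts and offsets are given in
// elements of value_t; comms_t::device_multicast_sendrecv converts them to
// bytes before calling the byte-oriented implementation.
template <typename value_t>
void multicast_exchange(raft::comms::comms_t const& comm,
                        value_t const* d_send, value_t* d_recv,
                        std::vector<size_t> const& send_counts,
                        std::vector<int> const& dests,
                        std::vector<size_t> const& recv_counts,
                        std::vector<int> const& sources, cudaStream_t stream) {
  // Element offsets as exclusive prefix sums of the element counts.
  std::vector<size_t> send_offsets(send_counts.size(), 0);
  for (size_t i = 1; i < send_counts.size(); ++i) {
    send_offsets[i] = send_offsets[i - 1] + send_counts[i - 1];
  }
  std::vector<size_t> recv_offsets(recv_counts.size(), 0);
  for (size_t i = 1; i < recv_counts.size(); ++i) {
    recv_offsets[i] = recv_offsets[i - 1] + recv_counts[i - 1];
  }

  // All sends and receives end up in one grouped NCCL operation, so the
  // per-peer transfers cannot deadlock against each other.
  comm.device_multicast_sendrecv(d_send, send_counts, send_offsets, dests,
                                 d_recv, recv_counts, recv_offsets, sources,
                                 stream);
}
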
48 changes: 45 additions & 3 deletions cpp/include/raft/comms/mpi_comms.hpp
@@ -133,9 +133,6 @@ class mpi_comms : public comms_iface {

int get_rank() const { return rank_; }

// FIXME: a temporary hack, should be removed
ncclComm_t get_nccl_comm() const { return nccl_comm_; }

std::unique_ptr<comms_iface> comm_split(int color, int key) const {
MPI_Comm new_comm;
MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm));
@@ -304,6 +301,51 @@
}
};

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void* buf, size_t size, int dest,
cudaStream_t stream) const {
NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream));
}

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_recv(void* buf, size_t size, int source,
cudaStream_t stream) const {
NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream));
}

void device_sendrecv(const void* sendbuf, size_t sendsize, int dest,
void* recvbuf, size_t recvsize, int source,
cudaStream_t stream) const {
// ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock
NCCL_TRY(ncclGroupStart());
NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream));
NCCL_TRY(
ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void device_multicast_sendrecv(const void* sendbuf,
std::vector<size_t> const& sendsizes,
std::vector<size_t> const& sendoffsets,
std::vector<int> const& dests, void* recvbuf,
std::vector<size_t> const& recvsizes,
std::vector<size_t> const& recvoffsets,
std::vector<int> const& sources,
cudaStream_t stream) const {
// ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock
NCCL_TRY(ncclGroupStart());
for (size_t i = 0; i < sendsizes.size(); ++i) {
NCCL_TRY(ncclSend(static_cast<const char*>(sendbuf) + sendoffsets[i],
sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream));
}
for (size_t i = 0; i < recvsizes.size(); ++i) {
NCCL_TRY(ncclRecv(static_cast<char*>(recvbuf) + recvoffsets[i],
recvsizes[i], ncclUint8, sources[i], nccl_comm_,
stream));
}
NCCL_TRY(ncclGroupEnd());
}

private:
bool owns_mpi_comm_;
MPI_Comm mpi_comm_;
48 changes: 45 additions & 3 deletions cpp/include/raft/comms/std_comms.hpp
@@ -112,9 +112,6 @@ class std_comms : public comms_iface {

int get_rank() const { return rank_; }

// FIXME: a temporary hack, should be removed
ncclComm_t get_nccl_comm() const { return nccl_comm_; }

std::unique_ptr<comms_iface> comm_split(int color, int key) const {
mr::device::buffer<int> d_colors(device_allocator_, stream_, get_size());
mr::device::buffer<int> d_keys(device_allocator_, stream_, get_size());
@@ -418,6 +415,51 @@
}
}

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_send(const void *buf, size_t size, int dest,
cudaStream_t stream) const {
NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream));
}

// if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
void device_recv(void *buf, size_t size, int source,
cudaStream_t stream) const {
NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream));
}

void device_sendrecv(const void *sendbuf, size_t sendsize, int dest,
void *recvbuf, size_t recvsize, int source,
cudaStream_t stream) const {
// ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock
NCCL_TRY(ncclGroupStart());
NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream));
NCCL_TRY(
ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void device_multicast_sendrecv(const void *sendbuf,
std::vector<size_t> const &sendsizes,
std::vector<size_t> const &sendoffsets,
std::vector<int> const &dests, void *recvbuf,
std::vector<size_t> const &recvsizes,
std::vector<size_t> const &recvoffsets,
std::vector<int> const &sources,
cudaStream_t stream) const {
// ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock
NCCL_TRY(ncclGroupStart());
for (size_t i = 0; i < sendsizes.size(); ++i) {
NCCL_TRY(ncclSend(static_cast<const char *>(sendbuf) + sendoffsets[i],
sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream));
}
for (size_t i = 0; i < recvsizes.size(); ++i) {
NCCL_TRY(ncclRecv(static_cast<char *>(recvbuf) + recvoffsets[i],
recvsizes[i], ncclUint8, sources[i], nccl_comm_,
stream));
}
NCCL_TRY(ncclGroupEnd());
}

private:
ncclComm_t nccl_comm_;
cudaStream_t stream_;
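Finally, a hedged test-style sketch (not taken from the commit's test suites) of the case where device_send and device_recv are safe on their own: each rank either only sends or only receives, so no grouped send/receive is required. The call to comm.sync_stream(stream) assumes that helper exists on comms_t outside the excerpt shown above.

#include <raft/comms/comms.hpp>

// Hedged sketch: with an even number of ranks, even ranks only send and odd
// ranks only receive, so plain device_send/device_recv cannot deadlock here.
void one_way_transfer(raft::comms::comms_t const& comm, int* d_buf, size_t n,
                      cudaStream_t stream) {
  int const rank = comm.get_rank();
  if (rank % 2 == 0) {
    comm.device_send(d_buf, n, rank + 1, stream);  // sender only
  } else {
    comm.device_recv(d_buf, n, rank - 1, stream);  // receiver only
  }
  comm.sync_stream(stream);  // assumed: sync_stream is part of comms_t
}
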
