Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Add gather & gatherv to raft::comms::comms_t #114

Merged
merged 6 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions cpp/include/raft/comms/comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ class comms_iface {
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, cudaStream_t stream) const = 0;

virtual void gather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void reducescatter(const void* sendbuff, void* recvbuff,
size_t recvcount, datatype_t datatype, op_t op,
cudaStream_t stream) const = 0;
Expand Down Expand Up @@ -316,6 +325,45 @@ class comms_t {
get_type<value_t>(), stream);
}

/**
* Gathers data from each rank onto all ranks
* @tparam value_t datatype of underlying buffers
* @param sendbuff buffer containing data to gather
* @param recvbuff buffer containing gathered data from all ranks
* @param sendcount number of elements in send buffer
* @param root rank to store the results
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void gather(const value_t* sendbuff, value_t* recvbuff, size_t sendcount,
int root, cudaStream_t stream) const {
impl_->gather(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), sendcount, get_type<value_t>(),
root, stream);
}

/**
* Gathers data from all ranks and delivers to combined data to all ranks
* @param value_t datatype of underlying buffers
* @param sendbuff buffer containing data to send
* @param recvbuff buffer containing data to receive
* @param sendcount number of elements in send buffer
* @param recvcounts pointer to an array (of length num_ranks size) containing the number of
* elements that are to be received from each rank
* @param displs pointer to an array (of length num_ranks size) to specify the displacement
* (relative to recvbuf) at which to place the incoming data from each rank
* @param root rank to store the results
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void gatherv(const value_t* sendbuf, value_t* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs, int root,
cudaStream_t stream) const {
impl_->gatherv(static_cast<const void*>(sendbuf),
static_cast<void*>(recvbuf), sendcount, recvcounts, displs,
get_type<value_t>(), root, stream);
}

/**
* Reduces data from all ranks then scatters the result across ranks
* @tparam value_t datatype of underlying buffers
Expand Down
31 changes: 31 additions & 0 deletions cpp/include/raft/comms/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,37 @@ class mpi_comms : public comms_iface {
}
}

void gather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(recvbuff + sendcount * i * dtype_size, sendcount,
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(recvbuff + displs[r] * dtype_size, recvcounts[r],
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount,
datatype_t datatype, op_t op, cudaStream_t stream) const {
NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount,
Expand Down
31 changes: 31 additions & 0 deletions cpp/include/raft/comms/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,37 @@ class std_comms : public comms_iface {
}
}

void gather(const void *sendbuff, void *recvbuff, size_t sendcount,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(recvbuff + sendcount * i * dtype_size, sendcount,
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void *sendbuf, void *recvbuf, size_t sendcount,
const size_t *recvcounts, const size_t *displs,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(recvbuff + displs[r] * dtype_size, recvcounts[r],
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void *sendbuff, void *recvbuff, size_t recvcount,
datatype_t datatype, op_t op, cudaStream_t stream) const {
NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount,
Expand Down