Merge pull request #114 from seunghwak/fea_comm_gather_gatherv

[REVIEW] Add gather & gatherv to raft::comms::comms_t

cjnolet authored Feb 3, 2021
2 parents b5d5044 + c314a6e commit 4a79adc
Showing 7 changed files with 248 additions and 6 deletions.
48 changes: 48 additions & 0 deletions cpp/include/raft/comms/comms.hpp
@@ -130,6 +130,15 @@ class comms_iface {
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, cudaStream_t stream) const = 0;

virtual void gather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void reducescatter(const void* sendbuff, void* recvbuff,
size_t recvcount, datatype_t datatype, op_t op,
cudaStream_t stream) const = 0;
@@ -316,6 +325,45 @@ class comms_t {
get_type<value_t>(), stream);
}

/**
* Gathers data from all ranks into a buffer on the root rank
* @tparam value_t datatype of underlying buffers
* @param sendbuff buffer containing data to gather
* @param recvbuff buffer that receives the gathered data (significant only at root)
* @param sendcount number of elements in send buffer
* @param root rank on which to store the results
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void gather(const value_t* sendbuff, value_t* recvbuff, size_t sendcount,
int root, cudaStream_t stream) const {
impl_->gather(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), sendcount, get_type<value_t>(),
root, stream);
}
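
For illustration, a minimal sketch of calling the new gather API, assuming a raft::handle_t whose comms have already been initialized; the helper name gather_example and the buffer names are hypothetical:

#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

// Hypothetical usage: every rank contributes its rank id; only root
// allocates (and receives into) a buffer of get_size() elements.
void gather_example(const raft::handle_t &handle, int root) {
  const auto &comm = handle.get_comms();
  cudaStream_t stream = handle.get_stream();

  int send = comm.get_rank();
  raft::mr::device::buffer<int> send_d(handle.get_device_allocator(), stream, 1);
  raft::mr::device::buffer<int> recv_d(
    handle.get_device_allocator(), stream,
    comm.get_rank() == root ? comm.get_size() : 0);

  CUDA_CHECK(cudaMemcpyAsync(send_d.data(), &send, sizeof(int),
                             cudaMemcpyHostToDevice, stream));
  comm.gather(send_d.data(), recv_d.data(), 1, root, stream);
  comm.sync_stream(stream);  // on root, recv_d now holds {0, 1, ..., size - 1}
}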

/**
* Gathers data of varying sizes from all ranks into a buffer on the root rank
* @tparam value_t datatype of underlying buffers
* @param sendbuf buffer containing data to send
* @param recvbuf buffer that receives the gathered data (significant only at root)
* @param sendcount number of elements in send buffer
* @param recvcounts pointer to an array (of length num_ranks) containing the number of
* elements to receive from each rank (significant only at root)
* @param displs pointer to an array (of length num_ranks) specifying the displacement
* (relative to recvbuf) at which to place the incoming data from each rank
* (significant only at root)
* @param root rank on which to store the results
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void gatherv(const value_t* sendbuf, value_t* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs, int root,
cudaStream_t stream) const {
impl_->gatherv(static_cast<const void*>(sendbuf),
static_cast<void*>(recvbuf), sendcount, recvcounts, displs,
get_type<value_t>(), root, stream);
}
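
Similarly, a minimal sketch of calling gatherv, where rank r contributes r + 1 elements; recvcounts and displs are only needed on root, and non-root ranks may pass nullptr, as the test in test.hpp below does. The helper name gatherv_example is hypothetical, and <vector> and <numeric> are assumed to be included in addition to the headers above:

// Hypothetical usage: variable-sized contributions, laid out contiguously
// on root according to displs. (send_d is left uninitialized for brevity.)
void gatherv_example(const raft::handle_t &handle, int root) {
  const auto &comm = handle.get_comms();
  cudaStream_t stream = handle.get_stream();

  size_t sendcount = comm.get_rank() + 1;
  std::vector<size_t> recvcounts(comm.get_size());
  std::iota(recvcounts.begin(), recvcounts.end(), size_t{1});  // 1, 2, ..., size
  std::vector<size_t> displs(comm.get_size() + 1, 0);
  std::partial_sum(recvcounts.begin(), recvcounts.end(), displs.begin() + 1);

  raft::mr::device::buffer<int> send_d(handle.get_device_allocator(), stream,
                                       sendcount);
  raft::mr::device::buffer<int> recv_d(
    handle.get_device_allocator(), stream,
    comm.get_rank() == root ? displs.back() : 0);

  comm.gatherv(send_d.data(), recv_d.data(), sendcount,
               comm.get_rank() == root ? recvcounts.data() : nullptr,
               comm.get_rank() == root ? displs.data() : nullptr,
               root, stream);
  comm.sync_stream(stream);
}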

/**
* Reduces data from all ranks then scatters the result across ranks
* @tparam value_t datatype of underlying buffers
33 changes: 33 additions & 0 deletions cpp/include/raft/comms/mpi_comms.hpp
@@ -232,6 +232,39 @@ class mpi_comms : public comms_iface {
}
}

void gather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
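// Gather is built from NCCL point-to-point primitives: root posts one
// ncclRecv per rank while every rank (root included) posts an ncclSend to
// root; ncclGroupStart/ncclGroupEnd fuse these calls so NCCL can progress
// them together rather than serializing (and potentially deadlocking) them.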
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(
static_cast<char*>(recvbuff) + sendcount * r * dtype_size, sendcount,
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void* sendbuff, void* recvbuff, size_t sendcount,
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
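// recvcounts and displs are dereferenced only at root; non-root ranks may
// pass nullptr for both.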
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + displs[r] * dtype_size,
recvcounts[r], get_nccl_datatype(datatype), r,
nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount,
datatype_t datatype, op_t op, cudaStream_t stream) const {
NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount,
33 changes: 33 additions & 0 deletions cpp/include/raft/comms/std_comms.hpp
@@ -346,6 +346,39 @@ class std_comms : public comms_iface {
}
}

void gather(const void *sendbuff, void *recvbuff, size_t sendcount,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(
static_cast<char *>(recvbuff) + sendcount * r * dtype_size, sendcount,
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void *sendbuff, void *recvbuff, size_t sendcount,
const size_t *recvcounts, const size_t *displs,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(
static_cast<char *>(recvbuff) + displs[r] * dtype_size, recvcounts[r],
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void *sendbuff, void *recvbuff, size_t recvcount,
datatype_t datatype, op_t op, cudaStream_t stream) const {
NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount,
102 changes: 96 additions & 6 deletions cpp/include/raft/comms/test.hpp
@@ -16,11 +16,13 @@

#pragma once

-#include <iostream>
#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

+#include <iostream>
+#include <numeric>

namespace raft {
namespace comms {

@@ -155,26 +157,114 @@ bool test_collective_allgather(const handle_t &handle, int root) {
return true;
}

bool test_collective_gather(const handle_t &handle, int root) {
comms_t const &communicator = handle.get_comms();

int const send = communicator.get_rank();

cudaStream_t stream = handle.get_stream();

raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream);
temp_d.resize(1, stream);

raft::mr::device::buffer<int> recv_d(
handle.get_device_allocator(), stream,
communicator.get_rank() == root ? communicator.get_size() : 0);

CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int),
cudaMemcpyHostToDevice, stream));

communicator.gather(temp_d.data(), recv_d.data(), 1, root, stream);
communicator.sync_stream(stream);

if (communicator.get_rank() == root) {
std::vector<int> temp_h(communicator.get_size(), 0);
CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(),
sizeof(int) * temp_h.size(),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

for (int i = 0; i < communicator.get_size(); i++) {
if (temp_h[i] != i) return false;
}
}
return true;
}

bool test_collective_gatherv(const handle_t &handle, int root) {
comms_t const &communicator = handle.get_comms();

std::vector<size_t> sendcounts(communicator.get_size());
std::iota(sendcounts.begin(), sendcounts.end(), size_t{1});
std::vector<size_t> displacements(communicator.get_size() + 1, 0);
std::partial_sum(sendcounts.begin(), sendcounts.end(),
displacements.begin() + 1);

std::vector<int> sends(displacements[communicator.get_rank() + 1] -
displacements[communicator.get_rank()],
communicator.get_rank());

cudaStream_t stream = handle.get_stream();

raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream);
temp_d.resize(sends.size(), stream);

raft::mr::device::buffer<int> recv_d(
handle.get_device_allocator(), stream,
communicator.get_rank() == root ? displacements.back() : 0);

CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(),
sends.size() * sizeof(int), cudaMemcpyHostToDevice,
stream));

communicator.gatherv(
temp_d.data(), recv_d.data(), temp_d.size(),
communicator.get_rank() == root ? sendcounts.data()
: static_cast<size_t *>(nullptr),
communicator.get_rank() == root ? displacements.data()
: static_cast<size_t *>(nullptr),
root, stream);
communicator.sync_stream(stream);

if (communicator.get_rank() == root) {
std::vector<int> temp_h(displacements.back(), 0);
CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(),
sizeof(int) * displacements.back(),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

for (int i = 0; i < communicator.get_size(); i++) {
if (std::count_if(temp_h.begin() + displacements[i],
temp_h.begin() + displacements[i + 1],
[i](auto val) { return val != i; }) != 0) {
return false;
}
}
}
return true;
}

bool test_collective_reducescatter(const handle_t &handle, int root) {
comms_t const &communicator = handle.get_comms();

-  int const send = 1;
+  std::vector<int> sends(communicator.get_size(), 1);

cudaStream_t stream = handle.get_stream();

   raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream,
-                                       1);
+                                       sends.size());
raft::mr::device::buffer<int> recv_d(handle.get_device_allocator(), stream,
1);

-  CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int),
-                             cudaMemcpyHostToDevice, stream));
+  CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(),
+                             sends.size() * sizeof(int), cudaMemcpyHostToDevice,
+                             stream));

communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM,
stream);
communicator.sync_stream(stream);
int temp_h = -1; // Verify more than one byte is being sent
-  CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int),
+  CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
communicator.barrier();
2 changes: 2 additions & 0 deletions python/raft/dask/common/__init__.py
@@ -21,6 +21,8 @@
from .comms_utils import perform_test_comms_allreduce
from .comms_utils import perform_test_comms_send_recv
from .comms_utils import perform_test_comms_allgather
from .comms_utils import perform_test_comms_gather
from .comms_utils import perform_test_comms_gatherv
from .comms_utils import perform_test_comms_bcast
from .comms_utils import perform_test_comms_reduce
from .comms_utils import perform_test_comms_reducescatter
32 changes: 32 additions & 0 deletions python/raft/dask/common/comms_utils.pyx
@@ -60,6 +60,8 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms":
bool test_collective_broadcast(const handle_t &h, int root) except +
bool test_collective_reduce(const handle_t &h, int root) except +
bool test_collective_allgather(const handle_t &h, int root) except +
bool test_collective_gather(const handle_t &h, int root) except +
bool test_collective_gatherv(const handle_t &h, int root) except +
bool test_collective_reducescatter(const handle_t &h, int root) except +
bool test_pointToPoint_simple_send_recv(const handle_t &h,
int numTrials) except +
@@ -131,6 +133,36 @@ def perform_test_comms_allgather(handle, root):
return test_collective_allgather(deref(h), root)


def perform_test_comms_gather(handle, root):
"""
Performs a gather on the current worker
Parameters
----------
handle : raft.common.Handle
handle containing comms_t to use
root : int
Rank of the root worker
"""
cdef const handle_t* h = <handle_t*><size_t>handle.getHandle()
return test_collective_gather(deref(h), root)


def perform_test_comms_gatherv(handle, root):
"""
Performs a gatherv on the current worker
Parameters
----------
handle : raft.common.Handle
handle containing comms_t to use
root : int
Rank of the root worker
"""
cdef const handle_t* h = <handle_t*><size_t>handle.getHandle()
return test_collective_gatherv(deref(h), root)


def perform_test_comms_send_recv(handle, n_trials):
"""
Performs a p2p send/recv on the current worker
4 changes: 4 additions & 0 deletions python/raft/test/test_comms.py
@@ -28,6 +28,8 @@
from raft.dask.common import perform_test_comms_bcast
from raft.dask.common import perform_test_comms_reduce
from raft.dask.common import perform_test_comms_allgather
from raft.dask.common import perform_test_comms_gather
from raft.dask.common import perform_test_comms_gatherv
from raft.dask.common import perform_test_comms_reducescatter
from raft.dask.common import perform_test_comm_split

@@ -130,6 +132,8 @@ def _has_handle(sessionId):
perform_test_comms_allgather,
perform_test_comms_allreduce,
perform_test_comms_bcast,
perform_test_comms_gather,
perform_test_comms_gatherv,
perform_test_comms_reduce,
perform_test_comms_reducescatter,
]