[REVIEW] Add gather & gatherv to raft::comms::comms_t #114

Merged Feb 3, 2021 · 6 commits
48 changes: 48 additions & 0 deletions cpp/include/raft/comms/comms.hpp
@@ -130,6 +130,15 @@ class comms_iface {
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, cudaStream_t stream) const = 0;

virtual void gather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, int root,
cudaStream_t stream) const = 0;

virtual void reducescatter(const void* sendbuff, void* recvbuff,
size_t recvcount, datatype_t datatype, op_t op,
cudaStream_t stream) const = 0;
@@ -316,6 +325,45 @@ class comms_t {
get_type<value_t>(), stream);
}

/**
* Gathers data from each rank onto the root rank
* @tparam value_t datatype of underlying buffers
* @param sendbuff buffer containing data to gather
* @param recvbuff buffer that receives the gathered data (significant only at root)
* @param sendcount number of elements in send buffer
* @param root rank that receives the gathered data
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void gather(const value_t* sendbuff, value_t* recvbuff, size_t sendcount,
int root, cudaStream_t stream) const {
impl_->gather(static_cast<const void*>(sendbuff),
static_cast<void*>(recvbuff), sendcount, get_type<value_t>(),
root, stream);
}

/**
* Gathers data of varying sizes from all ranks onto the root rank
* @tparam value_t datatype of underlying buffers
* @param sendbuf buffer containing data to send
* @param recvbuf buffer that receives the gathered data (significant only at root)
* @param sendcount number of elements in send buffer
* @param recvcounts pointer to an array (of length num_ranks) containing the number of
* elements to be received from each rank (significant only at root)
* @param displs pointer to an array (of length num_ranks) specifying the displacement
* (relative to recvbuf) at which to place the incoming data from each rank (significant only at root)
* @param root rank that receives the gathered data
* @param stream CUDA stream to synchronize operation
*/
template <typename value_t>
void gatherv(const value_t* sendbuf, value_t* recvbuf, size_t sendcount,
const size_t* recvcounts, const size_t* displs, int root,
cudaStream_t stream) const {
impl_->gatherv(static_cast<const void*>(sendbuf),
static_cast<void*>(recvbuf), sendcount, recvcounts, displs,
get_type<value_t>(), root, stream);
}

/**
* Reduces data from all ranks then scatters the result across ranks
* @tparam value_t datatype of underlying buffers
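For reviewers who want to exercise the new API end to end, here is a minimal usage sketch of the two calls added to comms_t above. It is illustrative only and not part of the PR: it assumes a raft::handle_t whose comms have already been initialized (as the tests further down do), and all buffer names are made up.

#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

#include <numeric>
#include <vector>

void gather_example(const raft::handle_t &handle, int root) {
  const auto &comm = handle.get_comms();
  cudaStream_t stream = handle.get_stream();

  // Fixed-size gather: every rank contributes one element; rank r's element
  // lands at offset r of recv_d on the root.
  raft::mr::device::buffer<int> send_d(handle.get_device_allocator(), stream, 1);
  raft::mr::device::buffer<int> recv_d(
    handle.get_device_allocator(), stream,
    comm.get_rank() == root ? comm.get_size() : 0);
  comm.gather(send_d.data(), recv_d.data(), 1, root, stream);

  // Variable-size gather: rank r contributes r + 1 elements. recvcounts and
  // displs are read only on the root; other ranks may pass nullptr.
  std::vector<size_t> recvcounts(comm.get_size());
  std::iota(recvcounts.begin(), recvcounts.end(), size_t{1});
  std::vector<size_t> displs(comm.get_size() + 1, 0);
  std::partial_sum(recvcounts.begin(), recvcounts.end(), displs.begin() + 1);

  raft::mr::device::buffer<int> sendv_d(handle.get_device_allocator(), stream,
                                        comm.get_rank() + 1);
  raft::mr::device::buffer<int> recvv_d(
    handle.get_device_allocator(), stream,
    comm.get_rank() == root ? displs.back() : 0);
  comm.gatherv(sendv_d.data(), recvv_d.data(), sendv_d.size(),
               comm.get_rank() == root ? recvcounts.data() : nullptr,
               comm.get_rank() == root ? displs.data() : nullptr,
               root, stream);
  comm.sync_stream(stream);
}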
33 changes: 33 additions & 0 deletions cpp/include/raft/comms/mpi_comms.hpp
@@ -232,6 +232,39 @@ class mpi_comms : public comms_iface {
}
}

void gather(const void* sendbuff, void* recvbuff, size_t sendcount,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(
static_cast<char*>(recvbuff) + sendcount * r * dtype_size, sendcount,
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void* sendbuff, void* recvbuff, size_t sendcount,
const size_t* recvcounts, const size_t* displs,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(static_cast<char*>(recvbuff) + displs[r] * dtype_size,
recvcounts[r], get_nccl_datatype(datatype), r,
nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount,
datatype_t datatype, op_t op, cudaStream_t stream) const {
NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount,
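A note on the implementation above: NCCL has no native gather or gatherv collective, so both backends compose them from ncclSend/ncclRecv inside an ncclGroupStart/ncclGroupEnd pair. The root posts one receive per rank, at element offset sendcount * r for gather and displs[r] for gatherv, while every rank (the root included) posts a single send. The displacement array callers pass to gatherv is therefore the exclusive prefix sum of recvcounts, in elements; a small standalone helper (illustrative, not part of the PR) makes that convention explicit:

#include <cstddef>
#include <numeric>
#include <vector>

// Illustrative helper: build the displs array expected by gatherv from the
// per-rank receive counts. displs[r] is the element offset in recvbuf where
// rank r's block starts; displs.back() + recvcounts.back() is the total size.
std::vector<size_t> build_displacements(const std::vector<size_t> &recvcounts) {
  std::vector<size_t> displs(recvcounts.size(), 0);
  if (!recvcounts.empty()) {
    std::partial_sum(recvcounts.begin(), recvcounts.end() - 1,
                     displs.begin() + 1);  // exclusive prefix sum
  }
  return displs;
}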
33 changes: 33 additions & 0 deletions cpp/include/raft/comms/std_comms.hpp
@@ -346,6 +346,39 @@ class std_comms : public comms_iface {
}
}

void gather(const void *sendbuff, void *recvbuff, size_t sendcount,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(
static_cast<char *>(recvbuff) + sendcount * r * dtype_size, sendcount,
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void gatherv(const void *sendbuff, void *recvbuff, size_t sendcount,
const size_t *recvcounts, const size_t *displs,
datatype_t datatype, int root, cudaStream_t stream) const {
size_t dtype_size = get_datatype_size(datatype);
NCCL_TRY(ncclGroupStart());
if (get_rank() == root) {
for (int r = 0; r < get_size(); ++r) {
NCCL_TRY(ncclRecv(
static_cast<char *>(recvbuff) + displs[r] * dtype_size, recvcounts[r],
get_nccl_datatype(datatype), r, nccl_comm_, stream));
}
}
NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root,
nccl_comm_, stream));
NCCL_TRY(ncclGroupEnd());
}

void reducescatter(const void *sendbuff, void *recvbuff, size_t recvcount,
datatype_t datatype, op_t op, cudaStream_t stream) const {
NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount,
102 changes: 96 additions & 6 deletions cpp/include/raft/comms/test.hpp
@@ -16,11 +16,13 @@

#pragma once

#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>
#include <raft/mr/device/buffer.hpp>

#include <iostream>
#include <numeric>

namespace raft {
namespace comms {

@@ -155,26 +157,114 @@ bool test_collective_allgather(const handle_t &handle, int root) {
return true;
}

bool test_collective_gather(const handle_t &handle, int root) {
comms_t const &communicator = handle.get_comms();

int const send = communicator.get_rank();

cudaStream_t stream = handle.get_stream();

raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream);
temp_d.resize(1, stream);

raft::mr::device::buffer<int> recv_d(
handle.get_device_allocator(), stream,
communicator.get_rank() == root ? communicator.get_size() : 0);

CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int),
cudaMemcpyHostToDevice, stream));

communicator.gather(temp_d.data(), recv_d.data(), 1, root, stream);
communicator.sync_stream(stream);

if (communicator.get_rank() == root) {
std::vector<int> temp_h(communicator.get_size(), 0);
CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(),
sizeof(int) * temp_h.size(),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

for (int i = 0; i < communicator.get_size(); i++) {
if (temp_h[i] != i) return false;
}
}
return true;
}

bool test_collective_gatherv(const handle_t &handle, int root) {
comms_t const &communicator = handle.get_comms();

std::vector<size_t> sendcounts(communicator.get_size());
std::iota(sendcounts.begin(), sendcounts.end(), size_t{1});
std::vector<size_t> displacements(communicator.get_size() + 1, 0);
std::partial_sum(sendcounts.begin(), sendcounts.end(),
displacements.begin() + 1);

std::vector<int> sends(displacements[communicator.get_rank() + 1] -
displacements[communicator.get_rank()],
communicator.get_rank());

cudaStream_t stream = handle.get_stream();

raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream);
temp_d.resize(sends.size(), stream);

raft::mr::device::buffer<int> recv_d(
handle.get_device_allocator(), stream,
communicator.get_rank() == root ? displacements.back() : 0);

CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(),
sends.size() * sizeof(int), cudaMemcpyHostToDevice,
stream));

communicator.gatherv(
temp_d.data(), recv_d.data(), temp_d.size(),
communicator.get_rank() == root ? sendcounts.data()
: static_cast<size_t *>(nullptr),
communicator.get_rank() == root ? displacements.data()
: static_cast<size_t *>(nullptr),
root, stream);
communicator.sync_stream(stream);

if (communicator.get_rank() == root) {
std::vector<int> temp_h(displacements.back(), 0);
CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(),
sizeof(int) * displacements.back(),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));

for (int i = 0; i < communicator.get_size(); i++) {
if (std::count_if(temp_h.begin() + displacements[i],
temp_h.begin() + displacements[i + 1],
[i](auto val) { return val != i; }) != 0) {
return false;
}
}
}
return true;
}

bool test_collective_reducescatter(const handle_t &handle, int root) {
comms_t const &communicator = handle.get_comms();

std::vector<int> sends(communicator.get_size(), 1);

cudaStream_t stream = handle.get_stream();

raft::mr::device::buffer<int> temp_d(handle.get_device_allocator(), stream,
sends.size());
raft::mr::device::buffer<int> recv_d(handle.get_device_allocator(), stream,
1);

CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(),
sends.size() * sizeof(int), cudaMemcpyHostToDevice,
stream));

communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM,
stream);
communicator.sync_stream(stream);
int temp_h = -1; // Verify more than one byte is being sent
CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
communicator.barrier();
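To spell out the values the new and revised tests above expect with n ranks: gather leaves the root with 0, 1, ..., n-1 (each rank sends its own rank id); gatherv leaves rank r's block of r + 1 copies of r starting at displacement r(r+1)/2; and in the fixed reducescatter test every rank contributes n ones, so after the SUM reduction each rank's single result equals n. A standalone check of that arithmetic (illustrative only, not part of the PR):

#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  const int n = 4;  // hypothetical number of ranks

  // gatherv layout: rank r contributes r + 1 copies of r.
  std::vector<size_t> recvcounts(n);
  std::iota(recvcounts.begin(), recvcounts.end(), size_t{1});  // {1, 2, 3, 4}
  std::vector<size_t> displs(n + 1, 0);
  std::partial_sum(recvcounts.begin(), recvcounts.end(), displs.begin() + 1);
  for (int r = 0; r < n; ++r) {
    assert(displs[r] == static_cast<size_t>(r * (r + 1) / 2));
  }
  assert(displs.back() == 10);  // root receives {0, 1,1, 2,2,2, 3,3,3,3}

  // reducescatter: each of the n ranks contributes a vector of n ones, so the
  // SUM reduction followed by the scatter leaves every rank with the value n.
  int per_rank_result = 0;
  for (int r = 0; r < n; ++r) per_rank_result += 1;  // one contribution per rank
  assert(per_rank_result == n);
  return 0;
}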
2 changes: 2 additions & 0 deletions python/raft/dask/common/__init__.py
@@ -21,6 +21,8 @@
from .comms_utils import perform_test_comms_allreduce
from .comms_utils import perform_test_comms_send_recv
from .comms_utils import perform_test_comms_allgather
from .comms_utils import perform_test_comms_gather
from .comms_utils import perform_test_comms_gatherv
from .comms_utils import perform_test_comms_bcast
from .comms_utils import perform_test_comms_reduce
from .comms_utils import perform_test_comms_reducescatter
32 changes: 32 additions & 0 deletions python/raft/dask/common/comms_utils.pyx
@@ -60,6 +60,8 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms":
bool test_collective_broadcast(const handle_t &h, int root) except +
bool test_collective_reduce(const handle_t &h, int root) except +
bool test_collective_allgather(const handle_t &h, int root) except +
bool test_collective_gather(const handle_t &h, int root) except +
bool test_collective_gatherv(const handle_t &h, int root) except +
bool test_collective_reducescatter(const handle_t &h, int root) except +
bool test_pointToPoint_simple_send_recv(const handle_t &h,
int numTrials) except +
@@ -131,6 +133,36 @@ def perform_test_comms_allgather(handle, root):
return test_collective_allgather(deref(h), root)


def perform_test_comms_gather(handle, root):
"""
Performs a gather on the current worker

Parameters
----------
handle : raft.common.Handle
handle containing comms_t to use
root : int
Rank of the root worker
"""
cdef const handle_t* h = <handle_t*><size_t>handle.getHandle()
return test_collective_gather(deref(h), root)


def perform_test_comms_gatherv(handle, root):
"""
Performs a gatherv on the current worker

Parameters
----------
handle : raft.common.Handle
handle containing comms_t to use
root : int
Rank of the root worker
"""
cdef const handle_t* h = <handle_t*><size_t>handle.getHandle()
return test_collective_gatherv(deref(h), root)


def perform_test_comms_send_recv(handle, n_trials):
"""
Performs a p2p send/recv on the current worker
4 changes: 4 additions & 0 deletions python/raft/test/test_comms.py
@@ -28,6 +28,8 @@
from raft.dask.common import perform_test_comms_bcast
from raft.dask.common import perform_test_comms_reduce
from raft.dask.common import perform_test_comms_allgather
from raft.dask.common import perform_test_comms_gather
from raft.dask.common import perform_test_comms_gatherv
from raft.dask.common import perform_test_comms_reducescatter
from raft.dask.common import perform_test_comm_split

@@ -130,6 +132,8 @@ def _has_handle(sessionId):
perform_test_comms_allgather,
perform_test_comms_allreduce,
perform_test_comms_bcast,
perform_test_comms_gather,
perform_test_comms_gatherv,
perform_test_comms_reduce,
perform_test_comms_reducescatter,
]