From f090fc0fd46909394145ed333546cf5979abac69 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Thu, 10 Dec 2020 16:15:59 -0500 Subject: [PATCH 1/4] add gather & gatherv to raft::comms_t --- cpp/include/raft/comms/comms.hpp | 48 ++++++++++++++++++++++++++++ cpp/include/raft/comms/mpi_comms.hpp | 31 ++++++++++++++++++ cpp/include/raft/comms/std_comms.hpp | 31 ++++++++++++++++++ 3 files changed, 110 insertions(+) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 73e52e781b..0ca9f3972f 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -130,6 +130,15 @@ class comms_iface { const size_t* recvcounts, const size_t* displs, datatype_t datatype, cudaStream_t stream) const = 0; + virtual void gather(const void* sendbuff, void* recvbuff, size_t sendcount, + datatype_t datatype, int root, + cudaStream_t stream) const = 0; + + virtual void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount, + const size_t* recvcounts, const size_t* displs, + datatype_t datatype, int root, + cudaStream_t stream) const = 0; + virtual void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const = 0; @@ -316,6 +325,45 @@ class comms_t { get_type(), stream); } + /** + * Gathers data from each rank onto the root rank + * @tparam value_t datatype of underlying buffers + * @param sendbuff buffer containing data to gather + * @param recvbuff buffer containing gathered data from all ranks + * @param sendcount number of elements in send buffer + * @param root rank to store the results + * @param stream CUDA stream to synchronize operation + */ + template + void gather(const value_t* sendbuff, value_t* recvbuff, size_t sendcount, + int root, cudaStream_t stream) const { + impl_->gather(static_cast(sendbuff), + static_cast(recvbuff), sendcount, get_type(), + root, stream); + } + + /** + * Gathers data from all ranks and delivers the combined data to the root
rank + * @tparam value_t datatype of underlying buffers + * @param sendbuf buffer containing data to send + * @param recvbuf buffer containing data to receive + * @param sendcount number of elements in send buffer + * @param recvcounts pointer to an array (of length num_ranks size) containing the number of + * elements that are to be received from each rank + * @param displs pointer to an array (of length num_ranks size) to specify the displacement + * (relative to recvbuf) at which to place the incoming data from each rank + * @param root rank to store the results + * @param stream CUDA stream to synchronize operation + */ + template + void gatherv(const value_t* sendbuf, value_t* recvbuf, size_t sendcount, + const size_t* recvcounts, const size_t* displs, int root, + cudaStream_t stream) const { + impl_->gatherv(static_cast(sendbuf), + static_cast(recvbuf), sendcount, recvcounts, displs, + get_type(), root, stream); + } + /** * Reduces data from all ranks then scatters the result across ranks * @tparam value_t datatype of underlying buffers diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index a372702c34..dddc709fa7 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -232,6 +232,37 @@ class mpi_comms : public comms_iface { } } + void gather(const void* sendbuff, void* recvbuff, size_t sendcount, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv(recvbuff + sendcount * i * dtype_size, sendcount, + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + + void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount, + const size_t* recvcounts, const size_t* 
displs, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv(recvbuff + displs[r] * dtype_size, recvcounts[r], + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const { NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount, diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index d4b9d2ba39..3cafdc87c0 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -346,6 +346,37 @@ class std_comms : public comms_iface { } } + void gather(const void *sendbuff, void *recvbuff, size_t sendcount, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv(recvbuff + sendcount * i * dtype_size, sendcount, + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + + void gatherv(const void *sendbuf, void *recvbuf, size_t sendcount, + const size_t *recvcounts, const size_t *displs, + datatype_t datatype, int root, cudaStream_t stream) const { + size_t dtype_size = get_datatype_size(datatype); + NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + NCCL_TRY(ncclRecv(recvbuff + displs[r] * dtype_size, recvcounts[r], + get_nccl_datatype(datatype), r, nccl_comm_, stream)); + } + } + NCCL_TRY(ncclSend(sendbuff, 
sendcount, get_nccl_datatype(datatype), r, + nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + void reducescatter(const void *sendbuff, void *recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const { NCCL_TRY(ncclReduceScatter(sendbuff, recvbuff, recvcount, From 7a46dc9d2280b3a0f899ce2b8c5799f19395464c Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Tue, 26 Jan 2021 15:58:41 -0500 Subject: [PATCH 2/4] fix build errors --- cpp/include/raft/comms/mpi_comms.hpp | 16 +++++++++------- cpp/include/raft/comms/std_comms.hpp | 16 +++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index dddc709fa7..8aebcc80cc 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -238,27 +238,29 @@ class mpi_comms : public comms_iface { NCCL_TRY(ncclGroupStart()); if (get_rank() == root) { for (int r = 0; r < get_size(); ++r) { - NCCL_TRY(ncclRecv(recvbuff + sendcount * i * dtype_size, sendcount, - get_nccl_datatype(datatype), r, nccl_comm_, stream)); + NCCL_TRY(ncclRecv( + static_cast(recvbuff) + sendcount * r * dtype_size, sendcount, + get_nccl_datatype(datatype), r, nccl_comm_, stream)); } } - NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); NCCL_TRY(ncclGroupEnd()); } - void gatherv(const void* sendbuf, void* recvbuf, size_t sendcount, + void gatherv(const void* sendbuff, void* recvbuff, size_t sendcount, const size_t* recvcounts, const size_t* displs, datatype_t datatype, int root, cudaStream_t stream) const { size_t dtype_size = get_datatype_size(datatype); NCCL_TRY(ncclGroupStart()); if (get_rank() == root) { for (int r = 0; r < get_size(); ++r) { - NCCL_TRY(ncclRecv(recvbuff + displs[r] * dtype_size, recvcounts[r], - get_nccl_datatype(datatype), r, nccl_comm_, stream)); + 
NCCL_TRY(ncclRecv(static_cast(recvbuff) + displs[r] * dtype_size, + recvcounts[r], get_nccl_datatype(datatype), r, + nccl_comm_, stream)); } } - NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); NCCL_TRY(ncclGroupEnd()); } diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 3cafdc87c0..a304955ceb 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -352,27 +352,29 @@ class std_comms : public comms_iface { NCCL_TRY(ncclGroupStart()); if (get_rank() == root) { for (int r = 0; r < get_size(); ++r) { - NCCL_TRY(ncclRecv(recvbuff + sendcount * i * dtype_size, sendcount, - get_nccl_datatype(datatype), r, nccl_comm_, stream)); + NCCL_TRY(ncclRecv( + static_cast(recvbuff) + sendcount * r * dtype_size, sendcount, + get_nccl_datatype(datatype), r, nccl_comm_, stream)); } } - NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); NCCL_TRY(ncclGroupEnd()); } - void gatherv(const void *sendbuf, void *recvbuf, size_t sendcount, + void gatherv(const void *sendbuff, void *recvbuff, size_t sendcount, const size_t *recvcounts, const size_t *displs, datatype_t datatype, int root, cudaStream_t stream) const { size_t dtype_size = get_datatype_size(datatype); NCCL_TRY(ncclGroupStart()); if (get_rank() == root) { for (int r = 0; r < get_size(); ++r) { - NCCL_TRY(ncclRecv(recvbuff + displs[r] * dtype_size, recvcounts[r], - get_nccl_datatype(datatype), r, nccl_comm_, stream)); + NCCL_TRY(ncclRecv( + static_cast(recvbuff) + displs[r] * dtype_size, recvcounts[r], + get_nccl_datatype(datatype), r, nccl_comm_, stream)); } } - NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), r, + NCCL_TRY(ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, 
nccl_comm_, stream)); NCCL_TRY(ncclGroupEnd()); } From e6180740181c986b11d5a3557bbab76932805530 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Tue, 26 Jan 2021 23:42:42 -0500 Subject: [PATCH 3/4] fix a bug in reducescatter test --- cpp/include/raft/comms/test.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index fa7e471174..627d629e68 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -158,23 +158,23 @@ bool test_collective_allgather(const handle_t &handle, int root) { bool test_collective_reducescatter(const handle_t &handle, int root) { comms_t const &communicator = handle.get_comms(); - int const send = 1; + std::vector sends(communicator.get_size(), 1); cudaStream_t stream = handle.get_stream(); raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream, - 1); + sends.size()); raft::mr::device::buffer recv_d(handle.get_device_allocator(), stream, 1); - CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int), + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), sends.size() * sizeof(int), cudaMemcpyHostToDevice, stream)); communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM, stream); communicator.sync_stream(stream); int temp_h = -1; // Verify more than one byte is being sent - CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), + CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); communicator.barrier(); From d390789761098d88193cd325e81712173d13d14d Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Wed, 27 Jan 2021 00:20:38 -0500 Subject: [PATCH 4/4] add python gather & gatherv tests --- cpp/include/raft/comms/test.hpp | 96 ++++++++++++++++++++++++- python/raft/dask/common/__init__.py | 2 + python/raft/dask/common/comms_utils.pyx | 32 +++++++++ python/raft/test/test_comms.py | 4 ++ 4 files 
changed, 131 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 627d629e68..5dc6f02d21 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -16,11 +16,13 @@ #pragma once -#include #include #include #include +#include +#include + namespace raft { namespace comms { @@ -155,6 +157,93 @@ bool test_collective_allgather(const handle_t &handle, int root) { return true; } +bool test_collective_gather(const handle_t &handle, int root) { + comms_t const &communicator = handle.get_comms(); + + int const send = communicator.get_rank(); + + cudaStream_t stream = handle.get_stream(); + + raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream); + temp_d.resize(1, stream); + + raft::mr::device::buffer recv_d( + handle.get_device_allocator(), stream, + communicator.get_rank() == root ? communicator.get_size() : 0); + + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int), + cudaMemcpyHostToDevice, stream)); + + communicator.gather(temp_d.data(), recv_d.data(), 1, root, stream); + communicator.sync_stream(stream); + + if (communicator.get_rank() == root) { + std::vector temp_h(communicator.get_size(), 0); + CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(), + sizeof(int) * temp_h.size(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int i = 0; i < communicator.get_size(); i++) { + if (temp_h[i] != i) return false; + } + } + return true; +} + +bool test_collective_gatherv(const handle_t &handle, int root) { + comms_t const &communicator = handle.get_comms(); + + std::vector sendcounts(communicator.get_size()); + std::iota(sendcounts.begin(), sendcounts.end(), size_t{1}); + std::vector displacements(communicator.get_size() + 1, 0); + std::partial_sum(sendcounts.begin(), sendcounts.end(), + displacements.begin() + 1); + + std::vector sends(displacements[communicator.get_rank() + 1] - + 
displacements[communicator.get_rank()], + communicator.get_rank()); + + cudaStream_t stream = handle.get_stream(); + + raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream); + temp_d.resize(sends.size(), stream); + + raft::mr::device::buffer recv_d( + handle.get_device_allocator(), stream, + communicator.get_rank() == root ? displacements.back() : 0); + + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), + sends.size() * sizeof(int), cudaMemcpyHostToDevice, + stream)); + + communicator.gatherv( + temp_d.data(), recv_d.data(), temp_d.size(), + communicator.get_rank() == root ? sendcounts.data() + : static_cast(nullptr), + communicator.get_rank() == root ? displacements.data() + : static_cast(nullptr), + root, stream); + communicator.sync_stream(stream); + + if (communicator.get_rank() == root) { + std::vector temp_h(displacements.back(), 0); + CUDA_CHECK(cudaMemcpyAsync(temp_h.data(), recv_d.data(), + sizeof(int) * displacements.back(), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int i = 0; i < communicator.get_size(); i++) { + if (std::count_if(temp_h.begin() + displacements[i], + temp_h.begin() + displacements[i + 1], + [i](auto val) { return val != i; }) != 0) { + return false; + } + } + } + return true; +} + bool test_collective_reducescatter(const handle_t &handle, int root) { comms_t const &communicator = handle.get_comms(); @@ -167,8 +256,9 @@ bool test_collective_reducescatter(const handle_t &handle, int root) { raft::mr::device::buffer recv_d(handle.get_device_allocator(), stream, 1); - CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), sends.size() * sizeof(int), - cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), + sends.size() * sizeof(int), cudaMemcpyHostToDevice, + stream)); communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM, stream); diff --git a/python/raft/dask/common/__init__.py b/python/raft/dask/common/__init__.py 
index 788af46c92..73bb5d6700 100644 --- a/python/raft/dask/common/__init__.py +++ b/python/raft/dask/common/__init__.py @@ -21,6 +21,8 @@ from .comms_utils import perform_test_comms_allreduce from .comms_utils import perform_test_comms_send_recv from .comms_utils import perform_test_comms_allgather +from .comms_utils import perform_test_comms_gather +from .comms_utils import perform_test_comms_gatherv from .comms_utils import perform_test_comms_bcast from .comms_utils import perform_test_comms_reduce from .comms_utils import perform_test_comms_reducescatter diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx index 4dbd2f1a7c..1a703485a9 100644 --- a/python/raft/dask/common/comms_utils.pyx +++ b/python/raft/dask/common/comms_utils.pyx @@ -60,6 +60,8 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms": bool test_collective_broadcast(const handle_t &h, int root) except + bool test_collective_reduce(const handle_t &h, int root) except + bool test_collective_allgather(const handle_t &h, int root) except + + bool test_collective_gather(const handle_t &h, int root) except + + bool test_collective_gatherv(const handle_t &h, int root) except + bool test_collective_reducescatter(const handle_t &h, int root) except + bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) except + @@ -131,6 +133,36 @@ def perform_test_comms_allgather(handle, root): return test_collective_allgather(deref(h), root) +def perform_test_comms_gather(handle, root): + """ + Performs a gather on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + root : int + Rank of the root worker + """ + cdef const handle_t* h = handle.getHandle() + return test_collective_gather(deref(h), root) + + +def perform_test_comms_gatherv(handle, root): + """ + Performs a gatherv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to 
use + root : int + Rank of the root worker + """ + cdef const handle_t* h = handle.getHandle() + return test_collective_gatherv(deref(h), root) + + def perform_test_comms_send_recv(handle, n_trials): """ Performs a p2p send/recv on the current worker diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 7dccb7bbae..a0db3b7f4f 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -28,6 +28,8 @@ from raft.dask.common import perform_test_comms_bcast from raft.dask.common import perform_test_comms_reduce from raft.dask.common import perform_test_comms_allgather + from raft.dask.common import perform_test_comms_gather + from raft.dask.common import perform_test_comms_gatherv from raft.dask.common import perform_test_comms_reducescatter from raft.dask.common import perform_test_comm_split @@ -130,6 +132,8 @@ def _has_handle(sessionId): perform_test_comms_allgather, perform_test_comms_allreduce, perform_test_comms_bcast, + perform_test_comms_gather, + perform_test_comms_gatherv, perform_test_comms_reduce, perform_test_comms_reducescatter, ]