diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 0ca9f3972f..dc172c9503 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -18,10 +18,8 @@ #include -// FIXME: for get_nccl_comm(), should be removed -#include - #include +#include namespace raft { namespace comms { @@ -96,9 +94,6 @@ class comms_iface { virtual int get_size() const = 0; virtual int get_rank() const = 0; - // FIXME: a temporary hack, should be removed - virtual ncclComm_t get_nccl_comm() const = 0; - virtual std::unique_ptr comm_split(int color, int key) const = 0; virtual void barrier() const = 0; @@ -142,6 +137,25 @@ class comms_iface { virtual void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const = 0; + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + virtual void device_send(const void* buf, size_t size, int dest, + cudaStream_t stream) const = 0; + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + virtual void device_recv(void* buf, size_t size, int source, + cudaStream_t stream) const = 0; + + virtual void device_sendrecv(const void* sendbuf, size_t sendsize, int dest, + void* recvbuf, size_t recvsize, int source, + cudaStream_t stream) const = 0; + + virtual void device_multicast_sendrecv( + const void* sendbuf, std::vector const& sendsizes, + std::vector const& sendoffsets, std::vector const& dests, + void* recvbuf, std::vector const& recvsizes, + std::vector const& recvoffsets, std::vector const& sources, + cudaStream_t stream) const = 0; }; class comms_t { @@ -166,9 +180,6 @@ class comms_t { */ int get_rank() const { return impl_->get_rank(); } - // FIXME: a temporary hack, should be removed - ncclComm_t get_nccl_comm() const { return impl_->get_nccl_comm(); } - /** * Splits the current communicator clique into sub-cliques matching * the given color and key @@ -380,6 +391,102 @@ class comms_t { get_type(), op, stream); } + /** + * Performs a point-to-point send + * + * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. + * + * @tparam value_t the type of data to send + * @param buf pointer to array of data to send + * @param size number of elements in buf + * @param dest destination rank + * @param stream CUDA stream to synchronize operation + */ + template + void device_send(const value_t* buf, size_t size, int dest, + cudaStream_t stream) const { + impl_->device_send(static_cast(buf), size * sizeof(value_t), + dest, stream); + } + + /** + * Performs a point-to-point receive + * + * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. 
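+ * (e.g., two ranks exchanging messages with each other should call device_sendrecv rather than pairing device_send with device_recv)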
+ * + * @tparam value_t the type of data to be received + * @param buf pointer to (initialized) array that will hold received data + * @param size number of elements in buf + * @param source source rank + * @param stream CUDA stream to synchronize operation + */ + template + void device_recv(value_t* buf, size_t size, int source, + cudaStream_t stream) const { + impl_->device_recv(static_cast(buf), size * sizeof(value_t), source, + stream); + } + + /** + * Performs a point-to-point send/receive + * + * @tparam value_t the type of data to be sent & received + * @param sendbuf pointer to array of data to send + * @param sendsize number of elements in sendbuf + * @param dest destination rank + * @param recvbuf pointer to (initialized) array that will hold received data + * @param recvsize number of elements in recvbuf + * @param source source rank + * @param stream CUDA stream to synchronize operation + */ + template + void device_sendrecv(const value_t* sendbuf, size_t sendsize, int dest, + value_t* recvbuf, size_t recvsize, int source, + cudaStream_t stream) const { + impl_->device_sendrecv( + static_cast(sendbuf), sendsize * sizeof(value_t), dest, + static_cast(recvbuf), recvsize * sizeof(value_t), source, stream); + } + + /** + * Performs a multicast send/receive + * + * @tparam value_t the type of data to be sent & received + * @param sendbuf pointer to array of data to send + * @param sendsizes numbers of elements to send + * @param sendoffsets offsets in a number of elements from sendbuf + * @param dest destination ranks + * @param recvbuf pointer to (initialized) array that will hold received data + * @param recvsizes numbers of elements to recv + * @param recvoffsets offsets in a number of elements from recvbuf + * @param sources source ranks + * @param stream CUDA stream to synchronize operation + */ + template + void device_multicast_sendrecv( + const value_t* sendbuf, std::vector const& sendsizes, + std::vector const& sendoffsets, std::vector const& dests, + value_t* recvbuf, std::vector const& recvsizes, + std::vector const& recvoffsets, std::vector const& sources, + cudaStream_t stream) const { + auto sendbytesizes = sendsizes; + auto sendbyteoffsets = sendoffsets; + for (size_t i = 0; i < sendsizes.size(); ++i) { + sendbytesizes[i] *= sizeof(value_t); + sendbyteoffsets[i] *= sizeof(value_t); + } + auto recvbytesizes = recvsizes; + auto recvbyteoffsets = recvoffsets; + for (size_t i = 0; i < recvsizes.size(); ++i) { + recvbytesizes[i] *= sizeof(value_t); + recvbyteoffsets[i] *= sizeof(value_t); + } + impl_->device_multicast_sendrecv(static_cast(sendbuf), + sendbytesizes, sendbyteoffsets, dests, + static_cast(recvbuf), recvbytesizes, + recvbyteoffsets, sources, stream); + } + private: std::unique_ptr impl_; }; diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 8aebcc80cc..8dda74f0a9 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -133,9 +133,6 @@ class mpi_comms : public comms_iface { int get_rank() const { return rank_; } - // FIXME: a temporary hack, should be removed - ncclComm_t get_nccl_comm() const { return nccl_comm_; } - std::unique_ptr comm_split(int color, int key) const { MPI_Comm new_comm; MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm)); @@ -304,6 +301,51 @@ class mpi_comms : public comms_iface { } }; + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void* buf, size_t size, int dest, + cudaStream_t 
stream) const { + NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); + } + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void* buf, size_t size, int source, + cudaStream_t stream) const { + NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); + } + + void device_sendrecv(const void* sendbuf, size_t sendsize, int dest, + void* recvbuf, size_t recvsize, int source, + cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); + NCCL_TRY( + ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + + void device_multicast_sendrecv(const void* sendbuf, + std::vector const& sendsizes, + std::vector const& sendoffsets, + std::vector const& dests, void* recvbuf, + std::vector const& recvsizes, + std::vector const& recvoffsets, + std::vector const& sources, + cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + for (size_t i = 0; i < sendsizes.size(); ++i) { + NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream)); + } + for (size_t i = 0; i < recvsizes.size(); ++i) { + NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], + recvsizes[i], ncclUint8, sources[i], nccl_comm_, + stream)); + } + NCCL_TRY(ncclGroupEnd()); + } + private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index a304955ceb..765e8741bb 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -112,9 +112,6 @@ class std_comms : public comms_iface { int get_rank() const { return rank_; } - // FIXME: a temporary hack, should be removed - ncclComm_t get_nccl_comm() const { return nccl_comm_; } - std::unique_ptr comm_split(int color, int key) const { mr::device::buffer d_colors(device_allocator_, stream_, get_size()); mr::device::buffer d_keys(device_allocator_, stream_, get_size()); @@ -418,6 +415,51 @@ class std_comms : public comms_iface { } } + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void *buf, size_t size, int dest, + cudaStream_t stream) const { + NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); + } + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void *buf, size_t size, int source, + cudaStream_t stream) const { + NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); + } + + void device_sendrecv(const void *sendbuf, size_t sendsize, int dest, + void *recvbuf, size_t recvsize, int source, + cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); + NCCL_TRY( + ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + + void device_multicast_sendrecv(const void *sendbuf, + std::vector const &sendsizes, + std::vector const &sendoffsets, + std::vector const &dests, void *recvbuf, + std::vector const &recvsizes, + std::vector const &recvoffsets, + 
std::vector const &sources, + cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + for (size_t i = 0; i < sendsizes.size(); ++i) { + NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream)); + } + for (size_t i = 0; i < recvsizes.size(); ++i) { + NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], + recvsizes[i], ncclUint8, sources[i], nccl_comm_, + stream)); + } + NCCL_TRY(ncclGroupEnd()); + } + private: ncclComm_t nccl_comm_; cudaStream_t stream_; diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 5dc6f02d21..2ba9e406be 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -19,6 +19,11 @@ #include #include #include +#include +#include + +#include +#include #include #include @@ -340,6 +345,156 @@ bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) { return ret; } +/** + * A simple sanity check that device is able to send OR receive. + * + * @param the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_send_or_recv(const handle_t &h, int numTrials) { + comms_t const &communicator = h.get_comms(); + int const rank = communicator.get_rank(); + cudaStream_t stream = h.get_stream(); + + bool ret = true; + for (int i = 0; i < numTrials; i++) { + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + std::cout << "Trial " << i << std::endl; + } + + bool sender = (rank % 2) == 0 ? true : false; + rmm::device_scalar received_data(-1, stream); + rmm::device_scalar sent_data(rank, stream); + + if (sender) { + if (rank + 1 < communicator.get_size()) { + communicator.device_send(sent_data.data(), 1, rank + 1, stream); + } + } else { + communicator.device_recv(received_data.data(), 1, rank - 1, stream); + } + + communicator.sync_stream(stream); + + if (!sender && received_data.value() != rank - 1) { + ret = false; + } + + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + } + } + + return ret; +} + +/** + * A simple sanity check that device is able to send and receive at the same time. + * + * @param the raft handle to use. This is expected to already have an + * initialized comms instance. 
+ * @param number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_sendrecv(const handle_t &h, int numTrials) { + comms_t const &communicator = h.get_comms(); + int const rank = communicator.get_rank(); + cudaStream_t stream = h.get_stream(); + + bool ret = true; + for (int i = 0; i < numTrials; i++) { + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + std::cout << "Trial " << i << std::endl; + } + + rmm::device_scalar received_data(-1, stream); + rmm::device_scalar sent_data(rank, stream); + + if (rank % 2 == 0) { + if (rank + 1 < communicator.get_size()) { + communicator.device_sendrecv(sent_data.data(), 1, rank + 1, + received_data.data(), 1, rank + 1, stream); + } + } else { + communicator.device_sendrecv(sent_data.data(), 1, rank - 1, + received_data.data(), 1, rank - 1, stream); + } + + communicator.sync_stream(stream); + + if (((rank % 2 == 0) && (received_data.value() != rank + 1)) || + ((rank % 2 == 1) && (received_data.value() != rank - 1))) { + ret = false; + } + + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + } + } + + return ret; +} + +/** + * A simple sanity check that device is able to perform multiple concurrent sends and receives. + * + * @param the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_multicast_sendrecv(const handle_t &h, + int numTrials) { + comms_t const &communicator = h.get_comms(); + int const rank = communicator.get_rank(); + cudaStream_t stream = h.get_stream(); + + bool ret = true; + for (int i = 0; i < numTrials; i++) { + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + std::cout << "Trial " << i << std::endl; + } + + rmm::device_uvector received_data(communicator.get_size(), stream); + rmm::device_scalar sent_data(rank, stream); + + std::vector sendsizes(communicator.get_size(), 1); + std::vector sendoffsets(communicator.get_size(), 0); + std::vector dests(communicator.get_size()); + std::iota(dests.begin(), dests.end(), int{0}); + + std::vector recvsizes(communicator.get_size(), 1); + std::vector recvoffsets(communicator.get_size()); + std::iota(recvoffsets.begin(), recvoffsets.end(), size_t{0}); + std::vector srcs(communicator.get_size()); + std::iota(srcs.begin(), srcs.end(), int{0}); + + communicator.device_multicast_sendrecv( + sent_data.data(), sendsizes, sendoffsets, dests, received_data.data(), + recvsizes, recvoffsets, srcs, stream); + + communicator.sync_stream(stream); + + std::vector h_received_data(communicator.get_size()); + raft::update_host(h_received_data.data(), received_data.data(), + received_data.size(), stream); + CUDA_TRY(cudaStreamSynchronize(stream)); + for (int i = 0; i < communicator.get_size(); ++i) { + if (h_received_data[i] != i) { + ret = false; + } + } + + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + } + } + + return ret; +} + /** * A simple test that the comms can be split into 2 separate subcommunicators * diff --git a/python/raft/dask/common/__init__.py b/python/raft/dask/common/__init__.py index 73bb5d6700..c2265f6828 100644 --- a/python/raft/dask/common/__init__.py +++ b/python/raft/dask/common/__init__.py @@ -20,6 +20,9 @@ from .comms_utils import inject_comms_on_handle_coll_only from .comms_utils import perform_test_comms_allreduce from 
.comms_utils import perform_test_comms_send_recv +from .comms_utils import perform_test_comms_device_send_or_recv +from .comms_utils import perform_test_comms_device_sendrecv +from .comms_utils import perform_test_comms_device_multicast_sendrecv from .comms_utils import perform_test_comms_allgather from .comms_utils import perform_test_comms_gather from .comms_utils import perform_test_comms_gatherv diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx index 1a703485a9..20f004b1d6 100644 --- a/python/raft/dask/common/comms_utils.pyx +++ b/python/raft/dask/common/comms_utils.pyx @@ -65,6 +65,12 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms": bool test_collective_reducescatter(const handle_t &h, int root) except + bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) except + + bool test_pointToPoint_device_send_or_recv(const handle_t &h, + int numTrials) except + + bool test_pointToPoint_device_sendrecv(const handle_t &h, + int numTrials) except + + bool test_pointToPoint_device_multicast_sendrecv(const handle_t &h, + int numTrials) except + bool test_commsplit(const handle_t &h, int n_colors) except + @@ -171,11 +177,58 @@ def perform_test_comms_send_recv(handle, n_trials): ---------- handle : raft.common.Handle handle containing comms_t to use + n_trials : int + Number of test trials """ cdef const handle_t *h = handle.getHandle() return test_pointToPoint_simple_send_recv(deref(h), n_trials) +def perform_test_comms_device_send_or_recv(handle, n_trials): + """ + Performs a p2p device send or recv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + n_trials : int + Number of test trials + """ + cdef const handle_t *h = handle.getHandle() + return test_pointToPoint_device_send_or_recv(deref(h), n_trials) + + +def perform_test_comms_device_sendrecv(handle, n_trials): + """ + Performs a p2p device concurrent send & recv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + n_trials : int + Number of test trials + """ + cdef const handle_t *h = handle.getHandle() + return test_pointToPoint_device_sendrecv(deref(h), n_trials) + + +def perform_test_comms_device_multicast_sendrecv(handle, n_trials): + """ + Performs a p2p device concurrent multicast send & recv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + n_trials : int + Number of test trials + """ + cdef const handle_t *h = handle.getHandle() + return test_pointToPoint_device_multicast_sendrecv(deref(h), n_trials) + + def perform_test_comm_split(handle, n_colors): """ Performs a p2p send/recv on the current worker diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index a0db3b7f4f..a540e8db10 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -24,6 +24,9 @@ from raft.dask import Comms from raft.dask.common import local_handle from raft.dask.common import perform_test_comms_send_recv + from raft.dask.common import perform_test_comms_device_send_or_recv + from raft.dask.common import perform_test_comms_device_sendrecv + from raft.dask.common import perform_test_comms_device_multicast_sendrecv from raft.dask.common import perform_test_comms_allreduce from raft.dask.common import perform_test_comms_bcast from raft.dask.common import perform_test_comms_reduce @@ -65,6 +68,21 @@ def func_test_send_recv(sessionId,
n_trials): return perform_test_comms_send_recv(handle, n_trials) +def func_test_device_send_or_recv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_device_send_or_recv(handle, n_trials) + + +def func_test_device_sendrecv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_device_sendrecv(handle, n_trials) + + +def func_test_device_multicast_sendrecv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_device_multicast_sendrecv(handle, n_trials) + + def func_test_comm_split(sessionId, n_trials): handle = local_handle(sessionId) return perform_test_comm_split(handle, n_trials) @@ -247,3 +265,72 @@ def test_send_recv(n_trials, client): wait(dfs, timeout=5) assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_device_send_or_recv(n_trials, client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_device_send_or_recv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_device_sendrecv(n_trials, client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_device_sendrecv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_device_multicast_sendrecv(n_trials, client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_device_multicast_sendrecv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs))
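
A minimal sketch of how the new point-to-point API might be called from application code, for illustration: the function ring_exchange and its arguments are hypothetical, and only get_rank, get_size, device_sendrecv, and sync_stream come from the comms_t interface shown above; the include paths assume the header layout used elsewhere in this diff.

#include <raft/comms/comms.hpp>
#include <raft/handle.hpp>

// Every rank forwards a block of floats to its right neighbor and receives one
// from its left neighbor. Because each rank sends and receives at the same
// time, device_sendrecv is used instead of separate device_send / device_recv
// calls, which could deadlock.
void ring_exchange(raft::handle_t const& handle, float const* d_send,
                   float* d_recv, size_t count) {
  raft::comms::comms_t const& comm = handle.get_comms();
  cudaStream_t stream = handle.get_stream();

  int const rank = comm.get_rank();
  int const size = comm.get_size();
  int const dest = (rank + 1) % size;           // right neighbor
  int const source = (rank + size - 1) % size;  // left neighbor

  comm.device_sendrecv(d_send, count, dest, d_recv, count, source, stream);
  comm.sync_stream(stream);  // block until the exchange has completed on the stream
}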