From 37332380376b15dd7153aae6fa921b4621f3f1cd Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Fri, 5 Feb 2021 15:00:57 -0500 Subject: [PATCH 01/12] remove temporary get_nccl_comm function --- cpp/include/raft/comms/comms.hpp | 9 --------- cpp/include/raft/comms/mpi_comms.hpp | 3 --- cpp/include/raft/comms/std_comms.hpp | 3 --- 3 files changed, 15 deletions(-) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 73e52e781b..575de279bd 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -18,9 +18,6 @@ #include -// FIXME: for get_nccl_comm(), should be removed -#include - #include namespace raft { @@ -96,9 +93,6 @@ class comms_iface { virtual int get_size() const = 0; virtual int get_rank() const = 0; - // FIXME: a temporary hack, should be removed - virtual ncclComm_t get_nccl_comm() const = 0; - virtual std::unique_ptr comm_split(int color, int key) const = 0; virtual void barrier() const = 0; @@ -157,9 +151,6 @@ class comms_t { */ int get_rank() const { return impl_->get_rank(); } - // FIXME: a temporary hack, should be removed - ncclComm_t get_nccl_comm() const { return impl_->get_nccl_comm(); } - /** * Splits the current communicator clique into sub-cliques matching * the given color and key diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index a372702c34..7471f9ddfc 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -133,9 +133,6 @@ class mpi_comms : public comms_iface { int get_rank() const { return rank_; } - // FIXME: a temporary hack, should be removed - ncclComm_t get_nccl_comm() const { return nccl_comm_; } - std::unique_ptr comm_split(int color, int key) const { MPI_Comm new_comm; MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm)); diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index d4b9d2ba39..d45e6096ec 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -112,9 +112,6 @@ class std_comms : public comms_iface { int get_rank() const { return rank_; } - // FIXME: a temporary hack, should be removed - ncclComm_t get_nccl_comm() const { return nccl_comm_; } - std::unique_ptr comm_split(int color, int key) const { mr::device::buffer d_colors(device_allocator_, stream_, get_size()); mr::device::buffer d_keys(device_allocator_, stream_, get_size()); From d89925ae4d66646f909fb1894d5f87f67c6d945e Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Fri, 5 Feb 2021 23:22:31 -0500 Subject: [PATCH 02/12] add device_send/device_recv --- cpp/include/raft/comms/comms.hpp | 44 ++++++++++++++++++++++++++++ cpp/include/raft/comms/mpi_comms.hpp | 12 ++++++++ cpp/include/raft/comms/std_comms.hpp | 12 ++++++++ 3 files changed, 68 insertions(+) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 575de279bd..ead2b05b5b 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -127,6 +127,14 @@ class comms_iface { virtual void reducescatter(const void* sendbuff, void* recvbuff, size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const = 0; + + // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + virtual void device_send(const void* buf, size_t size, int dest, + cudaStream_t stream) const = 0; + + // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + virtual void 
device_recv(void* buf, size_t size, int source, + cudaStream_t stream) const = 0; }; class comms_t { @@ -323,6 +331,42 @@ class comms_t { get_type(), op, stream); } + /** + * Performs a point-to-point send + * + * note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. + * + * @tparam value_t the type of data to send + * @param buf pointer to array of data to send + * @param size number of elements in buf + * @param dest destination rank + * @param stream CUDA stream to synchronize operation + */ + template + void device_send(const value_t* buf, size_t size, int dest, + cudaStream_t stream) const { + impl_->device_send(static_cast(buf), size * sizeof(value_t), + dest, stream); + } + + /** + * Performs a point-to-point receive + * + * note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. + * + * @tparam value_t the type of data to be received + * @param buf pointer to (initialized) array that will hold received data + * @param size number of elements in buf + * @param source source rank + * @param stream CUDA stream to synchronize operation + */ + template + void device_recv(value_t* buf, size_t size, int source, + cudaStream_t stream) const { + impl_->device_recv(static_cast(buf), size * sizeof(value_t), source, + stream); + } + private: std::unique_ptr impl_; }; diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 7471f9ddfc..396babecf4 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -268,6 +268,18 @@ class mpi_comms : public comms_iface { } }; + // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void* buf, size_t size, int dest, + cudaStream_t stream) const { + NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); + } + + // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void* buf, size_t size, int source, + cudaStream_t stream) const { + NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); + } + private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index d45e6096ec..8de576e9f7 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -382,6 +382,18 @@ class std_comms : public comms_iface { } } + // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void *buf, size_t size, int dest, + cudaStream_t stream) const { + NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); + } + + // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void *buf, size_t size, int source, + cudaStream_t stream) const { + NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); + } + private: ncclComm_t nccl_comm_; cudaStream_t stream_; From 0d080c2193b7607f189046eeb93c9d9dad66debc Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Fri, 5 Feb 2021 23:38:29 -0500 Subject: [PATCH 03/12] add device_sendrecv --- cpp/include/raft/comms/comms.hpp | 33 ++++++++++++++++++++++++---- cpp/include/raft/comms/mpi_comms.hpp | 14 ++++++++++-- cpp/include/raft/comms/std_comms.hpp | 14 ++++++++++-- 3 files changed, 53 insertions(+), 8 deletions(-) diff --git 
a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index ead2b05b5b..a26c03c096 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -128,13 +128,17 @@ class comms_iface { size_t recvcount, datatype_t datatype, op_t op, cudaStream_t stream) const = 0; - // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock virtual void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const = 0; - // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock virtual void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const = 0; + + virtual void device_sendrecv(void* sendbuf, size_t sendsize, int dest, + void* recvbuf, size_t recvsize, int source, + cudaStream_t stream) const = 0; }; class comms_t { @@ -334,7 +338,7 @@ class comms_t { /** * Performs a point-to-point send * - * note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. + * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. * * @tparam value_t the type of data to send * @param buf pointer to array of data to send @@ -352,7 +356,7 @@ class comms_t { /** * Performs a point-to-point receive * - * note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. + * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock. * * @tparam value_t the type of data to be received * @param buf pointer to (initialized) array that will hold received data @@ -367,6 +371,27 @@ class comms_t { stream); } + /** + * Performs a point-to-point send/receive + * + * @tparam value_t the type of data to be received + * @param sendbuf pointer to array of data to send + * @param sendsize number of elements in sendbuf + * @param dest destination rank + * @param recvbuf pointer to (initialized) array that will hold received data + * @param recvsize number of elements in recvbuf + * @param source source rank + * @param stream CUDA stream to synchronize operation + */ + template + void device_sendrecv(value_t* sendbuf, size_t sendsize, int dest, + value_t* recvbuf, size_t recvsize, int source, + cudaStream_t stream) const { + impl_->device_sendrecv( + static_cast(snedbuf), sendsize * sizeof(value_t), dest, + static_cast(recvbuf), recvsize * sizeof(value_t), source, stream); + } + private: std::unique_ptr impl_; }; diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 396babecf4..1c33cf60ad 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -268,18 +268,28 @@ class mpi_comms : public comms_iface { } }; - // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const { NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); } - // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock void 
device_recv(void* buf, size_t size, int source, cudaStream_t stream) const { NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); } + void device_sendrecv(void* sendbuf, size_t sendsize, int dest, void* recvbuf, + size_t recvsize, int source, cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); + NCCL_TRY( + ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 8de576e9f7..b150dd339a 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -382,18 +382,28 @@ class std_comms : public comms_iface { } } - // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock void device_send(const void *buf, size_t size, int dest, cudaStream_t stream) const { NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); } - // note that if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock void device_recv(void *buf, size_t size, int source, cudaStream_t stream) const { NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); } + void device_sendrecv(void *sendbuf, size_t sendsize, int dest, void *recvbuf, + size_t recvsize, int source, cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); + NCCL_TRY( + ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); + NCCL_TRY(ncclGroupEnd()); + } + private: ncclComm_t nccl_comm_; cudaStream_t stream_; From 9a49d9686b25e18e2401744d769ee5e50ffd3e7a Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Sat, 6 Feb 2021 00:06:38 -0500 Subject: [PATCH 04/12] add device_multicast_sendrecv --- cpp/include/raft/comms/comms.hpp | 49 +++++++++++++++++++++++++++- cpp/include/raft/comms/mpi_comms.hpp | 21 ++++++++++++ cpp/include/raft/comms/std_comms.hpp | 21 ++++++++++++ 3 files changed, 90 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index a26c03c096..5e729e48cf 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -139,6 +139,13 @@ class comms_iface { virtual void device_sendrecv(void* sendbuf, size_t sendsize, int dest, void* recvbuf, size_t recvsize, int source, cudaStream_t stream) const = 0; + + virtual void device_multicast_sendrecv( + void* sendbuf, std::vector const& sendsizes, + std::vector const& sendoffsets, std::vector const& dests, + void* recvbuf, std::vector const& recvsizes, + std::vector const& recvoffsets, std::vector const& sources, + cudaStream_t stream) const = 0; }; class comms_t { @@ -374,7 +381,7 @@ class comms_t { /** * Performs a point-to-point send/receive * - * @tparam value_t the type of data to be received + * @tparam value_t the type of data to be sent & received * @param sendbuf pointer to array of data to send * @param sendsize number of elements in 
sendbuf * @param dest destination rank @@ -392,6 +399,46 @@ class comms_t { static_cast(recvbuf), recvsize * sizeof(value_t), source, stream); } + /** + * Performs a multicast send/receive + * + * @tparam value_t the type of data to be sent & received + * @param sendbuf pointer to array of data to send + * @param sendsizes numbers of elements to send + * @param sendoffsets offsets in a number of elements from sendbuf + * @param dest destination ranks + * @param recvbuf pointer to (initialized) array that will hold received data + * @param recvsizes numbers of elements to recv + * @param recvoffsets offsets in a number of elements from recvbuf + * @param sources source ranks + * @param stream CUDA stream to synchronize operation + */ + template + void device_multicast_sendrecv(void* sendbuf, + std::vector const& sendsizes, + std::vector const& sendoffsets, + std::vector const& dests, void* recvbuf, + std::vector const& recvsizes, + std::vector const& recvoffsets, + std::vector const& sources, + cudaStream_t stream) const { + auto sendbytesizes = sendsizes; + auto sendbyteoffsets = sendoffsets; + for (size_t i = 0; i < sendsizes.size(); ++i) { + sendbytesizes[i] *= sizeof(value_t); + sendbyteoffsets[i] *= sizeof(value_t); + } + auto recvbytesizes = recvsizes; + auto recvbyteoffsets = recvoffsets; + for (size_t i = 0; i < recvsizes.size(); ++i) { + recvbytesizes[i] *= sizeof(value_t); + recvbyteoffsets[i] *= sizeof(value_t); + } + impl->device_multicast_sendrecv(sendbuf, sendbytesizes, sendbyteoffsets, + dests, recvbytesizes, recvbyteoffsets, + sources, stream); + } + private: std::unique_ptr impl_; }; diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 1c33cf60ad..e0e3f150b1 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -290,6 +290,27 @@ class mpi_comms : public comms_iface { NCCL_TRY(ncclGroupEnd()); } + void device_multicast_sendrecv(void *sendbuf, + std::vector const &sendsizes, + std::vector const &sendoffsets, + std::vector const &dests, void *recvbuf, + std::vector const &recvsizes, + std::vector const &recvoffsets, + std::vector const &sources, + cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + for (size_t i = 0; i < sendsizes.size(); ++i) { + NCCL_TRY(ncclSend(sendbuf + sendoffsets[i], sendsizes[i], ncclUint8, + dests[i], nccl_comm_, stream)); + } + for (size_t i = 0; i < recvsizes.size(); ++i) { + NCCL_TRY(ncclRecv(recvbuf + recvoffsets[i], recvsizes[i], ncclUint8, + sources[i], nccl_comm_, stream)); + } + NCCL_TRY(ncclGroupEnd()); + } + private: bool owns_mpi_comm_; MPI_Comm mpi_comm_; diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index b150dd339a..c2220a77ea 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -404,6 +404,27 @@ class std_comms : public comms_iface { NCCL_TRY(ncclGroupEnd()); } + void device_multicast_sendrecv(void *sendbuf, + std::vector const &sendsizes, + std::vector const &sendoffsets, + std::vector const &dests, void *recvbuf, + std::vector const &recvsizes, + std::vector const &recvoffsets, + std::vector const &sources, + cudaStream_t stream) const { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + NCCL_TRY(ncclGroupStart()); + for (size_t i = 0; i < sendsizes.size(); ++i) { + NCCL_TRY(ncclSend(sendbuf + sendoffsets[i], sendsizes[i], 
ncclUint8, + dests[i], nccl_comm_, stream)); + } + for (size_t i = 0; i < recvsizes.size(); ++i) { + NCCL_TRY(ncclRecv(recvbuf + recvoffsets[i], recvsizes[i], ncclUint8, + sources[i], nccl_comm_, stream)); + } + NCCL_TRY(ncclGroupEnd()); + } + private: ncclComm_t nccl_comm_; cudaStream_t stream_; From dc101f8c7535b2962f6d6639e64875ecfea14d1b Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Mon, 8 Feb 2021 14:46:57 -0500 Subject: [PATCH 05/12] fix compile errors --- cpp/include/raft/comms/comms.hpp | 24 ++++++++++++------------ cpp/include/raft/comms/mpi_comms.hpp | 23 ++++++++++++----------- cpp/include/raft/comms/std_comms.hpp | 9 +++++---- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 5e729e48cf..ab9fde61e2 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -19,6 +19,7 @@ #include #include +#include namespace raft { namespace comms { @@ -395,7 +396,7 @@ class comms_t { value_t* recvbuf, size_t recvsize, int source, cudaStream_t stream) const { impl_->device_sendrecv( - static_cast(snedbuf), sendsize * sizeof(value_t), dest, + static_cast(sendbuf), sendsize * sizeof(value_t), dest, static_cast(recvbuf), recvsize * sizeof(value_t), source, stream); } @@ -414,14 +415,12 @@ class comms_t { * @param stream CUDA stream to synchronize operation */ template - void device_multicast_sendrecv(void* sendbuf, - std::vector const& sendsizes, - std::vector const& sendoffsets, - std::vector const& dests, void* recvbuf, - std::vector const& recvsizes, - std::vector const& recvoffsets, - std::vector const& sources, - cudaStream_t stream) const { + void device_multicast_sendrecv( + value_t* sendbuf, std::vector const& sendsizes, + std::vector const& sendoffsets, std::vector const& dests, + value_t* recvbuf, std::vector const& recvsizes, + std::vector const& recvoffsets, std::vector const& sources, + cudaStream_t stream) const { auto sendbytesizes = sendsizes; auto sendbyteoffsets = sendoffsets; for (size_t i = 0; i < sendsizes.size(); ++i) { @@ -434,9 +433,10 @@ class comms_t { recvbytesizes[i] *= sizeof(value_t); recvbyteoffsets[i] *= sizeof(value_t); } - impl->device_multicast_sendrecv(sendbuf, sendbytesizes, sendbyteoffsets, - dests, recvbytesizes, recvbyteoffsets, - sources, stream); + impl_->device_multicast_sendrecv(static_cast(sendbuf), sendbytesizes, + sendbyteoffsets, dests, + static_cast(recvbuf), recvbytesizes, + recvbyteoffsets, sources, stream); } private: diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index e0e3f150b1..ff774d446c 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -290,23 +290,24 @@ class mpi_comms : public comms_iface { NCCL_TRY(ncclGroupEnd()); } - void device_multicast_sendrecv(void *sendbuf, - std::vector const &sendsizes, - std::vector const &sendoffsets, - std::vector const &dests, void *recvbuf, - std::vector const &recvsizes, - std::vector const &recvoffsets, - std::vector const &sources, + void device_multicast_sendrecv(void* sendbuf, + std::vector const& sendsizes, + std::vector const& sendoffsets, + std::vector const& dests, void* recvbuf, + std::vector const& recvsizes, + std::vector const& recvoffsets, + std::vector const& sources, cudaStream_t stream) const { // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock NCCL_TRY(ncclGroupStart()); for (size_t i = 0; i < sendsizes.size(); ++i) { - 
NCCL_TRY(ncclSend(sendbuf + sendoffsets[i], sendsizes[i], ncclUint8, - dests[i], nccl_comm_, stream)); + NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream)); } for (size_t i = 0; i < recvsizes.size(); ++i) { - NCCL_TRY(ncclRecv(recvbuf + recvoffsets[i], recvsizes[i], ncclUint8, - sources[i], nccl_comm_, stream)); + NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], + recvsizes[i], ncclUint8, sources[i], nccl_comm_, + stream)); } NCCL_TRY(ncclGroupEnd()); } diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index c2220a77ea..9350dcb975 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -415,12 +415,13 @@ class std_comms : public comms_iface { // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock NCCL_TRY(ncclGroupStart()); for (size_t i = 0; i < sendsizes.size(); ++i) { - NCCL_TRY(ncclSend(sendbuf + sendoffsets[i], sendsizes[i], ncclUint8, - dests[i], nccl_comm_, stream)); + NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream)); } for (size_t i = 0; i < recvsizes.size(); ++i) { - NCCL_TRY(ncclRecv(recvbuf + recvoffsets[i], recvsizes[i], ncclUint8, - sources[i], nccl_comm_, stream)); + NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], + recvsizes[i], ncclUint8, sources[i], nccl_comm_, + stream)); } NCCL_TRY(ncclGroupEnd()); } From eb88acc7b272c68ac2f560c8255d2d5421969896 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Mon, 8 Feb 2021 16:48:20 -0500 Subject: [PATCH 06/12] add device_(multicast)_send(_or_)recv test suites --- cpp/include/raft/comms/test.hpp | 167 +++++++++++++++++++++++- python/raft/dask/common/__init__.py | 3 + python/raft/dask/common/comms_utils.pyx | 53 ++++++++ python/raft/test/test_comms.py | 87 ++++++++++++ 4 files changed, 304 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index fa7e471174..90448c4b76 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -16,10 +16,14 @@ #pragma once -#include #include #include #include +#include +#include + +#include +#include namespace raft { namespace comms { @@ -158,23 +162,24 @@ bool test_collective_allgather(const handle_t &handle, int root) { bool test_collective_reducescatter(const handle_t &handle, int root) { comms_t const &communicator = handle.get_comms(); - int const send = 1; + std::vector sends(communicator.get_size(), 1); cudaStream_t stream = handle.get_stream(); raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream, - 1); + sends.size()); raft::mr::device::buffer recv_d(handle.get_device_allocator(), stream, 1); - CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int), - cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), + sends.size() * sizeof(int), cudaMemcpyHostToDevice, + stream)); communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM, stream); communicator.sync_stream(stream); int temp_h = -1; // Verify more than one byte is being sent - CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), + CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); communicator.barrier(); @@ -250,6 +255,156 @@ bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) { return ret; } +/** + * A 
simple sanity check that device is able to send OR receive. + * + * @param the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_send_or_recv(const handle_t &h, int numTrials) { + comms_t const &communicator = h.get_comms(); + int const rank = communicator.get_rank(); + cudaStream_t stream = h.get_stream(); + + bool ret = true; + for (int i = 0; i < numTrials; i++) { + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + std::cout << "Trial " << i << std::endl; + } + + bool sender = (rank % 2) == 0 ? true : false; + rmm::device_scalar received_data(-1, stream); + rmm::device_scalar sent_data(rank, stream); + + if (sender) { + if (rank + 1 < communicator.get_size()) { + communicator.device_send(sent_data.data(), 1, rank + 1, stream); + } + } else { + communicator.device_recv(received_data.data(), 1, rank - 1, stream); + } + + communicator.sync_stream(stream); + + if (!sender && received_data.value() != rank - 1) { + ret = false; + } + + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + } + } + + return ret; +} + +/** + * A simple sanity check that device is able to send and receive at the same time. + * + * @param the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_sendrecv(const handle_t &h, int numTrials) { + comms_t const &communicator = h.get_comms(); + int const rank = communicator.get_rank(); + cudaStream_t stream = h.get_stream(); + + bool ret = true; + for (int i = 0; i < numTrials; i++) { + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + std::cout << "Trial " << i << std::endl; + } + + rmm::device_scalar received_data(-1, stream); + rmm::device_scalar sent_data(rank, stream); + + if (rank % 2 == 0) { + if (rank + 1 < communicator.get_size()) { + communicator.device_sendrecv(sent_data.data(), 1, rank + 1, + received_data.data(), 1, rank + 1, stream); + } + } else { + communicator.device_sendrecv(sent_data.data(), 1, rank - 1, + received_data.data(), 1, rank - 1, stream); + } + + communicator.sync_stream(stream); + + if (((rank % 2 == 0) && (received_data.value() != rank + 1)) || + ((rank % 2 == 1) && (received_data.value() != rank - 1))) { + ret = false; + } + + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + } + } + + return ret; +} + +/** + * A simple sanity check that device is able to perform multiple concurrent sends and receives. + * + * @param the raft handle to use. This is expected to already have an + * initialized comms instance. 
+ * @param number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_multicast_sendrecv(const handle_t &h, + int numTrials) { + comms_t const &communicator = h.get_comms(); + int const rank = communicator.get_rank(); + cudaStream_t stream = h.get_stream(); + + bool ret = true; + for (int i = 0; i < numTrials; i++) { + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + std::cout << "Trial " << i << std::endl; + } + + rmm::device_uvector received_data(communicator.get_size(), stream); + rmm::device_scalar sent_data(rank, stream); + + std::vector sendsizes(communicator.get_size(), 1); + std::vector sendoffsets(communicator.get_size(), 0); + std::vector dests(communicator.get_size()); + std::iota(dests.begin(), dests.end(), int{0}); + + std::vector recvsizes(communicator.get_size(), 1); + std::vector recvoffsets(communicator.get_size()); + std::iota(recvoffsets.begin(), recvoffsets.end(), size_t{0}); + std::vector srcs(communicator.get_size()); + std::iota(srcs.begin(), srcs.end(), int{0}); + + communicator.device_multicast_sendrecv( + sent_data.data(), sendsizes, sendoffsets, dests, received_data.data(), + recvsizes, recvoffsets, srcs, stream); + + communicator.sync_stream(stream); + + std::vector h_received_data(communicator.get_size()); + raft::update_host(h_received_data.data(), received_data.data(), + received_data.size(), stream); + CUDA_TRY(cudaStreamSynchronize(stream)); + for (int i = 0; i < communicator.get_size(); ++i) { + if (h_received_data[i] != i) { + ret = false; + } + } + + if (communicator.get_rank() == 0) { + std::cout << "=========================" << std::endl; + } + } + + return ret; +} + /** * A simple test that the comms can be split into 2 separate subcommunicators * diff --git a/python/raft/dask/common/__init__.py b/python/raft/dask/common/__init__.py index 788af46c92..27b5e74e68 100644 --- a/python/raft/dask/common/__init__.py +++ b/python/raft/dask/common/__init__.py @@ -20,6 +20,9 @@ from .comms_utils import inject_comms_on_handle_coll_only from .comms_utils import perform_test_comms_allreduce from .comms_utils import perform_test_comms_send_recv +from .comms_utils import perform_test_comms_device_send_or_recv +from .comms_utils import perform_test_comms_device_sendrecv +from .comms_utils import perform_test_comms_device_multicast_sendrecv from .comms_utils import perform_test_comms_allgather from .comms_utils import perform_test_comms_bcast from .comms_utils import perform_test_comms_reduce diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx index 4dbd2f1a7c..2ea534fa22 100644 --- a/python/raft/dask/common/comms_utils.pyx +++ b/python/raft/dask/common/comms_utils.pyx @@ -63,6 +63,12 @@ cdef extern from "raft/comms/test.hpp" namespace "raft::comms": bool test_collective_reducescatter(const handle_t &h, int root) except + bool test_pointToPoint_simple_send_recv(const handle_t &h, int numTrials) except + + bool test_pointToPoint_device_send_or_recv(const handle_t &h, + int numTrials) except + + bool test_pointToPoint_device_sendrecv(const handle_t &h, + int numTrials) except + + bool test_pointToPoint_device_multicast_sendrecv(const handle_t &h, + int numTrials) except + bool test_commsplit(const handle_t &h, int n_colors) except + @@ -139,11 +145,58 @@ def perform_test_comms_send_recv(handle, n_trials): ---------- handle : raft.common.Handle handle containing comms_t to use + n_trilas : int + Number of test trials """ cdef const handle_t *h = 
handle.getHandle() return test_pointToPoint_simple_send_recv(deref(h), n_trials) +def perform_test_comms_device_send_or_recv(handle, n_trials): + """ + Performs a p2p device send or recv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + n_trilas : int + Number of test trials + """ + cdef const handle_t *h = handle.getHandle() + return test_pointToPoint_device_send_or_recv(deref(h), n_trials) + + +def perform_test_comms_device_sendrecv(handle, n_trials): + """ + Performs a p2p device concurrent send & recv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + n_trilas : int + Number of test trials + """ + cdef const handle_t *h = handle.getHandle() + return test_pointToPoint_device_sendrecv(deref(h), n_trials) + + +def perform_test_comms_device_multicast_sendrecv(handle, n_trials): + """ + Performs a p2p device concurrent multicast send & recv on the current worker + + Parameters + ---------- + handle : raft.common.Handle + handle containing comms_t to use + n_trilas : int + Number of test trials + """ + cdef const handle_t *h = handle.getHandle() + return test_pointToPoint_device_multicast_sendrecv(deref(h), n_trials) + + def perform_test_comm_split(handle, n_colors): """ Performs a p2p send/recv on the current worker diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 7dccb7bbae..0721f5d60f 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -24,6 +24,9 @@ from raft.dask import Comms from raft.dask.common import local_handle from raft.dask.common import perform_test_comms_send_recv + from raft.dask.common import perform_test_comms_device_send_or_recv + from raft.dask.common import perform_test_comms_device_sendrecv + from raft.dask.common import perform_test_comms_device_multicast_sendrecv from raft.dask.common import perform_test_comms_allreduce from raft.dask.common import perform_test_comms_bcast from raft.dask.common import perform_test_comms_reduce @@ -63,6 +66,21 @@ def func_test_send_recv(sessionId, n_trials): return perform_test_comms_send_recv(handle, n_trials) +def func_test_device_send_or_recv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_device_send_or_recv(handle, n_trials) + + +def func_test_device_sendrecv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_device_sendrecv(handle, n_trials) + + +def func_test_device_multicast_sendrecv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_device_multicast_sendrecv(handle, n_trials) + + def func_test_comm_split(sessionId, n_trials): handle = local_handle(sessionId) return perform_test_comm_split(handle, n_trials) @@ -243,3 +261,72 @@ def test_send_recv(n_trials, client): wait(dfs, timeout=5) assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_device_send_or_recv(n_trials, client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_device_send_or_recv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_device_sendrecv(n_trials, client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + 
func_test_device_sendrecv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + +@pytest.mark.nccl +@pytest.mark.parametrize("n_trials", [1, 5]) +def test_device_multicast_sendrecv(n_trials, client): + + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_device_multicast_sendrecv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) From cfd46c79106421443301d1afe42ce8a3e79c5794 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Mon, 8 Feb 2021 21:36:46 -0500 Subject: [PATCH 07/12] undo reducescatter test bug fixes (this should come from a PR merged in 0.18) --- cpp/include/raft/comms/test.hpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 90448c4b76..51ce76cee1 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -162,24 +162,23 @@ bool test_collective_allgather(const handle_t &handle, int root) { bool test_collective_reducescatter(const handle_t &handle, int root) { comms_t const &communicator = handle.get_comms(); - std::vector sends(communicator.get_size(), 1); + int const send = 1; cudaStream_t stream = handle.get_stream(); raft::mr::device::buffer temp_d(handle.get_device_allocator(), stream, - sends.size()); + 1); raft::mr::device::buffer recv_d(handle.get_device_allocator(), stream, 1); - CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), sends.data(), - sends.size() * sizeof(int), cudaMemcpyHostToDevice, - stream)); + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send sizeof(int), + cudaMemcpyHostToDevice, stream)); communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM, stream); communicator.sync_stream(stream); int temp_h = -1; // Verify more than one byte is being sent - CUDA_CHECK(cudaMemcpyAsync(&temp_h, recv_d.data(), sizeof(int), + CUDA_CHECK(cudaMemcpyAsync(&temp_h, temp_d.data(), sizeof(int), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); communicator.barrier(); From c87a0cd6786d821f81928984baba51da3de6c25e Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Mon, 8 Feb 2021 21:41:38 -0500 Subject: [PATCH 08/12] compile error fix --- cpp/include/raft/comms/test.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/test.hpp index 51ce76cee1..08d2cf05c7 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/test.hpp @@ -171,7 +171,7 @@ bool test_collective_reducescatter(const handle_t &handle, int root) { raft::mr::device::buffer recv_d(handle.get_device_allocator(), stream, 1); - CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send sizeof(int), + CUDA_CHECK(cudaMemcpyAsync(temp_d.data(), &send, sizeof(int), cudaMemcpyHostToDevice, stream)); communicator.reducescatter(temp_d.data(), recv_d.data(), 1, op_t::SUM, From df8a50c26874c362620c30c290f216530a953f74 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Mon, 8 Feb 2021 21:50:47 -0500 Subject: [PATCH 09/12] fix flake8 style error --- python/raft/dask/common/comms_utils.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx index 2ea534fa22..94ab7695f4 100644 --- 
a/python/raft/dask/common/comms_utils.pyx +++ b/python/raft/dask/common/comms_utils.pyx @@ -169,7 +169,7 @@ def perform_test_comms_device_send_or_recv(handle, n_trials): def perform_test_comms_device_sendrecv(handle, n_trials): """ - Performs a p2p device concurrent send & recv on the current worker + Performs a p2p device concurrent send&recv on the current worker Parameters ---------- @@ -184,7 +184,7 @@ def perform_test_comms_device_sendrecv(handle, n_trials): def perform_test_comms_device_multicast_sendrecv(handle, n_trials): """ - Performs a p2p device concurrent multicast send & recv on the current worker + Performs a p2p device concurrent multicast send&recv on the current worker Parameters ---------- From def1b4dc94a69e05d38db65fb9b6369e210e1ef1 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Tue, 9 Feb 2021 13:26:59 -0500 Subject: [PATCH 10/12] add const to send buffer pointer --- cpp/include/raft/comms/comms.hpp | 8 ++++---- cpp/include/raft/comms/mpi_comms.hpp | 7 ++++--- cpp/include/raft/comms/std_comms.hpp | 7 ++++--- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index ab9fde61e2..8b89cba321 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -137,12 +137,12 @@ class comms_iface { virtual void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const = 0; - virtual void device_sendrecv(void* sendbuf, size_t sendsize, int dest, + virtual void device_sendrecv(const void* sendbuf, size_t sendsize, int dest, void* recvbuf, size_t recvsize, int source, cudaStream_t stream) const = 0; virtual void device_multicast_sendrecv( - void* sendbuf, std::vector const& sendsizes, + const void* sendbuf, std::vector const& sendsizes, std::vector const& sendoffsets, std::vector const& dests, void* recvbuf, std::vector const& recvsizes, std::vector const& recvoffsets, std::vector const& sources, @@ -392,7 +392,7 @@ class comms_t { * @param stream CUDA stream to synchronize operation */ template - void device_sendrecv(value_t* sendbuf, size_t sendsize, int dest, + void device_sendrecv(const value_t* sendbuf, size_t sendsize, int dest, value_t* recvbuf, size_t recvsize, int source, cudaStream_t stream) const { impl_->device_sendrecv( @@ -416,7 +416,7 @@ class comms_t { */ template void device_multicast_sendrecv( - value_t* sendbuf, std::vector const& sendsizes, + const value_t* sendbuf, std::vector const& sendsizes, std::vector const& sendoffsets, std::vector const& dests, value_t* recvbuf, std::vector const& recvsizes, std::vector const& recvoffsets, std::vector const& sources, diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index ff774d446c..138eba66f3 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -280,8 +280,9 @@ class mpi_comms : public comms_iface { NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); } - void device_sendrecv(void* sendbuf, size_t sendsize, int dest, void* recvbuf, - size_t recvsize, int source, cudaStream_t stream) const { + void device_sendrecv(const void* sendbuf, size_t sendsize, int dest, + void* recvbuf, size_t recvsize, int source, + cudaStream_t stream) const { // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock NCCL_TRY(ncclGroupStart()); NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); @@ -290,7 +291,7 @@ class mpi_comms : public comms_iface { 
NCCL_TRY(ncclGroupEnd()); } - void device_multicast_sendrecv(void* sendbuf, + void device_multicast_sendrecv(const void* sendbuf, std::vector const& sendsizes, std::vector const& sendoffsets, std::vector const& dests, void* recvbuf, diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 9350dcb975..32f110464d 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -394,8 +394,9 @@ class std_comms : public comms_iface { NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); } - void device_sendrecv(void *sendbuf, size_t sendsize, int dest, void *recvbuf, - size_t recvsize, int source, cudaStream_t stream) const { + void device_sendrecv(const void *sendbuf, size_t sendsize, int dest, + void *recvbuf, size_t recvsize, int source, + cudaStream_t stream) const { // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock NCCL_TRY(ncclGroupStart()); NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); @@ -404,7 +405,7 @@ class std_comms : public comms_iface { NCCL_TRY(ncclGroupEnd()); } - void device_multicast_sendrecv(void *sendbuf, + void device_multicast_sendrecv(const void *sendbuf, std::vector const &sendsizes, std::vector const &sendoffsets, std::vector const &dests, void *recvbuf, From 16da04783d8f46852c0775b9cb572fd41c0b8348 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Tue, 9 Feb 2021 13:32:37 -0500 Subject: [PATCH 11/12] fix compile errors --- cpp/include/raft/comms/comms.hpp | 4 ++-- cpp/include/raft/comms/mpi_comms.hpp | 2 +- cpp/include/raft/comms/std_comms.hpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 8b89cba321..d2a8f08f6d 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -396,7 +396,7 @@ class comms_t { value_t* recvbuf, size_t recvsize, int source, cudaStream_t stream) const { impl_->device_sendrecv( - static_cast(sendbuf), sendsize * sizeof(value_t), dest, + static_cast(sendbuf), sendsize * sizeof(value_t), dest, static_cast(recvbuf), recvsize * sizeof(value_t), source, stream); } @@ -433,7 +433,7 @@ class comms_t { recvbytesizes[i] *= sizeof(value_t); recvbyteoffsets[i] *= sizeof(value_t); } - impl_->device_multicast_sendrecv(static_cast(sendbuf), sendbytesizes, + impl_->device_multicast_sendrecv(static_cast(sendbuf), sendbytesizes, sendbyteoffsets, dests, static_cast(recvbuf), recvbytesizes, recvbyteoffsets, sources, stream); diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 138eba66f3..2fa27f5bfe 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -302,7 +302,7 @@ class mpi_comms : public comms_iface { // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock NCCL_TRY(ncclGroupStart()); for (size_t i = 0; i < sendsizes.size(); ++i) { - NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream)); } for (size_t i = 0; i < recvsizes.size(); ++i) { diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 32f110464d..bcb1cfc2dc 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -416,7 +416,7 @@ class std_comms : public comms_iface { // ncclSend/ncclRecv pair needs to be inside 
ncclGroupStart/ncclGroupEnd to avoid deadlock NCCL_TRY(ncclGroupStart()); for (size_t i = 0; i < sendsizes.size(); ++i) { - NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], sendsizes[i], ncclUint8, dests[i], nccl_comm_, stream)); } for (size_t i = 0; i < recvsizes.size(); ++i) { From 4ec520e7362ba254eaf0052d2fb2bcaf8a3c0257 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Thu, 11 Feb 2021 18:06:36 -0500 Subject: [PATCH 12/12] clang-format --- cpp/include/raft/comms/comms.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index da768f45fb..dc172c9503 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -481,8 +481,8 @@ class comms_t { recvbytesizes[i] *= sizeof(value_t); recvbyteoffsets[i] *= sizeof(value_t); } - impl_->device_multicast_sendrecv(static_cast(sendbuf), sendbytesizes, - sendbyteoffsets, dests, + impl_->device_multicast_sendrecv(static_cast(sendbuf), + sendbytesizes, sendbyteoffsets, dests, static_cast(recvbuf), recvbytesizes, recvbyteoffsets, sources, stream); }
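
--
Editor's usage sketch (not part of the patch series). The snippet below shows how the point-to-point
primitives introduced by these patches — device_send, device_recv, device_sendrecv and
device_multicast_sendrecv on raft::comms::comms_t — would be called from user code. It assumes a
raft::handle_t whose comms instance has already been initialized (as the C++/Python test suites in
PATCH 06 assume), and mirrors the ring-style exchanges used there. The function name
example_p2p_exchange is hypothetical.

    #include <raft/handle.hpp>
    #include <raft/comms/comms.hpp>

    #include <rmm/device_scalar.hpp>
    #include <rmm/device_uvector.hpp>

    #include <numeric>
    #include <vector>

    void example_p2p_exchange(raft::handle_t const& handle) {
      auto const& comm = handle.get_comms();
      int const rank = comm.get_rank();
      int const size = comm.get_size();
      cudaStream_t stream = handle.get_stream();

      rmm::device_scalar<int> my_rank(rank, stream);
      rmm::device_scalar<int> peer_rank(-1, stream);

      // 1. One-sided send/recv: even ranks send their rank to the next rank.
      //    A rank that only sends or only receives can use device_send/device_recv directly.
      if (rank % 2 == 0) {
        if (rank + 1 < size) { comm.device_send(my_rank.data(), 1, rank + 1, stream); }
      } else {
        comm.device_recv(peer_rank.data(), 1, rank - 1, stream);
      }
      comm.sync_stream(stream);

      // 2. Concurrent send & receive with the same peer. device_sendrecv wraps the
      //    ncclSend/ncclRecv pair in ncclGroupStart/ncclGroupEnd, avoiding the deadlock
      //    that separate device_send/device_recv calls could cause here.
      int const peer = (rank % 2 == 0) ? rank + 1 : rank - 1;
      if (peer >= 0 && peer < size) {
        comm.device_sendrecv(my_rank.data(), 1, peer, peer_rank.data(), 1, peer, stream);
        comm.sync_stream(stream);
      }

      // 3. Multicast: send one element to every rank and receive one element from every
      //    rank (sizes and offsets are in elements; the wrapper converts them to bytes).
      std::vector<size_t> sendsizes(size, 1), sendoffsets(size, 0);
      std::vector<size_t> recvsizes(size, 1), recvoffsets(size);
      std::iota(recvoffsets.begin(), recvoffsets.end(), size_t{0});
      std::vector<int> dests(size), srcs(size);
      std::iota(dests.begin(), dests.end(), 0);
      std::iota(srcs.begin(), srcs.end(), 0);

      rmm::device_uvector<int> gathered(size, stream);
      comm.device_multicast_sendrecv(my_rank.data(), sendsizes, sendoffsets, dests,
                                     gathered.data(), recvsizes, recvoffsets, srcs, stream);
      comm.sync_stream(stream);  // gathered[i] now holds i on every rank
    }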