diff --git a/cpp/cmake/thirdparty/get_thrust.cmake b/cpp/cmake/thirdparty/get_thrust.cmake index c28ff6e66d..3813d0ea02 100644 --- a/cpp/cmake/thirdparty/get_thrust.cmake +++ b/cpp/cmake/thirdparty/get_thrust.cmake @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -14,14 +14,14 @@ # Use CPM to find or clone thrust function(find_and_configure_thrust) - include(${rapids-cmake-dir}/cpm/thrust.cmake) + include(${rapids-cmake-dir}/cpm/thrust.cmake) - rapids_cpm_thrust( - NAMESPACE raft - BUILD_EXPORT_SET raft-exports - INSTALL_EXPORT_SET raft-exports - ) + rapids_cpm_thrust( + NAMESPACE raft + BUILD_EXPORT_SET raft-exports + INSTALL_EXPORT_SET raft-exports + ) endfunction() -find_and_configure_thrust() +find_and_configure_thrust() \ No newline at end of file diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp index 0de84117e0..14c33c6cf2 100644 --- a/cpp/include/raft/comms/comms.hpp +++ b/cpp/include/raft/comms/comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,52 +38,70 @@ enum class status_t { }; template -constexpr datatype_t get_type(); +constexpr datatype_t + +get_type(); template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::CHAR; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::UINT8; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::INT32; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::UINT32; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::INT64; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::UINT64; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::FLOAT32; } template <> -constexpr datatype_t get_type() +constexpr datatype_t + +get_type() { return datatype_t::FLOAT64; } @@ -93,10 +111,12 @@ class comms_iface { virtual ~comms_iface() {} virtual int get_size() const = 0; + virtual int get_rank() const = 0; virtual std::unique_ptr comm_split(int color, int key) const = 0; - virtual void barrier() const = 0; + + virtual void barrier() const = 0; virtual status_t sync_stream(cudaStream_t stream) const = 0; diff --git a/cpp/include/raft/comms/comms_test.hpp b/cpp/include/raft/comms/comms_test.hpp new file mode 100644 index 0000000000..1acb72bc85 --- /dev/null +++ b/cpp/include/raft/comms/comms_test.hpp @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +namespace raft { +namespace comms { + +/** + * @brief A simple sanity check that NCCL is able to perform a collective operation + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_allreduce(const handle_t& handle, int root) +{ + return detail::test_collective_allreduce(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective operation + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_broadcast(const handle_t& handle, int root) +{ + return detail::test_collective_broadcast(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective reduce + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_reduce(const handle_t& handle, int root) +{ + return detail::test_collective_reduce(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective allgather + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_allgather(const handle_t& handle, int root) +{ + return detail::test_collective_allgather(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective gather + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_gather(const handle_t& handle, int root) +{ + return detail::test_collective_gather(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective gatherv + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_gatherv(const handle_t& handle, int root) +{ + return detail::test_collective_gatherv(handle, root); +} + +/** + * @brief A simple sanity check that NCCL is able to perform a collective reducescatter + * + * @param[in] handle the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] root the root rank id + */ +bool test_collective_reducescatter(const handle_t& handle, int root) +{ + return detail::test_collective_reducescatter(handle, root); +} + +/** + * A simple sanity check that UCX is able to send messages between all ranks + * + * @param[in] h the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param[in] numTrials number of iterations of all-to-all messaging to perform + */ +bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) +{ + return detail::test_pointToPoint_simple_send_recv(h, numTrials); +} + +/** + * A simple sanity check that device is able to send OR receive. + * + * @param h the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param numTrials number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) +{ + return detail::test_pointToPoint_device_send_or_recv(h, numTrials); +} + +/** + * A simple sanity check that device is able to send and receive at the same time. + * + * @param h the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param numTrials number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) +{ + return detail::test_pointToPoint_device_sendrecv(h, numTrials); +} + +/** + * A simple sanity check that device is able to perform multiple concurrent sends and receives. + * + * @param h the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param numTrials number of iterations of send or receive messaging to perform + */ +bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrials) +{ + return detail::test_pointToPoint_device_multicast_sendrecv(h, numTrials); +} + +/** + * A simple test that the comms can be split into 2 separate subcommunicators + * + * @param h the raft handle to use. This is expected to already have an + * initialized comms instance. + * @param n_colors number of different colors to test + */ +bool test_commsplit(const handle_t& h, int n_colors) { return detail::test_commsplit(h, n_colors); } +} // namespace comms +}; // namespace raft diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp new file mode 100644 index 0000000000..3bfd72baf9 --- /dev/null +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -0,0 +1,441 @@ +/* + * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +#define RAFT_MPI_TRY(call) \ + do { \ + int status = call; \ + if (MPI_SUCCESS != status) { \ + int mpi_error_string_lenght = 0; \ + char mpi_error_string[MPI_MAX_ERROR_STRING]; \ + MPI_Error_string(status, mpi_error_string, &mpi_error_string_lenght); \ + RAFT_EXPECTS( \ + MPI_SUCCESS == status, "ERROR: MPI call='%s'. Reason:%s\n", #call, mpi_error_string); \ + } \ + } while (0) + +// FIXME: Remove after consumer rename +#ifndef MPI_TRY +#define MPI_TRY(call) RAFT_MPI_TRY(call) +#endif + +#define RAFT_MPI_TRY_NO_THROW(call) \ + do { \ + int status = call; \ + if (MPI_SUCCESS != status) { \ + int mpi_error_string_lenght = 0; \ + char mpi_error_string[MPI_MAX_ERROR_STRING]; \ + MPI_Error_string(status, mpi_error_string, &mpi_error_string_lenght); \ + printf("MPI call='%s' at file=%s line=%d failed with %s ", \ + #call, \ + __FILE__, \ + __LINE__, \ + mpi_error_string); \ + } \ + } while (0) + +// FIXME: Remove after consumer rename +#ifndef MPI_TRY_NO_THROW +#define MPI_TRY_NO_THROW(call) RAFT_MPI_TRY_NO_THROW(call) +#endif +namespace raft { +namespace comms { +namespace detail { + +constexpr MPI_Datatype get_mpi_datatype(const datatype_t datatype) +{ + switch (datatype) { + case datatype_t::CHAR: return MPI_CHAR; + case datatype_t::UINT8: return MPI_UNSIGNED_CHAR; + case datatype_t::INT32: return MPI_INT; + case datatype_t::UINT32: return MPI_UNSIGNED; + case datatype_t::INT64: return MPI_LONG_LONG; + case datatype_t::UINT64: return MPI_UNSIGNED_LONG_LONG; + case datatype_t::FLOAT32: return MPI_FLOAT; + case datatype_t::FLOAT64: return MPI_DOUBLE; + default: + // Execution should never reach here. This takes care of compiler warning. + return MPI_DOUBLE; + } +} + +constexpr MPI_Op get_mpi_op(const op_t op) +{ + switch (op) { + case op_t::SUM: return MPI_SUM; + case op_t::PROD: return MPI_PROD; + case op_t::MIN: return MPI_MIN; + case op_t::MAX: return MPI_MAX; + default: + // Execution should never reach here. This takes care of compiler warning. + return MPI_MAX; + } +} + +class mpi_comms : public comms_iface { + public: + mpi_comms(MPI_Comm comm, const bool owns_mpi_comm) + : owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0) + { + int mpi_is_initialized = 0; + RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized)); + RAFT_EXPECTS(mpi_is_initialized, "ERROR: MPI is not initialized!"); + RAFT_MPI_TRY(MPI_Comm_size(mpi_comm_, &size_)); + RAFT_MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_)); + // get NCCL unique ID at rank 0 and broadcast it to all others + ncclUniqueId id; + if (0 == rank_) RAFT_NCCL_TRY(ncclGetUniqueId(&id)); + RAFT_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_)); + + // initializing NCCL + RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_)); + } + + virtual ~mpi_comms() + { + // finalizing NCCL + RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_)); + if (owns_mpi_comm_) { RAFT_MPI_TRY_NO_THROW(MPI_Comm_free(&mpi_comm_)); } + } + + int get_size() const { return size_; } + + int get_rank() const { return rank_; } + + std::unique_ptr comm_split(int color, int key) const + { + MPI_Comm new_comm; + RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm)); + return std::unique_ptr(new mpi_comms(new_comm, true)); + } + + void barrier() const { RAFT_MPI_TRY(MPI_Barrier(mpi_comm_)); } + + void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const + { + MPI_Request mpi_req; + request_t req_id; + if (free_requests_.empty()) { + req_id = next_request_id_++; + } else { + auto it = free_requests_.begin(); + req_id = *it; + free_requests_.erase(it); + } + RAFT_MPI_TRY(MPI_Isend(buf, size, MPI_BYTE, dest, tag, mpi_comm_, &mpi_req)); + requests_in_flight_.insert(std::make_pair(req_id, mpi_req)); + *request = req_id; + } + + void irecv(void* buf, size_t size, int source, int tag, request_t* request) const + { + MPI_Request mpi_req; + request_t req_id; + if (free_requests_.empty()) { + req_id = next_request_id_++; + } else { + auto it = free_requests_.begin(); + req_id = *it; + free_requests_.erase(it); + } + + RAFT_MPI_TRY(MPI_Irecv(buf, size, MPI_BYTE, source, tag, mpi_comm_, &mpi_req)); + requests_in_flight_.insert(std::make_pair(req_id, mpi_req)); + *request = req_id; + } + + void waitall(int count, request_t array_of_requests[]) const + { + std::vector requests; + requests.reserve(count); + for (int i = 0; i < count; ++i) { + auto req_it = requests_in_flight_.find(array_of_requests[i]); + RAFT_EXPECTS(requests_in_flight_.end() != req_it, + "ERROR: waitall on invalid request: %d", + array_of_requests[i]); + requests.push_back(req_it->second); + free_requests_.insert(req_it->first); + requests_in_flight_.erase(req_it); + } + RAFT_MPI_TRY(MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE)); + } + + void allreduce(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + op_t op, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclAllReduce( + sendbuff, recvbuff, count, get_nccl_datatype(datatype), get_nccl_op(op), nccl_comm_, stream)); + } + + void bcast(void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const + { + RAFT_NCCL_TRY( + ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + + void bcast(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclBroadcast( + sendbuff, recvbuff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + + void reduce(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + op_t op, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclReduce(sendbuff, + recvbuff, + count, + get_nccl_datatype(datatype), + get_nccl_op(op), + root, + nccl_comm_, + stream)); + } + + void allgather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclAllGather( + sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), nccl_comm_, stream)); + } + + void allgatherv(const void* sendbuf, + void* recvbuf, + const size_t* recvcounts, + const size_t* displs, + datatype_t datatype, + cudaStream_t stream) const + { + // From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" - + // https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4. + for (int root = 0; root < size_; ++root) { + RAFT_NCCL_TRY( + ncclBroadcast(sendbuf, + static_cast(recvbuf) + displs[root] * get_datatype_size(datatype), + recvcounts[root], + get_nccl_datatype(datatype), + root, + nccl_comm_, + stream)); + } + } + + void gather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + size_t dtype_size = get_datatype_size(datatype); + RAFT_NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + sendcount * r * dtype_size, + sendcount, + get_nccl_datatype(datatype), + r, + nccl_comm_, + stream)); + } + } + RAFT_NCCL_TRY( + ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + void gatherv(const void* sendbuff, + void* recvbuff, + size_t sendcount, + const size_t* recvcounts, + const size_t* displs, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + size_t dtype_size = get_datatype_size(datatype); + RAFT_NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + displs[r] * dtype_size, + recvcounts[r], + get_nccl_datatype(datatype), + r, + nccl_comm_, + stream)); + } + } + RAFT_NCCL_TRY( + ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + void reducescatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + op_t op, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclReduceScatter(sendbuff, + recvbuff, + recvcount, + get_nccl_datatype(datatype), + get_nccl_op(op), + nccl_comm_, + stream)); + } + + status_t sync_stream(cudaStream_t stream) const + { + cudaError_t cudaErr; + ncclResult_t ncclErr, ncclAsyncErr; + while (1) { + cudaErr = cudaStreamQuery(stream); + if (cudaErr == cudaSuccess) return status_t::SUCCESS; + + if (cudaErr != cudaErrorNotReady) { + // An error occurred querying the status of the stream + return status_t::ERROR; + } + + ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr); + if (ncclErr != ncclSuccess) { + // An error occurred retrieving the asynchronous error + return status_t::ERROR; + } + + if (ncclAsyncErr != ncclSuccess) { + // An asynchronous error happened. Stop the operation and destroy + // the communicator + ncclErr = ncclCommAbort(nccl_comm_); + if (ncclErr != ncclSuccess) + // Caller may abort with an exception or try to re-create a new communicator. + return status_t::ABORT; + } + + // Let other threads (including NCCL threads) use the CPU. + pthread_yield(); + } + }; + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); + } + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); + } + + void device_sendrecv(const void* sendbuf, + size_t sendsize, + int dest, + void* recvbuf, + size_t recvsize, + int source, + cudaStream_t stream) const + { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + RAFT_NCCL_TRY(ncclGroupStart()); + RAFT_NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + void device_multicast_sendrecv(const void* sendbuf, + std::vector const& sendsizes, + std::vector const& sendoffsets, + std::vector const& dests, + void* recvbuf, + std::vector const& recvsizes, + std::vector const& recvoffsets, + std::vector const& sources, + cudaStream_t stream) const + { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + RAFT_NCCL_TRY(ncclGroupStart()); + for (size_t i = 0; i < sendsizes.size(); ++i) { + RAFT_NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + sendsizes[i], + ncclUint8, + dests[i], + nccl_comm_, + stream)); + } + for (size_t i = 0; i < recvsizes.size(); ++i) { + RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], + recvsizes[i], + ncclUint8, + sources[i], + nccl_comm_, + stream)); + } + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + private: + bool owns_mpi_comm_; + MPI_Comm mpi_comm_; + + ncclComm_t nccl_comm_; + int size_; + int rank_; + mutable request_t next_request_id_; + mutable std::unordered_map requests_in_flight_; + mutable std::unordered_set free_requests_; +}; + +} // end namespace detail +}; // end namespace comms +}; // end namespace raft diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp new file mode 100644 index 0000000000..758a9d3781 --- /dev/null +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -0,0 +1,556 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace comms { +namespace detail { + +class std_comms : public comms_iface { + public: + std_comms() = delete; + + /** + * @brief Constructor for collective + point-to-point operation. + * @param nccl_comm initialized nccl comm + * @param ucp_worker initialized ucp_worker instance + * @param eps shared pointer to array of ucp endpoints + * @param num_ranks number of ranks in the cluster + * @param rank rank of the current worker + * @param stream cuda stream for synchronizing and ordering collective operations + * @param subcomms_ucp use ucp for subcommunicators + */ + std_comms(ncclComm_t nccl_comm, + ucp_worker_h ucp_worker, + std::shared_ptr eps, + int num_ranks, + int rank, + cudaStream_t stream, + bool subcomms_ucp = true) + : nccl_comm_(nccl_comm), + stream_(stream), + status_(2, stream), + num_ranks_(num_ranks), + rank_(rank), + subcomms_ucp_(subcomms_ucp), + ucp_worker_(ucp_worker), + ucp_eps_(eps), + next_request_id_(0) + { + initialize(); + }; + + /** + * @brief constructor for collective-only operation + * @param nccl_comm initilized nccl communicator + * @param num_ranks size of the cluster + * @param rank rank of the current worker + * @param stream stream for ordering collective operations + */ + std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, cudaStream_t stream) + : nccl_comm_(nccl_comm), + stream_(stream), + status_(2, stream), + num_ranks_(num_ranks), + rank_(rank), + subcomms_ucp_(false) + { + initialize(); + }; + + void initialize() + { + sendbuff_ = status_.data(); + recvbuff_ = status_.data() + 1; + } + + int get_size() const { return num_ranks_; } + + int get_rank() const { return rank_; } + + std::unique_ptr comm_split(int color, int key) const + { + rmm::device_uvector d_colors(get_size(), stream_); + rmm::device_uvector d_keys(get_size(), stream_); + + update_device(d_colors.data() + get_rank(), &color, 1, stream_); + update_device(d_keys.data() + get_rank(), &key, 1, stream_); + + allgather(d_colors.data() + get_rank(), d_colors.data(), 1, datatype_t::INT32, stream_); + allgather(d_keys.data() + get_rank(), d_keys.data(), 1, datatype_t::INT32, stream_); + this->sync_stream(stream_); + + std::vector h_colors(get_size()); + std::vector h_keys(get_size()); + + update_host(h_colors.data(), d_colors.data(), get_size(), stream_); + update_host(h_keys.data(), d_keys.data(), get_size(), stream_); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream_)); + + std::vector subcomm_ranks{}; + std::vector new_ucx_ptrs{}; + + for (int i = 0; i < get_size(); ++i) { + if (h_colors[i] == color) { + subcomm_ranks.push_back(i); + if (ucp_worker_ != nullptr && subcomms_ucp_) { new_ucx_ptrs.push_back((*ucp_eps_)[i]); } + } + } + + ncclUniqueId id{}; + if (get_rank() == subcomm_ranks[0]) { // root of the new subcommunicator + RAFT_NCCL_TRY(ncclGetUniqueId(&id)); + std::vector requests(subcomm_ranks.size() - 1); + for (size_t i = 1; i < subcomm_ranks.size(); ++i) { + isend(&id, sizeof(ncclUniqueId), subcomm_ranks[i], color, requests.data() + (i - 1)); + } + waitall(requests.size(), requests.data()); + } else { + request_t request{}; + irecv(&id, sizeof(ncclUniqueId), subcomm_ranks[0], color, &request); + waitall(1, &request); + } + // FIXME: this seems unnecessary, do more testing and remove this + barrier(); + + ncclComm_t nccl_comm; + RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_ranks.size(), id, key)); + + if (ucp_worker_ != nullptr && subcomms_ucp_) { + auto eps_sp = std::make_shared(new_ucx_ptrs.data()); + return std::unique_ptr(new std_comms(nccl_comm, + (ucp_worker_h)ucp_worker_, + eps_sp, + subcomm_ranks.size(), + key, + stream_, + subcomms_ucp_)); + } else { + return std::unique_ptr( + new std_comms(nccl_comm, subcomm_ranks.size(), key, stream_)); + } + } + + void barrier() const + { + RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); + RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_)); + + allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); + + ASSERT(sync_stream(stream_) == status_t::SUCCESS, + "ERROR: syncStream failed. This can be caused by a failed rank_."); + } + + void get_request_id(request_t* req) const + { + request_t req_id; + + if (this->free_requests_.empty()) + req_id = this->next_request_id_++; + else { + auto it = this->free_requests_.begin(); + req_id = *it; + this->free_requests_.erase(it); + } + *req = req_id; + } + + void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const + { + ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + + get_request_id(request); + ucp_ep_h ep_ptr = (*ucp_eps_)[dest]; + + ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); + + this->ucp_handler_.ucp_isend(ucp_req, ep_ptr, buf, size, tag, default_tag_mask, get_rank()); + + requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + } + + void irecv(void* buf, size_t size, int source, int tag, request_t* request) const + { + ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + + get_request_id(request); + + ucp_ep_h ep_ptr = (*ucp_eps_)[source]; + + ucp_tag_t tag_mask = default_tag_mask; + + ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); + ucp_handler_.ucp_irecv(ucp_req, ucp_worker_, ep_ptr, buf, size, tag, tag_mask, source); + + requests_in_flight_.insert(std::make_pair(*request, ucp_req)); + } + + void waitall(int count, request_t array_of_requests[]) const + { + ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); + + std::vector requests; + requests.reserve(count); + + time_t start = time(NULL); + + for (int i = 0; i < count; ++i) { + auto req_it = requests_in_flight_.find(array_of_requests[i]); + ASSERT(requests_in_flight_.end() != req_it, + "ERROR: waitall on invalid request: %d", + array_of_requests[i]); + requests.push_back(req_it->second); + free_requests_.insert(req_it->first); + requests_in_flight_.erase(req_it); + } + + while (requests.size() > 0) { + time_t now = time(NULL); + + // Timeout if we have not gotten progress or completed any requests + // in 10 or more seconds. + ASSERT(now - start < 10, "Timed out waiting for requests."); + + for (std::vector::iterator it = requests.begin(); it != requests.end();) { + bool restart = false; // resets the timeout when any progress was made + + // Causes UCP to progress through the send/recv message queue + while (ucp_handler_.ucp_progress(ucp_worker_) != 0) { + restart = true; + } + + auto req = *it; + + // If the message needs release, we know it will be sent/received + // asynchronously, so we will need to track and verify its state + if (req->needs_release) { + ASSERT(UCS_PTR_IS_PTR(req->req), "UCX Request Error. Request is not valid UCX pointer"); + ASSERT(!UCS_PTR_IS_ERR(req->req), "UCX Request Error: %d\n", UCS_PTR_STATUS(req->req)); + ASSERT(req->req->completed == 1 || req->req->completed == 0, + "request->completed not a valid value: %d\n", + req->req->completed); + } + + // If a message was sent synchronously (eg. completed before + // `isend`/`irecv` completed) or an asynchronous message + // is complete, we can go ahead and clean it up. + if (!req->needs_release || req->req->completed == 1) { + restart = true; + + // perform cleanup + ucp_handler_.free_ucp_request(req); + + // remove from pending requests + it = requests.erase(it); + } else { + ++it; + } + // if any progress was made, reset the timeout start time + if (restart) { start = time(NULL); } + } + } + } + + void allreduce(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + op_t op, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclAllReduce( + sendbuff, recvbuff, count, get_nccl_datatype(datatype), get_nccl_op(op), nccl_comm_, stream)); + } + + void bcast(void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const + { + RAFT_NCCL_TRY( + ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + + void bcast(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclBroadcast( + sendbuff, recvbuff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + } + + void reduce(const void* sendbuff, + void* recvbuff, + size_t count, + datatype_t datatype, + op_t op, + int root, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclReduce(sendbuff, + recvbuff, + count, + get_nccl_datatype(datatype), + get_nccl_op(op), + root, + nccl_comm_, + stream)); + } + + void allgather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclAllGather( + sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), nccl_comm_, stream)); + } + + void allgatherv(const void* sendbuf, + void* recvbuf, + const size_t* recvcounts, + const size_t* displs, + datatype_t datatype, + cudaStream_t stream) const + { + // From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" - + // https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4. + for (int root = 0; root < num_ranks_; ++root) { + size_t dtype_size = get_datatype_size(datatype); + RAFT_NCCL_TRY(ncclBroadcast(sendbuf, + static_cast(recvbuf) + displs[root] * dtype_size, + recvcounts[root], + get_nccl_datatype(datatype), + root, + nccl_comm_, + stream)); + } + } + + void gather(const void* sendbuff, + void* recvbuff, + size_t sendcount, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + size_t dtype_size = get_datatype_size(datatype); + RAFT_NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + sendcount * r * dtype_size, + sendcount, + get_nccl_datatype(datatype), + r, + nccl_comm_, + stream)); + } + } + RAFT_NCCL_TRY( + ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + void gatherv(const void* sendbuff, + void* recvbuff, + size_t sendcount, + const size_t* recvcounts, + const size_t* displs, + datatype_t datatype, + int root, + cudaStream_t stream) const + { + size_t dtype_size = get_datatype_size(datatype); + RAFT_NCCL_TRY(ncclGroupStart()); + if (get_rank() == root) { + for (int r = 0; r < get_size(); ++r) { + RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + displs[r] * dtype_size, + recvcounts[r], + get_nccl_datatype(datatype), + r, + nccl_comm_, + stream)); + } + } + RAFT_NCCL_TRY( + ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + void reducescatter(const void* sendbuff, + void* recvbuff, + size_t recvcount, + datatype_t datatype, + op_t op, + cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclReduceScatter(sendbuff, + recvbuff, + recvcount, + get_nccl_datatype(datatype), + get_nccl_op(op), + nccl_comm_, + stream)); + } + + status_t sync_stream(cudaStream_t stream) const + { + cudaError_t cudaErr; + ncclResult_t ncclErr, ncclAsyncErr; + while (1) { + cudaErr = cudaStreamQuery(stream); + if (cudaErr == cudaSuccess) return status_t::SUCCESS; + + if (cudaErr != cudaErrorNotReady) { + // An error occurred querying the status of the stream_ + return status_t::ERROR; + } + + ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr); + if (ncclErr != ncclSuccess) { + // An error occurred retrieving the asynchronous error + return status_t::ERROR; + } + + if (ncclAsyncErr != ncclSuccess) { + // An asynchronous error happened. Stop the operation and destroy + // the communicator + ncclErr = ncclCommAbort(nccl_comm_); + if (ncclErr != ncclSuccess) + // Caller may abort with an exception or try to re-create a new communicator. + return status_t::ABORT; + } + + // Let other threads (including NCCL threads) use the CPU. + std::this_thread::yield(); + } + } + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); + } + + // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock + void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const + { + RAFT_NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); + } + + void device_sendrecv(const void* sendbuf, + size_t sendsize, + int dest, + void* recvbuf, + size_t recvsize, + int source, + cudaStream_t stream) const + { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + RAFT_NCCL_TRY(ncclGroupStart()); + RAFT_NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + void device_multicast_sendrecv(const void* sendbuf, + std::vector const& sendsizes, + std::vector const& sendoffsets, + std::vector const& dests, + void* recvbuf, + std::vector const& recvsizes, + std::vector const& recvoffsets, + std::vector const& sources, + cudaStream_t stream) const + { + // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock + RAFT_NCCL_TRY(ncclGroupStart()); + for (size_t i = 0; i < sendsizes.size(); ++i) { + RAFT_NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], + sendsizes[i], + ncclUint8, + dests[i], + nccl_comm_, + stream)); + } + for (size_t i = 0; i < recvsizes.size(); ++i) { + RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], + recvsizes[i], + ncclUint8, + sources[i], + nccl_comm_, + stream)); + } + RAFT_NCCL_TRY(ncclGroupEnd()); + } + + private: + ncclComm_t nccl_comm_; + cudaStream_t stream_; + + int *sendbuff_, *recvbuff_; + rmm::device_uvector status_; + + int num_ranks_; + int rank_; + + bool subcomms_ucp_; + + comms_ucp_handler ucp_handler_; + ucp_worker_h ucp_worker_; + std::shared_ptr ucp_eps_; + mutable request_t next_request_id_; + mutable std::unordered_map requests_in_flight_; + mutable std::unordered_set free_requests_; +}; +} // namespace detail +} // end namespace comms +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/comms/test.hpp b/cpp/include/raft/comms/detail/test.hpp similarity index 99% rename from cpp/include/raft/comms/test.hpp rename to cpp/include/raft/comms/detail/test.hpp index 01ad6369f8..cd84d2becd 100644 --- a/cpp/include/raft/comms/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ namespace raft { namespace comms { +namespace detail { /** * @brief A simple sanity check that NCCL is able to perform a collective operation @@ -538,5 +539,6 @@ bool test_commsplit(const handle_t& h, int n_colors) return test_collective_allreduce(new_handle, 0); } +} // namespace detail } // namespace comms }; // namespace raft diff --git a/cpp/include/raft/comms/ucp_helper.hpp b/cpp/include/raft/comms/detail/ucp_helper.hpp similarity index 98% rename from cpp/include/raft/comms/ucp_helper.hpp rename to cpp/include/raft/comms/detail/ucp_helper.hpp index 89c7b25630..6ba66fb6f3 100644 --- a/cpp/include/raft/comms/ucp_helper.hpp +++ b/cpp/include/raft/comms/detail/ucp_helper.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,13 +24,17 @@ namespace raft { namespace comms { +namespace detail { typedef void (*dlsym_print_info)(ucp_ep_h, FILE*); + typedef void (*dlsym_rec_free)(void*); + typedef int (*dlsym_worker_progress)(ucp_worker_h); typedef ucs_status_ptr_t (*dlsym_send)( ucp_ep_h, const void*, size_t, ucp_datatype_t, ucp_tag_t, ucp_send_callback_t); + typedef ucs_status_ptr_t (*dlsym_recv)(ucp_worker_h, void*, size_t count, @@ -250,5 +254,6 @@ class comms_ucp_handler { UCS_PTR_STATUS(recv_result)); } }; +} // end namespace detail } // end namespace comms } // end namespace raft diff --git a/cpp/include/raft/comms/util.hpp b/cpp/include/raft/comms/detail/util.hpp similarity index 93% rename from cpp/include/raft/comms/util.hpp rename to cpp/include/raft/comms/detail/util.hpp index ef16773c75..1c0d152016 100644 --- a/cpp/include/raft/comms/util.hpp +++ b/cpp/include/raft/comms/detail/util.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,8 +61,11 @@ namespace raft { namespace comms { +namespace detail { -constexpr size_t get_datatype_size(const datatype_t datatype) +constexpr size_t + +get_datatype_size(const datatype_t datatype) { switch (datatype) { case datatype_t::CHAR: return sizeof(char); @@ -77,7 +80,9 @@ constexpr size_t get_datatype_size(const datatype_t datatype) } } -constexpr ncclDataType_t get_nccl_datatype(const datatype_t datatype) +constexpr ncclDataType_t + +get_nccl_datatype(const datatype_t datatype) { switch (datatype) { case datatype_t::CHAR: return ncclChar; @@ -92,7 +97,9 @@ constexpr ncclDataType_t get_nccl_datatype(const datatype_t datatype) } } -constexpr ncclRedOp_t get_nccl_op(const op_t op) +constexpr ncclRedOp_t + +get_nccl_op(const op_t op) { switch (op) { case op_t::SUM: return ncclSum; @@ -102,5 +109,6 @@ constexpr ncclRedOp_t get_nccl_op(const op_t op) default: throw "Unsupported datatype"; } } +}; // namespace detail }; // namespace comms }; // namespace raft diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index 432f250b59..bb1e30afc8 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,425 +16,13 @@ #pragma once -#include -#include - -#include -#include -#include - -#include -#include - #include -#include -#include -#include -#include - -#define RAFT_MPI_TRY(call) \ - do { \ - int status = call; \ - if (MPI_SUCCESS != status) { \ - int mpi_error_string_lenght = 0; \ - char mpi_error_string[MPI_MAX_ERROR_STRING]; \ - MPI_Error_string(status, mpi_error_string, &mpi_error_string_lenght); \ - RAFT_EXPECTS( \ - MPI_SUCCESS == status, "ERROR: MPI call='%s'. Reason:%s\n", #call, mpi_error_string); \ - } \ - } while (0) - -// FIXME: Remove after consumer rename -#ifndef MPI_TRY -#define MPI_TRY(call) RAFT_MPI_TRY(call) -#endif - -#define RAFT_MPI_TRY_NO_THROW(call) \ - do { \ - int status = call; \ - if (MPI_SUCCESS != status) { \ - int mpi_error_string_lenght = 0; \ - char mpi_error_string[MPI_MAX_ERROR_STRING]; \ - MPI_Error_string(status, mpi_error_string, &mpi_error_string_lenght); \ - printf("MPI call='%s' at file=%s line=%d failed with %s ", \ - #call, \ - __FILE__, \ - __LINE__, \ - mpi_error_string); \ - } \ - } while (0) - -// FIXME: Remove after consumer rename -#ifndef MPI_TRY_NO_THROW -#define MPI_TRY_NO_THROW(call) RAFT_MPI_TRY_NO_THROW(call) -#endif +#include namespace raft { namespace comms { -constexpr MPI_Datatype get_mpi_datatype(const datatype_t datatype) -{ - switch (datatype) { - case datatype_t::CHAR: return MPI_CHAR; - case datatype_t::UINT8: return MPI_UNSIGNED_CHAR; - case datatype_t::INT32: return MPI_INT; - case datatype_t::UINT32: return MPI_UNSIGNED; - case datatype_t::INT64: return MPI_LONG_LONG; - case datatype_t::UINT64: return MPI_UNSIGNED_LONG_LONG; - case datatype_t::FLOAT32: return MPI_FLOAT; - case datatype_t::FLOAT64: return MPI_DOUBLE; - default: - // Execution should never reach here. This takes care of compiler warning. - return MPI_DOUBLE; - } -} - -constexpr MPI_Op get_mpi_op(const op_t op) -{ - switch (op) { - case op_t::SUM: return MPI_SUM; - case op_t::PROD: return MPI_PROD; - case op_t::MIN: return MPI_MIN; - case op_t::MAX: return MPI_MAX; - default: - // Execution should never reach here. This takes care of compiler warning. - return MPI_MAX; - } -} - -class mpi_comms : public comms_iface { - public: - mpi_comms(MPI_Comm comm, const bool owns_mpi_comm) - : owns_mpi_comm_(owns_mpi_comm), mpi_comm_(comm), size_(0), rank_(1), next_request_id_(0) - { - int mpi_is_initialized = 0; - RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized)); - RAFT_EXPECTS(mpi_is_initialized, "ERROR: MPI is not initialized!"); - RAFT_MPI_TRY(MPI_Comm_size(mpi_comm_, &size_)); - RAFT_MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_)); - // get NCCL unique ID at rank 0 and broadcast it to all others - ncclUniqueId id; - if (0 == rank_) RAFT_NCCL_TRY(ncclGetUniqueId(&id)); - RAFT_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_)); - - // initializing NCCL - RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_)); - } - - virtual ~mpi_comms() - { - // finalizing NCCL - RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_)); - if (owns_mpi_comm_) { RAFT_MPI_TRY_NO_THROW(MPI_Comm_free(&mpi_comm_)); } - } - - int get_size() const { return size_; } - - int get_rank() const { return rank_; } - - std::unique_ptr comm_split(int color, int key) const - { - MPI_Comm new_comm; - RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm)); - return std::unique_ptr(new mpi_comms(new_comm, true)); - } - - void barrier() const { RAFT_MPI_TRY(MPI_Barrier(mpi_comm_)); } - - void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const - { - MPI_Request mpi_req; - request_t req_id; - if (free_requests_.empty()) { - req_id = next_request_id_++; - } else { - auto it = free_requests_.begin(); - req_id = *it; - free_requests_.erase(it); - } - RAFT_MPI_TRY(MPI_Isend(buf, size, MPI_BYTE, dest, tag, mpi_comm_, &mpi_req)); - requests_in_flight_.insert(std::make_pair(req_id, mpi_req)); - *request = req_id; - } - - void irecv(void* buf, size_t size, int source, int tag, request_t* request) const - { - MPI_Request mpi_req; - request_t req_id; - if (free_requests_.empty()) { - req_id = next_request_id_++; - } else { - auto it = free_requests_.begin(); - req_id = *it; - free_requests_.erase(it); - } - - RAFT_MPI_TRY(MPI_Irecv(buf, size, MPI_BYTE, source, tag, mpi_comm_, &mpi_req)); - requests_in_flight_.insert(std::make_pair(req_id, mpi_req)); - *request = req_id; - } - - void waitall(int count, request_t array_of_requests[]) const - { - std::vector requests; - requests.reserve(count); - for (int i = 0; i < count; ++i) { - auto req_it = requests_in_flight_.find(array_of_requests[i]); - RAFT_EXPECTS(requests_in_flight_.end() != req_it, - "ERROR: waitall on invalid request: %d", - array_of_requests[i]); - requests.push_back(req_it->second); - free_requests_.insert(req_it->first); - requests_in_flight_.erase(req_it); - } - RAFT_MPI_TRY(MPI_Waitall(requests.size(), requests.data(), MPI_STATUSES_IGNORE)); - } - - void allreduce(const void* sendbuff, - void* recvbuff, - size_t count, - datatype_t datatype, - op_t op, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclAllReduce( - sendbuff, recvbuff, count, get_nccl_datatype(datatype), get_nccl_op(op), nccl_comm_, stream)); - } - - void bcast(void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const - { - RAFT_NCCL_TRY( - ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - } - - void bcast(const void* sendbuff, - void* recvbuff, - size_t count, - datatype_t datatype, - int root, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclBroadcast( - sendbuff, recvbuff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - } - - void reduce(const void* sendbuff, - void* recvbuff, - size_t count, - datatype_t datatype, - op_t op, - int root, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclReduce(sendbuff, - recvbuff, - count, - get_nccl_datatype(datatype), - get_nccl_op(op), - root, - nccl_comm_, - stream)); - } - - void allgather(const void* sendbuff, - void* recvbuff, - size_t sendcount, - datatype_t datatype, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclAllGather( - sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), nccl_comm_, stream)); - } - - void allgatherv(const void* sendbuf, - void* recvbuf, - const size_t* recvcounts, - const size_t* displs, - datatype_t datatype, - cudaStream_t stream) const - { - // From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" - - // https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4. - for (int root = 0; root < size_; ++root) { - RAFT_NCCL_TRY( - ncclBroadcast(sendbuf, - static_cast(recvbuf) + displs[root] * get_datatype_size(datatype), - recvcounts[root], - get_nccl_datatype(datatype), - root, - nccl_comm_, - stream)); - } - } - - void gather(const void* sendbuff, - void* recvbuff, - size_t sendcount, - datatype_t datatype, - int root, - cudaStream_t stream) const - { - size_t dtype_size = get_datatype_size(datatype); - RAFT_NCCL_TRY(ncclGroupStart()); - if (get_rank() == root) { - for (int r = 0; r < get_size(); ++r) { - RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + sendcount * r * dtype_size, - sendcount, - get_nccl_datatype(datatype), - r, - nccl_comm_, - stream)); - } - } - RAFT_NCCL_TRY( - ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - void gatherv(const void* sendbuff, - void* recvbuff, - size_t sendcount, - const size_t* recvcounts, - const size_t* displs, - datatype_t datatype, - int root, - cudaStream_t stream) const - { - size_t dtype_size = get_datatype_size(datatype); - RAFT_NCCL_TRY(ncclGroupStart()); - if (get_rank() == root) { - for (int r = 0; r < get_size(); ++r) { - RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + displs[r] * dtype_size, - recvcounts[r], - get_nccl_datatype(datatype), - r, - nccl_comm_, - stream)); - } - } - RAFT_NCCL_TRY( - ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - void reducescatter(const void* sendbuff, - void* recvbuff, - size_t recvcount, - datatype_t datatype, - op_t op, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclReduceScatter(sendbuff, - recvbuff, - recvcount, - get_nccl_datatype(datatype), - get_nccl_op(op), - nccl_comm_, - stream)); - } - - status_t sync_stream(cudaStream_t stream) const - { - cudaError_t cudaErr; - ncclResult_t ncclErr, ncclAsyncErr; - while (1) { - cudaErr = cudaStreamQuery(stream); - if (cudaErr == cudaSuccess) return status_t::SUCCESS; - - if (cudaErr != cudaErrorNotReady) { - // An error occurred querying the status of the stream - return status_t::ERROR; - } - - ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr); - if (ncclErr != ncclSuccess) { - // An error occurred retrieving the asynchronous error - return status_t::ERROR; - } - - if (ncclAsyncErr != ncclSuccess) { - // An asynchronous error happened. Stop the operation and destroy - // the communicator - ncclErr = ncclCommAbort(nccl_comm_); - if (ncclErr != ncclSuccess) - // Caller may abort with an exception or try to re-create a new communicator. - return status_t::ABORT; - } - - // Let other threads (including NCCL threads) use the CPU. - pthread_yield(); - } - }; - - // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock - void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); - } - - // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock - void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); - } - - void device_sendrecv(const void* sendbuf, - size_t sendsize, - int dest, - void* recvbuf, - size_t recvsize, - int source, - cudaStream_t stream) const - { - // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock - RAFT_NCCL_TRY(ncclGroupStart()); - RAFT_NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - void device_multicast_sendrecv(const void* sendbuf, - std::vector const& sendsizes, - std::vector const& sendoffsets, - std::vector const& dests, - void* recvbuf, - std::vector const& recvsizes, - std::vector const& recvoffsets, - std::vector const& sources, - cudaStream_t stream) const - { - // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock - RAFT_NCCL_TRY(ncclGroupStart()); - for (size_t i = 0; i < sendsizes.size(); ++i) { - RAFT_NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], - sendsizes[i], - ncclUint8, - dests[i], - nccl_comm_, - stream)); - } - for (size_t i = 0; i < recvsizes.size(); ++i) { - RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], - recvsizes[i], - ncclUint8, - sources[i], - nccl_comm_, - stream)); - } - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - private: - bool owns_mpi_comm_; - MPI_Comm mpi_comm_; - - ncclComm_t nccl_comm_; - int size_; - int rank_; - mutable request_t next_request_id_; - mutable std::unordered_map requests_in_flight_; - mutable std::unordered_set free_requests_; -}; +using mpi_comms = detail::mpi_comms; inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) { @@ -443,5 +31,5 @@ inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) handle->set_comms(communicator); }; -}; // end namespace comms +}; // namespace comms }; // end namespace raft diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 99f15643a1..f54535a88c 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,539 +16,93 @@ #pragma once -#include - -#include #include -#include - -#include - -#include -#include -#include +#include +#include -#include -#include -#include +#include +#include #include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace raft { namespace comms { -class std_comms : public comms_iface { - public: - std_comms() = delete; - - /** - * @brief Constructor for collective + point-to-point operation. - * @param nccl_comm initialized nccl comm - * @param ucp_worker initialized ucp_worker instance - * @param eps shared pointer to array of ucp endpoints - * @param num_ranks number of ranks in the cluster - * @param rank rank of the current worker - * @param stream cuda stream for synchronizing and ordering collective operations - * @param subcomms_ucp use ucp for subcommunicators - */ - std_comms(ncclComm_t nccl_comm, - ucp_worker_h ucp_worker, - std::shared_ptr eps, - int num_ranks, - int rank, - cudaStream_t stream, - bool subcomms_ucp = true) - : nccl_comm_(nccl_comm), - stream_(stream), - status_(2, stream), - num_ranks_(num_ranks), - rank_(rank), - subcomms_ucp_(subcomms_ucp), - ucp_worker_(ucp_worker), - ucp_eps_(eps), - next_request_id_(0) - { - initialize(); - }; - - /** - * @brief constructor for collective-only operation - * @param nccl_comm initilized nccl communicator - * @param num_ranks size of the cluster - * @param rank rank of the current worker - * @param stream stream for ordering collective operations - */ - std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, cudaStream_t stream) - : nccl_comm_(nccl_comm), - stream_(stream), - status_(2, stream), - num_ranks_(num_ranks), - rank_(rank), - subcomms_ucp_(false) - { - initialize(); - }; - - void initialize() - { - sendbuff_ = status_.data(); - recvbuff_ = status_.data() + 1; - } - int get_size() const { return num_ranks_; } +using std_comms = detail::std_comms; - int get_rank() const { return rank_; } - - std::unique_ptr comm_split(int color, int key) const - { - rmm::device_uvector d_colors(get_size(), stream_); - rmm::device_uvector d_keys(get_size(), stream_); - - update_device(d_colors.data() + get_rank(), &color, 1, stream_); - update_device(d_keys.data() + get_rank(), &key, 1, stream_); - - allgather(d_colors.data() + get_rank(), d_colors.data(), 1, datatype_t::INT32, stream_); - allgather(d_keys.data() + get_rank(), d_keys.data(), 1, datatype_t::INT32, stream_); - this->sync_stream(stream_); - - std::vector h_colors(get_size()); - std::vector h_keys(get_size()); - - update_host(h_colors.data(), d_colors.data(), get_size(), stream_); - update_host(h_keys.data(), d_keys.data(), get_size(), stream_); - - RAFT_CUDA_TRY(cudaStreamSynchronize(stream_)); - - std::vector subcomm_ranks{}; - std::vector new_ucx_ptrs{}; - - for (int i = 0; i < get_size(); ++i) { - if (h_colors[i] == color) { - subcomm_ranks.push_back(i); - if (ucp_worker_ != nullptr && subcomms_ucp_) { new_ucx_ptrs.push_back((*ucp_eps_)[i]); } - } - } +/** + * Function to construct comms_t and inject it on a handle_t. This + * is used for convenience in the Python layer. + * + * @param handle raft::handle_t for injecting the comms + * @param nccl_comm initialized NCCL communicator to use for collectives + * @param num_ranks number of ranks in communicator clique + * @param rank rank of local instance + */ +void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank) +{ + cudaStream_t stream = handle->get_stream(); + + auto communicator = std::make_shared( + std::unique_ptr(new raft::comms::std_comms(nccl_comm, num_ranks, rank, stream))); + handle->set_comms(communicator); +} + +/** + * Function to construct comms_t and inject it on a handle_t. This + * is used for convenience in the Python layer. + * + * @param handle raft::handle_t for injecting the comms + * @param nccl_comm initialized NCCL communicator to use for collectives + * @param ucp_worker of local process + * Note: This is purposefully left as void* so that the ucp_worker_h + * doesn't need to be exposed through the cython layer + * @param eps array of ucp_ep_h instances. + * Note: This is purposefully left as void* so that + * the ucp_ep_h doesn't need to be exposed through the cython layer. + * @param num_ranks number of ranks in communicator clique + * @param rank rank of local instance + */ +void build_comms_nccl_ucx( + handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) +{ + auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); - ncclUniqueId id{}; - if (get_rank() == subcomm_ranks[0]) { // root of the new subcommunicator - RAFT_NCCL_TRY(ncclGetUniqueId(&id)); - std::vector requests(subcomm_ranks.size() - 1); - for (size_t i = 1; i < subcomm_ranks.size(); ++i) { - isend(&id, sizeof(ncclUniqueId), subcomm_ranks[i], color, requests.data() + (i - 1)); - } - waitall(requests.size(), requests.data()); - } else { - request_t request{}; - irecv(&id, sizeof(ncclUniqueId), subcomm_ranks[0], color, &request); - waitall(1, &request); - } - // FIXME: this seems unnecessary, do more testing and remove this - barrier(); + auto size_t_ep_arr = reinterpret_cast(eps); - ncclComm_t nccl_comm; - RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_ranks.size(), id, key)); + for (int i = 0; i < num_ranks; i++) { + size_t ptr = size_t_ep_arr[i]; + auto ucp_ep_v = reinterpret_cast(*eps_sp); - if (ucp_worker_ != nullptr && subcomms_ucp_) { - auto eps_sp = std::make_shared(new_ucx_ptrs.data()); - return std::unique_ptr(new std_comms(nccl_comm, - (ucp_worker_h)ucp_worker_, - eps_sp, - subcomm_ranks.size(), - key, - stream_, - subcomms_ucp_)); + if (ptr != 0) { + auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); + ucp_ep_v[i] = eps_ptr; } else { - return std::unique_ptr( - new std_comms(nccl_comm, subcomm_ranks.size(), key, stream_)); - } - } - - void barrier() const - { - RAFT_CUDA_TRY(cudaMemsetAsync(sendbuff_, 1, sizeof(int), stream_)); - RAFT_CUDA_TRY(cudaMemsetAsync(recvbuff_, 1, sizeof(int), stream_)); - - allreduce(sendbuff_, recvbuff_, 1, datatype_t::INT32, op_t::SUM, stream_); - - ASSERT(sync_stream(stream_) == status_t::SUCCESS, - "ERROR: syncStream failed. This can be caused by a failed rank_."); - } - - void get_request_id(request_t* req) const - { - request_t req_id; - - if (this->free_requests_.empty()) - req_id = this->next_request_id_++; - else { - auto it = this->free_requests_.begin(); - req_id = *it; - this->free_requests_.erase(it); - } - *req = req_id; - } - - void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const - { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); - - get_request_id(request); - ucp_ep_h ep_ptr = (*ucp_eps_)[dest]; - - ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); - - this->ucp_handler_.ucp_isend(ucp_req, ep_ptr, buf, size, tag, default_tag_mask, get_rank()); - - requests_in_flight_.insert(std::make_pair(*request, ucp_req)); - } - - void irecv(void* buf, size_t size, int source, int tag, request_t* request) const - { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); - - get_request_id(request); - - ucp_ep_h ep_ptr = (*ucp_eps_)[source]; - - ucp_tag_t tag_mask = default_tag_mask; - - ucp_request* ucp_req = (ucp_request*)malloc(sizeof(ucp_request)); - ucp_handler_.ucp_irecv(ucp_req, ucp_worker_, ep_ptr, buf, size, tag, tag_mask, source); - - requests_in_flight_.insert(std::make_pair(*request, ucp_req)); - } - - void waitall(int count, request_t array_of_requests[]) const - { - ASSERT(ucp_worker_ != nullptr, "ERROR: UCX comms not initialized on communicator."); - - std::vector requests; - requests.reserve(count); - - time_t start = time(NULL); - - for (int i = 0; i < count; ++i) { - auto req_it = requests_in_flight_.find(array_of_requests[i]); - ASSERT(requests_in_flight_.end() != req_it, - "ERROR: waitall on invalid request: %d", - array_of_requests[i]); - requests.push_back(req_it->second); - free_requests_.insert(req_it->first); - requests_in_flight_.erase(req_it); - } - - while (requests.size() > 0) { - time_t now = time(NULL); - - // Timeout if we have not gotten progress or completed any requests - // in 10 or more seconds. - ASSERT(now - start < 10, "Timed out waiting for requests."); - - for (std::vector::iterator it = requests.begin(); it != requests.end();) { - bool restart = false; // resets the timeout when any progress was made - - // Causes UCP to progress through the send/recv message queue - while (ucp_handler_.ucp_progress(ucp_worker_) != 0) { - restart = true; - } - - auto req = *it; - - // If the message needs release, we know it will be sent/received - // asynchronously, so we will need to track and verify its state - if (req->needs_release) { - ASSERT(UCS_PTR_IS_PTR(req->req), "UCX Request Error. Request is not valid UCX pointer"); - ASSERT(!UCS_PTR_IS_ERR(req->req), "UCX Request Error: %d\n", UCS_PTR_STATUS(req->req)); - ASSERT(req->req->completed == 1 || req->req->completed == 0, - "request->completed not a valid value: %d\n", - req->req->completed); - } - - // If a message was sent synchronously (eg. completed before - // `isend`/`irecv` completed) or an asynchronous message - // is complete, we can go ahead and clean it up. - if (!req->needs_release || req->req->completed == 1) { - restart = true; - - // perform cleanup - ucp_handler_.free_ucp_request(req); - - // remove from pending requests - it = requests.erase(it); - } else { - ++it; - } - // if any progress was made, reset the timeout start time - if (restart) { start = time(NULL); } - } - } - } - - void allreduce(const void* sendbuff, - void* recvbuff, - size_t count, - datatype_t datatype, - op_t op, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclAllReduce( - sendbuff, recvbuff, count, get_nccl_datatype(datatype), get_nccl_op(op), nccl_comm_, stream)); - } - - void bcast(void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const - { - RAFT_NCCL_TRY( - ncclBroadcast(buff, buff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - } - - void bcast(const void* sendbuff, - void* recvbuff, - size_t count, - datatype_t datatype, - int root, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclBroadcast( - sendbuff, recvbuff, count, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - } - - void reduce(const void* sendbuff, - void* recvbuff, - size_t count, - datatype_t datatype, - op_t op, - int root, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclReduce(sendbuff, - recvbuff, - count, - get_nccl_datatype(datatype), - get_nccl_op(op), - root, - nccl_comm_, - stream)); - } - - void allgather(const void* sendbuff, - void* recvbuff, - size_t sendcount, - datatype_t datatype, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclAllGather( - sendbuff, recvbuff, sendcount, get_nccl_datatype(datatype), nccl_comm_, stream)); - } - - void allgatherv(const void* sendbuf, - void* recvbuf, - const size_t* recvcounts, - const size_t* displs, - datatype_t datatype, - cudaStream_t stream) const - { - // From: "An Empirical Evaluation of Allgatherv on Multi-GPU Systems" - - // https://arxiv.org/pdf/1812.05964.pdf Listing 1 on page 4. - for (int root = 0; root < num_ranks_; ++root) { - size_t dtype_size = get_datatype_size(datatype); - RAFT_NCCL_TRY(ncclBroadcast(sendbuf, - static_cast(recvbuf) + displs[root] * dtype_size, - recvcounts[root], - get_nccl_datatype(datatype), - root, - nccl_comm_, - stream)); - } - } - - void gather(const void* sendbuff, - void* recvbuff, - size_t sendcount, - datatype_t datatype, - int root, - cudaStream_t stream) const - { - size_t dtype_size = get_datatype_size(datatype); - RAFT_NCCL_TRY(ncclGroupStart()); - if (get_rank() == root) { - for (int r = 0; r < get_size(); ++r) { - RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + sendcount * r * dtype_size, - sendcount, - get_nccl_datatype(datatype), - r, - nccl_comm_, - stream)); - } - } - RAFT_NCCL_TRY( - ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - void gatherv(const void* sendbuff, - void* recvbuff, - size_t sendcount, - const size_t* recvcounts, - const size_t* displs, - datatype_t datatype, - int root, - cudaStream_t stream) const - { - size_t dtype_size = get_datatype_size(datatype); - RAFT_NCCL_TRY(ncclGroupStart()); - if (get_rank() == root) { - for (int r = 0; r < get_size(); ++r) { - RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuff) + displs[r] * dtype_size, - recvcounts[r], - get_nccl_datatype(datatype), - r, - nccl_comm_, - stream)); - } - } - RAFT_NCCL_TRY( - ncclSend(sendbuff, sendcount, get_nccl_datatype(datatype), root, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - void reducescatter(const void* sendbuff, - void* recvbuff, - size_t recvcount, - datatype_t datatype, - op_t op, - cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclReduceScatter(sendbuff, - recvbuff, - recvcount, - get_nccl_datatype(datatype), - get_nccl_op(op), - nccl_comm_, - stream)); - } - - status_t sync_stream(cudaStream_t stream) const - { - cudaError_t cudaErr; - ncclResult_t ncclErr, ncclAsyncErr; - while (1) { - cudaErr = cudaStreamQuery(stream); - if (cudaErr == cudaSuccess) return status_t::SUCCESS; - - if (cudaErr != cudaErrorNotReady) { - // An error occurred querying the status of the stream_ - return status_t::ERROR; - } - - ncclErr = ncclCommGetAsyncError(nccl_comm_, &ncclAsyncErr); - if (ncclErr != ncclSuccess) { - // An error occurred retrieving the asynchronous error - return status_t::ERROR; - } - - if (ncclAsyncErr != ncclSuccess) { - // An asynchronous error happened. Stop the operation and destroy - // the communicator - ncclErr = ncclCommAbort(nccl_comm_); - if (ncclErr != ncclSuccess) - // Caller may abort with an exception or try to re-create a new communicator. - return status_t::ABORT; - } - - // Let other threads (including NCCL threads) use the CPU. - std::this_thread::yield(); - } - } - - // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock - void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclSend(buf, size, ncclUint8, dest, nccl_comm_, stream)); - } - - // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock - void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const - { - RAFT_NCCL_TRY(ncclRecv(buf, size, ncclUint8, source, nccl_comm_, stream)); - } - - void device_sendrecv(const void* sendbuf, - size_t sendsize, - int dest, - void* recvbuf, - size_t recvsize, - int source, - cudaStream_t stream) const - { - // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock - RAFT_NCCL_TRY(ncclGroupStart()); - RAFT_NCCL_TRY(ncclSend(sendbuf, sendsize, ncclUint8, dest, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclRecv(recvbuf, recvsize, ncclUint8, source, nccl_comm_, stream)); - RAFT_NCCL_TRY(ncclGroupEnd()); - } - - void device_multicast_sendrecv(const void* sendbuf, - std::vector const& sendsizes, - std::vector const& sendoffsets, - std::vector const& dests, - void* recvbuf, - std::vector const& recvsizes, - std::vector const& recvoffsets, - std::vector const& sources, - cudaStream_t stream) const - { - // ncclSend/ncclRecv pair needs to be inside ncclGroupStart/ncclGroupEnd to avoid deadlock - RAFT_NCCL_TRY(ncclGroupStart()); - for (size_t i = 0; i < sendsizes.size(); ++i) { - RAFT_NCCL_TRY(ncclSend(static_cast(sendbuf) + sendoffsets[i], - sendsizes[i], - ncclUint8, - dests[i], - nccl_comm_, - stream)); - } - for (size_t i = 0; i < recvsizes.size(); ++i) { - RAFT_NCCL_TRY(ncclRecv(static_cast(recvbuf) + recvoffsets[i], - recvsizes[i], - ncclUint8, - sources[i], - nccl_comm_, - stream)); + ucp_ep_v[i] = nullptr; } - RAFT_NCCL_TRY(ncclGroupEnd()); } - private: - ncclComm_t nccl_comm_; - cudaStream_t stream_; + cudaStream_t stream = handle->get_stream(); - int *sendbuff_, *recvbuff_; - rmm::device_uvector status_; + auto communicator = + std::make_shared(std::unique_ptr(new raft::comms::std_comms( + nccl_comm, (ucp_worker_h)ucp_worker, eps_sp, num_ranks, rank, stream))); + handle->set_comms(communicator); +} - int num_ranks_; - int rank_; +inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size) +{ + memcpy(id->internal, uniqueId, size); +} - bool subcomms_ucp_; +inline void get_unique_id(char* uid, int size) +{ + ncclUniqueId id; + ncclGetUniqueId(&id); - comms_ucp_handler ucp_handler_; - ucp_worker_h ucp_worker_; - std::shared_ptr ucp_eps_; - mutable request_t next_request_id_; - mutable std::unordered_map requests_in_flight_; - mutable std::unordered_set free_requests_; -}; -} // end namespace comms -} // end namespace raft + memcpy(uid, id.internal, size); +} +}; // namespace comms +}; // end namespace raft \ No newline at end of file diff --git a/python/raft/dask/common/comms_utils.pyx b/python/raft/dask/common/comms_utils.pyx index 7370085805..990e882be5 100644 --- a/python/raft/dask/common/comms_utils.pyx +++ b/python/raft/dask/common/comms_utils.pyx @@ -37,11 +37,6 @@ cdef extern from "raft/handle.hpp" namespace "raft": cdef extern from "raft/comms/std_comms.hpp" namespace "raft::comms": - cdef cppclass std_comms: - pass - -cdef extern from "raft/comms/helper.hpp" namespace "raft::comms": - void build_comms_nccl_ucx(handle_t *handle, ncclComm_t comm, void *ucp_worker, @@ -54,7 +49,7 @@ cdef extern from "raft/comms/helper.hpp" namespace "raft::comms": int size, int rank) except + -cdef extern from "raft/comms/test.hpp" namespace "raft::comms": +cdef extern from "raft/comms/comms_test.hpp" namespace "raft::comms": bool test_collective_allreduce(const handle_t &h, int root) except + bool test_collective_broadcast(const handle_t &h, int root) except + diff --git a/python/raft/dask/common/nccl.pyx b/python/raft/dask/common/nccl.pyx index 7fc813b515..fd91f34eb5 100644 --- a/python/raft/dask/common/nccl.pyx +++ b/python/raft/dask/common/nccl.pyx @@ -25,7 +25,7 @@ from cython.operator cimport dereference as deref from libcpp cimport bool from libc.stdlib cimport malloc, free -cdef extern from "raft/comms/helper.hpp" namespace "raft::comms": +cdef extern from "raft/comms/std_comms.hpp" namespace "raft::comms": void get_unique_id(char *uid, int size) except + void nccl_unique_id_from_char(ncclUniqueId *id, char *uniqueId,