From b78a07d9d2a117c12f396f330526a8e8ed23a9ba Mon Sep 17 00:00:00 2001 From: Victor Lafargue Date: Fri, 27 Sep 2024 00:36:44 +0200 Subject: [PATCH] Adding NCCL clique to the RAFT handle (#2431) Authors: - Victor Lafargue (https://github.com/viclafargue) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2431 --- cpp/include/raft/comms/detail/std_comms.hpp | 8 +- cpp/include/raft/comms/nccl_clique.hpp | 156 ++++++++++++++++++ .../raft/core/resource/nccl_clique.hpp | 66 ++++++++ .../raft/core/resource/resource_types.hpp | 1 + 4 files changed, 227 insertions(+), 4 deletions(-) create mode 100644 cpp/include/raft/comms/nccl_clique.hpp create mode 100644 cpp/include/raft/core/resource/nccl_clique.hpp diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index c5d64f6a29..ed869e6cae 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -310,13 +310,13 @@ class std_comms : public comms_iface { // Wait for a UCXX progress thread roundtrip, prevent waiting for longer // than 10ms for each operation, will retry in next iteration. ucxx::utils::CallbackNotifier callbackNotifierPre{}; - worker->registerGenericPre([&callbackNotifierPre]() { callbackNotifierPre.set(); }, - 10000000 /* 10ms */); + (void)worker->registerGenericPre( + [&callbackNotifierPre]() { callbackNotifierPre.set(); }, 10000000 /* 10ms */); callbackNotifierPre.wait(); ucxx::utils::CallbackNotifier callbackNotifierPost{}; - worker->registerGenericPost([&callbackNotifierPost]() { callbackNotifierPost.set(); }, - 10000000 /* 10ms */); + (void)worker->registerGenericPost( + [&callbackNotifierPost]() { callbackNotifierPost.set(); }, 10000000 /* 10ms */); callbackNotifierPost.wait(); } else { // Causes UCXX to progress through the send/recv message queue diff --git a/cpp/include/raft/comms/nccl_clique.hpp b/cpp/include/raft/comms/nccl_clique.hpp new file mode 100644 index 0000000000..c6520af753 --- /dev/null +++ b/cpp/include/raft/comms/nccl_clique.hpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include + +/** + * @brief Error checking macro for NCCL runtime API functions. + * + * Invokes a NCCL runtime API function call, if the call does not return ncclSuccess, throws an + * exception detailing the NCCL error that occurred + */ +#define RAFT_NCCL_TRY(call) \ + do { \ + ncclResult_t const status = (call); \ + if (ncclSuccess != status) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "NCCL error encountered at: ", \ + "call='%s', Reason=%d:%s", \ + #call, \ + status, \ + ncclGetErrorString(status)); \ + throw raft::logic_error(msg); \ + } \ + } while (0); + +namespace raft::comms { +void build_comms_nccl_only(raft::resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank); +} + +namespace raft::comms { + +struct nccl_clique { + using pool_mr = rmm::mr::pool_memory_resource; + + /** + * Instantiates a NCCL clique with all available GPUs + * + * @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool + * + */ + nccl_clique(int percent_of_free_memory = 80) + : root_rank_(0), + percent_of_free_memory_(percent_of_free_memory), + per_device_pools_(0), + device_resources_(0) + { + cudaGetDeviceCount(&num_ranks_); + device_ids_.resize(num_ranks_); + std::iota(device_ids_.begin(), device_ids_.end(), 0); + nccl_comms_.resize(num_ranks_); + nccl_clique_init(); + } + + /** + * Instantiates a NCCL clique + * + * Usage example: + * @code{.cpp} + * int n_devices; + * cudaGetDeviceCount(&n_devices); + * std::vector device_ids(n_devices); + * std::iota(device_ids.begin(), device_ids.end(), 0); + * cuvs::neighbors::mg::nccl_clique& clique(device_ids); // first device is the root rank + * @endcode + * + * @param[in] device_ids list of device IDs to be used to initiate the clique + * @param[in] percent_of_free_memory percentage of device memory to pre-allocate as memory pool + * + */ + nccl_clique(const std::vector& device_ids, int percent_of_free_memory = 80) + : root_rank_(0), + num_ranks_(device_ids.size()), + percent_of_free_memory_(percent_of_free_memory), + device_ids_(device_ids), + nccl_comms_(device_ids.size()), + per_device_pools_(0), + device_resources_(0) + { + nccl_clique_init(); + } + + void nccl_clique_init() + { + RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), num_ranks_, device_ids_.data())); + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank])); + + // create a pool memory resource for each device + auto old_mr = rmm::mr::get_current_device_resource(); + per_device_pools_.push_back(std::make_unique( + old_mr, rmm::percent_of_free_device_memory(percent_of_free_memory_))); + rmm::cuda_device_id id(device_ids_[rank]); + rmm::mr::set_per_device_resource(id, per_device_pools_.back().get()); + + // create a device resource handle for each device + device_resources_.emplace_back(); + + // add NCCL communications to the device resource handle + raft::comms::build_comms_nccl_only( + &device_resources_[rank], nccl_comms_[rank], num_ranks_, rank); + } + + for (int rank = 0; rank < num_ranks_; rank++) { + RAFT_CUDA_TRY(cudaSetDevice(device_ids_[rank])); + raft::resource::sync_stream(device_resources_[rank]); + } + } + + const raft::device_resources& set_current_device_to_root_rank() const + { + int root_device_id = device_ids_[root_rank_]; + RAFT_CUDA_TRY(cudaSetDevice(root_device_id)); + return device_resources_[root_rank_]; + } + + ~nccl_clique() + { +#pragma omp parallel for // necessary to avoid hangs + for (int rank = 0; rank < num_ranks_; rank++) { + cudaSetDevice(device_ids_[rank]); + ncclCommDestroy(nccl_comms_[rank]); + rmm::cuda_device_id id(device_ids_[rank]); + rmm::mr::set_per_device_resource(id, nullptr); + } + } + + int root_rank_; + int num_ranks_; + int percent_of_free_memory_; + std::vector device_ids_; + std::vector nccl_comms_; + std::vector> per_device_pools_; + std::vector device_resources_; +}; + +} // namespace raft::comms diff --git a/cpp/include/raft/core/resource/nccl_clique.hpp b/cpp/include/raft/core/resource/nccl_clique.hpp new file mode 100644 index 0000000000..edda5043ae --- /dev/null +++ b/cpp/include/raft/core/resource/nccl_clique.hpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +#include + +namespace raft::resource { + +class nccl_clique_resource : public resource { + public: + nccl_clique_resource() : clique_(std::make_unique()) {} + ~nccl_clique_resource() override {} + void* get_resource() override { return clique_.get(); } + + private: + std::unique_ptr clique_; +}; + +/** Factory that knows how to construct a specific raft::resource to populate the res_t. */ +class nccl_clique_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::NCCL_CLIQUE; } + resource* make_resource() override { return new nccl_clique_resource(); } +}; + +/** + * @defgroup nccl_clique_resource resource functions + * @{ + */ + +/** + * Retrieves a NCCL clique from raft res if it exists, otherwise initializes it and return it. + * + * @param[in] res the raft resources object + * @return NCCL clique + */ +inline const raft::comms::nccl_clique& get_nccl_clique(resources const& res) +{ + if (!res.has_resource_factory(resource_type::NCCL_CLIQUE)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::NCCL_CLIQUE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index d9126251c9..4fa84c3bdb 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -46,6 +46,7 @@ enum resource_type { CUBLASLT_HANDLE, // cublasLt handle CUSTOM, // runtime-shared default-constructible resource LARGE_WORKSPACE_RESOURCE, // rmm device memory resource for somewhat large temporary allocations + NCCL_CLIQUE, // nccl clique LAST_KEY // reserved for the last key };