From 422e937f5cf4a3ff7dbf14b4929a2026e558259e Mon Sep 17 00:00:00 2001 From: Chuck Hastings <45364586+ChuckHastings@users.noreply.github.com> Date: Mon, 24 Jul 2023 14:13:15 -0400 Subject: [PATCH] Modify comm_split to avoid ucp (#1649) During testing of a new feature in cugraph I discovered that the method required either MPI comms or UCP. I have an application that has neither. This PR modifies the `comm_split` implementation to continue using `allgather` when performing the split instead of using `allgather` followed by UCP comms. Authors: - Chuck Hastings (https://github.com/ChuckHastings) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1649 --- cpp/include/raft/comms/detail/std_comms.hpp | 71 +++++++++------------ 1 file changed, 31 insertions(+), 40 deletions(-) diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 8b92ed48f7..de2a7d3415 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -28,6 +28,8 @@ #include +#include <thrust/iterator/zip_iterator.h> + #include #include @@ -138,50 +140,39 @@ class std_comms : public comms_iface { update_host(h_colors.data(), d_colors.data(), get_size(), stream_); update_host(h_keys.data(), d_keys.data(), get_size(), stream_); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream_)); - - std::vector<int> subcomm_ranks{}; - std::vector<ucp_ep_h> new_ucx_ptrs{}; + this->sync_stream(stream_); - for (int i = 0; i < get_size(); ++i) { - if (h_colors[i] == color) { - subcomm_ranks.push_back(i); - if (ucp_worker_ != nullptr && subcomms_ucp_) { new_ucx_ptrs.push_back((*ucp_eps_)[i]); } - } - } + ncclComm_t nccl_comm; + // Create a structure to allgather...
ncclUniqueId id{}; - if (get_rank() == subcomm_ranks[0]) { // root of the new subcommunicator - RAFT_NCCL_TRY(ncclGetUniqueId(&id)); - std::vector<request_t> requests(subcomm_ranks.size() - 1); - for (size_t i = 1; i < subcomm_ranks.size(); ++i) { - isend(&id, sizeof(ncclUniqueId), subcomm_ranks[i], color, requests.data() + (i - 1)); - } - waitall(requests.size(), requests.data()); - } else { - request_t request{}; - irecv(&id, sizeof(ncclUniqueId), subcomm_ranks[0], color, &request); - waitall(1, &request); - } - // FIXME: this seems unnecessary, do more testing and remove this - barrier(); + rmm::device_uvector<ncclUniqueId> d_nccl_ids(get_size(), stream_); - ncclComm_t nccl_comm; - RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_ranks.size(), id, key)); - - if (ucp_worker_ != nullptr && subcomms_ucp_) { - auto eps_sp = std::make_shared<ucp_ep_h*>(new_ucx_ptrs.data()); - return std::unique_ptr<comms_iface>(new std_comms(nccl_comm, - (ucp_worker_h)ucp_worker_, - eps_sp, - subcomm_ranks.size(), - key, - stream_, - subcomms_ucp_)); - } else { - return std::unique_ptr<comms_iface>( - new std_comms(nccl_comm, subcomm_ranks.size(), key, stream_)); - } + if (key == 0) { RAFT_NCCL_TRY(ncclGetUniqueId(&id)); } + + update_device(d_nccl_ids.data() + get_rank(), &id, 1, stream_); + + allgather(d_nccl_ids.data() + get_rank(), + d_nccl_ids.data(), + sizeof(ncclUniqueId), + datatype_t::UINT8, + stream_); + + auto offset = + std::distance(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()), + std::find_if(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()), + thrust::make_zip_iterator(h_colors.end(), h_keys.end()), + [color](auto tuple) { return thrust::get<0>(tuple) == color; })); + + auto subcomm_size = std::count(h_colors.begin(), h_colors.end(), color); + + update_host(&id, d_nccl_ids.data() + offset, 1, stream_); + + this->sync_stream(stream_); + + RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key)); + + return std::unique_ptr<comms_iface>(new std_comms(nccl_comm, subcomm_size, key, stream_)); } void barrier()
const