diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index de2a7d3415..323e408cab 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,6 +81,7 @@ class std_comms : public comms_iface { num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(subcomms_ucp), + own_nccl_comm_(false), ucp_worker_(ucp_worker), ucp_eps_(eps), next_request_id_(0) @@ -95,13 +96,18 @@ class std_comms : public comms_iface { * @param rank rank of the current worker * @param stream stream for ordering collective operations */ - std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream) + std_comms(const ncclComm_t nccl_comm, + int num_ranks, + int rank, + rmm::cuda_stream_view stream, + bool own_nccl_comm = false) : nccl_comm_(nccl_comm), stream_(stream), status_(stream), num_ranks_(num_ranks), rank_(rank), - subcomms_ucp_(false) + subcomms_ucp_(false), + own_nccl_comm_(own_nccl_comm) { initialize(); }; @@ -116,6 +122,11 @@ class std_comms : public comms_iface { { requests_in_flight_.clear(); free_requests_.clear(); + + if (own_nccl_comm_) { + RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_)); + nccl_comm_ = nullptr; + } } int get_size() const { return num_ranks_; } @@ -172,7 +183,7 @@ class std_comms : public comms_iface { RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key)); - return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_)); + return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_, true)); } void barrier() const @@ -515,6 +526,7 @@ class std_comms : public comms_iface { int rank_; bool subcomms_ucp_; + bool own_nccl_comm_; comms_ucp_handler ucp_handler_; ucp_worker_h ucp_worker_;