From 3ee6a05c1b88013ff14d71d0be78bc6ac0067840 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 11 Jan 2024 18:17:42 -0500 Subject: [PATCH 1/2] Properly taking ownership of nccl subcomm (and destroying it) --- cpp/include/raft/comms/detail/std_comms.hpp | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index de2a7d3415..66b69c1b9e 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,6 +81,7 @@ class std_comms : public comms_iface { num_ranks_(num_ranks), rank_(rank), subcomms_ucp_(subcomms_ucp), + own_nccl_comm_(false), ucp_worker_(ucp_worker), ucp_eps_(eps), next_request_id_(0) @@ -95,13 +96,18 @@ class std_comms : public comms_iface { * @param rank rank of the current worker * @param stream stream for ordering collective operations */ - std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream) + std_comms(const ncclComm_t nccl_comm, + int num_ranks, + int rank, + rmm::cuda_stream_view stream, + bool own_nccl_comm = false) : nccl_comm_(nccl_comm), stream_(stream), status_(stream), num_ranks_(num_ranks), rank_(rank), - subcomms_ucp_(false) + subcomms_ucp_(false), + own_nccl_comm_(own_nccl_comm) { initialize(); }; @@ -116,6 +122,11 @@ class std_comms : public comms_iface { { requests_in_flight_.clear(); free_requests_.clear(); + + if (own_nccl_comm_) { + ncclCommDestroy(nccl_comm_); + nccl_comm_ = nullptr; + } } int get_size() const { return num_ranks_; } @@ -172,7 +183,7 @@ class std_comms : public comms_iface { RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key)); - return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_)); + return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_, true)); } void barrier() const @@ -515,6 +526,7 @@ class std_comms : public comms_iface { int rank_; bool subcomms_ucp_; + bool own_nccl_comm_; comms_ucp_handler ucp_handler_; ucp_worker_h ucp_worker_; From 052472fe1444f35db1aabecb3ffe6ad49e241a38 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 11 Jan 2024 21:00:34 -0500 Subject: [PATCH 2/2] Review feedback --- cpp/include/raft/comms/detail/std_comms.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 66b69c1b9e..323e408cab 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -124,7 +124,7 @@ class std_comms : public comms_iface { free_requests_.clear(); if (own_nccl_comm_) { - ncclCommDestroy(nccl_comm_); + RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_)); nccl_comm_ = nullptr; } }