Skip to content

Commit

Permalink
Properly taking ownership of nccl subcomm (and destroying it) (#2094)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)

URL: #2094
  • Loading branch information
cjnolet authored Jan 12, 2024
1 parent 856288a commit 7d5bb3c
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,6 +81,7 @@ class std_comms : public comms_iface {
num_ranks_(num_ranks),
rank_(rank),
subcomms_ucp_(subcomms_ucp),
+ own_nccl_comm_(false),
ucp_worker_(ucp_worker),
ucp_eps_(eps),
next_request_id_(0)
Expand All @@ -95,13 +96,18 @@ class std_comms : public comms_iface {
* @param rank rank of the current worker
* @param stream stream for ordering collective operations
*/
- std_comms(const ncclComm_t nccl_comm, int num_ranks, int rank, rmm::cuda_stream_view stream)
+ std_comms(const ncclComm_t nccl_comm,
+ int num_ranks,
+ int rank,
+ rmm::cuda_stream_view stream,
+ bool own_nccl_comm = false)
: nccl_comm_(nccl_comm),
stream_(stream),
status_(stream),
num_ranks_(num_ranks),
rank_(rank),
- subcomms_ucp_(false)
+ subcomms_ucp_(false),
+ own_nccl_comm_(own_nccl_comm)
{
initialize();
};
Expand All @@ -116,6 +122,11 @@ class std_comms : public comms_iface {
{
requests_in_flight_.clear();
free_requests_.clear();

+ if (own_nccl_comm_) {
+ RAFT_NCCL_TRY_NO_THROW(ncclCommDestroy(nccl_comm_));
+ nccl_comm_ = nullptr;
+ }
}

int get_size() const { return num_ranks_; }
Expand Down Expand Up @@ -172,7 +183,7 @@ class std_comms : public comms_iface {

RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key));

- return std::unique_ptr<comms_iface>(new std_comms(nccl_comm, subcomm_size, key, stream_));
+ return std::unique_ptr<comms_iface>(new std_comms(nccl_comm, subcomm_size, key, stream_, true));
}

void barrier() const
Expand Down Expand Up @@ -515,6 +526,7 @@ class std_comms : public comms_iface {
int rank_;

bool subcomms_ucp_;
+ bool own_nccl_comm_;

comms_ucp_handler ucp_handler_;
ucp_worker_h ucp_worker_;
Expand Down

0 comments on commit 7d5bb3c

Please sign in to comment.