diff --git a/cpp/include/raft/core/interruptible.hpp b/cpp/include/raft/core/interruptible.hpp index 76fb7aa7c3..0cc4af2bbf 100644 --- a/cpp/include/raft/core/interruptible.hpp +++ b/cpp/include/raft/core/interruptible.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -179,9 +179,44 @@ class interruptible { private: /** Global registry of thread-local cancellation stores. */ - static inline std::unordered_map> registry_; - /** Protect the access to the registry. */ - static inline std::mutex mutex_; + using registry_t = + std::tuple>>; + + /** + * The registry "garbage collector": a custom deleter for the interruptible tokens that removes + * the token from the registry, if the registry still exists. + */ + struct registry_gc_t { + std::weak_ptr weak_registry; + std::thread::id thread_id; + + inline void operator()(interruptible* thread_store) const noexcept + { + // the deleter kicks in at thread/program exit; in some cases, the registry_ (static variable) + // may have been destructed by this point of time. + // Hence, we use a weak pointer to check if the registry still exists. + auto registry = weak_registry.lock(); + if (registry) { + std::lock_guard guard_erase(std::get<0>(*registry)); + auto& map = std::get<1>(*registry); + auto found = map.find(thread_id); + if (found != map.end()) { + auto stored = found->second.lock(); + // thread_store is not moveable, thus retains its original location. + // Not equal pointers below imply the new store has been already placed + // in the registry by the same std::thread::id + if (!stored || stored.get() == thread_store) { map.erase(found); } + } + } + delete thread_store; + } + }; + + /** + * The registry itself is stored in the static memory, in a shared pointer. + * This is to safely access it from the destructors of the thread-local tokens. + */ + static inline std::shared_ptr registry_{new registry_t{}}; /** * Create a new interruptible token or get an existing from the global registry_. @@ -201,26 +236,22 @@ class interruptible { template static auto get_token_impl(std::thread::id thread_id) -> std::shared_ptr { - std::lock_guard guard_get(mutex_); - // the following constructs an empty shared_ptr if the key does not exist. - auto& weak_store = registry_[thread_id]; + // Make a local copy of the shared pointer to make sure the registry is not destroyed, + // if, for any reason, this function is called at program exit. + std::shared_ptr shared_registry = registry_; + // If the registry is not available, create a lone token that cannot be accessed from + // the outside of the thread. + if (!shared_registry) { return std::shared_ptr{new interruptible()}; } + // Otherwise, proceed with the normal logic + std::lock_guard guard_get(std::get<0>(*shared_registry)); + // the following two lines construct an empty shared_ptr if the key does not exist. + auto& weak_store = std::get<1>(*shared_registry)[thread_id]; auto thread_store = weak_store.lock(); if (!thread_store || (Claim && thread_store->claimed_)) { // Create a new thread_store in two cases: // 1. It does not exist in the map yet // 2. The previous store in the map has not yet been deleted - thread_store.reset(new interruptible(), [thread_id](auto ts) { - std::lock_guard guard_erase(mutex_); - auto found = registry_.find(thread_id); - if (found != registry_.end()) { - auto stored = found->second.lock(); - // thread_store is not moveable, thus retains its original location. - // Not equal pointers below imply the new store has been already placed - // in the registry_ by the same std::thread::id - if (!stored || stored.get() == ts) { registry_.erase(found); } - } - delete ts; - }); + thread_store.reset(new interruptible(), registry_gc_t{shared_registry, thread_id}); std::weak_ptr(thread_store).swap(weak_store); } // The thread_store is "claimed" by the thread @@ -268,4 +299,4 @@ class interruptible { } // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp index fc69f10d83..6120a4b75c 100644 --- a/cpp/include/raft/core/resource/cuda_stream.hpp +++ b/cpp/include/raft/core/resource/cuda_stream.hpp @@ -91,10 +91,7 @@ inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_v */ inline void sync_stream(const resources& res, rmm::cuda_stream_view stream) { - // TODO: Fix interruptible segfault: - // https://github.com/rapidsai/raft/issues/1225 - // interruptible::synchronize(stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + interruptible::synchronize(stream); } /** @@ -106,4 +103,4 @@ inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream * @} */ -} // namespace raft::resource \ No newline at end of file +} // namespace raft::resource