Skip to content

Commit

Permalink
Fix the destruction of interruptible token registry (#1229)
Browse files Browse the repository at this point in the history
Because there's no way to control the order of destruction between the global and thread-local static objects, the token registry may sometimes be accessed after it has already been destructed (in the program exit handlers).

This fix wraps the registry in a shared pointer and keeps the weak pointers in the deleters which cause the problem, thus it avoids accessing the registry after it's been destroyed.

Closes #1225
Closes #1275

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1229
  • Loading branch information
achirkin authored Feb 15, 2023
1 parent 7d6e4dc commit 27ca9b9
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 25 deletions.
71 changes: 51 additions & 20 deletions cpp/include/raft/core/interruptible.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -179,9 +179,44 @@ class interruptible {

private:
/** Global registry of thread-local cancellation stores. */
static inline std::unordered_map<std::thread::id, std::weak_ptr<interruptible>> registry_;
/** Protect the access to the registry. */
static inline std::mutex mutex_;
using registry_t =
std::tuple<std::mutex, std::unordered_map<std::thread::id, std::weak_ptr<interruptible>>>;

/**
* The registry "garbage collector": a custom deleter for the interruptible tokens that removes
* the token from the registry, if the registry still exists.
*/
struct registry_gc_t {
std::weak_ptr<registry_t> weak_registry;
std::thread::id thread_id;

inline void operator()(interruptible* thread_store) const noexcept
{
// the deleter kicks in at thread/program exit; in some cases, the registry_ (static variable)
// may have been destructed by this point of time.
// Hence, we use a weak pointer to check if the registry still exists.
auto registry = weak_registry.lock();
if (registry) {
std::lock_guard<std::mutex> guard_erase(std::get<0>(*registry));
auto& map = std::get<1>(*registry);
auto found = map.find(thread_id);
if (found != map.end()) {
auto stored = found->second.lock();
// thread_store is not moveable, thus retains its original location.
// Not equal pointers below imply the new store has been already placed
// in the registry by the same std::thread::id
if (!stored || stored.get() == thread_store) { map.erase(found); }
}
}
delete thread_store;
}
};

/**
* The registry itself is stored in the static memory, in a shared pointer.
* This is to safely access it from the destructors of the thread-local tokens.
*/
static inline std::shared_ptr<registry_t> registry_{new registry_t{}};

/**
* Create a new interruptible token or get an existing from the global registry_.
Expand All @@ -201,26 +236,22 @@ class interruptible {
template <bool Claim>
static auto get_token_impl(std::thread::id thread_id) -> std::shared_ptr<interruptible>
{
std::lock_guard<std::mutex> guard_get(mutex_);
// the following constructs an empty shared_ptr if the key does not exist.
auto& weak_store = registry_[thread_id];
// Make a local copy of the shared pointer to make sure the registry is not destroyed,
// if, for any reason, this function is called at program exit.
std::shared_ptr<registry_t> shared_registry = registry_;
// If the registry is not available, create a lone token that cannot be accessed from
// the outside of the thread.
if (!shared_registry) { return std::shared_ptr<interruptible>{new interruptible()}; }
// Otherwise, proceed with the normal logic
std::lock_guard<std::mutex> guard_get(std::get<0>(*shared_registry));
// the following two lines construct an empty shared_ptr if the key does not exist.
auto& weak_store = std::get<1>(*shared_registry)[thread_id];
auto thread_store = weak_store.lock();
if (!thread_store || (Claim && thread_store->claimed_)) {
// Create a new thread_store in two cases:
// 1. It does not exist in the map yet
// 2. The previous store in the map has not yet been deleted
thread_store.reset(new interruptible(), [thread_id](auto ts) {
std::lock_guard<std::mutex> guard_erase(mutex_);
auto found = registry_.find(thread_id);
if (found != registry_.end()) {
auto stored = found->second.lock();
// thread_store is not moveable, thus retains its original location.
// Not equal pointers below imply the new store has been already placed
// in the registry_ by the same std::thread::id
if (!stored || stored.get() == ts) { registry_.erase(found); }
}
delete ts;
});
thread_store.reset(new interruptible(), registry_gc_t{shared_registry, thread_id});
std::weak_ptr<interruptible>(thread_store).swap(weak_store);
}
// The thread_store is "claimed" by the thread
Expand Down Expand Up @@ -268,4 +299,4 @@ class interruptible {

} // namespace raft

#endif
#endif
7 changes: 2 additions & 5 deletions cpp/include/raft/core/resource/cuda_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_v
*/
inline void sync_stream(const resources& res, rmm::cuda_stream_view stream)
{
// TODO: Fix interruptible segfault:
// https://github.com/rapidsai/raft/issues/1225
// interruptible::synchronize(stream);
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
interruptible::synchronize(stream);
}

/**
Expand All @@ -106,4 +103,4 @@ inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream
* @}
*/

} // namespace raft::resource
} // namespace raft::resource

0 comments on commit 27ca9b9

Please sign in to comment.