Skip to content

Commit

Permalink
Revert "[mlir] Fix race condition introduced in ThreadLocalCache (#93… (
Browse files Browse the repository at this point in the history
#93290)

…280)"

This reverts commit 6977bfb.
  • Loading branch information
kiranchandramohan authored May 24, 2024
1 parent 430729d commit ebc6c28
Showing 1 changed file with 25 additions and 72 deletions.
97 changes: 25 additions & 72 deletions mlir/include/mlir/Support/ThreadLocalCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Mutex.h"

namespace mlir {
Expand All @@ -24,80 +25,28 @@ namespace mlir {
/// cache has very large lock contention.
template <typename ValueT>
class ThreadLocalCache {
struct PerInstanceState;

/// The "observer" is owned by a thread-local cache instance. It is
/// constructed the first time a `ThreadLocalCache` instance is accessed by a
/// thread, unless `perInstanceState` happens to get re-allocated to the same
/// address as a previous one. This class is destructed the thread in which
/// the `thread_local` cache lives is destroyed.
///
/// This class is called the "observer" because while values cached in
/// thread-local caches are owned by `PerInstanceState`, a reference is stored
/// via this class in the TLC. With a double pointer, it knows when the
/// referenced value has been destroyed.
struct Observer {
/// This is the double pointer, explicitly allocated because we need to keep
/// the address stable if the TLC map re-allocates. It is owned by the
/// observer and shared with the value owner.
std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
/// Because `Owner` living inside `PerInstanceState` contains a reference to
/// the double pointer, and livkewise this class contains a reference to the
/// value, we need to synchronize destruction of the TLC and the
/// `PerInstanceState` to avoid racing. This weak pointer is acquired during
/// TLC destruction if the `PerInstanceState` hasn't entered its destructor
/// yet, and prevents it from happening.
std::weak_ptr<PerInstanceState> keepalive;
};

/// This struct owns the cache entries. It contains a reference back to the
/// reference inside the cache so that it can be written to null to indicate
/// that the cache entry is invalidated. It needs to do this because
/// `perInstanceState` could get re-allocated to the same pointer and we don't
/// remove entries from the TLC when it is deallocated. Thus, we have to reset
/// the TLC entries to a starting state in case the `ThreadLocalCache` lives
/// shorter than the threads.
struct Owner {
/// Save a pointer to the reference and write it to the newly created entry.
Owner(Observer &observer)
: value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
*observer.ptr = value.get();
}
~Owner() {
if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
*ptr = nullptr;
}

Owner(Owner &&) = default;
Owner &operator=(Owner &&) = default;

std::unique_ptr<ValueT> value;
std::weak_ptr<ValueT *> ptrRef;
};

// Keep a separate shared_ptr protected state that can be acquired atomically
// instead of using shared_ptr's for each value. This avoids a problem
// where the instance shared_ptr is locked() successfully, and then the
// ThreadLocalCache gets destroyed before remove() can be called successfully.
struct PerInstanceState {
/// Remove the given value entry. This is called when a thread local cache
/// is destructing but still contains references to values owned by the
/// `PerInstanceState`. Removal is required because it prevents writeback to
/// a pointer that was deallocated.
/// Remove the given value entry. This is generally called when a thread
/// local cache is destructing.
void remove(ValueT *value) {
// Erase the found value directly, because it is guaranteed to be in the
// list.
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
auto it = llvm::find_if(instances, [&](Owner &instance) {
return instance.value.get() == value;
});
auto it =
llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
return instance.get() == value;
});
assert(it != instances.end() && "expected value to exist in cache");
instances.erase(it);
}

/// Owning pointers to all of the values that have been constructed for this
/// object in the static cache.
SmallVector<Owner, 1> instances;
SmallVector<std::unique_ptr<ValueT>, 1> instances;

/// A mutex used when a new thread instance has been added to the cache for
/// this object.
Expand All @@ -108,22 +57,22 @@ class ThreadLocalCache {
/// instance of the non-static cache and a weak reference to an instance of
/// ValueT. We use a weak reference here so that the object can be destroyed
/// without needing to lock access to the cache itself.
struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
struct CacheType
: public llvm::SmallDenseMap<PerInstanceState *,
std::pair<std::weak_ptr<ValueT>, ValueT *>> {
~CacheType() {
// Remove the values of this cache that haven't already expired. This is
// required because if we don't remove them, they will contain a reference
// back to the data here that is being destroyed.
for (auto &[instance, observer] : *this)
if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
state->remove(*observer.ptr);
// Remove the values of this cache that haven't already expired.
for (auto &it : *this)
if (std::shared_ptr<ValueT> value = it.second.first.lock())
it.first->remove(value.get());
}

/// Clear out any unused entries within the map. This method is not
/// thread-safe, and should only be called by the same thread as the cache.
void clearExpiredEntries() {
for (auto it = this->begin(), e = this->end(); it != e;) {
auto curIt = it++;
if (!*curIt->second.ptr)
if (curIt->second.first.expired())
this->erase(curIt);
}
}
Expand All @@ -140,23 +89,27 @@ class ThreadLocalCache {
ValueT &get() {
// Check for an already existing instance for this thread.
CacheType &staticCache = getStaticCache();
Observer &threadInstance = staticCache[perInstanceState.get()];
if (ValueT *value = *threadInstance.ptr)
std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
staticCache[perInstanceState.get()];
if (ValueT *value = threadInstance.second)
return *value;

// Otherwise, create a new instance for this thread.
{
llvm::sys::SmartScopedLock<true> threadInstanceLock(
perInstanceState->instanceMutex);
perInstanceState->instances.emplace_back(threadInstance);
threadInstance.second =
perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
.get();
}
threadInstance.keepalive = perInstanceState;
threadInstance.first =
std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);

// Before returning the new instance, take the chance to clear out any used
// entries in the static map. The cache is only cleared within the same
// thread to remove the need to lock the cache itself.
staticCache.clearExpiredEntries();
return **threadInstance.ptr;
return *threadInstance.second;
}
ValueT &operator*() { return get(); }
ValueT *operator->() { return &get(); }
Expand Down

0 comments on commit ebc6c28

Please sign in to comment.