Fix evictable cache invalidation race condition
Before this change, the following race was possible between threads A, B, and C:

- A calls invalidate(K)
- A increments: invalidations++
- B modifies the state that is being cached, and therefore calls invalidate(K) too
- B increments: invalidations++
- C calls get(K)
- C reads invalidations counter
- C retrieves current token T for key K
- C reads value V for T from cache
- A reads and removes token T (the same token C retrieved) for key K
- B attempts to read and remove token for key K, not found
- B exits invalidate(K)
- C checks the invalidations counter (it has not changed since C read it)
- C revives, i.e. re-inserts token T for key K
- B calls get(K)
- B retrieves token T (the same, revived token) for key K
- B reads value V for T from cache -- despite B having called
  invalidate(K)

At least in this scenario the problem is transient: thread A will shortly
invalidate the dataCache entry for token T, completing the invalidation.

This commit fixes the race. The bug was introduced by token reviving (commit
17faae3); this commit reverts that change and solves the problem the earlier
commit was addressing in a different way, sketched below.
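The replacement mechanism keeps a per-token count of in-flight loads (the new
ongoingLoads field, "logically a concurrent Multiset"): the removal listener
skips unmapping a token whose entry EXPIRED while a load is running, and get()
removes any token left dangling once its load completes. Below is a minimal,
self-contained sketch of just the reference-counting part; the names
OngoingLoads, increment, decrement and isLoading are illustrative, not the ones
used in the commit, and the Token/tokens/dataCache plumbing is omitted.

import java.util.concurrent.ConcurrentHashMap;

// Sketch of the "logically a concurrent Multiset" backing ongoingLoads:
// a ConcurrentHashMap<T, Long> whose value is the number of loads currently
// in flight for a token. The mapping disappears when the count drops to zero.
final class OngoingLoads<T>
{
    private final ConcurrentHashMap<T, Long> counts = new ConcurrentHashMap<>();

    // corresponds to startLoading() in the commit
    void increment(T token)
    {
        counts.compute(token, (ignored, count) -> count == null ? 1L : count + 1);
    }

    // corresponds to endLoading() in the commit
    void decrement(T token)
    {
        counts.compute(token, (ignored, count) -> {
            if (count == null || count <= 0) {
                throw new IllegalStateException("Incorrect count for token " + token + ": " + count);
            }
            if (count == 1) {
                return null; // returning null removes the mapping
            }
            return count - 1;
        });
    }

    // consulted by the removal listener: an EXPIRED entry must not unmap its
    // token while a load is in flight, otherwise a concurrent invalidate() could be undone
    boolean isLoading(T token)
    {
        return counts.containsKey(token);
    }

    public static void main(String[] args)
    {
        OngoingLoads<String> loads = new OngoingLoads<>();
        loads.increment("token-1");
        loads.increment("token-1"); // two concurrent loads of the same token
        loads.decrement("token-1");
        System.out.println(loads.isLoading("token-1")); // true, one load still running
        loads.decrement("token-1");
        System.out.println(loads.isLoading("token-1")); // false
    }
}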
findepi committed Jun 13, 2024
1 parent 4d3f279 commit e9062c0
Showing 1 changed file with 51 additions and 32 deletions.
lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java (51 additions, 32 deletions)
@@ -38,9 +38,9 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
@@ -68,7 +68,8 @@ class EvictableCache<K, V>
// The dataCache must be bounded.
private final LoadingCache<Token<K>, V> dataCache;

private final AtomicInteger invalidations = new AtomicInteger();
// Logically a concurrent Multiset
private final ConcurrentHashMap<Token<K>, Long> ongoingLoads = new ConcurrentHashMap<>();

EvictableCache(CacheBuilder<? super Token<K>, ? super V> cacheBuilder, CacheLoader<? super K, V> cacheLoader)
{
@@ -77,9 +78,13 @@ class EvictableCache<K, V>
.<Token<K>, V>removalListener(removal -> {
Token<K> token = removal.getKey();
verify(token != null, "token is null");
if (removal.getCause() != RemovalCause.REPLACED) {
tokens.remove(token.getKey(), token);
if (removal.getCause() == RemovalCause.REPLACED) {
return;
}
if (removal.getCause() == RemovalCause.EXPIRED && ongoingLoads.containsKey(token)) {
return;
}
tokens.remove(token.getKey(), token);
}),
new TokenCacheLoader<>(cacheLoader));
}
@@ -106,22 +111,15 @@ public V get(K key, Callable<? extends V> valueLoader)
throws ExecutionException
{
Token<K> newToken = new Token<>(key);
int invalidations = this.invalidations.get();
Token<K> token = tokens.computeIfAbsent(key, _ -> newToken);
try {
V value = dataCache.get(token, valueLoader);
if (invalidations == this.invalidations.get()) {
// Revive token if it got expired before reloading
if (tokens.putIfAbsent(key, token) == null) {
// Revived
if (!dataCache.asMap().containsKey(token)) {
// We revived, but the token does not correspond to a live entry anymore.
// It would stay in tokens forever, so let's remove it.
tokens.remove(key, token);
}
}
startLoading(token);
try {
return dataCache.get(token, valueLoader);
}
finally {
endLoading(token);
}
return value;
}
catch (Throwable e) {
if (newToken == token) {
@@ -132,29 +130,25 @@ public V get(K key, Callable<? extends V> valueLoader)
}
throw e;
}
finally {
removeDangling(token);
}
}

@Override
public V get(K key)
throws ExecutionException
{
Token<K> newToken = new Token<>(key);
int invalidations = this.invalidations.get();
Token<K> token = tokens.computeIfAbsent(key, _ -> newToken);
try {
V value = dataCache.get(token);
if (invalidations == this.invalidations.get()) {
// Revive token if it got expired before reloading
if (tokens.putIfAbsent(key, token) == null) {
// Revived
if (!dataCache.asMap().containsKey(token)) {
// We revived, but the token does not correspond to a live entry anymore.
// It would stay in tokens forever, so let's remove it.
tokens.remove(key, token);
}
}
startLoading(token);
try {
return dataCache.get(token);
}
finally {
endLoading(token);
}
return value;
}
catch (Throwable e) {
if (newToken == token) {
@@ -165,6 +159,9 @@ public V get(K key)
}
throw e;
}
finally {
removeDangling(token);
}
}

@Override
@@ -218,6 +215,30 @@ public ImmutableMap<K, V> getAll(Iterable<? extends K> keys)
}
}

private void startLoading(Token<K> token)
{
ongoingLoads.compute(token, (_, count) -> firstNonNull(count, 0L) + 1);
}

private void endLoading(Token<K> token)
{
ongoingLoads.compute(token, (_, count) -> {
verify(count != null && count > 0, "Incorrect count for token %s: %s", token, count);
if (count == 1) {
return null;
}
return count - 1;
});
}

// Token eviction via removalListener is blocked during loading, so we may need to do manual cleanup
private void removeDangling(Token<K> token)
{
if (!dataCache.asMap().containsKey(token)) {
tokens.remove(token.getKey(), token);
}
}

@Override
public void refresh(K key)
{
@@ -248,7 +269,6 @@ int tokensCount()
@Override
public void invalidate(Object key)
{
invalidations.incrementAndGet();
@SuppressWarnings("SuspiciousMethodCalls") // Object passed to map as key K
Token<K> token = tokens.remove(key);
if (token != null) {
@@ -259,7 +279,6 @@ public void invalidate(Object key)
@Override
public void invalidateAll()
{
invalidations.incrementAndGet();
dataCache.invalidateAll();
tokens.clear();
}

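The user-visible contract the fix protects is simple: once invalidate(key) has
returned, a later get(key) must run the loader again rather than serve the
previously cached value. The sketch below only illustrates that contract in a
single thread; it does not reproduce the race itself, which needs the exact
interleaving listed in the commit message. EvictableCacheBuilder.newBuilder(),
maximumSize() and build() are assumed to be available from the surrounding
trino-cache module; they are not part of this diff.

import com.google.common.cache.Cache;
import io.trino.cache.EvictableCacheBuilder;

import java.util.concurrent.atomic.AtomicInteger;

public final class InvalidationContract
{
    public static void main(String[] args)
            throws Exception
    {
        // Assumed builder API, mirroring Guava's CacheBuilder; the cache must be bounded
        Cache<String, Integer> cache = EvictableCacheBuilder.newBuilder()
                .maximumSize(1_000)
                .build();

        AtomicInteger loaderRuns = new AtomicInteger();

        int first = cache.get("key", loaderRuns::incrementAndGet);  // loader runs: 1

        cache.invalidate("key");                                    // must discard the entry

        int second = cache.get("key", loaderRuns::incrementAndGet); // loader must run again: 2

        if (second == first) {
            throw new AssertionError("stale value served after invalidate()");
        }
        System.out.println("loader ran " + loaderRuns.get() + " times");
    }
}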