diff --git a/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java b/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java index 18e87eb2557e..c62648e89d3a 100644 --- a/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java +++ b/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java @@ -38,9 +38,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; +import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Verify.verify; import static java.lang.String.format; @@ -68,7 +68,8 @@ class EvictableCache // The dataCache must be bounded. private final LoadingCache, V> dataCache; - private final AtomicInteger invalidations = new AtomicInteger(); + // Logically a concurrent Multiset + private final ConcurrentHashMap, Long> ongoingLoads = new ConcurrentHashMap<>(); EvictableCache(CacheBuilder, ? super V> cacheBuilder, CacheLoader cacheLoader) { @@ -77,9 +78,13 @@ class EvictableCache ., V>removalListener(removal -> { Token token = removal.getKey(); verify(token != null, "token is null"); - if (removal.getCause() != RemovalCause.REPLACED) { - tokens.remove(token.getKey(), token); + if (removal.getCause() == RemovalCause.REPLACED) { + return; } + if (removal.getCause() == RemovalCause.EXPIRED && ongoingLoads.containsKey(token)) { + return; + } + tokens.remove(token.getKey(), token); }), new TokenCacheLoader<>(cacheLoader)); } @@ -106,22 +111,15 @@ public V get(K key, Callable valueLoader) throws ExecutionException { Token newToken = new Token<>(key); - int invalidations = this.invalidations.get(); Token token = tokens.computeIfAbsent(key, _ -> newToken); try { - V value = dataCache.get(token, valueLoader); - if (invalidations == this.invalidations.get()) { - // Revive token if it got expired before reloading - if (tokens.putIfAbsent(key, token) == null) { - // Revived - if (!dataCache.asMap().containsKey(token)) { - // We revived, but the token does not correspond to a live entry anymore. - // It would stay in tokens forever, so let's remove it. - tokens.remove(key, token); - } - } + startLoading(token); + try { + return dataCache.get(token, valueLoader); + } + finally { + endLoading(token); } - return value; } catch (Throwable e) { if (newToken == token) { @@ -132,6 +130,9 @@ public V get(K key, Callable valueLoader) } throw e; } + finally { + removeDangling(token); + } } @Override @@ -139,22 +140,15 @@ public V get(K key) throws ExecutionException { Token newToken = new Token<>(key); - int invalidations = this.invalidations.get(); Token token = tokens.computeIfAbsent(key, _ -> newToken); try { - V value = dataCache.get(token); - if (invalidations == this.invalidations.get()) { - // Revive token if it got expired before reloading - if (tokens.putIfAbsent(key, token) == null) { - // Revived - if (!dataCache.asMap().containsKey(token)) { - // We revived, but the token does not correspond to a live entry anymore. - // It would stay in tokens forever, so let's remove it. - tokens.remove(key, token); - } - } + startLoading(token); + try { + return dataCache.get(token); + } + finally { + endLoading(token); } - return value; } catch (Throwable e) { if (newToken == token) { @@ -165,6 +159,9 @@ public V get(K key) } throw e; } + finally { + removeDangling(token); + } } @Override @@ -218,6 +215,30 @@ public ImmutableMap getAll(Iterable keys) } } + private void startLoading(Token token) + { + ongoingLoads.compute(token, (_, count) -> firstNonNull(count, 0L) + 1); + } + + private void endLoading(Token token) + { + ongoingLoads.compute(token, (_, count) -> { + verify(count != null && count > 0, "Incorrect count for token %s: %s", token, count); + if (count == 1) { + return null; + } + return count - 1; + }); + } + + // Token eviction via removalListener is blocked during loading, so we may need to do manual cleanup + private void removeDangling(Token token) + { + if (!dataCache.asMap().containsKey(token)) { + tokens.remove(token.getKey(), token); + } + } + @Override public void refresh(K key) { @@ -248,7 +269,6 @@ int tokensCount() @Override public void invalidate(Object key) { - invalidations.incrementAndGet(); @SuppressWarnings("SuspiciousMethodCalls") // Object passed to map as key K Token token = tokens.remove(key); if (token != null) { @@ -259,7 +279,6 @@ public void invalidate(Object key) @Override public void invalidateAll() { - invalidations.incrementAndGet(); dataCache.invalidateAll(); tokens.clear(); }