Less aggressively discard refreshes due to possible conflicts
Linearization means that the refresh should be dropped if another write
for that entry occurs, as we may otherwise populate the cache with stale
data. The cache used a write timestamp to help detect this. However, this
could be too aggressive when a refresh runs concurrently with the async
load completing.
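
As a rough, hypothetical sketch of that timestamp scheme (the types and
method names below are illustrative stand-ins, not Caffeine's internals):
the refresh snapshots the entry's write time when it starts and only
installs the reloaded value if no other write has bumped the timestamp in
the meantime.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative sketch only; not Caffeine's implementation.
final class TimestampRefreshSketch<K, V> {
  static final class Entry<V> {
    final V value;
    final long writeTime;
    Entry(V value, long writeTime) { this.value = value; this.writeTime = writeTime; }
  }

  final ConcurrentHashMap<K, Entry<V>> data = new ConcurrentHashMap<>();

  void put(K key, V value) {
    data.put(key, new Entry<>(value, System.nanoTime()));    // every write advances the timestamp
  }

  void refresh(K key, CompletableFuture<V> reload) {
    Entry<V> snapshot = data.get(key);
    if (snapshot == null) {
      return;
    }
    reload.thenAccept(newValue ->
        data.computeIfPresent(key, (k, current) ->
            (current.writeTime == snapshot.writeTime)
                ? new Entry<>(newValue, System.nanoTime())   // timestamp unchanged: install the reload
                : current));                                 // timestamp moved: treat as a conflict
  }
}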

An in-flight load is given an infinite timestamp so that it never expires.
A whenComplete callback then updates the entry with its policy metadata,
such as its expiration time. That changes the write timestamp, and if a
refresh runs concurrently with the callback then the refresh may see the
infinite timestamp, the replace updates it, and the refresh drops the
entry as a conflict.

Instead, if the refresh is successfully unregistered within the entry's
compute when swapping in the value, then it can be assumed to be valid.
Any other write will discard the refresh under the entry's compute, so we
retain linearizability. This requires that the load callback that updates
the metadata does not discard the refresh, since that is not needed for a
pseudo write.
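
A minimal, hypothetical sketch of that validation (the data and refreshes
maps below are illustrative stand-ins rather than the actual LocalCache
API): a refresh registers itself in a side map, real writes discard the
registration inside the entry's compute, and the refresh may only install
its value if it can still unregister itself inside that same compute.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative sketch only; not Caffeine's implementation.
final class UnregisterRefreshSketch<K, V> {
  final ConcurrentHashMap<K, V> data = new ConcurrentHashMap<>();
  final ConcurrentHashMap<K, CompletableFuture<V>> refreshes = new ConcurrentHashMap<>();

  void put(K key, V value) {
    data.compute(key, (k, v) -> {
      refreshes.remove(k);        // a real write discards any in-flight refresh
      return value;
    });
  }

  void refresh(K key, CompletableFuture<V> reload) {
    V oldValue = data.get(key);
    if ((oldValue == null) || (refreshes.putIfAbsent(key, reload) != null)) {
      return;                     // absent entry, or a refresh is already in flight
    }
    reload.thenAccept(newValue ->
        data.compute(key, (k, currentValue) -> {
          boolean unregistered = refreshes.remove(k, reload);
          if (unregistered && (currentValue == oldValue)) {
            return newValue;      // still registered and unmodified, so the swap is valid
          }
          return currentValue;    // another write won the race; keep its value
        }));
  }
}

In this model a pseudo write, like the load-completion callback that only
updates policy metadata, simply skips the discard so that a concurrent
refresh is not invalidated.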

The write timestamp trick for ABA detection is a holdover from the
previous implementation (2.x), which did not promise linearization.

Note that a refresh is optimistic and races with other writes, so it
may be discarded. As it relies on callbacks to write back into the
cache, one cannot expect the entry to be updated after the reload
completes (e.g. see #714).
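
For callers, a toy illustration of that caveat (assuming Caffeine 3.x,
where LoadingCache.refresh returns a future; the interleaving shown is
just one possible outcome):

import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.benmanes.caffeine.cache.LoadingCache;
import java.util.concurrent.CompletableFuture;

class RefreshVisibilityExample {
  public static void main(String[] args) {
    LoadingCache<String, String> cache = Caffeine.newBuilder().build(key -> "loaded");
    cache.put("k", "original");

    CompletableFuture<String> refresh = cache.refresh("k");
    cache.put("k", "updated");                  // racing write; the refresh may be discarded

    String reloaded = refresh.join();           // the value produced by the reload
    String current = cache.getIfPresent("k");   // may be "updated" rather than the reloaded value
    System.out.println(reloaded + " / " + current);
  }
}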
ben-manes committed May 29, 2022
1 parent 407340e commit 2d83b02
Showing 17 changed files with 157 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/qodana.yml
@@ -1,5 +1,4 @@
name: Qodana
permissions: read-all
on: [ push, pull_request ]

env:
Expand All @@ -11,6 +10,7 @@ jobs:
qodana:
runs-on: ubuntu-latest
permissions:
checks: write
actions: read
contents: read
security-events: write
11 changes: 10 additions & 1 deletion caffeine/build.gradle
@@ -136,6 +136,15 @@ tasks.register('memoryOverhead', JavaExec) {
group = 'Benchmarks'
description = 'Evaluates cache overhead'
classpath sourceSets.jmh.runtimeClasspath
jvmArgs "-javaagent:${configurations.javaAgent.singleFile}"
classpath sourceSets.codeGen.runtimeClasspath
mainClass = 'com.github.benmanes.caffeine.cache.MemoryBenchmark'
jvmArgs += [
'--add-opens', 'java.base/java.util.concurrent.atomic=ALL-UNNAMED',
'--add-opens', 'java.base/java.util.concurrent.locks=ALL-UNNAMED',
'--add-opens', 'java.base/java.util.concurrent=ALL-UNNAMED',
'--add-opens', 'java.base/java.lang.ref=ALL-UNNAMED',
'--add-opens', 'java.base/java.lang=ALL-UNNAMED',
'--add-opens', 'java.base/java.util=ALL-UNNAMED',
"-javaagent:${configurations.javaAgent.singleFile}",
]
}
Expand Up @@ -2097,18 +2097,6 @@ public boolean containsValue(Object value) {
return value;
}

@Override
public @Nullable V getIfPresentQuietly(K key, long[/* 1 */] writeTime) {
V value;
Node<K, V> node = data.get(nodeFactory.newLookupKey(key));
if ((node == null) || ((value = node.getValue()) == null)
|| hasExpired(node, expirationTicker().read())) {
return null;
}
writeTime[0] = node.getWriteTime();
return value;
}

/**
* Returns the key associated with the mapping in this cache, or {@code null} if there is none.
*
@@ -2441,6 +2429,11 @@ public boolean remove(Object key, Object value) {

@Override
public boolean replace(K key, V oldValue, V newValue) {
return replace(key, oldValue, newValue, /* shouldDiscardRefresh */ true);
}

@Override
public boolean replace(K key, V oldValue, V newValue, boolean shouldDiscardRefresh) {
requireNonNull(key);
requireNonNull(oldValue);
requireNonNull(newValue);
@@ -2471,7 +2464,10 @@ public boolean replace(K key, V oldValue, V newValue) {
setAccessTime(n, now[0]);
setWriteTime(n, now[0]);
replaced[0] = true;
discardRefresh(k);

if (shouldDiscardRefresh) {
discardRefresh(k);
}
}
return n;
});
@@ -2693,7 +2689,7 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
* @param computeIfAbsent if an absent entry can be computed
* @return the new value associated with the specified key, or null if none
*/
@SuppressWarnings("PMD.EmptyIfStmt")
@SuppressWarnings("PMD.EmptyControlStatement")
@Nullable V remap(K key, Object keyRef,
BiFunction<? super K, ? super V, ? extends V> remappingFunction,
Expiry<? super K, ? super V> expiry, long[/* 1 */] now, boolean computeIfAbsent) {
@@ -206,15 +206,15 @@ default void handleCompletion(K key, CompletableFuture<? extends V> valueFuture,
&& !(error instanceof TimeoutException)) {
logger.log(Level.WARNING, "Exception thrown during asynchronous load", error);
}
cache().remove(key, valueFuture);
cache().statsCounter().recordLoadFailure(loadTime);
cache().remove(key, valueFuture);
} else {
@SuppressWarnings("unchecked")
var castedFuture = (CompletableFuture<V>) valueFuture;

// update the weight and expiration timestamps
cache().replace(key, castedFuture, castedFuture);
cache().statsCounter().recordLoadSuccess(loadTime);
cache().replace(key, castedFuture, castedFuture, /* shouldDiscardRefresh */ false);
}
if (recordMiss) {
cache().statsCounter().recordMisses(1);
@@ -210,6 +210,7 @@ public CompletableFuture<Map<K, V>> refreshAll(Iterable<? extends K> keys) {
}

/** Attempts to avoid a reload if the entry is absent, or a load or reload is in-flight. */
@SuppressWarnings("FutureReturnValueIgnored")
private @Nullable CompletableFuture<V> tryOptimisticRefresh(K key, Object keyReference) {
// If a refresh is in-flight, then return it directly. If completed and not yet removed, then
// remove to trigger a new reload.
Expand All @@ -234,7 +235,9 @@ public CompletableFuture<Map<K, V>> refreshAll(Iterable<? extends K> keys) {
@SuppressWarnings("unchecked")
var prior = (CompletableFuture<V>) asyncCache.cache()
.refreshes().putIfAbsent(keyReference, future);
return (prior == null) ? future : prior;
var result = (prior == null) ? future : prior;
result.whenComplete((r, e) -> asyncCache.cache().refreshes().remove(keyReference, result));
return result;
} else if (!oldValueFuture.isDone()) {
// no-op if load is pending
return oldValueFuture;
Expand All @@ -248,12 +251,11 @@ public CompletableFuture<Map<K, V>> refreshAll(Iterable<? extends K> keys) {
@SuppressWarnings("FutureReturnValueIgnored")
private @Nullable CompletableFuture<V> tryComputeRefresh(K key, Object keyReference) {
long[] startTime = new long[1];
long[] writeTime = new long[1];
boolean[] refreshed = new boolean[1];
@SuppressWarnings({"unchecked", "rawtypes"})
CompletableFuture<V>[] oldValueFuture = new CompletableFuture[1];
var future = asyncCache.cache().refreshes().computeIfAbsent(keyReference, k -> {
oldValueFuture[0] = asyncCache.cache().getIfPresentQuietly(key, writeTime);
oldValueFuture[0] = asyncCache.cache().getIfPresentQuietly(key);
V oldValue = Async.getIfReady(oldValueFuture[0]);
if (oldValue == null) {
return null;
@@ -282,19 +284,20 @@ public CompletableFuture<Map<K, V>> refreshAll(Iterable<? extends K> keys) {
var castedFuture = (CompletableFuture<V>) future;
if (refreshed[0]) {
castedFuture.whenComplete((newValue, error) -> {
asyncCache.cache().refreshes().remove(keyReference, castedFuture);
long loadTime = asyncCache.cache().statsTicker().read() - startTime[0];
if (error != null) {
if (!(error instanceof CancellationException) && !(error instanceof TimeoutException)) {
logger.log(Level.WARNING, "Exception thrown during refresh", error);
}
asyncCache.cache().refreshes().remove(keyReference, castedFuture);
asyncCache.cache().statsCounter().recordLoadFailure(loadTime);
return;
}

boolean[] discard = new boolean[1];
var value = asyncCache.cache().compute(key, (ignored, currentValue) -> {
if (currentValue == oldValueFuture[0]) {
var successful = asyncCache.cache().refreshes().remove(keyReference, castedFuture);
if (successful && (currentValue == oldValueFuture[0])) {
if (currentValue == null) {
// If the entry is absent then discard the refresh and maybe notifying the listener
discard[0] = (newValue != null);
Expand All @@ -305,16 +308,8 @@ public CompletableFuture<Map<K, V>> refreshAll(Iterable<? extends K> keys) {
} else if (newValue == Async.getIfReady((CompletableFuture<?>) currentValue)) {
// If the completed futures hold the same value instance then no-op
return currentValue;
} else {
// If the entry was not modified while in-flight (no ABA) then replace
long expectedWriteTime = writeTime[0];
if (asyncCache.cache().hasWriteTime()) {
asyncCache.cache().getIfPresentQuietly(key, writeTime);
}
if (writeTime[0] == expectedWriteTime) {
return (newValue == null) ? null : castedFuture;
}
}
return (newValue == null) ? null : castedFuture;
}
// Otherwise, a write invalidated the refresh so discard it and notify the listener
discard[0] = true;
@@ -81,16 +81,15 @@ interface LocalCache<K, V> extends ConcurrentMap<K, V> {
@Nullable
V getIfPresentQuietly(Object key);

/**
* See {@link Cache#getIfPresent(K)}. This method differs by not recording the access with
* the statistics nor the eviction policy, and populates the write-time if known.
*/
@Nullable
V getIfPresentQuietly(K key, long[/* 1 */] writeTime);

/** See {@link Cache#getAllPresent}. */
Map<K, V> getAllPresent(Iterable<? extends K> keys);

/**
* See {@link ConcurrentMap#replace(K, K, V)}. This method differs by optionally not discarding an
* in-flight refresh for the entry if replaced.
*/
boolean replace(K key, V oldValue, V newValue, boolean shouldDiscardRefresh);

@Override
default @Nullable V compute(K key,
BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
@@ -98,7 +98,6 @@ default Map<K, V> loadSequentially(Iterable<? extends K> keys) {
default CompletableFuture<V> refresh(K key) {
requireNonNull(key);

long[] writeTime = new long[1];
long[] startTime = new long[1];
@SuppressWarnings("unchecked")
V[] oldValue = (V[]) new Object[1];
Expand All @@ -113,7 +112,7 @@ default CompletableFuture<V> refresh(K key) {

try {
startTime[0] = cache().statsTicker().read();
oldValue[0] = cache().getIfPresentQuietly(key, writeTime);
oldValue[0] = cache().getIfPresentQuietly(key);
var refreshFuture = (oldValue[0] == null)
? cacheLoader().asyncLoad(key, cache().executor())
: cacheLoader().asyncReload(key, oldValue[0], cache().executor());
Expand All @@ -131,34 +130,21 @@ default CompletableFuture<V> refresh(K key) {

if (reloading[0] != null) {
reloading[0].whenComplete((newValue, error) -> {
boolean removed = cache().refreshes().remove(keyReference, reloading[0]);
long loadTime = cache().statsTicker().read() - startTime[0];
if (error != null) {
if (!(error instanceof CancellationException) && !(error instanceof TimeoutException)) {
logger.log(Level.WARNING, "Exception thrown during refresh", error);
}
cache().refreshes().remove(keyReference, reloading[0]);
cache().statsCounter().recordLoadFailure(loadTime);
return;
}

boolean[] discard = new boolean[1];
var value = cache().compute(key, (k, currentValue) -> {
if (currentValue == oldValue[0]) {
if (currentValue == null) {
if (newValue == null) {
return null;
} else if (removed) {
return newValue;
}
} else {
long expectedWriteTime = writeTime[0];
if (cache().hasWriteTime()) {
cache().getIfPresentQuietly(key, writeTime);
}
if (writeTime[0] == expectedWriteTime) {
return newValue;
}
}
boolean removed = cache().refreshes().remove(keyReference, reloading[0]);
if (removed && (currentValue == oldValue[0])) {
return (currentValue == null) && (newValue == null) ? null : newValue;
}
discard[0] = (currentValue != newValue);
return currentValue;
@@ -136,11 +136,6 @@ public Object referenceKey(K key) {
return data.get(key);
}

@Override
public @Nullable V getIfPresentQuietly(K key, long[/* 1 */] writeTime) {
return data.get(key);
}

@Override
public long estimatedSize() {
return data.mappingCount();
@@ -519,15 +514,22 @@ public boolean remove(Object key, Object value) {

@Override
public boolean replace(K key, V oldValue, V newValue) {
return replace(key, oldValue, newValue, /* shouldDiscardRefresh */ true);
}

@Override
public boolean replace(K key, V oldValue, V newValue, boolean shouldDiscardRefresh) {
requireNonNull(oldValue);
requireNonNull(newValue);

@SuppressWarnings({"unchecked", "rawtypes"})
V[] prev = (V[]) new Object[1];
data.computeIfPresent(key, (k, v) -> {
if (v.equals(oldValue)) {
if (shouldDiscardRefresh) {
discardRefresh(k);
}
prev[0] = v;
discardRefresh(k);
return newValue;
}
return v;
@@ -35,6 +35,7 @@
import static com.github.benmanes.caffeine.cache.testing.CacheSpec.Expiration.VARIABLE;
import static com.github.benmanes.caffeine.cache.testing.CacheSubject.assertThat;
import static com.github.benmanes.caffeine.testing.Awaits.await;
import static com.github.benmanes.caffeine.testing.FutureSubject.assertThat;
import static com.github.benmanes.caffeine.testing.MapSubject.assertThat;
import static com.google.common.truth.Truth.assertThat;
import static java.lang.Thread.State.BLOCKED;
Expand All @@ -57,14 +58,17 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.testng.Assert;
import org.testng.annotations.Listeners;
import org.testng.annotations.Test;
@@ -1417,6 +1421,65 @@ public CompletableFuture<Int> asyncReload(Int key, Int oldValue, Executor execut
await().untilAsserted(() -> assertThat(cache).containsEntry(context.absentKey(), newValue));
}

@Test(dataProvider = "caches", groups = "isolated")
@CacheSpec(population = Population.EMPTY, executor = CacheExecutor.THREADED,
compute = Compute.ASYNC, stats = Stats.DISABLED)
public void refresh_startReloadBeforeLoadCompletion(CacheContext context) {
var stats = Mockito.mock(StatsCounter.class);
var beganLoadSuccess = new AtomicBoolean();
var endLoadSuccess = new CountDownLatch(1);
var beganReloading = new AtomicBoolean();
var beganLoading = new AtomicBoolean();
var endReloading = new AtomicBoolean();
var endLoading = new AtomicBoolean();

context.ticker().setAutoIncrementStep(Duration.ofSeconds(1));
context.caffeine().recordStats(() -> stats);
var asyncCache = context.buildAsync(new CacheLoader<Int, Int>() {
@Override public Int load(Int key) {
beganLoading.set(true);
await().untilTrue(endLoading);
return new Int(ThreadLocalRandom.current().nextInt());
}
@Override public Int reload(Int key, Int oldValue) {
beganReloading.set(true);
await().untilTrue(endReloading);
return new Int(ThreadLocalRandom.current().nextInt());
}
});

Answer<?> answer = invocation -> {
beganLoadSuccess.set(true);
endLoadSuccess.await();
return null;
};
doAnswer(answer).when(stats).recordLoadSuccess(anyLong());

// Start load
var future1 = asyncCache.get(context.absentKey());
await().untilTrue(beganLoading);

// Complete load; start load callback
endLoading.set(true);
await().untilTrue(beganLoadSuccess);

// Start reload
var refresh = asyncCache.synchronous().refresh(context.absentKey());
await().untilTrue(beganReloading);

// Complete load callback
endLoadSuccess.countDown();
await().untilAsserted(() -> assertThat(future1.getNumberOfDependents()).isEqualTo(0));

// Complete reload callback
endReloading.set(true);
await().untilAsserted(() -> assertThat(refresh.getNumberOfDependents()).isEqualTo(0));

// Assert new value
await().untilAsserted(() ->
assertThat(asyncCache.get(context.absentKey())).succeedsWith(refresh.get()));
}

/* --------------- Miscellaneous --------------- */

@Test
@@ -61,6 +61,7 @@
import com.github.benmanes.caffeine.cache.testing.CacheSpec;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.CacheExecutor;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Compute;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.ExecutorFailure;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Implementation;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Listener;
import com.github.benmanes.caffeine.cache.testing.CacheSpec.Population;
@@ -769,7 +770,8 @@ public void invalidateAll_null(Cache<Int, Int> cache, CacheContext context) {
@CheckNoStats
@Test(dataProvider = "caches")
@CacheSpec(population = Population.FULL, compute = Compute.SYNC,
executor = CacheExecutor.REJECTING, removalListener = Listener.CONSUMING)
executorFailure = ExecutorFailure.IGNORED, executor = CacheExecutor.REJECTING,
removalListener = Listener.CONSUMING)
public void removalListener_rejected(Cache<Int, Int> cache, CacheContext context) {
cache.invalidateAll();
assertThat(context).removalNotifications().withCause(EXPLICIT)