Skip to content

Commit

Permalink
fix when a weigher or expiry fail on an async completion (fixes #1687)
Browse files Browse the repository at this point in the history
  • Loading branch information
ben-manes committed Jun 9, 2024
1 parent 851ec95 commit 2d15a4e
Show file tree
Hide file tree
Showing 17 changed files with 946 additions and 214 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

import com.github.benmanes.caffeine.cache.LocalAsyncCache.AsyncBulkCompleter.NullMapCompletionException;
import com.github.benmanes.caffeine.cache.stats.CacheStats;
import com.google.errorprone.annotations.CanIgnoreReturnValue;

/**
* This class provides a skeletal implementation of the {@link AsyncCache} interface to minimize the
Expand Down Expand Up @@ -144,9 +145,9 @@ default CompletableFuture<Map<K, V>> getAll(Iterable<? extends K> keys,
try {
var loader = mappingFunction.apply(
Collections.unmodifiableSet(proxies.keySet()), cache().executor());
return loader.whenComplete(completer).thenCompose(ignored -> composeResult(futures));
return loader.handle(completer).thenCompose(ignored -> composeResult(futures));
} catch (Throwable t) {
completer.accept(/* result */ null, t);
completer.apply(/* result */ null, t);
throw t;
}
}
Expand Down Expand Up @@ -214,9 +215,15 @@ default void handleCompletion(K key, CompletableFuture<? extends V> valueFuture,
@SuppressWarnings("unchecked")
var castedFuture = (CompletableFuture<V>) valueFuture;

// update the weight and expiration timestamps
cache().statsCounter().recordLoadSuccess(loadTime);
cache().replace(key, castedFuture, castedFuture, /* shouldDiscardRefresh */ false);
try {
// update the weight and expiration timestamps
cache().replace(key, castedFuture, castedFuture, /* shouldDiscardRefresh */ false);
cache().statsCounter().recordLoadSuccess(loadTime);
} catch (Throwable t) {
logger.log(Level.WARNING, "Exception thrown during asynchronous load", t);
cache().statsCounter().recordLoadFailure(loadTime);
cache().remove(key, valueFuture);
}
}
if (recordMiss) {
cache().statsCounter().recordMisses(1);
Expand All @@ -226,7 +233,7 @@ default void handleCompletion(K key, CompletableFuture<? extends V> valueFuture,

/** A function executed asynchronously after a bulk load completes. */
final class AsyncBulkCompleter<K, V>
implements BiConsumer<Map<? extends K, ? extends V>, Throwable> {
implements BiFunction<Map<? extends K, ? extends V>, Throwable, Map<? extends K, ? extends V>> {
private final LocalCache<K, CompletableFuture<V>> cache;
private final Map<K, CompletableFuture<V>> proxies;
private final long startTime;
Expand All @@ -239,9 +246,28 @@ final class AsyncBulkCompleter<K, V>
}

@Override
public void accept(@Nullable Map<? extends K, ? extends V> result, @Nullable Throwable error) {
@CanIgnoreReturnValue
public @Nullable Map<? extends K, ? extends V> apply(
@Nullable Map<? extends K, ? extends V> result, @Nullable Throwable error) {
long loadTime = cache.statsTicker().read() - startTime;
var failure = handleResponse(result, error);

if (failure == null) {
cache.statsCounter().recordLoadSuccess(loadTime);
return result;
}

cache.statsCounter().recordLoadFailure(loadTime);
if (failure instanceof RuntimeException) {
throw (RuntimeException) failure;
} else if (failure instanceof Error) {
throw (Error) failure;
}
throw new CompletionException(failure);
}

private @Nullable Throwable handleResponse(
@Nullable Map<? extends K, ? extends V> result, @Nullable Throwable error) {
if (result == null) {
if (error == null) {
error = new NullMapCompletionException();
Expand All @@ -250,38 +276,65 @@ public void accept(@Nullable Map<? extends K, ? extends V> result, @Nullable Thr
cache.remove(entry.getKey(), entry.getValue());
entry.getValue().obtrudeException(error);
}
cache.statsCounter().recordLoadFailure(loadTime);
if (!(error instanceof CancellationException) && !(error instanceof TimeoutException)) {
logger.log(Level.WARNING, "Exception thrown during asynchronous load", error);
}
return error;
} else {
fillProxies(result);
addNewEntries(result);
cache.statsCounter().recordLoadSuccess(loadTime);
var failure = fillProxies(result);
return addNewEntries(result, failure);
}
}

/** Populates the proxies with the computed result. */
private void fillProxies(Map<? extends K, ? extends V> result) {
proxies.forEach((key, future) -> {
V value = result.get(key);
private @Nullable Throwable fillProxies(Map<? extends K, ? extends V> result) {
Throwable error = null;
for (var entry : proxies.entrySet()) {
var key = entry.getKey();
var value = result.get(key);
var future = entry.getValue();
future.obtrudeValue(value);

if (value == null) {
cache.remove(key, future);
} else {
// update the weight and expiration timestamps
cache.replace(key, future, future);
try {
// update the weight and expiration timestamps
cache.replace(key, future, future);
} catch (Throwable t) {
logger.log(Level.WARNING, "Exception thrown during asynchronous load", t);
cache.remove(key, future);
if (error == null) {
error = t;
} else {
error.addSuppressed(t);
}
}
}
});
}
return error;
}

/** Adds to the cache any extra entries computed that were not requested. */
private void addNewEntries(Map<? extends K, ? extends V> result) {
result.forEach((key, value) -> {
private @Nullable Throwable addNewEntries(
Map<? extends K, ? extends V> result, @Nullable Throwable error) {
for (var entry : result.entrySet()) {
var key = entry.getKey();
var value = result.get(key);
if (!proxies.containsKey(key)) {
cache.put(key, CompletableFuture.completedFuture(value));
try {
cache.put(key, CompletableFuture.completedFuture(value));
} catch (Throwable t) {
logger.log(Level.WARNING, "Exception thrown during asynchronous load", t);
if (error == null) {
error = t;
} else {
error.addSuppressed(t);
}
}
}
});
}
return error;
}

static final class NullMapCompletionException extends CompletionException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,36 +295,42 @@ public CompletableFuture<Map<K, V>> refreshAll(Iterable<? extends K> keys) {
return;
}

boolean[] discard = new boolean[1];
var value = asyncCache.cache().compute(key, (ignored, currentValue) -> {
var successful = asyncCache.cache().refreshes().remove(keyReference, castedFuture);
if (successful && (currentValue == oldValueFuture[0])) {
if (currentValue == null) {
// If the entry is absent then discard the refresh and maybe notifying the listener
discard[0] = (newValue != null);
return null;
} else if ((currentValue == newValue) || (currentValue == castedFuture)) {
// If the reloaded value is the same instance then no-op
return currentValue;
} else if (newValue == Async.getIfReady((CompletableFuture<?>) currentValue)) {
// If the completed futures hold the same value instance then no-op
return currentValue;
try {
boolean[] discard = new boolean[1];
var value = asyncCache.cache().compute(key, (ignored, currentValue) -> {
var successful = asyncCache.cache().refreshes().remove(keyReference, castedFuture);
if (successful && (currentValue == oldValueFuture[0])) {
if (currentValue == null) {
// If the entry is absent then discard the refresh and maybe notifying the listener
discard[0] = (newValue != null);
return null;
} else if ((currentValue == newValue) || (currentValue == castedFuture)) {
// If the reloaded value is the same instance then no-op
return currentValue;
} else if (newValue == Async.getIfReady((CompletableFuture<?>) currentValue)) {
// If the completed futures hold the same value instance then no-op
return currentValue;
}
return (newValue == null) ? null : castedFuture;
}
return (newValue == null) ? null : castedFuture;
// Otherwise, a write invalidated the refresh so discard it and notify the listener
discard[0] = true;
return currentValue;
}, asyncCache.cache().expiry(), /* recordLoad */ false, /* recordLoadFailure */ true);

if (discard[0] && (newValue != null)) {
var cause = (value == null) ? RemovalCause.EXPLICIT : RemovalCause.REPLACED;
asyncCache.cache().notifyRemoval(key, castedFuture, cause);
}
// Otherwise, a write invalidated the refresh so discard it and notify the listener
discard[0] = true;
return currentValue;
}, asyncCache.cache().expiry(), /* recordLoad */ false, /* recordLoadFailure */ true);

if (discard[0] && (newValue != null)) {
var cause = (value == null) ? RemovalCause.EXPLICIT : RemovalCause.REPLACED;
asyncCache.cache().notifyRemoval(key, castedFuture, cause);
}
if (newValue == null) {
if (newValue == null) {
asyncCache.cache().statsCounter().recordLoadFailure(loadTime);
} else {
asyncCache.cache().statsCounter().recordLoadSuccess(loadTime);
}
} catch (Throwable t) {
logger.log(Level.WARNING, "Exception thrown during asynchronous load", t);
asyncCache.cache().statsCounter().recordLoadFailure(loadTime);
} else {
asyncCache.cache().statsCounter().recordLoadSuccess(loadTime);
asyncCache.cache().remove(key, castedFuture);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static com.github.benmanes.caffeine.testing.CollectionSubject.assertThat;
import static com.github.benmanes.caffeine.testing.FutureSubject.assertThat;
import static com.github.benmanes.caffeine.testing.IntSubject.assertThat;
import static com.github.benmanes.caffeine.testing.LoggingEvents.logEvents;
import static com.github.benmanes.caffeine.testing.MapSubject.assertThat;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
Expand Down Expand Up @@ -66,7 +67,6 @@
import com.github.benmanes.caffeine.cache.testing.CheckNoEvictions;
import com.github.benmanes.caffeine.cache.testing.CheckNoStats;
import com.github.benmanes.caffeine.testing.Int;
import com.github.valfirst.slf4jtest.TestLoggerFactory;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
Expand Down Expand Up @@ -169,12 +169,12 @@ public void getFunc_absent_failure(AsyncCache<Int, Int> cache, CacheContext cont

assertThat(valueFuture).hasCompletedExceptionally();
assertThat(cache).doesNotContainKey(context.absentKey());

var event = Iterables.getOnlyElement(TestLoggerFactory.getLoggingEvents());
assertThat(event.getFormattedMessage()).isEqualTo("Exception thrown during asynchronous load");
assertThat(event.getThrowable().orElseThrow())
.hasCauseThat().isInstanceOf(IllegalStateException.class);
assertThat(event.getLevel()).isEqualTo(WARN);
assertThat(logEvents()
.withMessage("Exception thrown during asynchronous load")
.withUnderlyingCause(IllegalStateException.class)
.withLevel(WARN)
.exclusively())
.hasSize(1);
}

@Test(dataProvider = "caches")
Expand Down Expand Up @@ -297,11 +297,12 @@ public void getBiFunc_absent_failure_before(AsyncCache<Int, Int> cache, CacheCon

assertThat(valueFuture).hasCompletedExceptionally();
assertThat(cache).doesNotContainKey(key);

var event = Iterables.getOnlyElement(TestLoggerFactory.getLoggingEvents());
assertThat(event.getFormattedMessage()).isEqualTo("Exception thrown during asynchronous load");
assertThat(event.getThrowable().orElseThrow()).isInstanceOf(IllegalStateException.class);
assertThat(event.getLevel()).isEqualTo(WARN);
assertThat(logEvents()
.withMessage("Exception thrown during asynchronous load")
.withThrowable(IllegalStateException.class)
.withLevel(WARN)
.exclusively())
.hasSize(1);
}

@CacheSpec
Expand All @@ -316,11 +317,12 @@ public void getBiFunc_absent_failure_after(AsyncCache<Int, Int> cache, CacheCont

assertThat(valueFuture).hasCompletedExceptionally();
assertThat(cache).doesNotContainKey(key);

var event = Iterables.getOnlyElement(TestLoggerFactory.getLoggingEvents());
assertThat(event.getFormattedMessage()).isEqualTo("Exception thrown during asynchronous load");
assertThat(event.getThrowable().orElseThrow()).isInstanceOf(IllegalStateException.class);
assertThat(event.getLevel()).isEqualTo(WARN);
assertThat(logEvents()
.withMessage("Exception thrown during asynchronous load")
.withThrowable(IllegalStateException.class)
.withLevel(WARN)
.exclusively())
.hasSize(1);
}

@CacheSpec
Expand Down Expand Up @@ -446,12 +448,12 @@ public void getAllFunction_absent_failure(AsyncCache<Int, Int> cache, CacheConte
.hasCauseThat().isInstanceOf(IllegalStateException.class);
int misses = context.absentKeys().size();
assertThat(context).stats().hits(0).misses(misses).success(0).failures(1);

var event = Iterables.getOnlyElement(TestLoggerFactory.getLoggingEvents());
assertThat(event.getFormattedMessage()).isEqualTo("Exception thrown during asynchronous load");
assertThat(event.getThrowable().orElseThrow())
.hasCauseThat().isInstanceOf(IllegalStateException.class);
assertThat(event.getLevel()).isEqualTo(WARN);
assertThat(logEvents()
.withMessage("Exception thrown during asynchronous load")
.withUnderlyingCause(IllegalStateException.class)
.withLevel(WARN)
.exclusively())
.hasSize(1);
}

@Test(dataProvider = "caches")
Expand Down Expand Up @@ -739,11 +741,12 @@ public void getAllBifunction_absent_failure(AsyncCache<Int, Int> cache, CacheCon
assertThat(future).failsWith(CompletionException.class)
.hasCauseThat().isInstanceOf(IllegalStateException.class);
assertThat(context).stats().hits(0).misses(context.absentKeys().size()).success(0).failures(1);

var event = Iterables.getOnlyElement(TestLoggerFactory.getLoggingEvents());
assertThat(event.getFormattedMessage()).isEqualTo("Exception thrown during asynchronous load");
assertThat(event.getThrowable().orElseThrow()).isInstanceOf(IllegalStateException.class);
assertThat(event.getLevel()).isEqualTo(WARN);
assertThat(logEvents()
.withMessage("Exception thrown during asynchronous load")
.withThrowable(IllegalStateException.class)
.withLevel(WARN)
.exclusively())
.hasSize(1);
}

@CacheSpec
Expand Down Expand Up @@ -1010,7 +1013,7 @@ public void getAllBifunction_early_failure(AsyncCache<Int, Int> cache, CacheCont
} else {
assertThat(result.join()).containsExactlyEntriesIn(context.absent());
}
assertThat(TestLoggerFactory.getLoggingEvents()).isEmpty();
assertThat(logEvents()).isEmpty();
}

/* --------------- put --------------- */
Expand Down Expand Up @@ -1043,7 +1046,7 @@ public void put_insert_failure_before(AsyncCache<Int, Int> cache, CacheContext c
cache.put(context.absentKey(), failedFuture);
assertThat(cache).hasSize(context.initialSize());
assertThat(cache).doesNotContainKey(context.absentKey());
assertThat(TestLoggerFactory.getLoggingEvents()).isEmpty();
assertThat(logEvents()).isEmpty();
}

@Test(dataProvider = "caches")
Expand All @@ -1055,11 +1058,12 @@ public void put_insert_failure_after(AsyncCache<Int, Int> cache, CacheContext co
failedFuture.completeExceptionally(new IllegalStateException());
assertThat(cache).doesNotContainKey(context.absentKey());
assertThat(cache).hasSize(context.initialSize());

var event = Iterables.getOnlyElement(TestLoggerFactory.getLoggingEvents());
assertThat(event.getFormattedMessage()).isEqualTo("Exception thrown during asynchronous load");
assertThat(event.getThrowable().orElseThrow()).isInstanceOf(IllegalStateException.class);
assertThat(event.getLevel()).isEqualTo(WARN);
assertThat(logEvents()
.withMessage("Exception thrown during asynchronous load")
.withThrowable(IllegalStateException.class)
.withLevel(WARN)
.exclusively())
.hasSize(1);
}

@Test(dataProvider = "caches")
Expand All @@ -1081,7 +1085,7 @@ public void put_replace_failure_before(AsyncCache<Int, Int> cache, CacheContext
cache.put(context.middleKey(), failedFuture);
assertThat(cache).hasSize(context.initialSize() - 1);
assertThat(cache).doesNotContainKey(context.absentKey());
assertThat(TestLoggerFactory.getLoggingEvents()).isEmpty();
assertThat(logEvents()).isEmpty();
}

@Test(dataProvider = "caches")
Expand All @@ -1093,7 +1097,7 @@ public void put_replace_failure_after(AsyncCache<Int, Int> cache, CacheContext c
failedFuture.completeExceptionally(new IllegalStateException());
assertThat(cache).doesNotContainKey(context.absentKey());
assertThat(cache).hasSize(context.initialSize() - 1);
assertThat(TestLoggerFactory.getLoggingEvents()).isEmpty();
assertThat(logEvents()).isEmpty();
}

@Test(dataProvider = "caches")
Expand Down
Loading

0 comments on commit 2d15a4e

Please sign in to comment.