diff --git a/extensions/cache/deployment/src/test/java/io/quarkus/cache/test/runtime/CacheResultCompletionStageReturnTypeTest.java b/extensions/cache/deployment/src/test/java/io/quarkus/cache/test/runtime/CacheResultCompletionStageReturnTypeTest.java deleted file mode 100644 index 3cc2d167cffefb..00000000000000 --- a/extensions/cache/deployment/src/test/java/io/quarkus/cache/test/runtime/CacheResultCompletionStageReturnTypeTest.java +++ /dev/null @@ -1,87 +0,0 @@ -package io.quarkus.cache.test.runtime; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionStage; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; - -import javax.enterprise.context.ApplicationScoped; -import javax.inject.Inject; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; - -import io.quarkus.cache.CacheResult; -import io.quarkus.test.QuarkusUnitTest; - -public class CacheResultCompletionStageReturnTypeTest { - - private static final Object KEY_1 = new Object(); - private static final Object KEY_2 = new Object(); - - @RegisterExtension - static final QuarkusUnitTest TEST = new QuarkusUnitTest().withApplicationRoot(jar -> jar.addClass(CachedService.class)); - - @Inject - CachedService cachedService; - - @Test - public void testAllCacheAnnotations() throws InterruptedException, ExecutionException { - // STEP 1 - // Action: @CacheResult-annotated method call. - // Expected effect: method invoked and result cached. - // Verified by: STEP 2. - CompletionStage completionStage1 = cachedService.cachedMethod(KEY_1); - - // STEP 2 - // Action: same call as STEP 1. - // Expected effect: method not invoked and result coming from the cache. - // Verified by: same object reference between STEPS 1 and 2 results. - CompletionStage completionStage2 = cachedService.cachedMethod(KEY_1); - assertTrue(completionStage1 == completionStage2); - - // STEP 3 - // Action: same call as STEP 2 with a new key. - // Expected effect: method invoked and result cached. - // Verified by: different objects references between STEPS 2 and 3 results. - CompletionStage completionStage3 = cachedService.cachedMethod(KEY_2); - assertTrue(completionStage2 != completionStage3); - - // We need all of the futures to complete at this point. - CompletableFuture.allOf(completionStage1.toCompletableFuture(), completionStage2.toCompletableFuture(), - completionStage3.toCompletableFuture()).get(); - - Object value1 = completionStage1.toCompletableFuture().get(); - Object value2 = completionStage2.toCompletableFuture().get(); - Object value3 = completionStage3.toCompletableFuture().get(); - - // Values objects references resulting from STEPS 1 and 2 should be equal since the same cache key was used. - assertTrue(value1 == value2); - - // Values objects references resulting from STEPS 2 and 3 should be different since a different cache key was used. - assertTrue(value2 != value3); - } - - @ApplicationScoped - static class CachedService { - - // This is required to make sure the CompletableFuture from the tests are executed concurrently. - private ExecutorService executorService = Executors.newFixedThreadPool(3); - - @CacheResult(cacheName = "test-cache") - public CompletionStage cachedMethod(Object key) { - return CompletableFuture.supplyAsync(() -> { - try { - // This is another requirement for concurrent CompletableFuture executions. - Thread.sleep(1000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new Object(); - }, executorService); - } - } -} diff --git a/extensions/cache/deployment/src/test/java/io/quarkus/cache/test/runtime/CompletionStageReturnTypeTest.java b/extensions/cache/deployment/src/test/java/io/quarkus/cache/test/runtime/CompletionStageReturnTypeTest.java new file mode 100644 index 00000000000000..afa31f46fc6922 --- /dev/null +++ b/extensions/cache/deployment/src/test/java/io/quarkus/cache/test/runtime/CompletionStageReturnTypeTest.java @@ -0,0 +1,164 @@ +package io.quarkus.cache.test.runtime; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import javax.enterprise.context.ApplicationScoped; +import javax.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.cache.CacheInvalidate; +import io.quarkus.cache.CacheInvalidateAll; +import io.quarkus.cache.CacheResult; +import io.quarkus.test.QuarkusUnitTest; + +/** + * Tests the caching annotations on methods returning {@link CompletableFuture}. + */ +public class CompletionStageReturnTypeTest { + + private static final String CACHE_NAME_1 = "test-cache-1"; + private static final String CACHE_NAME_2 = "test-cache-2"; + private static final String KEY_1 = "key-1"; + private static final String KEY_2 = "key-2"; + + @RegisterExtension + static final QuarkusUnitTest TEST = new QuarkusUnitTest().withApplicationRoot((jar) -> jar.addClass(CachedService.class)); + + @Inject + CachedService cachedService; + + @Test + void testCacheResult() throws ExecutionException, InterruptedException { + // STEP 1 + // Action: a method annotated with @CacheResult and returning a CompletionStage is called. + // Expected effect: the method is invoked, as CompletionStage is eager. + // Verified by: invocations counter. + CompletableFuture cf1 = cachedService.cacheResult1(KEY_1); + assertEquals(1, cachedService.getCacheResultInvocations()); + + // STEP 2 + // Action: same call as STEP 1. + // Expected effect: same as STEP 1 with a different CompletionStage instance returned. + // Verified by: invocations counter and different objects references between STEPS 1 AND 2 results. + CompletableFuture cf2 = cachedService.cacheResult1(KEY_1); + assertEquals(1, cachedService.getCacheResultInvocations()); + assertNotSame(cf1, cf2); + + // STEP 3 + // Action: the Uni returned in STEP 1 is subscribed to and we wait for an item event to be fired. + // Expected effect: the method from STEP 1 is invoked and its result is cached. + // Verified by: invocations counter and STEP 4. + String emittedItem1 = cf1.get(); + assertEquals(1, cachedService.getCacheResultInvocations()); + + // STEP 4 + // Action: the Uni returned in STEP 2 is subscribed to and we wait for an item event to be fired. + // Expected effect: the method from STEP 2 is not invoked and the value cached in STEP 3 is returned. + // Verified by: invocations counter and same object reference between STEPS 3 and 4 emitted items. + String emittedItem2 = cf2.get(); + assertEquals(1, cachedService.getCacheResultInvocations()); + assertSame(emittedItem1, emittedItem2); + + // STEP 5 + // Action: same call as STEP 2 with a different key and an immediate subscription. + // Expected effect: the method is invoked and a new item is emitted (also cached). + // Verified by: invocations counter. + String emittedItem3 = cachedService.cacheResult1("another-key").get(); + assertEquals(2, cachedService.getCacheResultInvocations()); + } + + @Test + void testCacheInvalidate() throws ExecutionException, InterruptedException { + // First, let's put some data into the caches. + String value1 = cachedService.cacheResult1(KEY_1).get(); + Object value2 = cachedService.cacheResult2(KEY_2).get(); + + // We will invalidate some data (only KEY_1) in all caches later. + cachedService.cacheInvalidate(KEY_1).get(); + // For now, the method that will invalidate the data should not be invoked, as CompletionStage is eager. + assertEquals(1, cachedService.getCacheInvalidateInvocations()); + + // The data for the second key should still be cached at this point. + Object value4 = cachedService.cacheResult2(KEY_2).get(); + assertSame(value2, value4); + + // Let's call the methods annotated with @CacheResult again. + String value7 = cachedService.cacheResult1(KEY_1).get(); + + // The objects references should be different for the invalidated key. + assertNotSame(value1, value7); + } + + @Test + void testCacheInvalidateAll() throws ExecutionException, InterruptedException { + // First, let's put some data into the caches. + String value1 = cachedService.cacheResult1(KEY_1).get(); + Object value2 = cachedService.cacheResult2(KEY_2).get(); + + // We will invalidate all the data in all caches later. + cachedService.cacheInvalidateAll().get(); + + // For now, the method that will invalidate the data should not be invoked, as CompletionStage is eager. + assertEquals(1, cachedService.getCacheInvalidateAllInvocations()); + + // Let's call the methods annotated with @CacheResult again. + String value3 = cachedService.cacheResult1(KEY_1).get(); + Object value4 = cachedService.cacheResult2(KEY_2).get(); + + // All objects references should be different. + assertNotSame(value1, value3); + assertNotSame(value2, value4); + } + + @ApplicationScoped + static class CachedService { + + private volatile int cacheResultInvocations; + private volatile int cacheInvalidateInvocations; + private volatile int cacheInvalidateAllInvocations; + + @CacheResult(cacheName = CACHE_NAME_1) + public CompletableFuture cacheResult1(String key) { + cacheResultInvocations++; + return CompletableFuture.completedFuture(new String()); + } + + @CacheResult(cacheName = CACHE_NAME_2) + public CompletableFuture cacheResult2(String key) { + return CompletableFuture.completedFuture(new Object()); + } + + @CacheInvalidate(cacheName = CACHE_NAME_1) + @CacheInvalidate(cacheName = CACHE_NAME_2) + public CompletableFuture cacheInvalidate(String key) { + cacheInvalidateInvocations++; + return CompletableFuture.completedFuture(null); + } + + @CacheInvalidateAll(cacheName = CACHE_NAME_1) + @CacheInvalidateAll(cacheName = CACHE_NAME_2) + public CompletableFuture cacheInvalidateAll() { + cacheInvalidateAllInvocations++; + return CompletableFuture.completedFuture(null); + } + + public int getCacheResultInvocations() { + return cacheResultInvocations; + } + + public int getCacheInvalidateInvocations() { + return cacheInvalidateInvocations; + } + + public int getCacheInvalidateAllInvocations() { + return cacheInvalidateAllInvocations; + } + } +} diff --git a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInterceptor.java b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInterceptor.java index 35b0ef596a6271..fe86e0cfce9d0c 100644 --- a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInterceptor.java +++ b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInterceptor.java @@ -6,6 +6,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletionStage; import java.util.function.Supplier; import javax.inject.Inject; @@ -16,6 +17,7 @@ import io.quarkus.arc.runtime.InterceptorBindings; import io.quarkus.cache.Cache; +import io.quarkus.cache.CacheException; import io.quarkus.cache.CacheKey; import io.quarkus.cache.CacheManager; import io.quarkus.cache.CompositeCacheKey; @@ -27,6 +29,7 @@ public abstract class CacheInterceptor { private static final Logger LOGGER = Logger.getLogger(CacheInterceptor.class); private static final String PERFORMANCE_WARN_MSG = "Cache key resolution based on reflection calls. Please create a GitHub issue in the Quarkus repository, the maintainers might be able to improve your application performance."; + protected static final String UNHANDLED_ASYNC_RETURN_TYPE_MSG = "Unhandled async return type"; @Inject CacheManager cacheManager; @@ -135,7 +138,44 @@ protected Object getCacheKey(Cache cache, List cacheKeyParameterPositions } } - protected static boolean isUniReturnType(InvocationContext invocationContext) { - return Uni.class.isAssignableFrom(invocationContext.getMethod().getReturnType()); + protected static ReturnType determineReturnType(Class returnType) { + if (Uni.class.isAssignableFrom(returnType)) { + return ReturnType.Uni; + } + if (CompletionStage.class.isAssignableFrom(returnType)) { + return ReturnType.CompletionStage; + } + return ReturnType.NonAsync; + } + + protected Uni asyncInvocationResultToUni(Object invocationResult, ReturnType returnType) { + if (returnType == ReturnType.Uni) { + return (Uni) invocationResult; + } else if (returnType == ReturnType.CompletionStage) { + return Uni.createFrom().completionStage(new Supplier<>() { + @Override + public CompletionStage get() { + return (CompletionStage) invocationResult; + } + }); + } else { + throw new CacheException(new IllegalStateException(UNHANDLED_ASYNC_RETURN_TYPE_MSG)); + } + } + + protected Object createAsyncResult(Uni cacheValue, ReturnType returnType) { + if (returnType == ReturnType.Uni) { + return cacheValue; + } + if (returnType == ReturnType.CompletionStage) { + return cacheValue.subscribeAsCompletionStage(); + } + throw new CacheException(new IllegalStateException(UNHANDLED_ASYNC_RETURN_TYPE_MSG)); + } + + protected enum ReturnType { + NonAsync, + Uni, + CompletionStage } } diff --git a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateAllInterceptor.java b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateAllInterceptor.java index 8f6168335d2292..4d504f62e391d5 100644 --- a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateAllInterceptor.java +++ b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateAllInterceptor.java @@ -27,21 +27,26 @@ public class CacheInvalidateAllInterceptor extends CacheInterceptor { public Object intercept(InvocationContext invocationContext) throws Exception { CacheInterceptionContext interceptionContext = getInterceptionContext(invocationContext, CacheInvalidateAll.class, false); + if (interceptionContext.getInterceptorBindings().isEmpty()) { // This should never happen. LOGGER.warn(INTERCEPTOR_BINDINGS_ERROR_MSG); return invocationContext.proceed(); - } else if (isUniReturnType(invocationContext)) { - return invalidateAllNonBlocking(invocationContext, interceptionContext); - } else { + } + ReturnType returnType = determineReturnType(invocationContext.getMethod().getReturnType()); + if (returnType == ReturnType.NonAsync) { return invalidateAllBlocking(invocationContext, interceptionContext); + + } else { + return invalidateAllNonBlocking(invocationContext, interceptionContext, returnType); } } private Object invalidateAllNonBlocking(InvocationContext invocationContext, - CacheInterceptionContext interceptionContext) { + CacheInterceptionContext interceptionContext, + ReturnType returnType) { LOGGER.trace("Invalidating all cache entries in a non-blocking way"); - return Multi.createFrom().iterable(interceptionContext.getInterceptorBindings()) + var uni = Multi.createFrom().iterable(interceptionContext.getInterceptorBindings()) .onItem().transformToUniAndMerge(new Function>() { @Override public Uni apply(CacheInvalidateAll binding) { @@ -53,12 +58,13 @@ public Uni apply(CacheInvalidateAll binding) { @Override public Uni apply(Object ignored) { try { - return (Uni) invocationContext.proceed(); + return asyncInvocationResultToUni(invocationContext.proceed(), returnType); } catch (Exception e) { throw new CacheException(e); } } }); + return createAsyncResult(uni, returnType); } private Object invalidateAllBlocking(InvocationContext invocationContext, diff --git a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateInterceptor.java b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateInterceptor.java index 548bc1d4a6023f..62cf500d32af69 100644 --- a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateInterceptor.java +++ b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheInvalidateInterceptor.java @@ -32,17 +32,20 @@ public Object intercept(InvocationContext invocationContext) throws Exception { // This should never happen. LOGGER.warn(INTERCEPTOR_BINDINGS_ERROR_MSG); return invocationContext.proceed(); - } else if (isUniReturnType(invocationContext)) { - return invalidateNonBlocking(invocationContext, interceptionContext); - } else { + } + ReturnType returnType = determineReturnType(invocationContext.getMethod().getReturnType()); + if (returnType == ReturnType.NonAsync) { return invalidateBlocking(invocationContext, interceptionContext); + } else { + return invalidateNonBlocking(invocationContext, interceptionContext, returnType); } } private Object invalidateNonBlocking(InvocationContext invocationContext, - CacheInterceptionContext interceptionContext) { + CacheInterceptionContext interceptionContext, + ReturnType returnType) { LOGGER.trace("Invalidating cache entries in a non-blocking way"); - return Multi.createFrom().iterable(interceptionContext.getInterceptorBindings()) + var uni = Multi.createFrom().iterable(interceptionContext.getInterceptorBindings()) .onItem().transformToUniAndMerge(new Function>() { @Override public Uni apply(CacheInvalidate binding) { @@ -55,12 +58,13 @@ public Uni apply(CacheInvalidate binding) { @Override public Uni apply(Object ignored) { try { - return (Uni) invocationContext.proceed(); + return asyncInvocationResultToUni(invocationContext.proceed(), returnType); } catch (Exception e) { throw new CacheException(e); } } }); + return createAsyncResult(uni, returnType); } private Object invalidateBlocking(InvocationContext invocationContext, diff --git a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheResultInterceptor.java b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheResultInterceptor.java index 9eeb368866a2f1..d048be69f93341 100644 --- a/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheResultInterceptor.java +++ b/extensions/cache/runtime/src/main/java/io/quarkus/cache/runtime/CacheResultInterceptor.java @@ -50,7 +50,8 @@ public Object intercept(InvocationContext invocationContext) throws Throwable { LOGGER.debugf("Loading entry with key [%s] from cache [%s]", key, binding.cacheName()); try { - if (isUniReturnType(invocationContext)) { + ReturnType returnType = determineReturnType(invocationContext.getMethod().getReturnType()); + if (returnType != ReturnType.NonAsync) { Uni cacheValue = cache.get(key, new Function() { @Override public Object apply(Object k) { @@ -63,7 +64,7 @@ public Object apply(Object k) { public Uni apply(Object value) { if (value == UnresolvedUniValue.INSTANCE) { try { - return ((Uni) invocationContext.proceed()) + return asyncInvocationResultToUni(invocationContext.proceed(), returnType) .call(new Function>() { @Override public Uni apply(Object emittedValue) { @@ -81,14 +82,14 @@ public Uni apply(Object emittedValue) { } }); if (binding.lockTimeout() <= 0) { - return cacheValue; + return createAsyncResult(cacheValue, returnType); } - return cacheValue.ifNoItem().after(Duration.ofMillis(binding.lockTimeout())) + cacheValue = cacheValue.ifNoItem().after(Duration.ofMillis(binding.lockTimeout())) .recoverWithUni(new Supplier>() { @Override public Uni get() { try { - return (Uni) invocationContext.proceed(); + return asyncInvocationResultToUni(invocationContext.proceed(), returnType); } catch (CacheException e) { throw e; } catch (Exception e) { @@ -96,7 +97,7 @@ public Uni get() { } } }); - + return createAsyncResult(cacheValue, returnType); } else { Uni cacheValue = cache.get(key, new Function() { @Override @@ -137,4 +138,5 @@ public Object apply(Object k) { } } } + }