diff --git a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsAutoTest.java b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsAutoTest.java index 666c38dbc79f3..56be28ce16c3b 100644 --- a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsAutoTest.java +++ b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsAutoTest.java @@ -10,6 +10,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import io.quarkus.arc.Arc; import io.quarkus.arc.ComponentsProvider; import io.quarkus.test.QuarkusUnitTest; @@ -29,7 +30,8 @@ public void testContexts() { assertTrue(bean.ping()); for (ComponentsProvider componentsProvider : ServiceLoader.load(ComponentsProvider.class)) { // We have less than 1000 beans - assertFalse(componentsProvider.getComponents().getContextInstances().isEmpty()); + assertFalse(componentsProvider.getComponents(Arc.container().getCurrentContextFactory()).getContextInstances() + .isEmpty()); } } } diff --git a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsDisabledTest.java b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsDisabledTest.java index b1b611c81312c..64fc8c86de6ae 100644 --- a/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsDisabledTest.java +++ b/extensions/arc/deployment/src/test/java/io/quarkus/arc/test/context/optimized/OptimizeContextsDisabledTest.java @@ -9,6 +9,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import io.quarkus.arc.Arc; import io.quarkus.arc.ComponentsProvider; import io.quarkus.test.QuarkusUnitTest; @@ -27,7 +28,8 @@ public class OptimizeContextsDisabledTest { public void testContexts() { assertTrue(bean.ping()); for (ComponentsProvider componentsProvider : ServiceLoader.load(ComponentsProvider.class)) { - assertTrue(componentsProvider.getComponents().getContextInstances().isEmpty()); + assertTrue(componentsProvider.getComponents(Arc.container().getCurrentContextFactory()).getContextInstances() + .isEmpty()); } } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java index f6e931a2c850f..3d6c488289c41 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketSessionContext.java @@ -24,6 +24,7 @@ import io.quarkus.arc.ArcContainer; import io.quarkus.arc.ContextInstanceHandle; import io.quarkus.arc.CurrentContext; +import io.quarkus.arc.CurrentContextFactory; import io.quarkus.arc.InjectableBean; import io.quarkus.arc.ManagedContext; import io.quarkus.arc.impl.ComputingCacheContextInstances; @@ -35,19 +36,13 @@ public class WebSocketSessionContext implements ManagedContext { private static final Logger LOG = Logger.getLogger(WebSocketSessionContext.class); - private final LazyValue> currentContext; + private final CurrentContext currentContext; private final LazyValue> initializedEvent; private final LazyValue> beforeDestroyEvent; private final LazyValue> destroyEvent; - public WebSocketSessionContext() { - // Use lazy value because no-args constructor is needed - this.currentContext = new LazyValue<>(new Supplier>() { - @Override - public CurrentContext get() { - return Arc.container().getCurrentContextFactory().create(SessionScoped.class); - } - }); + public WebSocketSessionContext(CurrentContextFactory currentContextFactory) { + this.currentContext = currentContextFactory.create(SessionScoped.class); this.initializedEvent = newEvent(Initialized.Literal.SESSION, Any.Literal.INSTANCE); this.beforeDestroyEvent = newEvent(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE); this.destroyEvent = newEvent(Destroyed.Literal.SESSION, Any.Literal.INSTANCE); @@ -62,7 +57,6 @@ public Class getScope() { public ContextState getState() { SessionContextState state = currentState(); if (state == null) { - // Thread local not set - context is not active! throw notActive(); } return state; @@ -72,11 +66,11 @@ public ContextState getState() { public ContextState activate(ContextState initialState) { if (initialState == null) { SessionContextState state = initializeContextState(); - currentContext().set(state); + currentContext.set(state); return state; } else { if (initialState instanceof SessionContextState) { - currentContext().set((SessionContextState) initialState); + currentContext.set((SessionContextState) initialState); return initialState; } else { throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName()); @@ -86,7 +80,7 @@ public ContextState activate(ContextState initialState) { @Override public void deactivate() { - currentContext().remove(); + currentContext.remove(); } @SuppressWarnings("unchecked") @@ -176,12 +170,8 @@ SessionContextState initializeContextState() { return state; } - private CurrentContext currentContext() { - return currentContext.get(); - } - private SessionContextState currentState() { - return currentContext().get(); + return currentContext.get(); } private IllegalArgumentException invalidScope() { diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanDeployment.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanDeployment.java index b58413e26ead9..39d5ce8b1769c 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanDeployment.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanDeployment.java @@ -114,6 +114,7 @@ public class BeanDeployment { private final Set beansWithRuntimeDeferredUnproxyableError; + // scope -> fun that accepts the method creator for ComponentsProvider#getComponents() private final Map>> customContexts; private final Map beanDefiningAnnotations; @@ -214,7 +215,7 @@ public class BeanDeployment { additionalStereotypes.addAll(stereotypeRegistrar.getAdditionalStereotypes()); } - this.stereotypes = findStereotypes(interceptorBindings, customContexts, additionalStereotypes, + this.stereotypes = findStereotypes(interceptorBindings, customContexts.keySet(), additionalStereotypes, annotationStore); buildContext.putInternal(Key.STEREOTYPES, Collections.unmodifiableMap(stereotypes)); @@ -734,7 +735,7 @@ Map>> getCustomContexts() } ScopeInfo getScope(DotName scopeAnnotationName) { - return getScope(scopeAnnotationName, customContexts); + return getScope(scopeAnnotationName, customContexts.keySet()); } /** @@ -874,8 +875,7 @@ private static Set recursiveBuild(DotName name, } private Map findStereotypes(Map interceptorBindings, - Map>> customContexts, - Set additionalStereotypes, AnnotationStore annotationStore) { + Set customContextScopes, Set additionalStereotypes, AnnotationStore annotationStore) { Map stereotypes = new HashMap<>(); @@ -917,7 +917,7 @@ private Map findStereotypes(Map int } else if (DotNames.PRIORITY.equals(annotation.name())) { alternativePriority = annotation.value().asInt(); } else { - final ScopeInfo scope = getScope(annotation.name(), customContexts); + final ScopeInfo scope = getScope(annotation.name(), customContextScopes); if (scope != null) { scopes.add(scope); } @@ -933,13 +933,12 @@ private Map findStereotypes(Map int return stereotypes; } - private ScopeInfo getScope(DotName scopeAnnotationName, - Map>> customContexts) { + private ScopeInfo getScope(DotName scopeAnnotationName, Set customContextScopes) { BuiltinScope builtin = BuiltinScope.from(scopeAnnotationName); if (builtin != null) { return builtin.getInfo(); } - for (ScopeInfo customScope : customContexts.keySet()) { + for (ScopeInfo customScope : customContextScopes) { if (customScope.getDotName().equals(scopeAnnotationName)) { return customScope; } diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java index 5cfb81c54e5f0..995405f064e9c 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/BeanProcessor.java @@ -543,7 +543,8 @@ private Set findSingleContextNormalScopes() { // built-in contexts contextsForScope.put(BuiltinScope.REQUEST.getName(), 1); // custom contexts - for (Map.Entry>> entry : beanDeployment.getCustomContexts() + for (Map.Entry>> entry : beanDeployment + .getCustomContexts() .entrySet()) { if (entry.getKey().isNormal()) { contextsForScope.merge(entry.getKey().getDotName(), entry.getValue().size(), Integer::sum); diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ComponentsProviderGenerator.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ComponentsProviderGenerator.java index f60ee63e7907d..7a38309b826c3 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ComponentsProviderGenerator.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ComponentsProviderGenerator.java @@ -28,6 +28,7 @@ import io.quarkus.arc.Arc; import io.quarkus.arc.Components; import io.quarkus.arc.ComponentsProvider; +import io.quarkus.arc.CurrentContextFactory; import io.quarkus.arc.InjectableBean; import io.quarkus.arc.processor.ResourceOutput.Resource; import io.quarkus.gizmo.AssignableResultHandle; @@ -82,7 +83,8 @@ Collection generate(String name, BeanDeployment beanDeployment, Map> dependencyMap = initBeanDependencyMap(beanDeployment); @@ -100,9 +102,10 @@ Collection generate(String name, BeanDeployment beanDeployment, Map>> entry : beanDeployment.getCustomContexts() + for (Entry>> e : beanDeployment + .getCustomContexts() .entrySet()) { - for (Function func : entry.getValue()) { + for (Function func : e.getValue()) { ResultHandle contextHandle = func.apply(getComponents); getComponents.invokeInterfaceMethod(MethodDescriptors.LIST_ADD, contextsHandle, contextHandle); } diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ContextConfigurator.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ContextConfigurator.java index ab8f8959372f2..b7cdb0ee02074 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ContextConfigurator.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/ContextConfigurator.java @@ -1,6 +1,8 @@ package io.quarkus.arc.processor; import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; @@ -12,6 +14,7 @@ import jakarta.enterprise.context.NormalScope; import io.quarkus.arc.ContextCreator; +import io.quarkus.arc.CurrentContextFactory; import io.quarkus.arc.InjectableContext; import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; @@ -100,12 +103,53 @@ public ContextConfigurator normal(boolean value) { } public ContextConfigurator contextClass(Class contextClazz) { - return creator(mc -> mc.newInstance(MethodDescriptor.ofConstructor(contextClazz))); + if (!Modifier.isPublic(contextClazz.getModifiers()) + || Modifier.isAbstract(contextClazz.getModifiers()) + || contextClazz.isAnonymousClass() + || contextClazz.isLocalClass() + || (contextClazz.getEnclosingClass() != null && !Modifier.isStatic(contextClazz.getModifiers()))) { + throw new IllegalArgumentException( + "A context class must be a public non-abstract top-level or static nested class"); + } + Constructor constructor = getConstructor(contextClazz); + if (constructor == null) { + throw new IllegalArgumentException( + "A context class must either declare a no-args constructor or a constructor that accepts a single parameter of type io.quarkus.arc.CurrentContextFactory"); + } + return creator(new Function<>() { + @Override + public ResultHandle apply(MethodCreator mc) { + ResultHandle[] args; + if (constructor.getParameterCount() == 0) { + args = new ResultHandle[0]; + } else { + args = new ResultHandle[] { mc.getMethodParam(0) }; + } + return mc.newInstance(MethodDescriptor.ofConstructor(contextClazz, constructor.getParameterTypes()), args); + } + }); + } + + private Constructor getConstructor(Class contextClazz) { + Constructor constructor = null; + try { + constructor = contextClazz.getDeclaredConstructor(CurrentContextFactory.class); + } catch (NoSuchMethodException ignored) { + } + if (constructor == null) { + try { + constructor = contextClazz.getDeclaredConstructor(); + } catch (NoSuchMethodException ignored) { + } + } + return constructor; } public ContextConfigurator creator(Class creatorClazz) { return creator(mc -> { ResultHandle paramsHandle = mc.newInstance(MethodDescriptor.ofConstructor(HashMap.class)); + mc.invokeInterfaceMethod(MethodDescriptors.MAP_PUT, paramsHandle, + mc.load(ContextCreator.KEY_CURRENT_CONTEXT_FACTORY), mc.getMethodParam(0)); for (Entry entry : params.entrySet()) { ResultHandle valHandle = null; if (entry.getValue() instanceof String) { diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ComponentsProvider.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ComponentsProvider.java index 6ef91d6baa22f..e306643dffbc8 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ComponentsProvider.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ComponentsProvider.java @@ -9,7 +9,7 @@ public interface ComponentsProvider { static Logger LOG = Logger.getLogger(ComponentsProvider.class); - Components getComponents(); + Components getComponents(CurrentContextFactory currentContextFactory); static void unableToLoadRemovedBeanType(String type, Throwable problem) { LOG.warnf("Unable to load removed bean type [%s]: %s", type, problem.toString()); diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextCreator.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextCreator.java index cdd590104eedb..f8163a1a085c3 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextCreator.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/ContextCreator.java @@ -7,7 +7,10 @@ */ public interface ContextCreator { + public static final String KEY_CURRENT_CONTEXT_FACTORY = "io.quarkus.arc.currentContextFactory"; + /** + * The {@link #KEY_CURRENT_CONTEXT_FACTORY} can be used to obtain the {@link CurrentContextFactory}. * * @param params * @return the context instance diff --git a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java index 25e8e75258b5d..c63710b5cbd86 100644 --- a/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java +++ b/independent-projects/arc/runtime/src/main/java/io/quarkus/arc/impl/ArcContainerImpl.java @@ -125,7 +125,7 @@ public ArcContainerImpl(CurrentContextFactory currentContextFactory, boolean str List components = new ArrayList<>(); for (ComponentsProvider componentsProvider : ServiceLoader.load(ComponentsProvider.class)) { - components.add(componentsProvider.getComponents()); + components.add(componentsProvider.getComponents(this.currentContextFactory)); } for (Components c : components) { diff --git a/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/buildextension/context/CustomContextTest.java b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/buildextension/context/CustomContextTest.java new file mode 100644 index 0000000000000..80a9987956e67 --- /dev/null +++ b/independent-projects/arc/tests/src/test/java/io/quarkus/arc/test/buildextension/context/CustomContextTest.java @@ -0,0 +1,312 @@ +package io.quarkus.arc.test.buildextension.context; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.lang.annotation.Annotation; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.Map; + +import jakarta.enterprise.context.NormalScope; +import jakarta.enterprise.context.spi.Contextual; +import jakarta.enterprise.context.spi.CreationalContext; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.arc.ArcContainer; +import io.quarkus.arc.ContextCreator; +import io.quarkus.arc.CurrentContextFactory; +import io.quarkus.arc.InjectableContext; +import io.quarkus.arc.processor.ContextConfigurator; +import io.quarkus.arc.processor.ContextRegistrar; +import io.quarkus.arc.test.ArcTestContainer; + +public class CustomContextTest { + + @RegisterExtension + public ArcTestContainer container = ArcTestContainer.builder() + .beanClasses(FieldScoped.class, MeadowScoped.class, Mina.class, InvalidNestedContext.class, + InvalidAbstractContext.class, InvalidAbstractContext.class, FieldContext.class, MeadowContext.class) + .contextRegistrars(new ContextRegistrar() { + @Override + public void register(RegistrationContext ctx) { + ContextConfigurator configurator = ctx.configure(FieldScoped.class); + assertThrows(IllegalArgumentException.class, () -> configurator.contextClass(InvalidNestedContext.class)); + assertThrows(IllegalArgumentException.class, () -> configurator.contextClass(InvalidAbstractContext.class)); + assertThrows(IllegalArgumentException.class, + () -> configurator.contextClass(InvalidConcstructorContext.class)); + configurator.contextClass(FieldContext.class).done(); + } + }) + .contextRegistrars(new ContextRegistrar() { + @Override + public void register(RegistrationContext ctx) { + ctx.configure(MeadowScoped.class).creator(MeadowCreator.class).done(); + } + + }) + .build(); + + @Test + public void testCustomScope() { + ArcContainer arc = Arc.container(); + assertEquals("bac", arc.instance(Mina.class).get().bum()); + } + + @FieldScoped + public static class Mina { + + public String bum() { + return "bac"; + } + + } + + @MeadowScoped + public static class Flower { + + public void bloom() { + } + + } + + @NormalScope + @Inherited + @Target({ TYPE, METHOD, FIELD }) + @Retention(RUNTIME) + public @interface FieldScoped { + } + + @NormalScope + @Inherited + @Target({ TYPE, METHOD, FIELD }) + @Retention(RUNTIME) + public @interface MeadowScoped { + } + + public static class FieldContext implements InjectableContext { + + public FieldContext(CurrentContextFactory ccf) { + assertNotNull(ccf); + } + + @Override + public void destroy(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public Class getScope() { + return FieldScoped.class; + } + + @SuppressWarnings("unchecked") + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + return (T) new Mina(); + } + + @SuppressWarnings("unchecked") + @Override + public T get(Contextual contextual) { + return (T) new Mina(); + } + + @Override + public boolean isActive() { + return true; + } + + @Override + public void destroy() { + throw new UnsupportedOperationException(); + } + + @Override + public ContextState getState() { + throw new UnsupportedOperationException(); + } + + } + + public static class MeadowCreator implements ContextCreator { + + @Override + public InjectableContext create(Map params) { + assertNotNull(params.get(KEY_CURRENT_CONTEXT_FACTORY)); + return new MeadowContext(); + } + + } + + public static class MeadowContext implements InjectableContext { + + @Override + public void destroy(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public Class getScope() { + return MeadowScoped.class; + } + + @SuppressWarnings("unchecked") + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + return (T) new Flower(); + } + + @SuppressWarnings("unchecked") + @Override + public T get(Contextual contextual) { + return (T) new Flower(); + } + + @Override + public boolean isActive() { + return true; + } + + @Override + public void destroy() { + throw new UnsupportedOperationException(); + } + + @Override + public ContextState getState() { + throw new UnsupportedOperationException(); + } + + } + + class InvalidNestedContext implements InjectableContext { + + @Override + public void destroy(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public Class getScope() { + return FieldScoped.class; + } + + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + throw new UnsupportedOperationException(); + } + + @Override + public T get(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isActive() { + throw new UnsupportedOperationException(); + } + + @Override + public void destroy() { + throw new UnsupportedOperationException(); + } + + @Override + public ContextState getState() { + throw new UnsupportedOperationException(); + } + + } + + public abstract class InvalidAbstractContext implements InjectableContext { + + @Override + public void destroy(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public Class getScope() { + return FieldScoped.class; + } + + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + throw new UnsupportedOperationException(); + } + + @Override + public T get(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isActive() { + throw new UnsupportedOperationException(); + } + + @Override + public void destroy() { + throw new UnsupportedOperationException(); + } + + @Override + public ContextState getState() { + throw new UnsupportedOperationException(); + } + + } + + public class InvalidConcstructorContext implements InjectableContext { + + public InvalidConcstructorContext(Long age) { + } + + @Override + public void destroy(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public Class getScope() { + return FieldScoped.class; + } + + @Override + public T get(Contextual contextual, CreationalContext creationalContext) { + throw new UnsupportedOperationException(); + } + + @Override + public T get(Contextual contextual) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isActive() { + throw new UnsupportedOperationException(); + } + + @Override + public void destroy() { + throw new UnsupportedOperationException(); + } + + @Override + public ContextState getState() { + throw new UnsupportedOperationException(); + } + + } + +}