Skip to content

Commit

Permalink
ArC: custom context can get CurrentContextFactory during instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
mkouba committed May 14, 2024
1 parent ae51d7e commit d964eba
Show file tree
Hide file tree
Showing 11 changed files with 391 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.arc.ComponentsProvider;
import io.quarkus.test.QuarkusUnitTest;

Expand All @@ -29,7 +30,8 @@ public void testContexts() {
assertTrue(bean.ping());
for (ComponentsProvider componentsProvider : ServiceLoader.load(ComponentsProvider.class)) {
// We have less than 1000 beans
assertFalse(componentsProvider.getComponents().getContextInstances().isEmpty());
assertFalse(componentsProvider.getComponents(Arc.container().getCurrentContextFactory()).getContextInstances()
.isEmpty());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.arc.Arc;
import io.quarkus.arc.ComponentsProvider;
import io.quarkus.test.QuarkusUnitTest;

Expand All @@ -27,7 +28,8 @@ public class OptimizeContextsDisabledTest {
public void testContexts() {
assertTrue(bean.ping());
for (ComponentsProvider componentsProvider : ServiceLoader.load(ComponentsProvider.class)) {
assertTrue(componentsProvider.getComponents().getContextInstances().isEmpty());
assertTrue(componentsProvider.getComponents(Arc.container().getCurrentContextFactory()).getContextInstances()
.isEmpty());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.quarkus.arc.ArcContainer;
import io.quarkus.arc.ContextInstanceHandle;
import io.quarkus.arc.CurrentContext;
import io.quarkus.arc.CurrentContextFactory;
import io.quarkus.arc.InjectableBean;
import io.quarkus.arc.ManagedContext;
import io.quarkus.arc.impl.ComputingCacheContextInstances;
Expand All @@ -35,19 +36,13 @@ public class WebSocketSessionContext implements ManagedContext {

private static final Logger LOG = Logger.getLogger(WebSocketSessionContext.class);

private final LazyValue<CurrentContext<SessionContextState>> currentContext;
private final CurrentContext<SessionContextState> currentContext;
private final LazyValue<Event<Object>> initializedEvent;
private final LazyValue<Event<Object>> beforeDestroyEvent;
private final LazyValue<Event<Object>> destroyEvent;

public WebSocketSessionContext() {
// Use lazy value because no-args constructor is needed
this.currentContext = new LazyValue<>(new Supplier<CurrentContext<SessionContextState>>() {
@Override
public CurrentContext<SessionContextState> get() {
return Arc.container().getCurrentContextFactory().create(SessionScoped.class);
}
});
public WebSocketSessionContext(CurrentContextFactory currentContextFactory) {
this.currentContext = currentContextFactory.create(SessionScoped.class);
this.initializedEvent = newEvent(Initialized.Literal.SESSION, Any.Literal.INSTANCE);
this.beforeDestroyEvent = newEvent(BeforeDestroyed.Literal.SESSION, Any.Literal.INSTANCE);
this.destroyEvent = newEvent(Destroyed.Literal.SESSION, Any.Literal.INSTANCE);
Expand All @@ -62,7 +57,6 @@ public Class<? extends Annotation> getScope() {
public ContextState getState() {
SessionContextState state = currentState();
if (state == null) {
// Thread local not set - context is not active!
throw notActive();
}
return state;
Expand All @@ -72,11 +66,11 @@ public ContextState getState() {
public ContextState activate(ContextState initialState) {
if (initialState == null) {
SessionContextState state = initializeContextState();
currentContext().set(state);
currentContext.set(state);
return state;
} else {
if (initialState instanceof SessionContextState) {
currentContext().set((SessionContextState) initialState);
currentContext.set((SessionContextState) initialState);
return initialState;
} else {
throw new IllegalArgumentException("Invalid initial state: " + initialState.getClass().getName());
Expand All @@ -86,7 +80,7 @@ public ContextState activate(ContextState initialState) {

@Override
public void deactivate() {
currentContext().remove();
currentContext.remove();
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -176,12 +170,8 @@ SessionContextState initializeContextState() {
return state;
}

private CurrentContext<SessionContextState> currentContext() {
return currentContext.get();
}

private SessionContextState currentState() {
return currentContext().get();
return currentContext.get();
}

private IllegalArgumentException invalidScope() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public class BeanDeployment {

private final Set<BeanInfo> beansWithRuntimeDeferredUnproxyableError;

// scope -> fun that accepts the method creator for ComponentsProvider#getComponents()
private final Map<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> customContexts;

private final Map<DotName, BeanDefiningAnnotation> beanDefiningAnnotations;
Expand Down Expand Up @@ -214,7 +215,7 @@ public class BeanDeployment {
additionalStereotypes.addAll(stereotypeRegistrar.getAdditionalStereotypes());
}

this.stereotypes = findStereotypes(interceptorBindings, customContexts, additionalStereotypes,
this.stereotypes = findStereotypes(interceptorBindings, customContexts.keySet(), additionalStereotypes,
annotationStore);
buildContext.putInternal(Key.STEREOTYPES, Collections.unmodifiableMap(stereotypes));

Expand Down Expand Up @@ -734,7 +735,7 @@ Map<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> getCustomContexts()
}

ScopeInfo getScope(DotName scopeAnnotationName) {
return getScope(scopeAnnotationName, customContexts);
return getScope(scopeAnnotationName, customContexts.keySet());
}

/**
Expand Down Expand Up @@ -874,8 +875,7 @@ private static Set<AnnotationInstance> recursiveBuild(DotName name,
}

private Map<DotName, StereotypeInfo> findStereotypes(Map<DotName, ClassInfo> interceptorBindings,
Map<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> customContexts,
Set<DotName> additionalStereotypes, AnnotationStore annotationStore) {
Set<ScopeInfo> customContextScopes, Set<DotName> additionalStereotypes, AnnotationStore annotationStore) {

Map<DotName, StereotypeInfo> stereotypes = new HashMap<>();

Expand Down Expand Up @@ -917,7 +917,7 @@ private Map<DotName, StereotypeInfo> findStereotypes(Map<DotName, ClassInfo> int
} else if (DotNames.PRIORITY.equals(annotation.name())) {
alternativePriority = annotation.value().asInt();
} else {
final ScopeInfo scope = getScope(annotation.name(), customContexts);
final ScopeInfo scope = getScope(annotation.name(), customContextScopes);
if (scope != null) {
scopes.add(scope);
}
Expand All @@ -933,13 +933,12 @@ private Map<DotName, StereotypeInfo> findStereotypes(Map<DotName, ClassInfo> int
return stereotypes;
}

private ScopeInfo getScope(DotName scopeAnnotationName,
Map<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> customContexts) {
private ScopeInfo getScope(DotName scopeAnnotationName, Set<ScopeInfo> customContextScopes) {
BuiltinScope builtin = BuiltinScope.from(scopeAnnotationName);
if (builtin != null) {
return builtin.getInfo();
}
for (ScopeInfo customScope : customContexts.keySet()) {
for (ScopeInfo customScope : customContextScopes) {
if (customScope.getDotName().equals(scopeAnnotationName)) {
return customScope;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,8 @@ private Set<DotName> findSingleContextNormalScopes() {
// built-in contexts
contextsForScope.put(BuiltinScope.REQUEST.getName(), 1);
// custom contexts
for (Map.Entry<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> entry : beanDeployment.getCustomContexts()
for (Map.Entry<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> entry : beanDeployment
.getCustomContexts()
.entrySet()) {
if (entry.getKey().isNormal()) {
contextsForScope.merge(entry.getKey().getDotName(), entry.getValue().size(), Integer::sum);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.quarkus.arc.Arc;
import io.quarkus.arc.Components;
import io.quarkus.arc.ComponentsProvider;
import io.quarkus.arc.CurrentContextFactory;
import io.quarkus.arc.InjectableBean;
import io.quarkus.arc.processor.ResourceOutput.Resource;
import io.quarkus.gizmo.AssignableResultHandle;
Expand Down Expand Up @@ -82,7 +83,8 @@ Collection<Resource> generate(String name, BeanDeployment beanDeployment, Map<Be
ClassCreator componentsProvider = ClassCreator.builder().classOutput(classOutput).className(generatedName)
.interfaces(ComponentsProvider.class).build();

MethodCreator getComponents = componentsProvider.getMethodCreator("getComponents", Components.class)
MethodCreator getComponents = componentsProvider
.getMethodCreator("getComponents", Components.class, CurrentContextFactory.class)
.setModifiers(ACC_PUBLIC);

Map<BeanInfo, List<BeanInfo>> dependencyMap = initBeanDependencyMap(beanDeployment);
Expand All @@ -100,9 +102,10 @@ Collection<Resource> generate(String name, BeanDeployment beanDeployment, Map<Be

// Custom contexts
ResultHandle contextsHandle = getComponents.newInstance(MethodDescriptor.ofConstructor(ArrayList.class));
for (Entry<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> entry : beanDeployment.getCustomContexts()
for (Entry<ScopeInfo, List<Function<MethodCreator, ResultHandle>>> e : beanDeployment
.getCustomContexts()
.entrySet()) {
for (Function<MethodCreator, ResultHandle> func : entry.getValue()) {
for (Function<MethodCreator, ResultHandle> func : e.getValue()) {
ResultHandle contextHandle = func.apply(getComponents);
getComponents.invokeInterfaceMethod(MethodDescriptors.LIST_ADD, contextsHandle, contextHandle);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.quarkus.arc.processor;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
Expand All @@ -12,6 +14,7 @@
import jakarta.enterprise.context.NormalScope;

import io.quarkus.arc.ContextCreator;
import io.quarkus.arc.CurrentContextFactory;
import io.quarkus.arc.InjectableContext;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
Expand Down Expand Up @@ -100,12 +103,53 @@ public ContextConfigurator normal(boolean value) {
}

public ContextConfigurator contextClass(Class<? extends InjectableContext> contextClazz) {
return creator(mc -> mc.newInstance(MethodDescriptor.ofConstructor(contextClazz)));
if (!Modifier.isPublic(contextClazz.getModifiers())
|| Modifier.isAbstract(contextClazz.getModifiers())
|| contextClazz.isAnonymousClass()
|| contextClazz.isLocalClass()
|| (contextClazz.getEnclosingClass() != null && !Modifier.isStatic(contextClazz.getModifiers()))) {
throw new IllegalArgumentException(
"A context class must be a public non-abstract top-level or static nested class");
}
Constructor<?> constructor = getConstructor(contextClazz);
if (constructor == null) {
throw new IllegalArgumentException(
"A context class must either declare a no-args constructor or a constructor that accepts a single parameter of type io.quarkus.arc.CurrentContextFactory");
}
return creator(new Function<>() {
@Override
public ResultHandle apply(MethodCreator mc) {
ResultHandle[] args;
if (constructor.getParameterCount() == 0) {
args = new ResultHandle[0];
} else {
args = new ResultHandle[] { mc.getMethodParam(0) };
}
return mc.newInstance(MethodDescriptor.ofConstructor(contextClazz, constructor.getParameterTypes()), args);
}
});
}

private Constructor<?> getConstructor(Class<? extends InjectableContext> contextClazz) {
Constructor<?> constructor = null;
try {
constructor = contextClazz.getDeclaredConstructor(CurrentContextFactory.class);
} catch (NoSuchMethodException ignored) {
}
if (constructor == null) {
try {
constructor = contextClazz.getDeclaredConstructor();
} catch (NoSuchMethodException ignored) {
}
}
return constructor;
}

public ContextConfigurator creator(Class<? extends ContextCreator> creatorClazz) {
return creator(mc -> {
ResultHandle paramsHandle = mc.newInstance(MethodDescriptor.ofConstructor(HashMap.class));
mc.invokeInterfaceMethod(MethodDescriptors.MAP_PUT, paramsHandle,
mc.load(ContextCreator.KEY_CURRENT_CONTEXT_FACTORY), mc.getMethodParam(0));
for (Entry<String, Object> entry : params.entrySet()) {
ResultHandle valHandle = null;
if (entry.getValue() instanceof String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public interface ComponentsProvider {

static Logger LOG = Logger.getLogger(ComponentsProvider.class);

Components getComponents();
Components getComponents(CurrentContextFactory currentContextFactory);

static void unableToLoadRemovedBeanType(String type, Throwable problem) {
LOG.warnf("Unable to load removed bean type [%s]: %s", type, problem.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
*/
public interface ContextCreator {

public static final String KEY_CURRENT_CONTEXT_FACTORY = "io.quarkus.arc.currentContextFactory";

/**
* The {@link #KEY_CURRENT_CONTEXT_FACTORY} can be used to obtain the {@link CurrentContextFactory}.
*
* @param params
* @return the context instance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public ArcContainerImpl(CurrentContextFactory currentContextFactory, boolean str

List<Components> components = new ArrayList<>();
for (ComponentsProvider componentsProvider : ServiceLoader.load(ComponentsProvider.class)) {
components.add(componentsProvider.getComponents());
components.add(componentsProvider.getComponents(this.currentContextFactory));
}

for (Components c : components) {
Expand Down
Loading

0 comments on commit d964eba

Please sign in to comment.