Skip to content

Commit

Permalink
Add AOT support for @resource
Browse files Browse the repository at this point in the history
This commit adds ahead of time support for @resource on fields and
methods. Lookup elements are discovered and code is generated to replace
that introspection at runtime.

Closes spring-projectsgh-29614
  • Loading branch information
snicoll committed Oct 17, 2023
1 parent 4fd1431 commit 6944fec
Show file tree
Hide file tree
Showing 19 changed files with 1,555 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
Expand All @@ -35,6 +36,13 @@

import org.springframework.aop.TargetSource;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.aot.generate.AccessControl;
import org.springframework.aot.generate.GeneratedClass;
import org.springframework.aot.generate.GeneratedMethod;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.ExecutableMode;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.support.ClassHintUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.PropertyValues;
import org.springframework.beans.factory.BeanCreationException;
Expand All @@ -43,20 +51,30 @@
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.annotation.InitDestroyAnnotationBeanPostProcessor;
import org.springframework.beans.factory.annotation.InjectionMetadata;
import org.springframework.beans.factory.aot.BeanRegistrationAotContribution;
import org.springframework.beans.factory.aot.BeanRegistrationCode;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableBeanFactory;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.config.EmbeddedValueResolver;
import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor;
import org.springframework.beans.factory.support.AutowireCandidateResolver;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.aot.ResourceFieldValueResolver;
import org.springframework.context.aot.ResourceMethodArgumentResolver;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.javapoet.ClassName;
import org.springframework.javapoet.CodeBlock;
import org.springframework.jndi.support.SimpleJndiBeanFactory;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.util.StringValueResolver;
Expand Down Expand Up @@ -298,6 +316,37 @@ public void postProcessMergedBeanDefinition(RootBeanDefinition beanDefinition, C
metadata.checkConfigMembers(beanDefinition);
}

@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
BeanRegistrationAotContribution parentAotContribution = super.processAheadOfTime(registeredBean);
Class<?> beanClass = registeredBean.getBeanClass();
String beanName = registeredBean.getBeanName();
RootBeanDefinition beanDefinition = registeredBean.getMergedBeanDefinition();
InjectionMetadata metadata = findResourceMetadata(beanName, beanClass,
beanDefinition.getPropertyValues());
Collection<LookupElement> injectedElements = getInjectedElements(metadata,
beanDefinition.getPropertyValues());
if (!ObjectUtils.isEmpty(injectedElements)) {
AotContribution aotContribution = new AotContribution(beanClass, injectedElements,
getAutowireCandidateResolver(registeredBean));
return BeanRegistrationAotContribution.concat(parentAotContribution, aotContribution);
}
return parentAotContribution;
}

@Nullable
private AutowireCandidateResolver getAutowireCandidateResolver(RegisteredBean registeredBean) {
if (registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory lbf) {
return lbf.getAutowireCandidateResolver();
}
return null;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private Collection<LookupElement> getInjectedElements(InjectionMetadata metadata, PropertyValues propertyValues) {
return (Collection) metadata.getInjectedElements(propertyValues);
}

@Override
public void resetBeanDefinition(String beanName) {
this.injectionMetadataCache.remove(beanName);
Expand Down Expand Up @@ -789,4 +838,144 @@ public Class<?> getDependencyType() {
}
}

/**
* {@link BeanRegistrationAotContribution} to inject resources on fields and methods.
*/
private static class AotContribution implements BeanRegistrationAotContribution {

private static final String REGISTERED_BEAN_PARAMETER = "registeredBean";

private static final String INSTANCE_PARAMETER = "instance";

private final Class<?> target;

private final Collection<LookupElement> lookupElements;

@Nullable
private final AutowireCandidateResolver candidateResolver;

AotContribution(Class<?> target, Collection<LookupElement> lookupElements,
@Nullable AutowireCandidateResolver candidateResolver) {

this.target = target;
this.lookupElements = lookupElements;
this.candidateResolver = candidateResolver;
}

@Override
public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) {
GeneratedClass generatedClass = generationContext.getGeneratedClasses()
.addForFeatureComponent("ResourceAutowiring", this.target, type -> {
type.addJavadoc("Resource autowiring for {@link $T}.", this.target);
type.addModifiers(javax.lang.model.element.Modifier.PUBLIC);
});
GeneratedMethod generateMethod = generatedClass.getMethods().add("apply", method -> {
method.addJavadoc("Apply resource autowiring.");
method.addModifiers(javax.lang.model.element.Modifier.PUBLIC,
javax.lang.model.element.Modifier.STATIC);
method.addParameter(RegisteredBean.class, REGISTERED_BEAN_PARAMETER);
method.addParameter(this.target, INSTANCE_PARAMETER);
method.returns(this.target);
method.addCode(generateMethodCode(generatedClass.getName(),
generationContext.getRuntimeHints()));
});
beanRegistrationCode.addInstancePostProcessor(generateMethod.toMethodReference());

registerHints(generationContext.getRuntimeHints());
}

private CodeBlock generateMethodCode(ClassName targetClassName, RuntimeHints hints) {
CodeBlock.Builder code = CodeBlock.builder();
for (LookupElement lookupElement : this.lookupElements) {
code.addStatement(generateMethodStatementForElement(
targetClassName, lookupElement, hints));
}
code.addStatement("return $L", INSTANCE_PARAMETER);
return code.build();
}

private CodeBlock generateMethodStatementForElement(ClassName targetClassName,
LookupElement lookupElement, RuntimeHints hints) {

Member member = lookupElement.getMember();
if (member instanceof Field field) {
return generateMethodStatementForField(
targetClassName, field, lookupElement, hints);
}
if (member instanceof Method method) {
return generateMethodStatementForMethod(
targetClassName, method, lookupElement, hints);
}
throw new IllegalStateException(
"Unsupported member type " + member.getClass().getName());
}

private CodeBlock generateMethodStatementForField(ClassName targetClassName,
Field field, LookupElement lookupElement, RuntimeHints hints) {

hints.reflection().registerField(field);
CodeBlock resolver = generateFieldResolverCode(field, lookupElement);
AccessControl accessControl = AccessControl.forMember(field);
if (!accessControl.isAccessibleFrom(targetClassName)) {
return CodeBlock.of("$L.resolveAndSet($L, $L)", resolver,
REGISTERED_BEAN_PARAMETER, INSTANCE_PARAMETER);
}
return CodeBlock.of("$L.$L = $L.resolve($L)", INSTANCE_PARAMETER,
field.getName(), resolver, REGISTERED_BEAN_PARAMETER);
}

private CodeBlock generateFieldResolverCode(Field field, LookupElement lookupElement) {
if (lookupElement.isDefaultName) {
return CodeBlock.of("$T.$L($S)", ResourceFieldValueResolver.class,
"forField", field.getName());
}
else {
return CodeBlock.of("$T.$L($S, $S)", ResourceFieldValueResolver.class,
"forField", field.getName(), lookupElement.getName());
}
}

private CodeBlock generateMethodStatementForMethod(ClassName targetClassName,
Method method, LookupElement lookupElement, RuntimeHints hints) {

CodeBlock resolver = generateMethodResolverCode(method, lookupElement);
AccessControl accessControl = AccessControl.forMember(method);
if (!accessControl.isAccessibleFrom(targetClassName)) {
hints.reflection().registerMethod(method, ExecutableMode.INVOKE);
return CodeBlock.of("$L.resolveAndInvoke($L, $L)", resolver,
REGISTERED_BEAN_PARAMETER, INSTANCE_PARAMETER);
}
hints.reflection().registerMethod(method, ExecutableMode.INTROSPECT);
return CodeBlock.of("$L.$L($L.resolve($L))", INSTANCE_PARAMETER,
method.getName(), resolver, REGISTERED_BEAN_PARAMETER);

}

private CodeBlock generateMethodResolverCode(Method method, LookupElement lookupElement) {
if (lookupElement.isDefaultName) {
return CodeBlock.of("$T.$L($S, $T.class)", ResourceMethodArgumentResolver.class,
"forMethod", method.getName(), lookupElement.getLookupType());
}
else {
return CodeBlock.of("$T.$L($S, $T.class, $S)", ResourceMethodArgumentResolver.class,
"forMethod", method.getName(), lookupElement.getLookupType(), lookupElement.getName());
}
}

private void registerHints(RuntimeHints runtimeHints) {
this.lookupElements.forEach(lookupElement ->
registerProxyIfNecessary(runtimeHints, lookupElement.getDependencyDescriptor()));
}

private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) {
if (this.candidateResolver != null) {
Class<?> proxyClass =
this.candidateResolver.getLazyResolutionProxyClass(dependencyDescriptor, null);
if (proxyClass != null) {
ClassHintUtils.registerProxyIfNecessary(proxyClass, runtimeHints);
}
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.context.aot;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;

import javax.lang.model.element.Element;

import jakarta.annotation.Resource;

import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.MethodParameter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
* Base class for resolvers that support injection of named beans on
* an {@link Element}.
*
* @author Stephane Nicoll
* @since 6.1
* @see Resource
*/
public abstract class ResourceElementResolver {

protected final String name;

protected final boolean defaultName;

protected ResourceElementResolver(String name, boolean defaultName) {
this.name = name;
this.defaultName = defaultName;
}

/**
* Resolve the field value for the specified registered bean.
* @param registeredBean the registered bean
* @return the resolved field value
*/
@Nullable
@SuppressWarnings("unchecked")
public <T> T resolve(RegisteredBean registeredBean) {
return (T) resolveObject(registeredBean);
}

/**
* Resolve the field value for the specified registered bean.
* @param registeredBean the registered bean
* @return the resolved field value
*/
public Object resolveObject(RegisteredBean registeredBean) {
Assert.notNull(registeredBean, "'registeredBean' must not be null");
return resolveValue(registeredBean);
}


/**
* Create a suitable {@link DependencyDescriptor} for the specified bean.
* @param bean the registered bean
* @return a descriptor for that bean
*/
protected abstract DependencyDescriptor createDependencyDescriptor(RegisteredBean bean);

/**
* Resolve the value to inject for this instance.
* @param bean the bean registration
* @return the value to inject
*/
protected Object resolveValue(RegisteredBean bean) {
ConfigurableListableBeanFactory factory = bean.getBeanFactory();

Object resource;
Set<String> autowiredBeanNames;
DependencyDescriptor descriptor = createDependencyDescriptor(bean);
if (this.defaultName && !factory.containsBean(this.name)) {
autowiredBeanNames = new LinkedHashSet<>();
resource = factory.resolveDependency(descriptor, bean.getBeanName(), autowiredBeanNames, null);
if (resource == null) {
throw new NoSuchBeanDefinitionException(descriptor.getDependencyType(), "No resolvable resource object");
}
}
else {
resource = factory.resolveBeanByName(this.name, descriptor);
autowiredBeanNames = Collections.singleton(this.name);
}

for (String autowiredBeanName : autowiredBeanNames) {
if (factory.containsBean(autowiredBeanName)) {
factory.registerDependentBean(autowiredBeanName, bean.getBeanName());
}
}
return resource;
}


@SuppressWarnings("serial")
protected static class LookupDependencyDescriptor extends DependencyDescriptor {

private final Class<?> lookupType;

public LookupDependencyDescriptor(Field field, Class<?> lookupType) {
super(field, true);
this.lookupType = lookupType;
}

public LookupDependencyDescriptor(Method method, Class<?> lookupType) {
super(new MethodParameter(method, 0), true);
this.lookupType = lookupType;
}

@Override
public Class<?> getDependencyType() {
return this.lookupType;
}
}

}
Loading

0 comments on commit 6944fec

Please sign in to comment.