Skip to content

Commit

Permalink
Infer proxy on @Lazy-annotated injection points
Browse files Browse the repository at this point in the history
This commit makes use of the new `getLazyResolutionProxyClass` on
`AutowireCandidateResolver` to detect if a injection point requires
a proxy.

Closes gh-28980
  • Loading branch information
snicoll committed Aug 22, 2022
1 parent e5f9bb7 commit 4557158
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -70,6 +71,8 @@
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor;
import org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory;
import org.springframework.beans.factory.support.AutowireCandidateResolver;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.LookupOverride;
import org.springframework.beans.factory.support.MergedBeanDefinitionPostProcessor;
import org.springframework.beans.factory.support.RegisteredBean;
Expand Down Expand Up @@ -289,7 +292,7 @@ public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registe
InjectionMetadata metadata = findInjectionMetadata(beanName, beanClass, beanDefinition);
Collection<AutowiredElement> autowiredElements = getAutowiredElements(metadata);
if (!ObjectUtils.isEmpty(autowiredElements)) {
return new AotContribution(beanClass, autowiredElements);
return new AotContribution(beanClass, autowiredElements, getAutowireCandidateResolver());
}
return null;
}
Expand All @@ -300,6 +303,14 @@ private Collection<AutowiredElement> getAutowiredElements(InjectionMetadata meta
return (Collection) metadata.getInjectedElements();
}

@Nullable
private AutowireCandidateResolver getAutowireCandidateResolver() {
if (this.beanFactory instanceof DefaultListableBeanFactory lbf) {
return lbf.getAutowireCandidateResolver();
}
return null;
}

private InjectionMetadata findInjectionMetadata(String beanName, Class<?> beanType, RootBeanDefinition beanDefinition) {
InjectionMetadata metadata = findAutowiringMetadata(beanName, beanType, null);
metadata.checkConfigMembers(beanDefinition);
Expand Down Expand Up @@ -914,10 +925,15 @@ private static class AotContribution implements BeanRegistrationAotContribution

private final Collection<AutowiredElement> autowiredElements;

@Nullable
private final AutowireCandidateResolver candidateResolver;

AotContribution(Class<?> target, Collection<AutowiredElement> autowiredElements,
@Nullable AutowireCandidateResolver candidateResolver) {

AotContribution(Class<?> target, Collection<AutowiredElement> autowiredElements) {
this.target = target;
this.autowiredElements = autowiredElements;
this.candidateResolver = candidateResolver;
}


Expand All @@ -940,6 +956,10 @@ public void applyTo(GenerationContext generationContext,
});
beanRegistrationCode.addInstancePostProcessor(
MethodReference.ofStatic(generatedClass.getName(), generateMethod.getName()));

if (this.candidateResolver != null) {
registerHints(generationContext.getRuntimeHints());
}
}

private CodeBlock generateMethodCode(RuntimeHints hints) {
Expand Down Expand Up @@ -1023,6 +1043,35 @@ private CodeBlock generateParameterTypesCode(Class<?>[] parameterTypes) {
return code.build();
}

private void registerHints(RuntimeHints runtimeHints) {
this.autowiredElements.forEach(autowiredElement -> {
boolean required = autowiredElement.required;
Member member = autowiredElement.getMember();
if (member instanceof Field field) {
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
field, required);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
if (member instanceof Method method) {
Class<?>[] parameterTypes = method.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(method, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
methodParam, required);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}
});
}

private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) {
Class<?> proxyType = this.candidateResolver
.getLazyResolutionProxyClass(dependencyDescriptor, null);
if (proxyType != null && Proxy.isProxyClass(proxyType)) {
runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces());
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package org.springframework.beans.factory.aot;

import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.List;

import javax.lang.model.element.Modifier;
Expand All @@ -26,8 +29,13 @@
import org.springframework.aot.generate.GeneratedMethods;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.generate.MethodReference;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.DependencyDescriptor;
import org.springframework.beans.factory.support.AutowireCandidateResolver;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.core.MethodParameter;
import org.springframework.javapoet.ClassName;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -83,6 +91,7 @@ class BeanDefinitionMethodGenerator {
MethodReference generateBeanDefinitionMethod(GenerationContext generationContext,
BeanRegistrationsCode beanRegistrationsCode) {

registerRuntimeHintsIfNecessary(generationContext.getRuntimeHints());
BeanRegistrationCodeFragments codeFragments = getCodeFragments(generationContext,
beanRegistrationsCode);
Class<?> target = codeFragments.getTarget(this.registeredBean,
Expand Down Expand Up @@ -166,4 +175,54 @@ private String getSimpleBeanName(String beanName) {
return StringUtils.uncapitalize(beanName);
}

private void registerRuntimeHintsIfNecessary(RuntimeHints runtimeHints) {
if (this.registeredBean.getBeanFactory() instanceof DefaultListableBeanFactory dlbf) {
ProxyRuntimeHintsRegistrar registrar = new ProxyRuntimeHintsRegistrar(dlbf.getAutowireCandidateResolver());
if (this.constructorOrFactoryMethod instanceof Method method) {
registrar.registerRuntimeHints(runtimeHints, method);
}
else if (this.constructorOrFactoryMethod instanceof Constructor<?> constructor) {
registrar.registerRuntimeHints(runtimeHints, constructor);
}
}
}

private static class ProxyRuntimeHintsRegistrar {

private final AutowireCandidateResolver candidateResolver;

public ProxyRuntimeHintsRegistrar(AutowireCandidateResolver candidateResolver) {
this.candidateResolver = candidateResolver;
}

public void registerRuntimeHints(RuntimeHints runtimeHints, Method method) {
Class<?>[] parameterTypes = method.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(method, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}

public void registerRuntimeHints(RuntimeHints runtimeHints, Constructor<?> constructor) {
Class<?>[] parameterTypes = constructor.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
MethodParameter methodParam = new MethodParameter(constructor, i);
DependencyDescriptor dependencyDescriptor = new DependencyDescriptor(
methodParam, true);
registerProxyIfNecessary(runtimeHints, dependencyDescriptor);
}
}

private void registerProxyIfNecessary(RuntimeHints runtimeHints, DependencyDescriptor dependencyDescriptor) {
Class<?> proxyType = this.candidateResolver
.getLazyResolutionProxyClass(dependencyDescriptor, null);
if (proxyType != null && Proxy.isProxyClass(proxyType)) {
runtimeHints.proxies().registerJdkProxy(proxyType.getInterfaces());
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package org.springframework.context.aot;

import java.io.IOException;
import java.lang.reflect.Proxy;
import java.util.function.BiConsumer;

import org.junit.jupiter.api.Test;

import org.springframework.aot.generate.GeneratedFiles.Kind;
import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.TypeReference;
import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
import org.springframework.aot.test.generator.compile.Compiled;
Expand All @@ -44,11 +47,18 @@
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.AnnotationConfigUtils;
import org.springframework.context.annotation.CommonAnnotationBeanPostProcessor;
import org.springframework.context.annotation.ContextAnnotationAutowireCandidateResolver;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.context.testfixture.context.generator.SimpleComponent;
import org.springframework.context.testfixture.context.generator.annotation.AutowiredComponent;
import org.springframework.context.testfixture.context.generator.annotation.CglibConfiguration;
import org.springframework.context.testfixture.context.generator.annotation.InitDestroyComponent;
import org.springframework.context.testfixture.context.generator.annotation.LazyAutowiredFieldComponent;
import org.springframework.context.testfixture.context.generator.annotation.LazyAutowiredMethodComponent;
import org.springframework.context.testfixture.context.generator.annotation.LazyConstructorArgumentComponent;
import org.springframework.context.testfixture.context.generator.annotation.LazyFactoryMethodArgumentComponent;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -93,6 +103,87 @@ void processAheadOfTimeWhenHasAutowiring() {
});
}

@Test
void processAheadOfTimeWhenHasLazyAutowiringOnField() {
testAutowiredComponent(LazyAutowiredFieldComponent.class, (bean, generationContext) -> {
Environment environment = bean.getEnvironment();
assertThat(environment).isInstanceOf(Proxy.class);
ResourceLoader resourceLoader = bean.getResourceLoader();
assertThat(resourceLoader).isNotInstanceOf(Proxy.class);
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
assertThat(runtimeHints.proxies().jdkProxies()).singleElement().satisfies(proxyHint ->
assertThat(proxyHint.getProxiedInterfaces()).isEqualTo(TypeReference.listOf(
environment.getClass().getInterfaces())));

});
}

@Test
void processAheadOfTimeWhenHasLazyAutowiringOnMethod() {
testAutowiredComponent(LazyAutowiredMethodComponent.class, (bean, generationContext) -> {
Environment environment = bean.getEnvironment();
assertThat(environment).isNotInstanceOf(Proxy.class);
ResourceLoader resourceLoader = bean.getResourceLoader();
assertThat(resourceLoader).isInstanceOf(Proxy.class);
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
assertThat(runtimeHints.proxies().jdkProxies()).singleElement().satisfies(proxyHint ->
assertThat(proxyHint.getProxiedInterfaces()).isEqualTo(TypeReference.listOf(
resourceLoader.getClass().getInterfaces())));
});
}

@Test
void processAheadOfTimeWhenHasLazyAutowiringOnConstructor() {
testAutowiredComponent(LazyConstructorArgumentComponent.class, (bean, generationContext) -> {
Environment environment = bean.getEnvironment();
assertThat(environment).isInstanceOf(Proxy.class);
ResourceLoader resourceLoader = bean.getResourceLoader();
assertThat(resourceLoader).isNotInstanceOf(Proxy.class);
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
assertThat(runtimeHints.proxies().jdkProxies()).singleElement().satisfies(proxyHint ->
assertThat(proxyHint.getProxiedInterfaces()).isEqualTo(TypeReference.listOf(
environment.getClass().getInterfaces())));
});
}

@Test
void processAheadOfTimeWhenHasLazyAutowiringOnFactoryMethod() {
RootBeanDefinition bd = new RootBeanDefinition(LazyFactoryMethodArgumentComponent.class);
bd.setFactoryMethodName("of");
testAutowiredComponent(LazyFactoryMethodArgumentComponent.class, bd, (bean, generationContext) -> {
Environment environment = bean.getEnvironment();
assertThat(environment).isInstanceOf(Proxy.class);
ResourceLoader resourceLoader = bean.getResourceLoader();
assertThat(resourceLoader).isNotInstanceOf(Proxy.class);
RuntimeHints runtimeHints = generationContext.getRuntimeHints();
assertThat(runtimeHints.proxies().jdkProxies()).singleElement().satisfies(proxyHint ->
assertThat(proxyHint.getProxiedInterfaces()).isEqualTo(TypeReference.listOf(
environment.getClass().getInterfaces())));
});
}

private <T> void testAutowiredComponent(Class<T> type, BiConsumer<T, GenerationContext> assertions) {
testAutowiredComponent(type, new RootBeanDefinition(type), assertions);
}

private <T> void testAutowiredComponent(Class<T> type, RootBeanDefinition beanDefinition,
BiConsumer<T, GenerationContext> assertions) {
GenericApplicationContext applicationContext = new GenericApplicationContext();
applicationContext.getDefaultListableBeanFactory().setAutowireCandidateResolver(
new ContextAnnotationAutowireCandidateResolver());
applicationContext.registerBeanDefinition(AnnotationConfigUtils.AUTOWIRED_ANNOTATION_PROCESSOR_BEAN_NAME,
BeanDefinitionBuilder
.rootBeanDefinition(AutowiredAnnotationBeanPostProcessor.class)
.setRole(BeanDefinition.ROLE_INFRASTRUCTURE).getBeanDefinition());
applicationContext.registerBeanDefinition("testComponent", beanDefinition);
TestGenerationContext generationContext = processAheadOfTime(applicationContext);
testCompiledResult(generationContext, (initializer, compiled) -> {
GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer);
assertThat(freshApplicationContext.getBeanDefinitionNames()).containsOnly("testComponent");
assertions.accept(freshApplicationContext.getBean("testComponent", type), generationContext);
});
}

@Test
void processAheadOfTimeWhenHasInitDestroyMethods() {
GenericApplicationContext applicationContext = new GenericApplicationContext();
Expand Down Expand Up @@ -189,10 +280,14 @@ private static TestGenerationContext processAheadOfTime(GenericApplicationContex
return generationContext;
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void testCompiledResult(GenericApplicationContext applicationContext,
BiConsumer<ApplicationContextInitializer<GenericApplicationContext>, Compiled> result) {
TestGenerationContext generationContext = processAheadOfTime(applicationContext);
testCompiledResult(processAheadOfTime(applicationContext), result);
}

@SuppressWarnings({ "rawtypes", "unchecked" })
private void testCompiledResult(TestGenerationContext generationContext,
BiConsumer<ApplicationContextInitializer<GenericApplicationContext>, Compiled> result) {
TestCompiler.forSystem().withFiles(generationContext.getGeneratedFiles()).compile(compiled ->
result.accept(compiled.getInstance(ApplicationContextInitializer.class), compiled));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.context.testfixture.context.generator.annotation;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.env.Environment;
import org.springframework.core.io.ResourceLoader;

public class LazyAutowiredFieldComponent {

@Lazy
@Autowired
private Environment environment;

@Autowired
private ResourceLoader resourceLoader;

public Environment getEnvironment() {
return this.environment;
}


public ResourceLoader getResourceLoader() {
return this.resourceLoader;
}
}
Loading

0 comments on commit 4557158

Please sign in to comment.