Skip to content

Commit

Permalink
Consistently set target type in generated code
Browse files Browse the repository at this point in the history
Previously, a bean definition that is optimized AOT could have
different metadata based on whether its resolved type had a generic or
not. This is due to RootBeanDefinition taking either a Class or a
ResolvableType doing fundamentally different things. While the former
sets the bean class which is to little use with an instance supplier,
the latter specifies the target type of the bean.

This commit sets the target type of the bean, using the existing
setter methods that take either a class or a ResolvableType and set the
same attribute consistently.

Closes gh-30689
  • Loading branch information
snicoll committed Jun 21, 2023
1 parent 4879b56 commit 6c42f37
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.function.Predicate;

Expand Down Expand Up @@ -47,12 +48,6 @@
*/
class DefaultBeanRegistrationCodeFragments implements BeanRegistrationCodeFragments {

/**
* The variable name used to hold the bean type.
*/
private static final String BEAN_TYPE_VARIABLE = "beanType";


private final BeanRegistrationsCode beanRegistrationsCode;

private final RegisteredBean registeredBean;
Expand Down Expand Up @@ -118,19 +113,45 @@ public CodeBlock generateNewBeanDefinitionCode(GenerationContext generationConte
ResolvableType beanType, BeanRegistrationCode beanRegistrationCode) {

CodeBlock.Builder code = CodeBlock.builder();
code.addStatement(generateBeanTypeCode(beanType));
RootBeanDefinition mergedBeanDefinition = this.registeredBean.getMergedBeanDefinition();
Class<?> beanClass = (mergedBeanDefinition.hasBeanClass()
? ClassUtils.getUserClass(mergedBeanDefinition.getBeanClass()) : null);
CodeBlock beanClassCode = generateBeanClassCode(
beanRegistrationCode.getClassName().packageName(), beanClass);
code.addStatement("$T $L = new $T($L)", RootBeanDefinition.class,
BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, BEAN_TYPE_VARIABLE);
BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, beanClassCode);
if (targetTypeNecessary(beanType, beanClass)) {
code.addStatement("$L.setTargetType($L)", BEAN_DEFINITION_VARIABLE,
generateBeanTypeCode(beanType));
}
return code.build();
}

private CodeBlock generateBeanClassCode(String targetPackage, @Nullable Class<?> beanClass) {
if (beanClass != null) {
if (Modifier.isPublic(beanClass.getModifiers()) || targetPackage.equals(beanClass.getPackageName())) {
return CodeBlock.of("$T.class", beanClass);
}
else {
return CodeBlock.of("$S", beanClass.getName());
}
}
return CodeBlock.of("");
}

private CodeBlock generateBeanTypeCode(ResolvableType beanType) {
if (!beanType.hasGenerics()) {
return CodeBlock.of("$T<?> $L = $T.class", Class.class, BEAN_TYPE_VARIABLE,
ClassUtils.getUserClass(beanType.toClass()));
return CodeBlock.of("$T.class", ClassUtils.getUserClass(beanType.toClass()));
}
return ResolvableTypeCodeGenerator.generateCode(beanType);
}

private boolean targetTypeNecessary(ResolvableType beanType, @Nullable Class<?> beanClass) {
if (beanType.hasGenerics() || beanClass == null) {
return true;
}
return CodeBlock.of("$T $L = $L", ResolvableType.class, BEAN_TYPE_VARIABLE,
ResolvableTypeCodeGenerator.generateCode(beanType));
return (!beanType.toClass().equals(beanClass)
|| this.registeredBean.getMergedBeanDefinition().getFactoryMethodName() != null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
import org.springframework.beans.testfixture.beans.factory.aot.InnerBeanConfiguration;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanRegistrationsCode;
import org.springframework.beans.testfixture.beans.factory.aot.SimpleBean;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two;
import org.springframework.core.ResolvableType;
import org.springframework.core.test.io.support.MockSpringFactoriesLoader;
import org.springframework.core.test.tools.CompileWithForkedClassLoader;
Expand All @@ -56,6 +60,7 @@
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName;
import org.springframework.util.ReflectionUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
Expand Down Expand Up @@ -89,8 +94,10 @@ class BeanDefinitionMethodGeneratorTests {


@Test
void generateBeanDefinitionMethodGeneratesMethod() {
RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class));
void generateBeanDefinitionMethodWithOnlyTargetTypeDoesNotSetBeanClass() {
RootBeanDefinition beanDefinition = new RootBeanDefinition();
beanDefinition.setTargetType(TestBean.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
Expand All @@ -99,7 +106,67 @@ void generateBeanDefinitionMethodGeneratesMethod() {
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("beanType = TestBean.class");
assertThat(sourceFile).contains("new RootBeanDefinition()");
assertThat(sourceFile).contains("setTargetType(TestBean.class)");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}

@Test
void generateBeanDefinitionMethodSpecifiesBeanClassIfSet() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(TestBean.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("new RootBeanDefinition(TestBean.class)");
assertThat(sourceFile).doesNotContain("setTargetType(");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}

@Test
void generateBeanDefinitionMethodSpecifiesBeanClassAndTargetTypIfDifferent() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class);
beanDefinition.setTargetType(Implementation.class);
beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean"));
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("new RootBeanDefinition(TestHierarchy.One.class)");
assertThat(sourceFile).contains("setTargetType(TestHierarchy.Implementation.class)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}

@Test
void generateBeanDefinitionMethodUSeBeanClassNameIfNotReachable() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(PackagePrivateTestBean.class);
beanDefinition.setTargetType(TestBean.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("new RootBeanDefinition(\"org.springframework.beans.factory.aot.PackagePrivateTestBean\"");
assertThat(sourceFile).contains("setTargetType(TestBean.class)");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
Expand All @@ -116,7 +183,6 @@ void generateBeanDefinitionMethodGeneratesMethodWithInstanceSupplier() {
compile(method, (actual, compiled) -> {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains("beanType = TestBean.class");
assertThat(sourceFile).contains("setInstanceSupplier(TestBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
Expand Down Expand Up @@ -183,12 +249,26 @@ void generateBeanDefinitionMethodWhenHasGenericsGeneratesMethod() {
SourceFile sourceFile = compiled.getSourceFile(".*BeanDefinitions");
assertThat(sourceFile).contains("Get the bean definition for 'testBean'");
assertThat(sourceFile).contains(
"beanType = ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class)");
"setTargetType(ResolvableType.forClassWithGenerics(GenericBean.class, Integer.class))");
assertThat(sourceFile).contains("setInstanceSupplier(GenericBean::new)");
assertThat(actual).isInstanceOf(RootBeanDefinition.class);
});
}

@Test
void generateBeanDefinitionMethodWhenHasExplicitResolvableType() {
RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class);
beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean"));
beanDefinition.setTargetType(Two.class);
RegisteredBean registeredBean = registerBean(beanDefinition);
BeanDefinitionMethodGenerator generator = new BeanDefinitionMethodGenerator(
this.methodGeneratorFactory, registeredBean, null,
Collections.emptyList());
MethodReference method = generator.generateBeanDefinitionMethod(
this.generationContext, this.beanRegistrationsCode);
compile(method, (actual, compiled) -> assertThat(actual.getResolvableType().resolve()).isEqualTo(Two.class));
}

@Test
void generateBeanDefinitionMethodWhenHasInstancePostProcessorGeneratesMethod() {
RegisteredBean registeredBean = registerBean(new RootBeanDefinition(TestBean.class));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.beans.testfixture.beans.factory.aot;

/**
* A hierarchy where the exposed type of a bean is a partial signature.
*
* @author Stephane Nicoll
*/
public class TestHierarchy {

public interface One {
}

public interface Two {
}

public static class Implementation implements One, Two {
}

public static One oneBean() {
return new Implementation();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,13 @@
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RegisteredBean;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Implementation;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.One;
import org.springframework.beans.testfixture.beans.factory.aot.TestHierarchy.Two;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.AnnotationConfigUtils;
Expand Down Expand Up @@ -75,6 +80,7 @@
import org.springframework.core.test.tools.Compiled;
import org.springframework.core.test.tools.TestCompiler;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.util.ReflectionUtils;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -327,6 +333,22 @@ void processAheadOfTimeWithInjectionPoint() {
});
}

@Test // gh-30689
void processAheadOfTimeWithExplicitResolvableType() {
GenericApplicationContext applicationContext = new GenericApplicationContext();
DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory();
RootBeanDefinition beanDefinition = new RootBeanDefinition(One.class);
beanDefinition.setResolvedFactoryMethod(ReflectionUtils.findMethod(TestHierarchy.class, "oneBean"));
// Override target type
beanDefinition.setTargetType(Two.class);
beanFactory.registerBeanDefinition("hierarchyBean", beanDefinition);
testCompiledResult(applicationContext, (initializer, compiled) -> {
GenericApplicationContext freshApplicationContext = toFreshApplicationContext(initializer);
assertThat(freshApplicationContext.getBean(Two.class))
.isInstanceOf(Implementation.class);
});
}

@Nested
@CompileWithForkedClassLoader
class ConfigurationClassCglibProxy {
Expand Down

0 comments on commit 6c42f37

Please sign in to comment.