Skip to content

Commit

Permalink
Restore order of setImportMetadata in AOT optimized contexts
Browse files Browse the repository at this point in the history
This commit adapts the generated code for handling ImportAware to
register a bean definition rather than adding the BeanPostProcessor
directly on the beanFactory. The previous arrangement put the
post processor handling import aware callbacks first on the list,
leading to inconsistent callback orders.

Tests have been adapted to validate this exact scenario.

Closes gh-28915
  • Loading branch information
snicoll committed Aug 3, 2022
1 parent 058b5fe commit 1fdd91e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.BeanNameGenerator;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.ApplicationStartupAware;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.ResourceLoaderAware;
Expand Down Expand Up @@ -518,6 +519,10 @@ private static class AotContribution implements BeanFactoryInitializationAotCont

private static final String MAPPINGS_VARIABLE = "mappings";

private static final String BEAN_DEFINITION_VARIABLE = "beanDefinition";

private static final String BEAN_NAME = "org.springframework.context.annotation.internalImportAwareAotProcessor";


private final ConfigurableListableBeanFactory beanFactory;

Expand Down Expand Up @@ -561,9 +566,12 @@ private CodeBlock generateAddPostProcessorCode(Map<String, String> mappings) {
MAPPINGS_VARIABLE, HashMap.class);
mappings.forEach((type, from) -> builder.addStatement("$L.put($S, $S)",
MAPPINGS_VARIABLE, type, from));
builder.addStatement("$L.addBeanPostProcessor(new $T($L))",
BEAN_FACTORY_VARIABLE, ImportAwareAotBeanPostProcessor.class,
MAPPINGS_VARIABLE);
builder.addStatement("$T $L = new $T($T.class)", RootBeanDefinition.class,
BEAN_DEFINITION_VARIABLE, RootBeanDefinition.class, ImportAwareAotBeanPostProcessor.class);
builder.addStatement("$L.getConstructorArgumentValues().addIndexedArgumentValue(0, $L)",
BEAN_DEFINITION_VARIABLE, MAPPINGS_VARIABLE);
builder.addStatement("$L.registerBeanDefinition($S, $L)",
BEAN_FACTORY_VARIABLE, BEAN_NAME, BEAN_DEFINITION_VARIABLE);
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.context.annotation;

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

Expand All @@ -28,14 +30,20 @@
import org.springframework.aot.hint.ResourcePatternHint;
import org.springframework.aot.test.generator.compile.Compiled;
import org.springframework.aot.test.generator.compile.TestCompiler;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.beans.testfixture.beans.factory.aot.MockBeanFactoryInitializationCode;
import org.springframework.beans.testfixture.beans.factory.generator.SimpleConfiguration;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.context.testfixture.context.generator.annotation.ImportAwareConfiguration;
import org.springframework.context.testfixture.context.generator.annotation.ImportConfiguration;
import org.springframework.core.testfixture.aot.generate.TestGenerationContext;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.javapoet.CodeBlock;
import org.springframework.javapoet.MethodSpec;
import org.springframework.javapoet.ParameterizedTypeName;
Expand All @@ -62,19 +70,43 @@ class ConfigurationClassPostProcessorAotContributionTests {
this.beanFactoryInitializationCode = new MockBeanFactoryInitializationCode(this.generationContext);
}

@Test
void processAheadOfTimeWhenNoImportAwareConfigurationReturnsNull() {
assertThat(getContribution(SimpleConfiguration.class)).isNull();
}

@Test
void applyToWhenHasImportAwareConfigurationRegistersBeanPostProcessorWithMapEntry() {
BeanFactoryInitializationAotContribution contribution = getContribution(
ImportConfiguration.class);
contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode);
compile((initializer, compiled) -> {
DefaultListableBeanFactory freshBeanFactory = new DefaultListableBeanFactory();
GenericApplicationContext freshContext = new GenericApplicationContext();
DefaultListableBeanFactory freshBeanFactory = freshContext.getDefaultListableBeanFactory();
initializer.accept(freshBeanFactory);
freshContext.refresh();
assertThat(freshBeanFactory.getBeanPostProcessors()).filteredOn(ImportAwareAotBeanPostProcessor.class::isInstance)
.singleElement().satisfies(postProcessor -> assertPostProcessorEntry(postProcessor, ImportAwareConfiguration.class,
ImportConfiguration.class));
});
}

@Test
void applyToWhenHasImportAwareConfigurationRegistersBeanPostProcessorAfterApplicationContextAwareProcessor() {
BeanFactoryInitializationAotContribution contribution = getContribution(
ImportConfiguration.class);
contribution.applyTo(this.generationContext, this.beanFactoryInitializationCode);
compile((initializer, compiled) -> {
GenericApplicationContext freshContext = new AnnotationConfigApplicationContext();
DefaultListableBeanFactory freshBeanFactory = freshContext.getDefaultListableBeanFactory();
initializer.accept(freshBeanFactory);
ImportAwareAotBeanPostProcessor postProcessor = (ImportAwareAotBeanPostProcessor) freshBeanFactory
.getBeanPostProcessors().get(0);
assertPostProcessorEntry(postProcessor, ImportAwareConfiguration.class,
ImportConfiguration.class);
freshContext.registerBean(TestAwareCallbackConfiguration.class);
freshContext.refresh();
TestAwareCallbackBean bean = freshContext.getBean(TestAwareCallbackBean.class);
assertThat(bean.instances).hasSize(2);
assertThat(bean.instances.get(0)).isEqualTo(freshContext);
assertThat(bean.instances.get(1)).isInstanceOfSatisfying(AnnotationMetadata.class, metadata ->
assertThat(metadata.getClassName()).isEqualTo(TestAwareCallbackConfiguration.class.getName()));
});
}

Expand All @@ -91,11 +123,6 @@ void applyToWhenHasImportAwareConfigurationRegistersHints() {
+ "ImportConfiguration.class"));
}

@Test
void processAheadOfTimeWhenNoImportAwareConfigurationReturnsNull() {
assertThat(getContribution(SimpleConfiguration.class)).isNull();
}

@Nullable
private BeanFactoryInitializationAotContribution getContribution(Class<?> type) {
DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory();
Expand All @@ -105,6 +132,13 @@ private BeanFactoryInitializationAotContribution getContribution(Class<?> type)
return postProcessor.processAheadOfTime(beanFactory);
}

private void assertPostProcessorEntry(BeanPostProcessor postProcessor,
Class<?> key, Class<?> value) {
assertThat(postProcessor).extracting("importsMapping")
.asInstanceOf(InstanceOfAssertFactories.MAP)
.containsExactly(entry(key.getName(), value.getName()));
}

@SuppressWarnings("unchecked")
private void compile(BiConsumer<Consumer<DefaultListableBeanFactory>, Compiled> result) {
MethodReference methodReference = this.beanFactoryInitializationCode
Expand All @@ -122,11 +156,26 @@ private void compile(BiConsumer<Consumer<DefaultListableBeanFactory>, Compiled>
result.accept(compiled.getInstance(Consumer.class), compiled));
}

private void assertPostProcessorEntry(ImportAwareAotBeanPostProcessor postProcessor,
Class<?> key, Class<?> value) {
assertThat(postProcessor).extracting("importsMapping")
.asInstanceOf(InstanceOfAssertFactories.MAP)
.containsExactly(entry(key.getName(), value.getName()));
@Configuration(proxyBeanMethods = false)
@Import(TestAwareCallbackBean.class)
static class TestAwareCallbackConfiguration {

}

static class TestAwareCallbackBean implements ImportAware, ApplicationContextAware {

private final List<Object> instances = new ArrayList<>();

@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.instances.add(applicationContext);
}

@Override
public void setImportMetadata(AnnotationMetadata importMetadata) {
this.instances.add(importMetadata);
}

}

}

0 comments on commit 1fdd91e

Please sign in to comment.