Skip to content

Commit

Permalink
Fix @ConditionalOnBean with annotation early FactoryBean initialization
Browse files Browse the repository at this point in the history
Update `OnBeanCondition` with a variant of `getBeanNamesForAnnotation`
that does not cause early `FactoryBean` initialization.

Fixes gh-38473
  • Loading branch information
philwebb committed Nov 22, 2023
1 parent e7aeeb8 commit bc504a8
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.config.SingletonBeanRegistry;
import org.springframework.boot.autoconfigure.AutoConfigurationMetadata;
import org.springframework.boot.autoconfigure.condition.ConditionMessage.Style;
import org.springframework.context.annotation.Bean;
Expand Down Expand Up @@ -279,7 +280,7 @@ private Class<? extends Annotation> resolveAnnotationType(ClassLoader classLoade

private Set<String> collectBeanNamesForAnnotation(ListableBeanFactory beanFactory,
Class<? extends Annotation> annotationType, boolean considerHierarchy, Set<String> result) {
result = addAll(result, beanFactory.getBeanNamesForAnnotation(annotationType));
result = addAll(result, getBeanNamesForAnnotation(beanFactory, annotationType));
if (considerHierarchy) {
BeanFactory parent = ((HierarchicalBeanFactory) beanFactory).getParentBeanFactory();
if (parent instanceof ListableBeanFactory listableBeanFactory) {
Expand All @@ -289,6 +290,30 @@ private Set<String> collectBeanNamesForAnnotation(ListableBeanFactory beanFactor
return result;
}

private String[] getBeanNamesForAnnotation(ListableBeanFactory beanFactory,
Class<? extends Annotation> annotationType) {
Set<String> foundBeanNames = new LinkedHashSet<>();
for (String beanName : beanFactory.getBeanDefinitionNames()) {
if (beanFactory instanceof ConfigurableListableBeanFactory configurableListableBeanFactory) {
BeanDefinition beanDefinition = configurableListableBeanFactory.getBeanDefinition(beanName);
if (beanDefinition != null && beanDefinition.isAbstract()) {
continue;
}
}
if (beanFactory.findAnnotationOnBean(beanName, annotationType, false) != null) {
foundBeanNames.add(beanName);
}
}
if (beanFactory instanceof SingletonBeanRegistry singletonBeanRegistry) {
for (String beanName : singletonBeanRegistry.getSingletonNames()) {
if (beanFactory.findAnnotationOnBean(beanName, annotationType) != null) {
foundBeanNames.add(beanName);
}
}
}
return foundBeanNames.toArray(String[]::new);
}

private boolean containsBean(ConfigurableListableBeanFactory beanFactory, String beanName,
boolean considerHierarchy) {
if (considerHierarchy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.boot.autoconfigure.condition.ConditionEvaluationReport.ConditionAndOutcomes;
import org.springframework.boot.test.context.assertj.AssertableApplicationContext;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
Expand Down Expand Up @@ -131,6 +133,19 @@ void beanProducedByFactoryBeanIsConsideredWhenMatchingOnAnnotation() {
});
}

@Test
void beanProducedByFactoryBeanIsConsideredWhenMatchingOnAnnotation2() {
this.contextRunner
.withUserConfiguration(EarlyInitializationFactoryBeanConfiguration.class,
EarlyInitializationOnAnnotationFactoryBeanConfiguration.class)
.run((context) -> {
assertThat(EarlyInitializationFactoryBeanConfiguration.calledWhenNoFrozen).as("calledWhenNoFrozen")
.isFalse();
assertThat(context).hasBean("bar");
assertThat(context).hasSingleBean(ExampleBean.class);
});
}

private void hasBarBean(AssertableApplicationContext context) {
assertThat(context).hasBean("bar");
assertThat(context.getBean("bar")).isEqualTo("bar");
Expand Down Expand Up @@ -352,6 +367,35 @@ String bar() {

}

@Configuration(proxyBeanMethods = false)
static class EarlyInitializationFactoryBeanConfiguration {

static boolean calledWhenNoFrozen;

@Bean
@TestAnnotation
static FactoryBean<?> exampleBeanFactoryBean(ApplicationContext applicationContext) {
// NOTE: must be static and return raw FactoryBean and not the subclass so
// Spring can't guess type
ConfigurableListableBeanFactory beanFactory = ((ConfigurableApplicationContext) applicationContext)
.getBeanFactory();
calledWhenNoFrozen = calledWhenNoFrozen || !beanFactory.isConfigurationFrozen();
return new ExampleFactoryBean();
}

}

@Configuration(proxyBeanMethods = false)
@ConditionalOnBean(annotation = TestAnnotation.class)
static class EarlyInitializationOnAnnotationFactoryBeanConfiguration {

@Bean
String bar() {
return "bar";
}

}

static class WithPropertyPlaceholderClassNameRegistrar implements ImportBeanDefinitionRegistrar {

@Override
Expand Down Expand Up @@ -518,7 +562,7 @@ static class OtherExampleBean extends ExampleBean {

}

@Target(ElementType.TYPE)
@Target({ ElementType.TYPE, ElementType.METHOD })
@Retention(RetentionPolicy.RUNTIME)
@Documented
@interface TestAnnotation {
Expand Down

0 comments on commit bc504a8

Please sign in to comment.