diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java index a098524d396e..4e2ca911c3c9 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/SpringApplication.java @@ -17,12 +17,15 @@ package org.springframework.boot; import java.lang.StackWalker.StackFrame; +import java.lang.reflect.Method; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; +import java.util.IdentityHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -38,14 +41,17 @@ import org.springframework.aot.AotDetector; import org.springframework.beans.BeansException; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.groovy.GroovyBeanDefinitionReader; import org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanNameGenerator; import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.beans.factory.xml.XmlBeanDefinitionReader; import org.springframework.boot.Banner.Mode; import org.springframework.boot.context.properties.bind.Bindable; @@ -68,6 +74,8 @@ import org.springframework.context.support.AbstractApplicationContext; import org.springframework.context.support.GenericApplicationContext; import org.springframework.core.GenericTypeResolver; +import org.springframework.core.OrderComparator; +import org.springframework.core.OrderComparator.OrderSourceProvider; import org.springframework.core.Ordered; import org.springframework.core.annotation.AnnotationAwareOrderComparator; import org.springframework.core.annotation.Order; @@ -746,35 +754,42 @@ protected void refresh(ConfigurableApplicationContext applicationContext) { protected void afterRefresh(ConfigurableApplicationContext context, ApplicationArguments args) { } - private void callRunners(ApplicationContext context, ApplicationArguments args) { - context.getBeanProvider(Runner.class).orderedStream().forEach((runner) -> { - if (runner instanceof ApplicationRunner applicationRunner) { - callRunner(applicationRunner, args); - } - if (runner instanceof CommandLineRunner commandLineRunner) { - callRunner(commandLineRunner, args); - } - }); + private void callRunners(ConfigurableApplicationContext context, ApplicationArguments args) { + ConfigurableListableBeanFactory beanFactory = context.getBeanFactory(); + String[] beanNames = beanFactory.getBeanNamesForType(Runner.class); + Map instancesToBeanNames = new IdentityHashMap<>(); + for (String beanName : beanNames) { + instancesToBeanNames.put(beanFactory.getBean(beanName, Runner.class), beanName); + } + Comparator comparator = getOrderComparator(beanFactory) + .withSourceProvider(new FactoryAwareOrderSourceProvider(beanFactory, instancesToBeanNames)); + instancesToBeanNames.keySet().stream().sorted(comparator).forEach((runner) -> callRunner(runner, args)); } - private void callRunner(ApplicationRunner runner, ApplicationArguments args) { - try { - (runner).run(args); - } - catch (Exception ex) { - throw new IllegalStateException("Failed to execute ApplicationRunner", ex); - } + private OrderComparator getOrderComparator(ConfigurableListableBeanFactory beanFactory) { + Comparator dependencyComparator = (beanFactory instanceof DefaultListableBeanFactory defaultListableBeanFactory) + ? defaultListableBeanFactory.getDependencyComparator() : null; + return (dependencyComparator instanceof OrderComparator orderComparator) ? orderComparator + : AnnotationAwareOrderComparator.INSTANCE; } - private void callRunner(CommandLineRunner runner, ApplicationArguments args) { - try { - (runner).run(args.getSourceArgs()); + private void callRunner(Runner runner, ApplicationArguments args) { + if (runner instanceof ApplicationRunner) { + callRunner(ApplicationRunner.class, runner, (applicationRunner) -> applicationRunner.run(args)); } - catch (Exception ex) { - throw new IllegalStateException("Failed to execute CommandLineRunner", ex); + if (runner instanceof CommandLineRunner) { + callRunner(CommandLineRunner.class, runner, + (commandLineRunner) -> commandLineRunner.run(args.getSourceArgs())); } } + @SuppressWarnings("unchecked") + private void callRunner(Class type, Runner runner, ThrowingConsumer call) { + call.throwing( + (message, ex) -> new IllegalStateException("Failed to execute " + ClassUtils.getShortName(type), ex)) + .accept((R) runner); + } + private void handleRunFailure(ConfigurableApplicationContext context, Throwable exception, SpringApplicationRunListeners listeners) { try { @@ -1598,4 +1613,41 @@ public SpringApplicationRunListener getRunListener(SpringApplication springAppli } + /** + * {@link OrderSourceProvider} used to obtain factory method and target type order + * sources. Based on internal {@link DefaultListableBeanFactory} code. + */ + private class FactoryAwareOrderSourceProvider implements OrderSourceProvider { + + private final ConfigurableBeanFactory beanFactory; + + private final Map instancesToBeanNames; + + FactoryAwareOrderSourceProvider(ConfigurableBeanFactory beanFactory, Map instancesToBeanNames) { + this.beanFactory = beanFactory; + this.instancesToBeanNames = instancesToBeanNames; + } + + @Override + public Object getOrderSource(Object obj) { + String beanName = this.instancesToBeanNames.get(obj); + return (beanName != null) ? getOrderSource(beanName, obj.getClass()) : null; + } + + private Object getOrderSource(String beanName, Class instanceType) { + try { + RootBeanDefinition beanDefinition = (RootBeanDefinition) this.beanFactory + .getMergedBeanDefinition(beanName); + Method factoryMethod = beanDefinition.getResolvedFactoryMethod(); + Class targetType = beanDefinition.getTargetType(); + targetType = (targetType != instanceType) ? targetType : null; + return Stream.of(factoryMethod, targetType).filter(Objects::nonNull).toArray(); + } + catch (NoSuchBeanDefinitionException ex) { + return null; + } + } + + } + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java index eb8c7e3cc11b..8f8b714bd235 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/SpringApplicationTests.java @@ -60,6 +60,7 @@ import org.springframework.boot.availability.AvailabilityState; import org.springframework.boot.availability.LivenessState; import org.springframework.boot.availability.ReadinessState; +import org.springframework.boot.builder.ParentContextApplicationContextInitializer; import org.springframework.boot.builder.SpringApplicationBuilder; import org.springframework.boot.context.event.ApplicationContextInitializedEvent; import org.springframework.boot.context.event.ApplicationEnvironmentPreparedEvent; @@ -630,6 +631,19 @@ void runCommandLineRunnersAndApplicationRunners() { assertThat(this.context).has(runTestRunnerBean("runnerC")); } + @Test + void runCommandLineRunnersAndApplicationRunnersWithParentContext() { + SpringApplication application = new SpringApplication(CommandLineRunConfig.class); + application.setWebApplicationType(WebApplicationType.NONE); + application.addInitializers(new ParentContextApplicationContextInitializer( + new AnnotationConfigApplicationContext(CommandLineRunParentConfig.class))); + this.context = application.run("arg"); + assertThat(this.context).has(runTestRunnerBean("runnerA")); + assertThat(this.context).has(runTestRunnerBean("runnerB")); + assertThat(this.context).has(runTestRunnerBean("runnerC")); + assertThat(this.context).doesNotHave(runTestRunnerBean("runnerP")); + } + @Test void runCommandLineRunnersAndApplicationRunnersUsingOrderOnBeanDefinitions() { SpringApplication application = new SpringApplication(BeanDefinitionOrderRunnerConfig.class); @@ -1432,7 +1446,7 @@ public boolean matches(ConfigurableEnvironment value) { }; } - private Condition runTestRunnerBean(final String name) { + private Condition runTestRunnerBean(String name) { return new Condition<>("run testrunner bean") { @Override @@ -1642,17 +1656,27 @@ static class CommandLineRunConfig { @Bean TestCommandLineRunner runnerC() { - return new TestCommandLineRunner(Ordered.LOWEST_PRECEDENCE, "runnerB", "runnerA"); + return new TestCommandLineRunner("runnerC", Ordered.LOWEST_PRECEDENCE, "runnerB", "runnerA"); } @Bean TestApplicationRunner runnerB() { - return new TestApplicationRunner(Ordered.LOWEST_PRECEDENCE - 1, "runnerA"); + return new TestApplicationRunner("runnerB", Ordered.LOWEST_PRECEDENCE - 1, "runnerA"); } @Bean TestCommandLineRunner runnerA() { - return new TestCommandLineRunner(Ordered.HIGHEST_PRECEDENCE); + return new TestCommandLineRunner("runnerA", Ordered.HIGHEST_PRECEDENCE); + } + + } + + @Configuration(proxyBeanMethods = false) + static class CommandLineRunParentConfig { + + @Bean + TestCommandLineRunner runnerP() { + return new TestCommandLineRunner("runnerP", Ordered.LOWEST_PRECEDENCE); } } @@ -1861,12 +1885,16 @@ boolean hasRun() { static class TestCommandLineRunner extends AbstractTestRunner implements CommandLineRunner { - TestCommandLineRunner(int order, String... expectedBefore) { + private final String name; + + TestCommandLineRunner(String name, int order, String... expectedBefore) { super(order, expectedBefore); + this.name = name; } @Override public void run(String... args) { + System.out.println(">>> " + this.name); markAsRan(); } @@ -1874,12 +1902,16 @@ public void run(String... args) { static class TestApplicationRunner extends AbstractTestRunner implements ApplicationRunner { - TestApplicationRunner(int order, String... expectedBefore) { + private final String name; + + TestApplicationRunner(String name, int order, String... expectedBefore) { super(order, expectedBefore); + this.name = name; } @Override public void run(ApplicationArguments args) { + System.out.println(">>> " + this.name); markAsRan(); }