Skip to content

Commit

Permalink
feat: Support multiple SolverManager instances in Spring Boot (Timefo…
Browse files Browse the repository at this point in the history
…ldAI#590)

This pull request updates the Spring boot extension to support multiple
instances of Solver in an application. The new logic utilizes a map to
keep track of the configurations, and nothing changes for the current
behavior for single solvers, which are mapped by default to the key name
`default`.
  • Loading branch information
zepfred authored Jan 26, 2024
1 parent 1277507 commit 9009fa4
Show file tree
Hide file tree
Showing 45 changed files with 2,502 additions and 291 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ BenchmarkConfigBuildItem registerAdditionalBeans(BuildProducer<AdditionalBeanBui
if (solverConfigBuildItem.getSolvetConfigMap().size() > 1) {
throw new ConfigurationException("""
When defining multiple solvers, the benchmark feature is not enabled.
Consider using separate <solverBenchmark> instances for evaluating different solver configurations.
""");
Consider using separate <solverBenchmark> instances for evaluating different solver configurations.""");
}
if (solverConfigBuildItem.getGeneratedGizmoClasses() == null) {
log.warn("Skipping Timefold Benchmark extension because the Timefold extension was skipped.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class TimefoldBenchmarkProcessorMultipleSolversConfigTest {
TestdataQuarkusSolution.class, TestdataQuarkusConstraintProvider.class))
.assertException(t -> assertThat(t)
.isInstanceOf(ConfigurationException.class)
.hasMessageContaining("""
When defining multiple solvers, the benchmark feature is not enabled.
Consider using separate <solverBenchmark> instances for evaluating different solver configurations.
"""));
.hasMessageContaining(
"""
When defining multiple solvers, the benchmark feature is not enabled.
Consider using separate <solverBenchmark> instances for evaluating different solver configurations."""));

@Test
void benchmark() throws ExecutionException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,12 @@ Some solver configs (%s) don't specify a %s score class, yet there are multiple
}

private void assertEmptyInstances(IndexView indexView, DotName dotName) {
// Validate the solution class
Collection<AnnotationInstance> annotationInstanceCollection = indexView.getAnnotations(dotName);
// No solution class
if (annotationInstanceCollection.isEmpty()) {
try {
throw new IllegalStateException(
"No classes found with a @%s annotation.".formatted(Class.forName(dotName.local()).getSimpleName()));
"No classes were found with a @%s annotation."
.formatted(Class.forName(dotName.local()).getSimpleName()));
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -517,7 +516,7 @@ private SolverConfig createSolverConfig(ClassLoader classLoader, String solverNa
.formatted(solverUrl);
if (!solverName.equals(TimefoldBuildTimeConfig.DEFAULT_SOLVER_NAME)) {
message =
"Invalid quarkus.timefold.\"%s\".solverConfigXML property (%s): that classpath resource does not exist."
"Invalid quarkus.timefold.solver.\"%s\".solverConfigXML property (%s): that classpath resource does not exist."
.formatted(solverName, solverUrl);
}
throw new ConfigurationException(message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TimefoldProcessorMultipleSolversInvalidEntityClassTest {
.assertException(t -> assertThat(t)
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining(
"No classes found with a @PlanningEntity annotation."));
"No classes were found with a @PlanningEntity annotation."));

@Test
void test() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TimefoldProcessorMultipleSolversInvalidSolutionClassTest {
.assertException(t -> assertThat(t)
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining(
"No classes found with a @PlanningSolution annotation."));
"No classes were found with a @PlanningSolution annotation."));

// Multiple classes
@RegisterExtension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TimefoldProcessorSolverInvalidEntityClassTest {
.assertException(t -> assertThat(t)
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining(
"No classes found with a @PlanningEntity annotation."));
"No classes were found with a @PlanningEntity annotation."));

@Test
void test() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TimefoldProcessorSolverInvalidSolutionClassTest {
.assertException(t -> assertThat(t)
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining(
"No classes found with a @PlanningSolution annotation."));
"No classes were found with a @PlanningSolution annotation."));

// Multiple classes
@RegisterExtension
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package ai.timefold.solver.spring.boot.autoconfigure;

import static ai.timefold.solver.spring.boot.autoconfigure.util.LambdaUtils.rethrowFunction;
import static java.util.Collections.emptyList;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import ai.timefold.solver.core.api.domain.entity.PlanningEntity;
import ai.timefold.solver.core.api.domain.solution.PlanningSolution;

import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.boot.autoconfigure.AutoConfigurationPackages;
import org.springframework.boot.autoconfigure.domain.EntityScanPackages;
import org.springframework.boot.autoconfigure.domain.EntityScanner;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.core.type.filter.AssignableTypeFilter;
import org.springframework.util.ClassUtils;

public class IncludeAbstractClassesEntityScanner extends EntityScanner {

private final ApplicationContext context;

public IncludeAbstractClassesEntityScanner(ApplicationContext context) {
super(context);
this.context = context;
}

public <T> Class<? extends T> findFirstImplementingClass(Class<T> targetClass) {
List<Class<? extends T>> classes = findImplementingClassList(targetClass);
if (!classes.isEmpty()) {
return classes.get(0);
}
return null;
}

private Set<String> findPackages() {
Set<String> packages = new HashSet<>();
packages.addAll(AutoConfigurationPackages.get(context));
EntityScanPackages entityScanPackages = EntityScanPackages.get(context);
packages.addAll(entityScanPackages.getPackageNames());
return packages;
}

public <T> List<Class<? extends T>> findImplementingClassList(Class<T> targetClass) {
if (!AutoConfigurationPackages.has(context)) {
return emptyList();
}
ClassPathScanningCandidateComponentProvider scanner = new ClassPathScanningCandidateComponentProvider(false);
scanner.setEnvironment(context.getEnvironment());
scanner.setResourceLoader(context);
scanner.addIncludeFilter(new AssignableTypeFilter(targetClass));
Set<String> packages = findPackages();
return packages.stream()
.flatMap(basePackage -> scanner.findCandidateComponents(basePackage).stream())
// findCandidateComponents can return the same package for different base packages
.distinct()
.sorted(Comparator.comparing(BeanDefinition::getBeanClassName))
.map(candidate -> {
try {
return (Class<? extends T>) ClassUtils.forName(candidate.getBeanClassName(), context.getClassLoader())
.asSubclass(targetClass);
} catch (ClassNotFoundException e) {
throw new IllegalStateException("The %s class (%s) cannot be found."
.formatted(targetClass.getSimpleName(), candidate.getBeanClassName()), e);
}
})
.collect(Collectors.toList());
}

@SafeVarargs
public final List<Class<?>> findClassesWithAnnotation(Class<? extends Annotation>... annotations) {
if (!AutoConfigurationPackages.has(context)) {
return emptyList();
}
Set<String> packages = findPackages();
return packages.stream().flatMap(rethrowFunction(
basePackage -> findAllClassesUsingClassLoader(this.context.getClassLoader(), basePackage).stream()))
.filter(clazz -> hasAnyFieldOrMethodWithAnnotation(clazz, annotations))
.toList();
}

private boolean hasAnyFieldOrMethodWithAnnotation(Class<?> clazz, Class<? extends Annotation>[] annotations) {
List<Field> fieldList = List.of(clazz.getDeclaredFields());
List<Method> methodList = List.of(clazz.getDeclaredMethods());
return List.of(annotations).stream().anyMatch(a -> fieldList.stream().anyMatch(f -> f.getAnnotation(a) != null)
|| methodList.stream().anyMatch(m -> m.getDeclaredAnnotation(a) != null));
}

public boolean hasSolutionOrEntityClasses() {
try {
return !scan(PlanningSolution.class).isEmpty() || !scan(PlanningEntity.class).isEmpty();
} catch (ClassNotFoundException e) {
throw new IllegalStateException("Scanning for @%s and @%s annotations failed."
.formatted(PlanningSolution.class.getSimpleName(), PlanningEntity.class.getSimpleName()), e);
}
}

public Class<?> findFirstSolutionClass() {
Set<Class<?>> solutionClassSet;
try {
solutionClassSet = scan(PlanningSolution.class);
} catch (ClassNotFoundException e) {
throw new IllegalStateException(
"Scanning for @%s annotations failed.".formatted(PlanningSolution.class.getSimpleName()), e);
}
return solutionClassSet.iterator().next();
}

public List<Class<?>> findEntityClassList() {
Set<Class<?>> entityClassSet;
try {
entityClassSet = scan(PlanningEntity.class);
} catch (ClassNotFoundException e) {
throw new IllegalStateException("Scanning for @%s failed.".formatted(PlanningEntity.class.getSimpleName()), e);
}
return new ArrayList<>(entityClassSet);
}

private Set<Class<?>> findAllClassesUsingClassLoader(ClassLoader classLoader, String packageName) throws IOException {
try (InputStream stream = classLoader.getResourceAsStream(packageName.replaceAll("[.]", "/"));
BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
return reader.lines()
.filter(line -> line.endsWith(".class"))
.map(className -> packageName + "." + className.substring(0, className.lastIndexOf('.')))
.map(className -> getClass(classLoader, className))
.collect(Collectors.toSet());
}
}

private Class<?> getClass(ClassLoader classLoader, String className) {
try {
return Class.forName(className, false, classLoader);
} catch (ClassNotFoundException e) {
// ignore the exception
}
return null;
}

@Override
protected ClassPathScanningCandidateComponentProvider
createClassPathScanningCandidateComponentProvider(ApplicationContext context) {
return new ClassPathScanningCandidateComponentProvider(false) {
@Override
protected boolean isCandidateComponent(AnnotatedBeanDefinition beanDefinition) {
AnnotationMetadata metadata = beanDefinition.getMetadata();
// Do not exclude abstract classes nor interfaces
return metadata.isIndependent();
}
};
}

}
Loading

0 comments on commit 9009fa4

Please sign in to comment.