From bce9ba02874543b7dc104865cae9ee43c28c6e6b Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Tue, 11 Jun 2024 16:46:42 -0400 Subject: [PATCH] fix: Use custom thread factor class in SolverManager --- .../impl/solver/DefaultSolverManager.java | 8 ++++- .../core/api/solver/SolverManagerTest.java | 33 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolverManager.java b/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolverManager.java index ae74e639e9..0f226c2e9b 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolverManager.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolverManager.java @@ -23,6 +23,7 @@ import ai.timefold.solver.core.api.solver.SolverStatus; import ai.timefold.solver.core.api.solver.change.ProblemChange; import ai.timefold.solver.core.config.solver.SolverManagerConfig; +import ai.timefold.solver.core.config.util.ConfigUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,7 +48,12 @@ public DefaultSolverManager(SolverFactory solverFactory, this.solverFactory = solverFactory; validateSolverFactory(); int parallelSolverCount = solverManagerConfig.resolveParallelSolverCount(); - solverThreadPool = Executors.newFixedThreadPool(parallelSolverCount); + var threadFactory = Executors.defaultThreadFactory(); + if (solverManagerConfig.getThreadFactoryClass() != null) { + threadFactory = ConfigUtils.newInstance(solverManagerConfig, "threadFactoryClass", + solverManagerConfig.getThreadFactoryClass()); + } + solverThreadPool = Executors.newFixedThreadPool(parallelSolverCount, threadFactory); problemIdToSolverJobMap = new ConcurrentHashMap<>(parallelSolverCount * 10); } diff --git a/core/src/test/java/ai/timefold/solver/core/api/solver/SolverManagerTest.java b/core/src/test/java/ai/timefold/solver/core/api/solver/SolverManagerTest.java index ff163d596e..11d4de09f5 100644 --- a/core/src/test/java/ai/timefold/solver/core/api/solver/SolverManagerTest.java +++ b/core/src/test/java/ai/timefold/solver/core/api/solver/SolverManagerTest.java @@ -27,6 +27,7 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -1010,4 +1011,36 @@ void terminateScheduledSolverJobEarly_returnsInputProblem() throws ExecutionExce assertThat(result).isSameAs(inputProblem); assertThat(solverJob.isTerminatedEarly()).isTrue(); } + + public static class CustomThreadFactory implements ThreadFactory { + private static final String CUSTOM_THREAD_NAME = "CustomThread"; + + @Override + public Thread newThread(Runnable runnable) { + return new Thread(runnable, CUSTOM_THREAD_NAME); + } + } + + @Test + @Timeout(60) + void threadFactoryIsUsed() throws ExecutionException, InterruptedException { + var threadCheckingPhaseConfig = new CustomPhaseConfig().withCustomPhaseCommands( + scoreDirector -> { + if (!Thread.currentThread().getName().equals(CustomThreadFactory.CUSTOM_THREAD_NAME)) { + fail("Custom thread factory not used"); + } + }); + + var solverConfig = PlannerTestUtils.buildSolverConfig(TestdataSolution.class, TestdataEntity.class) + .withPhases(threadCheckingPhaseConfig, new ConstructionHeuristicPhaseConfig()); + + var solverManagerConfig = new SolverManagerConfig().withThreadFactoryClass(CustomThreadFactory.class); + solverManager = SolverManager.create(solverConfig, solverManagerConfig); + + var inputProblem = PlannerTestUtils.generateTestdataSolution("s1", 4); + var solverJob = solverManager.solve(1L, inputProblem); + + TestdataSolution result = solverJob.getFinalBestSolution(); + assertThat(result).isNotNull(); + } }