diff --git a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/TracebackUtils.java b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/TracebackUtils.java index 0a58f6a266..f38d9ed812 100644 --- a/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/TracebackUtils.java +++ b/python/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/util/TracebackUtils.java @@ -1,7 +1,7 @@ package ai.timefold.jpyinterpreter.util; -import java.io.ByteArrayOutputStream; import java.io.PrintWriter; +import java.io.StringWriter; public class TracebackUtils { private TracebackUtils() { @@ -9,9 +9,9 @@ private TracebackUtils() { } public static String getTraceback(Throwable t) { - ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream(); - PrintWriter printWriter = new PrintWriter(byteOutputStream); + var output = new StringWriter(); + PrintWriter printWriter = new PrintWriter(output); t.printStackTrace(printWriter); - return byteOutputStream.toString(); + return output.toString(); } } diff --git a/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/TracebackUtilsTest.java b/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/TracebackUtilsTest.java new file mode 100644 index 0000000000..4c57c461c2 --- /dev/null +++ b/python/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/TracebackUtilsTest.java @@ -0,0 +1,18 @@ +package ai.timefold.jpyinterpreter.util; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class TracebackUtilsTest { + @Test + void getTraceback() { + try { + throw new RuntimeException("A runtime error has occurred."); + } catch (RuntimeException e) { + assertThat(TracebackUtils.getTraceback(e)) + .contains("A runtime error has occurred.") + .contains(TracebackUtilsTest.class.getCanonicalName()); + } + } +} diff --git a/python/python-core/src/main/python/_solver_manager.py b/python/python-core/src/main/python/_solver_manager.py index e7428345b7..7f27784958 100644 --- a/python/python-core/src/main/python/_solver_manager.py +++ b/python/python-core/src/main/python/_solver_manager.py @@ -1,12 +1,13 @@ +import logging +from datetime import timedelta +from enum import Enum +from typing import Awaitable, TypeVar, Generic, Callable, TYPE_CHECKING + +from ._future import wrap_future from ._problem_change import ProblemChange, ProblemChangeWrapper -from .config import SolverConfig, SolverConfigOverride, SolverManagerConfig from ._solver_factory import SolverFactory -from ._future import wrap_future from ._timefold_java_interop import update_log_level - -from typing import Awaitable, TypeVar, Generic, Callable, TYPE_CHECKING -from datetime import timedelta -from enum import Enum +from .config import SolverConfig, SolverConfigOverride, SolverManagerConfig if TYPE_CHECKING: # These imports require a JVM to be running, so only import if type checking @@ -16,6 +17,7 @@ Solution_ = TypeVar('Solution_') ProblemId_ = TypeVar('ProblemId_') +logger = logging.getLogger('timefold.solver') class SolverStatus(Enum): @@ -155,6 +157,17 @@ def add_problem_change(self, problem_change: ProblemChange[Solution_]) -> Awaita return wrap_future(self._delegate.addProblemChange(ProblemChangeWrapper(problem_change))) +def default_exception_handler(problem_id, error): + try: + raise error + except: + # logger does not have a method for printing a message with exception info, + # so we need to raise the unwrapped recorded error to include the traceback + # in the logs + logger.exception(f'Solving failed for problem_id ({problem_id}).') + raise + + class SolverJobBuilder(Generic[Solution_, ProblemId_]): """ Provides a fluent contract that allows customization and submission of planning problems to solve. @@ -348,7 +361,7 @@ def with_exception_handler(self, exception_handler: Callable[[ProblemId_, Except from _jpyinterpreter import unwrap_python_like_object java_consumer = BiConsumer @ (lambda problem_id, error: exception_handler(unwrap_python_like_object(problem_id), - error)) + unwrap_python_like_object(error))) return SolverJobBuilder( self._delegate.withExceptionHandler(java_consumer)) @@ -508,7 +521,7 @@ def solve_builder(self) -> SolverJobBuilder[Solution_, ProblemId_]: SolverJobBuilder A new `SolverJobBuilder`. """ - return SolverJobBuilder(self._delegate.solveBuilder()) + return SolverJobBuilder(self._delegate.solveBuilder()).with_exception_handler(default_exception_handler) def get_solver_status(self, problem_id: ProblemId_) -> SolverStatus: """ diff --git a/python/python-core/tests/test_solver_manager.py b/python/python-core/tests/test_solver_manager.py index 71d7ef5618..ce91d306dd 100644 --- a/python/python-core/tests/test_solver_manager.py +++ b/python/python-core/tests/test_solver_manager.py @@ -1,10 +1,10 @@ +import logging +import pytest +from dataclasses import dataclass, field from timefold.solver import * -from timefold.solver.domain import * from timefold.solver.config import * +from timefold.solver.domain import * from timefold.solver.score import * - -import pytest -from dataclasses import dataclass, field from typing import Annotated, List @@ -243,6 +243,66 @@ def my_exception_handler(problem_id, exception): assert the_problem_id == 1 assert the_exception is not None + +@pytest.mark.filterwarnings("ignore:.*Exception in thread.*:pytest.PytestUnhandledThreadExceptionWarning") +def test_default_error(caplog): + @dataclass + class Value: + value: Annotated[int, PlanningId] + + @planning_entity + @dataclass + class Entity: + code: Annotated[str, PlanningId] + value: Annotated[Value, PlanningVariable] = field(default=None) + + @constraint_provider + def my_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .filter(lambda e: e.missing_attribute == 1) + .reward(SimpleScore.ONE, lambda entity: entity.value.value) + .as_constraint('Maximize Value') + ] + + @planning_solution + @dataclass + class Solution: + entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty] + value_list: Annotated[List[Value], + DeepPlanningClone, + ProblemFactCollectionProperty, + ValueRangeProvider] + score: Annotated[SimpleScore, PlanningScore] = field(default=None) + + solver_config = SolverConfig( + solution_class=Solution, + entity_class_list=[Entity], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=my_constraints + ), + termination_config=TerminationConfig( + best_score_limit='6' + ) + ) + problem: Solution = Solution([Entity('A'), Entity('B'), Entity('C')], [Value(1), Value(2), Value(3)], + SimpleScore.ONE) + with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager: + with caplog.at_level(logging.ERROR, logger="timefold.solver"): + try: + (solver_manager.solve_builder() + .with_problem_id(1) + .with_problem(problem) + .run().get_final_best_solution()) + except: + pass + + assert len(caplog.records) == 1 + error_msg = str(caplog.records[0].exc_info[1]) + assert 'AttributeError' in error_msg + assert 'e.missing_attribute == 1' in error_msg + + def test_solver_config(): @dataclass class Value: