Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Improve SolverManager error message in Python #1114

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package ai.timefold.jpyinterpreter.util;

import java.io.ByteArrayOutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;

public class TracebackUtils {
private TracebackUtils() {

}

public static String getTraceback(Throwable t) {
ByteArrayOutputStream byteOutputStream = new ByteArrayOutputStream();
PrintWriter printWriter = new PrintWriter(byteOutputStream);
var output = new StringWriter();
PrintWriter printWriter = new PrintWriter(output);
t.printStackTrace(printWriter);
return byteOutputStream.toString();
return output.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package ai.timefold.jpyinterpreter.util;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;

class TracebackUtilsTest {
@Test
void getTraceback() {
try {
throw new RuntimeException("A runtime error has occurred.");
} catch (RuntimeException e) {
assertThat(TracebackUtils.getTraceback(e))
.contains("A runtime error has occurred.")
.contains(TracebackUtilsTest.class.getCanonicalName());
}
}
}
29 changes: 21 additions & 8 deletions python/python-core/src/main/python/_solver_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from datetime import timedelta
from enum import Enum
from typing import Awaitable, TypeVar, Generic, Callable, TYPE_CHECKING

from ._future import wrap_future
from ._problem_change import ProblemChange, ProblemChangeWrapper
from .config import SolverConfig, SolverConfigOverride, SolverManagerConfig
from ._solver_factory import SolverFactory
from ._future import wrap_future
from ._timefold_java_interop import update_log_level

from typing import Awaitable, TypeVar, Generic, Callable, TYPE_CHECKING
from datetime import timedelta
from enum import Enum
from .config import SolverConfig, SolverConfigOverride, SolverManagerConfig

if TYPE_CHECKING:
# These imports require a JVM to be running, so only import if type checking
Expand All @@ -16,6 +17,7 @@

Solution_ = TypeVar('Solution_')
ProblemId_ = TypeVar('ProblemId_')
logger = logging.getLogger('timefold.solver')


class SolverStatus(Enum):
Expand Down Expand Up @@ -155,6 +157,17 @@ def add_problem_change(self, problem_change: ProblemChange[Solution_]) -> Awaita
return wrap_future(self._delegate.addProblemChange(ProblemChangeWrapper(problem_change)))


def default_exception_handler(problem_id, error):
try:
raise error
except:
# logger does not have a method for printing a message with exception info,
# so we need to raise the unwrapped recorded error to include the traceback
# in the logs
logger.exception(f'Solving failed for problem_id ({problem_id}).')
raise


class SolverJobBuilder(Generic[Solution_, ProblemId_]):
"""
Provides a fluent contract that allows customization and submission of planning problems to solve.
Expand Down Expand Up @@ -348,7 +361,7 @@ def with_exception_handler(self, exception_handler: Callable[[ProblemId_, Except
from _jpyinterpreter import unwrap_python_like_object

java_consumer = BiConsumer @ (lambda problem_id, error: exception_handler(unwrap_python_like_object(problem_id),
error))
unwrap_python_like_object(error)))
return SolverJobBuilder(
self._delegate.withExceptionHandler(java_consumer))

Expand Down Expand Up @@ -508,7 +521,7 @@ def solve_builder(self) -> SolverJobBuilder[Solution_, ProblemId_]:
SolverJobBuilder
A new `SolverJobBuilder`.
"""
return SolverJobBuilder(self._delegate.solveBuilder())
return SolverJobBuilder(self._delegate.solveBuilder()).with_exception_handler(default_exception_handler)

def get_solver_status(self, problem_id: ProblemId_) -> SolverStatus:
"""
Expand Down
68 changes: 64 additions & 4 deletions python/python-core/tests/test_solver_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import pytest
from dataclasses import dataclass, field
from timefold.solver import *
from timefold.solver.domain import *
from timefold.solver.config import *
from timefold.solver.domain import *
from timefold.solver.score import *

import pytest
from dataclasses import dataclass, field
from typing import Annotated, List


Expand Down Expand Up @@ -243,6 +243,66 @@ def my_exception_handler(problem_id, exception):
assert the_problem_id == 1
assert the_exception is not None


@pytest.mark.filterwarnings("ignore:.*Exception in thread.*:pytest.PytestUnhandledThreadExceptionWarning")
def test_default_error(caplog):
@dataclass
class Value:
value: Annotated[int, PlanningId]

@planning_entity
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[Value, PlanningVariable] = field(default=None)

@constraint_provider
def my_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.filter(lambda e: e.missing_attribute == 1)
.reward(SimpleScore.ONE, lambda entity: entity.value.value)
.as_constraint('Maximize Value')
]

@planning_solution
@dataclass
class Solution:
entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty]
value_list: Annotated[List[Value],
DeepPlanningClone,
ProblemFactCollectionProperty,
ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=my_constraints
),
termination_config=TerminationConfig(
best_score_limit='6'
)
)
problem: Solution = Solution([Entity('A'), Entity('B'), Entity('C')], [Value(1), Value(2), Value(3)],
SimpleScore.ONE)
with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
with caplog.at_level(logging.ERROR, logger="timefold.solver"):
try:
(solver_manager.solve_builder()
.with_problem_id(1)
.with_problem(problem)
.run().get_final_best_solution())
except:
pass

assert len(caplog.records) == 1
error_msg = str(caplog.records[0].exc_info[1])
assert 'AttributeError' in error_msg
assert 'e.missing_attribute == 1' in error_msg


def test_solver_config():
@dataclass
class Value:
Expand Down
Loading