Skip to content

Commit

Permalink
fix: avoid exception when analysing uninitialized solution (#536)
Browse files Browse the repository at this point in the history
  • Loading branch information
triceo authored Jan 3, 2024
1 parent 90bce17 commit 67d4448
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -755,8 +755,8 @@ or uncorrupted constraintMatchEnabled (%s) is disabled.
.formatted(constraintMatchEnabledPreference, uncorruptedScoreDirector.isConstraintMatchEnabled());
}

var corruptedAnalysis = buildScoreAnalysis(true);
var uncorruptedAnalysis = uncorruptedScoreDirector.buildScoreAnalysis(true);
var corruptedAnalysis = buildScoreAnalysis(true, ScoreAnalysisMode.SCORE_CORRUPTION);
var uncorruptedAnalysis = uncorruptedScoreDirector.buildScoreAnalysis(true, ScoreAnalysisMode.SCORE_CORRUPTION);

var excessSet = new LinkedHashSet<MatchAnalysis<Score_>>();
var missingSet = new LinkedHashSet<MatchAnalysis<Score_>>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Consumer;

Expand Down Expand Up @@ -456,26 +457,28 @@ void afterListVariableChanged(ListVariableDescriptor<Solution_> variableDescript
void forceTriggerVariableListeners();

default ScoreAnalysis<Score_> buildScoreAnalysis(boolean analyzeConstraintMatches) {
return buildScoreAnalysis(analyzeConstraintMatches, false);
return buildScoreAnalysis(analyzeConstraintMatches, ScoreAnalysisMode.DEFAULT);
}

/**
*
* @param analyzeConstraintMatches True if the result's {@link ConstraintAnalysis} should have its {@link MatchAnalysis}
* populated.
* @param overrideInitScore True if the result's {@link Score} should have its {@link Score#isSolutionInitialized()} set to
* true.
* @param mode Allows to tweak the behavior of this method.
* @return never null
*/
default ScoreAnalysis<Score_> buildScoreAnalysis(boolean analyzeConstraintMatches, boolean overrideInitScore) {
default ScoreAnalysis<Score_> buildScoreAnalysis(boolean analyzeConstraintMatches, ScoreAnalysisMode mode) {
var score = calculateScore();
if (overrideInitScore) {
score = score.withInitScore(0);
} else if (!score.isSolutionInitialized()) {
throw new IllegalArgumentException("""
Cannot analyze solution (%s) as it is not initialized (%s).
Maybe run the solver first?"""
.formatted(getWorkingSolution(), score));
switch (Objects.requireNonNull(mode)) {
case RECOMMENDATION_API -> score = score.withInitScore(0);
case DEFAULT -> {
if (!score.isSolutionInitialized()) {
throw new IllegalArgumentException("""
Cannot analyze solution (%s) as it is not initialized (%s).
Maybe run the solver first?"""
.formatted(getWorkingSolution(), score));
}
}
}
var constraintAnalysisMap = new TreeMap<ConstraintRef, ConstraintAnalysis<Score_>>();
for (var constraintMatchTotal : getConstraintMatchTotalMap().values()) {
Expand All @@ -485,4 +488,25 @@ Cannot analyze solution (%s) as it is not initialized (%s).
return new ScoreAnalysis<>(score, constraintAnalysisMap);
}

enum ScoreAnalysisMode {
/**
* The default mode, which will throw an exception if the solution is not initialized.
*/
DEFAULT,
/**
* If analysis is requested as a result of a score corruption detection,
* there will be no tweaks to the score and no initialization exception will be thrown.
* This is because score corruption may have been detected during construction heuristics,
* where the score is rightfully uninitialized.
*/
SCORE_CORRUPTION,
/**
* Will not throw an exception if the solution is not initialized,
* but will set {@link Score#initScore()} to zero.
* Recommendation API always has an uninitialized solution by design.
*/
RECOMMENDATION_API

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public List<RecommendedFit<Out_, Score_>> apply(InnerScoreDirector<Solution_, Sc
if (uninitializedCount > 1) {
throw new IllegalStateException("""
Solution (%s) has (%d) uninitialized elements.
Fit Recommendation API requires at most one uninitialized element in the solution.
"""
Fit Recommendation API requires at most one uninitialized element in the solution."""
.formatted(originalSolution, uninitializedCount));
}
var originalScoreAnalysis = scoreDirector.buildScoreAnalysis(fetchPolicy == ScoreAnalysisFetchPolicy.FETCH_ALL, true);
var originalScoreAnalysis = scoreDirector.buildScoreAnalysis(fetchPolicy == ScoreAnalysisFetchPolicy.FETCH_ALL,
InnerScoreDirector.ScoreAnalysisMode.RECOMMENDATION_API);
var clonedElement = scoreDirector.lookUpWorkingObject(originalElement);
var processor =
new FitProcessor<>(solverFactory, propositionFunction, originalScoreAnalysis, clonedElement, fetchPolicy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand All @@ -21,7 +22,11 @@
import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.buildin.hardsoft.HardSoftScore;
import ai.timefold.solver.core.api.score.buildin.simple.SimpleScore;
import ai.timefold.solver.core.api.score.calculator.ConstraintMatchAwareIncrementalScoreCalculator;
import ai.timefold.solver.core.api.score.calculator.EasyScoreCalculator;
import ai.timefold.solver.core.api.score.constraint.ConstraintMatchTotal;
import ai.timefold.solver.core.api.score.constraint.ConstraintRef;
import ai.timefold.solver.core.api.score.constraint.Indictment;
import ai.timefold.solver.core.api.score.director.ScoreDirector;
import ai.timefold.solver.core.api.solver.SolutionManager;
import ai.timefold.solver.core.api.solver.Solver;
Expand All @@ -38,6 +43,7 @@
import ai.timefold.solver.core.config.localsearch.LocalSearchType;
import ai.timefold.solver.core.config.phase.custom.CustomPhaseConfig;
import ai.timefold.solver.core.config.score.director.ScoreDirectorFactoryConfig;
import ai.timefold.solver.core.config.solver.EnvironmentMode;
import ai.timefold.solver.core.config.solver.SolverConfig;
import ai.timefold.solver.core.config.solver.monitoring.MonitoringConfig;
import ai.timefold.solver.core.config.solver.monitoring.SolverMetric;
Expand All @@ -49,6 +55,8 @@
import ai.timefold.solver.core.impl.phase.event.PhaseLifecycleListenerAdapter;
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
import ai.timefold.solver.core.impl.score.DummySimpleScoreEasyScoreCalculator;
import ai.timefold.solver.core.impl.score.constraint.DefaultConstraintMatchTotal;
import ai.timefold.solver.core.impl.score.constraint.DefaultIndictment;
import ai.timefold.solver.core.impl.testdata.domain.TestdataEntity;
import ai.timefold.solver.core.impl.testdata.domain.TestdataSolution;
import ai.timefold.solver.core.impl.testdata.domain.TestdataValue;
Expand All @@ -71,6 +79,7 @@
import ai.timefold.solver.core.impl.testdata.util.PlannerTestUtils;
import ai.timefold.solver.core.impl.testutil.TestMeterRegistry;

import org.assertj.core.api.Assertions;
import org.assertj.core.api.SoftAssertions;
import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -111,6 +120,96 @@ void solve() {
assertThat(solution.getScore().isSolutionInitialized()).isTrue();
}

@Test
void solveCorruptedEasyUninitialized() {
var solverConfig = PlannerTestUtils.buildSolverConfig(TestdataSolution.class, TestdataEntity.class)
.withEnvironmentMode(EnvironmentMode.FULL_ASSERT)
.withEasyScoreCalculatorClass(CorruptedEasyScoreCalculator.class);
var solverFactory = SolverFactory.<TestdataSolution> create(solverConfig);
var solver = solverFactory.buildSolver();

var solution = new TestdataSolution("s1");
solution.setValueList(Arrays.asList(new TestdataValue("v1"), new TestdataValue("v2")));
solution.setEntityList(Arrays.asList(new TestdataEntity("e1"), new TestdataEntity("e2")));

Assertions.assertThatThrownBy(() -> solver.solve(solution))
.hasMessageContaining("Score corruption")
.hasMessageContaining("workingScore")
.hasMessageContaining("uncorruptedScore")
.hasMessageContaining("Score corruption analysis could not be generated");
}

@Test
void solveCorruptedEasyInitialized() {
var solverConfig = PlannerTestUtils.buildSolverConfig(TestdataSolution.class, TestdataEntity.class)
.withEnvironmentMode(EnvironmentMode.FULL_ASSERT)
.withEasyScoreCalculatorClass(CorruptedEasyScoreCalculator.class);
var solverFactory = SolverFactory.<TestdataSolution> create(solverConfig);
var solver = solverFactory.buildSolver();

var solution = new TestdataSolution("s1");
var value1 = new TestdataValue("v1");
var value2 = new TestdataValue("v2");
solution.setValueList(List.of(value1, value2));
var entity1 = new TestdataEntity("e1");
entity1.setValue(value1);
var entity2 = new TestdataEntity("e2");
entity2.setValue(value2);
solution.setEntityList(List.of(entity1, entity2));

Assertions.assertThatThrownBy(() -> solver.solve(solution))
.hasMessageContaining("Score corruption")
.hasMessageContaining("workingScore")
.hasMessageContaining("uncorruptedScore")
.hasMessageContaining("Score corruption analysis could not be generated");
}

@Test
void solveCorruptedIncrementalUninitialized() {
var solverConfig = PlannerTestUtils.buildSolverConfig(TestdataSolution.class, TestdataEntity.class)
.withEnvironmentMode(EnvironmentMode.FULL_ASSERT)
.withScoreDirectorFactory(new ScoreDirectorFactoryConfig()
.withIncrementalScoreCalculatorClass(CorruptedIncrementalScoreCalculator.class));
var solverFactory = SolverFactory.<TestdataSolution> create(solverConfig);
var solver = solverFactory.buildSolver();

var solution = new TestdataSolution("s1");
solution.setValueList(Arrays.asList(new TestdataValue("v1"), new TestdataValue("v2")));
solution.setEntityList(Arrays.asList(new TestdataEntity("e1"), new TestdataEntity("e2")));

Assertions.assertThatThrownBy(() -> solver.solve(solution))
.hasMessageContaining("Score corruption")
.hasMessageContaining("workingScore")
.hasMessageContaining("uncorruptedScore")
.hasMessageContaining("Score corruption analysis:");
}

@Test
void solveCorruptedIncrementalInitialized() {
var solverConfig = PlannerTestUtils.buildSolverConfig(TestdataSolution.class, TestdataEntity.class)
.withEnvironmentMode(EnvironmentMode.FULL_ASSERT)
.withScoreDirectorFactory(new ScoreDirectorFactoryConfig()
.withIncrementalScoreCalculatorClass(CorruptedIncrementalScoreCalculator.class));
var solverFactory = SolverFactory.<TestdataSolution> create(solverConfig);
var solver = solverFactory.buildSolver();

var solution = new TestdataSolution("s1");
var value1 = new TestdataValue("v1");
var value2 = new TestdataValue("v2");
solution.setValueList(List.of(value1, value2));
var entity1 = new TestdataEntity("e1");
entity1.setValue(value1);
var entity2 = new TestdataEntity("e2");
entity2.setValue(value2);
solution.setEntityList(List.of(entity1, entity2));

Assertions.assertThatThrownBy(() -> solver.solve(solution))
.hasMessageContaining("Score corruption")
.hasMessageContaining("workingScore")
.hasMessageContaining("uncorruptedScore")
.hasMessageContaining("Score corruption analysis:");
}

@Test
void checkDefaultMeters() {
TestMeterRegistry meterRegistry = new TestMeterRegistry();
Expand Down Expand Up @@ -865,4 +964,74 @@ void solveWithMultipleChainedPlanningEntities() {
assertThat(solution.getScore().isSolutionInitialized()).isTrue();
}

public static class CorruptedEasyScoreCalculator implements EasyScoreCalculator<TestdataSolution, SimpleScore> {

@Override
public SimpleScore calculateScore(TestdataSolution testdataSolution) {
int random = (int) (Math.random() * 1000);
return SimpleScore.of(random);
}
}

public static class CorruptedIncrementalScoreCalculator
implements ConstraintMatchAwareIncrementalScoreCalculator<TestdataSolution, SimpleScore> {

@Override
public void resetWorkingSolution(TestdataSolution workingSolution, boolean constraintMatchEnabled) {

}

@Override
public Collection<ConstraintMatchTotal<SimpleScore>> getConstraintMatchTotals() {
return Collections.singletonList(new DefaultConstraintMatchTotal<>(ConstraintRef.of("a", "b"), SimpleScore.of(1)));
}

@Override
public Map<Object, Indictment<SimpleScore>> getIndictmentMap() {
return Collections.singletonMap(new TestdataEntity("e1"),
new DefaultIndictment<>(new TestdataEntity("e1"), SimpleScore.ONE));
}

@Override
public void resetWorkingSolution(TestdataSolution workingSolution) {

}

@Override
public void beforeEntityAdded(Object entity) {

}

@Override
public void afterEntityAdded(Object entity) {

}

@Override
public void beforeVariableChanged(Object entity, String variableName) {

}

@Override
public void afterVariableChanged(Object entity, String variableName) {

}

@Override
public void beforeEntityRemoved(Object entity) {

}

@Override
public void afterEntityRemoved(Object entity) {

}

@Override
public SimpleScore calculateScore() {
int random = (int) (Math.random() * 1000);
return SimpleScore.of(random);
}
}

}

0 comments on commit 67d4448

Please sign in to comment.