Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Verify variable change calls on every move in TRACKED_FULL_ASSERT #1260

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
package ai.timefold.solver.core.impl.domain.variable.listener.support.violation;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import ai.timefold.solver.core.api.domain.variable.ListVariableListener;
import ai.timefold.solver.core.api.score.director.ScoreDirector;
Expand All @@ -21,13 +28,25 @@
public class ListVariableTracker<Solution_>
implements SourcedVariableListener<Solution_>, ListVariableListener<Solution_, Object, Object>, Supply {
private final ListVariableDescriptor<Solution_> variableDescriptor;
private final List<Object> beforeVariableChangedEntityList;
private final List<Object> afterVariableChangedEntityList;
private final Map<Object, SortedSet<ChangeRange>> beforeVariableChangeEventMap;
private final Map<Object, SortedSet<ChangeRange>> afterVariableChangeEventMap;
private final Set<Object> afterUnassignedEvents;

private record ChangeRange(int start, int end) implements Comparable<ChangeRange> {
@Override
public int compareTo(@NonNull ChangeRange other) {
return Comparator.comparingInt(ChangeRange::end)
.thenComparing(ChangeRange::start)
.reversed()
.compare(this, other);
}
}

public ListVariableTracker(ListVariableDescriptor<Solution_> variableDescriptor) {
this.variableDescriptor = variableDescriptor;
beforeVariableChangedEntityList = new ArrayList<>();
afterVariableChangedEntityList = new ArrayList<>();
beforeVariableChangeEventMap = new IdentityHashMap<>();
afterVariableChangeEventMap = new IdentityHashMap<>();
afterUnassignedEvents = Collections.newSetFromMap(new IdentityHashMap<>());
}

@Override
Expand All @@ -37,8 +56,9 @@ public VariableDescriptor<Solution_> getSourceVariableDescriptor() {

@Override
public void resetWorkingSolution(@NonNull ScoreDirector<Solution_> scoreDirector) {
beforeVariableChangedEntityList.clear();
afterVariableChangedEntityList.clear();
beforeVariableChangeEventMap.clear();
afterVariableChangeEventMap.clear();
afterUnassignedEvents.clear();
}

@Override
Expand All @@ -63,42 +83,69 @@ public void afterEntityRemoved(@NonNull ScoreDirector<Solution_> scoreDirector,

@Override
public void afterListVariableElementUnassigned(@NonNull ScoreDirector<Solution_> scoreDirector, @NonNull Object element) {

afterUnassignedEvents.add(element);
}

@Override
public void beforeListVariableChanged(@NonNull ScoreDirector<Solution_> scoreDirector, @NonNull Object entity,
int fromIndex, int toIndex) {
beforeVariableChangedEntityList.add(entity);
beforeVariableChangeEventMap.computeIfAbsent(entity, k -> new TreeSet<>())
.add(new ChangeRange(fromIndex, toIndex));
}

@Override
public void afterListVariableChanged(@NonNull ScoreDirector<Solution_> scoreDirector, @NonNull Object entity, int fromIndex,
int toIndex) {
afterVariableChangedEntityList.add(entity);
afterVariableChangeEventMap.computeIfAbsent(entity, k -> new TreeSet<>())
.add(new ChangeRange(fromIndex, toIndex));
}

public List<String> getEntitiesMissingBeforeAfterEvents(
List<VariableId<Solution_>> changedVariables) {
List<VariableId<Solution_>> changedVariables,
VariableSnapshotTotal<Solution_> beforeSolution,
VariableSnapshotTotal<Solution_> afterSolution) {
List<String> out = new ArrayList<>();
Set<Object> allBeforeValues = Collections.newSetFromMap(new IdentityHashMap<>());
Set<Object> allAfterValues = Collections.newSetFromMap(new IdentityHashMap<>());
for (var changedVariable : changedVariables) {
if (!variableDescriptor.equals(changedVariable.variableDescriptor())) {
continue;
}
Object entity = changedVariable.entity();
if (!beforeVariableChangedEntityList.contains(entity)) {

if (!beforeVariableChangeEventMap.containsKey(entity)) {
out.add("Entity (" + entity
+ ") is missing a beforeListVariableChanged call for list variable ("
+ variableDescriptor.getVariableName() + ").");
}
if (!afterVariableChangedEntityList.contains(entity)) {
if (!afterVariableChangeEventMap.containsKey(entity)) {
out.add("Entity (" + entity
+ ") is missing a afterListVariableChanged call for list variable ("
+ variableDescriptor.getVariableName() + ").");
}

List<Object> beforeList =
new ArrayList<>((List<Object>) beforeSolution.getVariableSnapshot(changedVariable).value());

List<Object> afterList = new ArrayList<>((List<Object>) afterSolution.getVariableSnapshot(changedVariable).value());

allBeforeValues.addAll(beforeList);
allAfterValues.addAll(afterList);
}

var unassignedValues = Collections.newSetFromMap(new IdentityHashMap<>());
unassignedValues.addAll(allBeforeValues);
unassignedValues.removeAll(allAfterValues);

for (var unassignedValue : unassignedValues) {
if (!afterUnassignedEvents.contains(unassignedValue)) {
out.add("Missing afterListElementUnassigned: " + unassignedValue);
}
}
beforeVariableChangedEntityList.clear();
afterVariableChangedEntityList.clear();

beforeVariableChangeEventMap.clear();
afterVariableChangeEventMap.clear();
afterUnassignedEvents.clear();
return out;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import ai.timefold.solver.core.impl.domain.variable.descriptor.VariableDescriptor;
import ai.timefold.solver.core.impl.domain.variable.supply.SupplyManager;

import org.jspecify.annotations.Nullable;

public final class SolutionTracker<Solution_> {
private final SolutionDescriptor<Solution_> solutionDescriptor;
private final List<VariableTracker<Solution_>> normalVariableTrackers;
Expand Down Expand Up @@ -68,7 +70,7 @@ public void setAfterMoveSolution(Solution_ workingSolution) {
if (beforeVariables != null) {
missingEventsForward = getEntitiesMissingBeforeAfterEvents(beforeVariables, afterVariables);
} else {
missingEventsBackward = Collections.emptyList();
missingEventsForward = Collections.emptyList();
}
}

Expand Down Expand Up @@ -106,11 +108,23 @@ private List<String> getEntitiesMissingBeforeAfterEvents(VariableSnapshotTotal<S
out.addAll(normalVariableTracker.getEntitiesMissingBeforeAfterEvents(changes));
}
for (ListVariableTracker<Solution_> listVariableTracker : listVariableTrackers) {
out.addAll(listVariableTracker.getEntitiesMissingBeforeAfterEvents(changes));
out.addAll(listVariableTracker.getEntitiesMissingBeforeAfterEvents(changes,
beforeSolution, afterSolution));
}
return out;
}

public @Nullable String buildDirectorCorruptionMessage(Object completedAction) {
if (missingEventsForward != null && !missingEventsForward.isEmpty()) {
return """
Score Director Corruption Detected.
Missing variable listener events for actual move (%s):
%s
""".formatted(completedAction, formatList(missingEventsForward));
}
return null;
}

public String buildScoreCorruptionMessage() {
if (beforeMoveSolution == null) {
return "";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,12 @@ private void assertScoreFromScratch(Score_ score, Object completedAction, boolea
if (assertionScoreDirectorFactory == null) {
assertionScoreDirectorFactory = scoreDirectorFactory;
}
if (trackingWorkingSolution) {
var directorCorruption = solutionTracker.buildDirectorCorruptionMessage(completedAction);
if (directorCorruption != null) {
throw new IllegalStateException(directorCorruption);
}
}
try (var uncorruptedScoreDirector =
assertionScoreDirectorFactory.buildDerivedScoreDirector(false, ConstraintMatchPolicy.ENABLED)) {
uncorruptedScoreDirector.setWorkingSolution(workingSolution);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import ai.timefold.solver.core.api.score.director.ScoreDirector;
import ai.timefold.solver.core.api.solver.Solver;
import ai.timefold.solver.core.api.solver.SolverFactory;
import ai.timefold.solver.core.config.constructionheuristic.ConstructionHeuristicPhaseConfig;
import ai.timefold.solver.core.config.heuristic.selector.move.factory.MoveListFactoryConfig;
import ai.timefold.solver.core.config.localsearch.LocalSearchPhaseConfig;
import ai.timefold.solver.core.config.phase.custom.CustomPhaseConfig;
import ai.timefold.solver.core.config.score.director.ScoreDirectorFactoryConfig;
Expand All @@ -26,6 +28,9 @@
import ai.timefold.solver.core.config.solver.testutil.corruptedundoshadow.CorruptedUndoShadowEntity;
import ai.timefold.solver.core.config.solver.testutil.corruptedundoshadow.CorruptedUndoShadowSolution;
import ai.timefold.solver.core.config.solver.testutil.corruptedundoshadow.CorruptedUndoShadowValue;
import ai.timefold.solver.core.impl.heuristic.move.AbstractMove;
import ai.timefold.solver.core.impl.heuristic.move.Move;
import ai.timefold.solver.core.impl.heuristic.selector.move.factory.MoveListFactory;
import ai.timefold.solver.core.impl.phase.custom.CustomPhaseCommand;
import ai.timefold.solver.core.impl.phase.event.PhaseLifecycleListenerAdapter;
import ai.timefold.solver.core.impl.phase.scope.AbstractStepScope;
Expand All @@ -37,6 +42,7 @@
import ai.timefold.solver.core.impl.testdata.domain.TestdataValue;
import ai.timefold.solver.core.impl.testdata.util.PlannerTestUtils;

import org.jspecify.annotations.NonNull;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
Expand Down Expand Up @@ -192,6 +198,35 @@ void corruptedConstraints(EnvironmentMode environmentMode) {
}
}

@ParameterizedTest(name = "{0}")
@EnumSource(EnvironmentMode.class)
void corruptedScoreDirector(EnvironmentMode environmentMode) {
SolverConfig solverConfig = buildSolverConfig(environmentMode);
// For full assert modes it should throw exception about corrupted score
solverConfig.setPhaseConfigList(List.of(
new ConstructionHeuristicPhaseConfig(),
new LocalSearchPhaseConfig()
.withMoveSelectorConfig(
new MoveListFactoryConfig()
.withMoveListFactoryClass(ChangeMoveWithoutListenersMoveListFactory.class))));
setSolverConfigCalculatorClass(solverConfig, ConstantEasyScoreCalculator.class);

switch (environmentMode) {
case TRACKED_FULL_ASSERT -> {
assertIllegalStateExceptionWhileSolving(
solverConfig,
"Score Director Corruption Detected.");
}
case FULL_ASSERT,
NON_INTRUSIVE_FULL_ASSERT,
FAST_ASSERT,
REPRODUCIBLE,
NON_REPRODUCIBLE -> {
// No exception expected
}
}
}

private void assertReproducibility(Solver<TestdataSolution> solver1, Solver<TestdataSolution> solver2) {
assertGeneratingSameNumbers(((DefaultSolver<TestdataSolution>) solver1).getRandomFactory(),
((DefaultSolver<TestdataSolution>) solver2).getRandomFactory());
Expand Down Expand Up @@ -320,4 +355,44 @@ public List<SimpleScore> getScores() {
return scores;
}
}

public static class ConstantEasyScoreCalculator implements EasyScoreCalculator<TestdataSolution, SimpleScore> {
@Override
public @NonNull SimpleScore calculateScore(@NonNull TestdataSolution o) {
return SimpleScore.ZERO;
}
}

public static class ChangeMoveWithoutListeners extends AbstractMove<TestdataSolution> {
private final TestdataEntity entity;
private final TestdataValue value;

public ChangeMoveWithoutListeners(TestdataEntity entity, TestdataValue value) {
this.entity = entity;
this.value = value;
}

@Override
protected void doMoveOnGenuineVariables(ScoreDirector<TestdataSolution> scoreDirector) {
entity.setValue(value);
}

@Override
public boolean isMoveDoable(ScoreDirector<TestdataSolution> scoreDirector) {
return entity.getValue() != value;
}
}

public static class ChangeMoveWithoutListenersMoveListFactory implements MoveListFactory<TestdataSolution> {
@Override
public List<? extends Move<TestdataSolution>> createMoveList(TestdataSolution testdataSolution) {
var out = new ArrayList<Move<TestdataSolution>>();
for (var entity : testdataSolution.getEntityList()) {
for (var value : testdataSolution.getValueList()) {
out.add(new ChangeMoveWithoutListeners(entity, value));
}
}
return out;
}
}
}
Loading
Loading