Skip to content

Commit

Permalink
feat: expose constraint metamodel in Quarkus and Spring Boot (#1108)
Browse files Browse the repository at this point in the history
  • Loading branch information
triceo authored Sep 25, 2024
1 parent d696678 commit 0801dfd
Show file tree
Hide file tree
Showing 43 changed files with 655 additions and 308 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
*/
public interface Constraint {

String DEFAULT_CONSTRAINT_GROUP = "default";

/**
* The {@link ConstraintFactory} that built this.
*
Expand All @@ -32,6 +34,20 @@ default String getDescription() {
return "";
}

default String getConstraintGroup() {
return DEFAULT_CONSTRAINT_GROUP;
}

/**
* Returns the weight of the constraint as defined in the {@link ConstraintProvider},
* without any overrides.
*
* @return null if the constraint does not have a weight defined
*/
default <Score_ extends Score<Score_>> Score_ getConstraintWeight() {
return null;
}

/**
* @deprecated Prefer {@link #getConstraintRef()}.
* @return never null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public interface ConstraintBuilder {
/**
* Builds a {@link Constraint} from the constraint stream.
* The {@link ConstraintRef#packageName() constraint package} defaults to the package of the {@link PlanningSolution} class.
* The constraint will be placed in the {@link Constraint#DEFAULT_CONSTRAINT_GROUP default constraint group}.
*
* @param constraintName never null, shows up in {@link ConstraintMatchTotal} during score justification
* @return never null
Expand All @@ -20,12 +21,26 @@ default Constraint asConstraint(String constraintName) {
/**
* Builds a {@link Constraint} from the constraint stream.
* The {@link ConstraintRef#packageName() constraint package} defaults to the package of the {@link PlanningSolution} class.
* The constraint will be placed in the {@link Constraint#DEFAULT_CONSTRAINT_GROUP default constraint group}.
*
* @param constraintName never null, shows up in {@link ConstraintMatchTotal} during score justification
* @param constraintDescription never null
* @return never null
*/
Constraint asConstraintDescribed(String constraintName, String constraintDescription);
default Constraint asConstraintDescribed(String constraintName, String constraintDescription) {
return asConstraintDescribed(constraintName, constraintDescription, Constraint.DEFAULT_CONSTRAINT_GROUP);
}

/**
* Builds a {@link Constraint} from the constraint stream.
* The {@link ConstraintRef#packageName() constraint package} defaults to the package of the {@link PlanningSolution} class.
*
* @param constraintName never null, shows up in {@link ConstraintMatchTotal} during score justification
* @param constraintDescription never null
* @param constraintGroup never null, only allows alphanumeric characters, "-" and "_"
* @return never null
*/
Constraint asConstraintDescribed(String constraintName, String constraintDescription, String constraintGroup);

/**
* Builds a {@link Constraint} from the constraint stream.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package ai.timefold.solver.core.api.score.stream;

import java.util.Collection;
import java.util.Set;

import ai.timefold.solver.core.api.score.constraint.ConstraintRef;

/**
* Provides information about the known constraints.
* Works in combination with {@link ConstraintProvider}.
*/
public interface ConstraintMetaModel {

/**
* Returns the constraint for the given reference.
*
* @param constraintRef never null
* @return null if such constraint does not exist
*/
Constraint getConstraint(ConstraintRef constraintRef);

/**
* Returns all constraints defined in the {@link ConstraintProvider}.
*
* @return never null, iteration order is undefined
*/
Collection<Constraint> getConstraints();

/**
* Returns all constraints from {@link #getConstraints()} that belong to the given group.
*
* @param constraintGroup never null
* @return never null, iteration order is undefined
*/
Collection<Constraint> getConstraintsPerGroup(String constraintGroup);

/**
* Returns constraint groups with at least one constraint in it.
*
* @return never null, iteration order is undefined
*/
Set<String> getConstraintGroups();

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.Objects;

import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.stream.ConstraintMetaModel;
import ai.timefold.solver.core.api.score.stream.ConstraintProvider;
import ai.timefold.solver.core.api.score.stream.ConstraintStreamImplType;
import ai.timefold.solver.core.config.score.director.ScoreDirectorFactoryConfig;
Expand All @@ -18,7 +19,6 @@
import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintSession;
import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintSessionFactory;
import ai.timefold.solver.core.impl.score.stream.common.AbstractConstraintStreamScoreDirectorFactory;
import ai.timefold.solver.core.impl.score.stream.common.ConstraintLibrary;
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;

public final class BavetConstraintStreamScoreDirectorFactory<Solution_, Score_ extends Score<Score_>>
Expand Down Expand Up @@ -56,14 +56,15 @@ private static Class<? extends ConstraintProvider> getConstraintProviderClass(Sc
}

private final BavetConstraintSessionFactory<Solution_, Score_> constraintSessionFactory;
private final ConstraintLibrary<Score_> constraintLibrary;
private final ConstraintMetaModel constraintMetaModel;

public BavetConstraintStreamScoreDirectorFactory(SolutionDescriptor<Solution_> solutionDescriptor,
ConstraintProvider constraintProvider, EnvironmentMode environmentMode) {
super(solutionDescriptor);
var constraintFactory = new BavetConstraintFactory<>(solutionDescriptor, environmentMode);
constraintLibrary = ConstraintLibrary.of(constraintFactory.buildConstraints(constraintProvider));
constraintSessionFactory = new BavetConstraintSessionFactory<>(solutionDescriptor, constraintLibrary);
constraintMetaModel =
DefaultConstraintMetaModel.of(constraintFactory.buildConstraints(constraintProvider));
constraintSessionFactory = new BavetConstraintSessionFactory<>(solutionDescriptor, constraintMetaModel);
}

@Override
Expand Down Expand Up @@ -94,8 +95,8 @@ public AbstractScoreInliner<Score_> fireAndForget(Object... facts) {
}

@Override
public ConstraintLibrary<Score_> getConstraintLibrary() {
return constraintLibrary;
public ConstraintMetaModel getConstraintMetaModel() {
return constraintMetaModel;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package ai.timefold.solver.core.impl.score.director.stream;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import ai.timefold.solver.core.api.score.constraint.ConstraintRef;
import ai.timefold.solver.core.api.score.stream.Constraint;
import ai.timefold.solver.core.api.score.stream.ConstraintMetaModel;
import ai.timefold.solver.core.impl.util.CollectionUtils;

record DefaultConstraintMetaModel(
Map<ConstraintRef, Constraint> constraintPerRefMap,
Map<String, List<Constraint>> constraintPerGroupMap) implements ConstraintMetaModel {

public static ConstraintMetaModel of(List<? extends Constraint> constraints) {
var constraintCount = constraints.size();
// Preserve iteration order by using LinkedHashMap.
var perRefMap = CollectionUtils.<ConstraintRef, Constraint> newLinkedHashMap(constraintCount);
var perGroupMap = new TreeMap<String, List<Constraint>>();
for (var constraint : constraints) {
perRefMap.put(constraint.getConstraintRef(), constraint);
// The list is used to preserve iteration order of the constraints.
// Constraint groups are an optional feature, therefore most people won't use them,
// therefore sizing the list assuming all constraints end up in the default group.
perGroupMap.computeIfAbsent(constraint.getConstraintGroup(), k -> new ArrayList<>(constraintCount))
.add(constraint);
}
return new DefaultConstraintMetaModel(
Collections.unmodifiableMap(perRefMap),
Collections.unmodifiableMap(perGroupMap));
}

@Override
public Constraint getConstraint(ConstraintRef constraintRef) {
return constraintPerRefMap.get(constraintRef);
}

@Override
public Collection<Constraint> getConstraintsPerGroup(String constraintGroup) {
return constraintPerGroupMap.getOrDefault(constraintGroup, Collections.emptyList());
}

@Override
public Set<String> getConstraintGroups() {
return constraintPerGroupMap.keySet();
}

@Override
public Collection<Constraint> getConstraints() {
return constraintPerRefMap.values();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ public final class BavetConstraint<Solution_> extends
private final BavetScoringConstraintStream<Solution_> scoringConstraintStream;

public BavetConstraint(BavetConstraintFactory<Solution_> constraintFactory, ConstraintRef constraintRef,
String description, Score<?> constraintWeight, ScoreImpactType scoreImpactType, Object justificationMapping,
Object indictedObjectsMapping, BavetScoringConstraintStream<Solution_> scoringConstraintStream) {
super(constraintFactory, constraintRef, description, constraintWeight, scoreImpactType, justificationMapping,
indictedObjectsMapping);
String description, String constraintGroup, Score<?> constraintWeight, ScoreImpactType scoreImpactType,
Object justificationMapping, Object indictedObjectsMapping,
BavetScoringConstraintStream<Solution_> scoringConstraintStream) {
super(constraintFactory, constraintRef, description, constraintGroup, constraintWeight, scoreImpactType,
justificationMapping, indictedObjectsMapping);
this.scoringConstraintStream = scoringConstraintStream;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import ai.timefold.solver.core.api.score.Score;
import ai.timefold.solver.core.api.score.stream.Constraint;
import ai.timefold.solver.core.api.score.stream.ConstraintMetaModel;
import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor;
import ai.timefold.solver.core.impl.score.stream.bavet.common.AbstractConcatNode;
import ai.timefold.solver.core.impl.score.stream.bavet.common.AbstractIfExistsNode;
Expand All @@ -26,7 +27,6 @@
import ai.timefold.solver.core.impl.score.stream.bavet.common.PropagationQueue;
import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator;
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode;
import ai.timefold.solver.core.impl.score.stream.common.ConstraintLibrary;
import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner;
import ai.timefold.solver.core.impl.util.CollectionUtils;

Expand All @@ -40,13 +40,12 @@ public final class BavetConstraintSessionFactory<Solution_, Score_ extends Score
private static final Level CONSTRAINT_WEIGHT_LOGGING_LEVEL = Level.DEBUG;

private final SolutionDescriptor<Solution_> solutionDescriptor;
private final ConstraintLibrary<Score_> constraintLibrary;
private final ConstraintMetaModel constraintMetaModel;

@SuppressWarnings("unchecked")
public BavetConstraintSessionFactory(SolutionDescriptor<Solution_> solutionDescriptor,
ConstraintLibrary<Score_> constraintLibrary) {
ConstraintMetaModel constraintMetaModel) {
this.solutionDescriptor = Objects.requireNonNull(solutionDescriptor);
this.constraintLibrary = Objects.requireNonNull(constraintLibrary);
this.constraintMetaModel = Objects.requireNonNull(constraintMetaModel);
}

// ************************************************************************
Expand All @@ -57,17 +56,17 @@ public BavetConstraintSessionFactory(SolutionDescriptor<Solution_> solutionDescr
public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, boolean constraintMatchEnabled,
boolean scoreDirectorDerived) {
var constraintWeightSupplier = solutionDescriptor.getConstraintWeightSupplier();
var constraints = constraintMetaModel.getConstraints();
if (constraintWeightSupplier != null) { // Fail fast on unknown constraints.
var knownConstraints = constraintLibrary.getConstraints()
.stream()
var knownConstraints = constraints.stream()
.map(Constraint::getConstraintRef)
.collect(Collectors.toSet());
constraintWeightSupplier.validate(workingSolution, knownConstraints);
}
var scoreDefinition = solutionDescriptor.<Score_> getScoreDefinition();
var zeroScore = scoreDefinition.getZeroScore();
var constraintStreamSet = new LinkedHashSet<BavetAbstractConstraintStream<Solution_>>();
var constraintWeightMap = CollectionUtils.<Constraint, Score_> newHashMap(constraintLibrary.getConstraints().size());
var constraintWeightMap = CollectionUtils.<Constraint, Score_> newHashMap(constraints.size());

// Only log constraint weights if logging is enabled; otherwise we don't need to build the string.
var constraintWeightLoggingEnabled = !scoreDirectorDerived && LOGGER.isEnabledForLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL);
Expand All @@ -76,10 +75,10 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
.formatted(workingSolution))
: null;

for (var constraint : constraintLibrary.getConstraints()) {
for (var constraint : constraints) {
var constraintRef = constraint.getConstraintRef();
var castConstraint = (BavetConstraint<Solution_>) constraint;
var defaultConstraintWeight = castConstraint.getDefaultConstraintWeight();
var defaultConstraintWeight = castConstraint.getConstraintWeight();
var constraintWeight = (Score_) castConstraint.extractConstraintWeight(workingSolution);
if (!constraintWeight.equals(zeroScore)) {
if (constraintWeightLoggingEnabled) {
Expand Down Expand Up @@ -122,6 +121,7 @@ public BavetConstraintSession<Score_> buildSession(Solution_ workingSolution, bo
return new BavetConstraintSession<>(scoreInliner, buildNodeNetwork(constraintStreamSet, scoreInliner));
}

@SuppressWarnings("unchecked")
private static <Solution_, Score_ extends Score<Score_>> NodeNetwork buildNodeNetwork(
Set<BavetAbstractConstraintStream<Solution_>> constraintStreamSet, AbstractScoreInliner<Score_> scoreInliner) {
/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,10 @@ public <Score_ extends Score<Score_>> BiConstraintBuilder<A, B, Score_> innerImp
private <Score_ extends Score<Score_>> BiConstraintBuilderImpl<A, B, Score_> newTerminator(
BavetScoringConstraintStream<Solution_> stream, ScoreImpactType impactType, Score_ constraintWeight) {
return new BiConstraintBuilderImpl<>(
(constraintPackage, constraintName, constraintDescription, constraintWeight_, impactType_, justificationMapping,
indictedObjectsMapping) -> buildConstraint(constraintPackage, constraintName, constraintDescription,
constraintWeight_, impactType_, justificationMapping, indictedObjectsMapping, stream),
(constraintPackage, constraintName, constraintDescription, constraintGroup, constraintWeight_, impactType_,
justificationMapping, indictedObjectsMapping) -> buildConstraint(constraintPackage, constraintName,
constraintDescription, constraintGroup, constraintWeight_, impactType_, justificationMapping,
indictedObjectsMapping, stream),
impactType, constraintWeight);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ public boolean guaranteesDistinct() {

@SuppressWarnings("unchecked")
protected <Score_ extends Score<Score_>> Constraint buildConstraint(String constraintPackage, String constraintName,
String description, Score_ constraintWeight, ScoreImpactType impactType, Object justificationFunction,
Object indictedObjectsMapping,
BavetScoringConstraintStream<Solution_> stream) {
String description, String constraintGroup, Score_ constraintWeight, ScoreImpactType impactType,
Object justificationFunction, Object indictedObjectsMapping, BavetScoringConstraintStream<Solution_> stream) {
var resolvedConstraintPackage =
Objects.requireNonNullElseGet(constraintPackage, this.constraintFactory::getDefaultConstraintPackage);
var resolvedJustificationMapping =
Expand All @@ -66,7 +65,7 @@ protected <Score_ extends Score<Score_>> Constraint buildConstraint(String const
Objects.requireNonNullElseGet(indictedObjectsMapping, this::getDefaultIndictedObjectsMapping);
var isConstraintWeightConfigurable = constraintWeight == null;
var constraintRef = ConstraintRef.of(resolvedConstraintPackage, constraintName);
var constraint = new BavetConstraint<>(constraintFactory, constraintRef, description,
var constraint = new BavetConstraint<>(constraintFactory, constraintRef, description, constraintGroup,
isConstraintWeightConfigurable ? null : constraintWeight, impactType, resolvedJustificationMapping,
resolvedIndictedObjectsMapping, stream);
stream.setConstraint(constraint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,10 @@ public <Score_ extends Score<Score_>> QuadConstraintBuilder<A, B, C, D, Score_>
private <Score_ extends Score<Score_>> QuadConstraintBuilderImpl<A, B, C, D, Score_> newTerminator(
BavetScoringConstraintStream<Solution_> stream, Score_ constraintWeight, ScoreImpactType impactType) {
return new QuadConstraintBuilderImpl<>(
(constraintPackage, constraintName, constraintDescription, constraintWeight_, impactType_, justificationMapping,
indictedObjectsMapping) -> buildConstraint(constraintPackage, constraintName, constraintDescription,
constraintWeight_, impactType_, justificationMapping, indictedObjectsMapping, stream),
(constraintPackage, constraintName, constraintDescription, constraintGroup, constraintWeight_, impactType_,
justificationMapping, indictedObjectsMapping) -> buildConstraint(constraintPackage, constraintName,
constraintDescription, constraintGroup, constraintWeight_, impactType_, justificationMapping,
indictedObjectsMapping, stream),
impactType, constraintWeight);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,10 @@ public <Score_ extends Score<Score_>> TriConstraintBuilder<A, B, C, Score_> inne
private <Score_ extends Score<Score_>> TriConstraintBuilderImpl<A, B, C, Score_>
newTerminator(BavetScoringConstraintStream<Solution_> stream, Score_ constraintWeight, ScoreImpactType impactType) {
return new TriConstraintBuilderImpl<>(
(constraintPackage, constraintName, constraintDescription, constraintWeight_, impactType_, justificationMapping,
indictedObjectsMapping) -> buildConstraint(constraintPackage, constraintName, constraintDescription,
constraintWeight_, impactType_, justificationMapping, indictedObjectsMapping, stream),
(constraintPackage, constraintName, constraintDescription, constraintGroup, constraintWeight_, impactType_,
justificationMapping, indictedObjectsMapping) -> buildConstraint(constraintPackage, constraintName,
constraintDescription, constraintGroup, constraintWeight_, impactType_, justificationMapping,
indictedObjectsMapping, stream),
impactType, constraintWeight);
}

Expand Down
Loading

0 comments on commit 0801dfd

Please sign in to comment.