Skip to content

Commit

Permalink
perf: minor improvements to Constraint Streams (#830)
Browse files Browse the repository at this point in the history
- Iterating manually in join allows to create much less instances of
`Iterator`, reducing GC pressure significantly.
- min/max constraint collector was needlessly using a `Map`, which only
ever had one key.

In cases particularly exposed to these issues, the performance
improvements seen were ~ 10 %.
  • Loading branch information
triceo authored May 6, 2024
1 parent 4f19adb commit 5443347
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ protected AbstractJoinNode(int inputStoreIndexLeftOutTupleList, int inputStoreIn
protected abstract boolean testFiltering(LeftTuple_ leftTuple, UniTuple<Right_> rightTuple);

protected final void insertOutTuple(LeftTuple_ leftTuple, UniTuple<Right_> rightTuple) {
OutTuple_ outTuple = createOutTuple(leftTuple, rightTuple);
var outTuple = createOutTuple(leftTuple, rightTuple);
ElementAwareList<OutTuple_> outTupleListLeft = leftTuple.getStore(inputStoreIndexLeftOutTupleList);
ElementAwareListEntry<OutTuple_> outEntryLeft = outTupleListLeft.add(outTuple);
var outEntryLeft = outTupleListLeft.add(outTuple);
outTuple.setStore(outputStoreIndexLeftOutEntry, outEntryLeft);
ElementAwareList<OutTuple_> outTupleListRight = rightTuple.getStore(inputStoreIndexRightOutTupleList);
ElementAwareListEntry<OutTuple_> outEntryRight = outTupleListRight.add(outTuple);
var outEntryRight = outTupleListRight.add(outTuple);
outTuple.setStore(outputStoreIndexRightOutEntry, outEntryRight);
propagationQueue.insert(outTuple);
}
Expand All @@ -72,7 +72,7 @@ protected final void innerUpdateLeft(LeftTuple_ leftTuple, Consumer<Consumer<Uni
ElementAwareList<OutTuple_> outTupleListLeft = leftTuple.getStore(inputStoreIndexLeftOutTupleList);
// Propagate the update for downstream filters, matchWeighers, ...
if (!isFiltering) {
for (OutTuple_ outTuple : outTupleListLeft) {
for (var outTuple : outTupleListLeft) {
updateOutTupleLeft(outTuple, leftTuple);
}
} else {
Expand All @@ -89,7 +89,7 @@ private void updateOutTupleLeft(OutTuple_ outTuple, LeftTuple_ leftTuple) {
}

private void doUpdateOutTuple(OutTuple_ outTuple) {
TupleState state = outTuple.state;
var state = outTuple.state;
if (!state.isActive()) { // Impossible because they shouldn't linger in the indexes.
throw new IllegalStateException("Impossible state: The tuple (" + outTuple.state + ") in node (" +
this + ") is in an unexpected state (" + outTuple.state + ").");
Expand All @@ -104,7 +104,7 @@ protected final void innerUpdateRight(UniTuple<Right_> rightTuple, Consumer<Cons
ElementAwareList<OutTuple_> outTupleListRight = rightTuple.getStore(inputStoreIndexRightOutTupleList);
if (!isFiltering) {
// Propagate the update for downstream filters, matchWeighers, ...
for (OutTuple_ outTuple : outTupleListRight) {
for (var outTuple : outTupleListRight) {
setOutTupleRightFact(outTuple, rightTuple);
doUpdateOutTuple(outTuple);
}
Expand All @@ -118,7 +118,7 @@ protected final void innerUpdateRight(UniTuple<Right_> rightTuple, Consumer<Cons

private void processOutTupleUpdate(LeftTuple_ leftTuple, UniTuple<Right_> rightTuple, ElementAwareList<OutTuple_> outList,
ElementAwareList<OutTuple_> outTupleList, int outputStoreIndexOutEntry) {
OutTuple_ outTuple = findOutTuple(outTupleList, outList, outputStoreIndexOutEntry);
var outTuple = findOutTuple(outTupleList, outList, outputStoreIndexOutEntry);
if (testFiltering(leftTuple, rightTuple)) {
if (outTuple == null) {
insertOutTuple(leftTuple, rightTuple);
Expand All @@ -132,15 +132,19 @@ private void processOutTupleUpdate(LeftTuple_ leftTuple, UniTuple<Right_> rightT
}
}

private OutTuple_ findOutTuple(ElementAwareList<OutTuple_> outTupleList, ElementAwareList<OutTuple_> outList,
int outputStoreIndexOutEntry) {
private static <Tuple_ extends AbstractTuple> Tuple_ findOutTuple(ElementAwareList<Tuple_> outTupleList,
ElementAwareList<Tuple_> outList, int outputStoreIndexOutEntry) {
// Hack: the outTuple has no left/right input tuple reference, use the left/right outList reference instead.
for (OutTuple_ outTuple : outTupleList) {
ElementAwareListEntry<OutTuple_> outEntry = outTuple.getStore(outputStoreIndexOutEntry);
ElementAwareList<OutTuple_> outEntryList = outEntry.getList();
var item = outTupleList.first();
while (item != null) {
// Creating list iterators here caused major GC pressure; therefore, we iterate over the entries directly.
var outTuple = item.getElement();
ElementAwareListEntry<Tuple_> outEntry = outTuple.getStore(outputStoreIndexOutEntry);
var outEntryList = outEntry.getList();
if (outList == outEntryList) {
return outTuple;
}
item = item.next();
}
return null;
}
Expand All @@ -150,7 +154,7 @@ protected final void retractOutTuple(OutTuple_ outTuple) {
outEntryLeft.remove();
ElementAwareListEntry<OutTuple_> outEntryRight = outTuple.removeStore(outputStoreIndexRightOutEntry);
outEntryRight.remove();
TupleState state = outTuple.state;
var state = outTuple.state;
if (!state.isActive()) {
// Impossible because they shouldn't linger in the indexes.
throw new IllegalStateException("Impossible state: The tuple (" + outTuple.state + ") in node (" + this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,20 @@
import ai.timefold.solver.core.impl.util.ElementAwareListEntry;

final class ComparisonIndexer<T, Key_ extends Comparable<Key_>>
implements ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T> {
implements Indexer<T> {

private final int propertyIndex;
private final Supplier<ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> downstreamIndexerSupplier;
private final Supplier<Indexer<T>> downstreamIndexerSupplier;
private final Comparator<Key_> keyComparator;
private final boolean hasOrEquals;
private final NavigableMap<Key_, ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> comparisonMap;
private final NavigableMap<Key_, Indexer<T>> comparisonMap;

public ComparisonIndexer(JoinerType comparisonJoinerType,
Supplier<ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> downstreamIndexerSupplier) {
public ComparisonIndexer(JoinerType comparisonJoinerType, Supplier<Indexer<T>> downstreamIndexerSupplier) {
this(comparisonJoinerType, 0, downstreamIndexerSupplier);
}

public ComparisonIndexer(JoinerType comparisonJoinerType, int propertyIndex,
Supplier<ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> downstreamIndexerSupplier) {
Supplier<Indexer<T>> downstreamIndexerSupplier) {
this.propertyIndex = propertyIndex;
this.downstreamIndexerSupplier = Objects.requireNonNull(downstreamIndexerSupplier);
/*
Expand All @@ -44,11 +43,10 @@ public ComparisonIndexer(JoinerType comparisonJoinerType, int propertyIndex,
}

@Override
public ElementAwareListEntry<T>
put(ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties indexProperties, T tuple) {
public ElementAwareListEntry<T> put(IndexProperties indexProperties, T tuple) {
Key_ indexKey = indexProperties.toKey(propertyIndex);
// Avoids computeIfAbsent in order to not create lambdas on the hot path.
ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T> downstreamIndexer = comparisonMap.get(indexKey);
var downstreamIndexer = comparisonMap.get(indexKey);
if (downstreamIndexer == null) {
downstreamIndexer = downstreamIndexerSupplier.get();
comparisonMap.put(indexKey, downstreamIndexer);
Expand All @@ -57,22 +55,17 @@ public ComparisonIndexer(JoinerType comparisonJoinerType, int propertyIndex,
}

@Override
public void remove(ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties indexProperties,
ElementAwareListEntry<T> entry) {
public void remove(IndexProperties indexProperties, ElementAwareListEntry<T> entry) {
Key_ indexKey = indexProperties.toKey(propertyIndex);
ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T> downstreamIndexer =
getDownstreamIndexer(indexProperties, indexKey, entry);
var downstreamIndexer = getDownstreamIndexer(indexProperties, indexKey, entry);
downstreamIndexer.remove(indexProperties, entry);
if (downstreamIndexer.isEmpty()) {
comparisonMap.remove(indexKey);
}
}

private ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T> getDownstreamIndexer(
ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties indexProperties, Key_ indexerKey,
ElementAwareListEntry<T> entry) {
ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T> downstreamIndexer =
comparisonMap.get(indexerKey);
private Indexer<T> getDownstreamIndexer(IndexProperties indexProperties, Key_ indexerKey, ElementAwareListEntry<T> entry) {
var downstreamIndexer = comparisonMap.get(indexerKey);
if (downstreamIndexer == null) {
throw new IllegalStateException("Impossible state: the tuple (" + entry.getElement()
+ ") with indexProperties (" + indexProperties
Expand All @@ -83,16 +76,15 @@ private ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>

// TODO clean up DRY
@Override
public int size(ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties indexProperties) {
int mapSize = comparisonMap.size();
public int size(IndexProperties indexProperties) {
var mapSize = comparisonMap.size();
if (mapSize == 0) {
return 0;
}
Key_ indexKey = indexProperties.toKey(propertyIndex);
if (mapSize == 1) { // Avoid creation of the entry set and iterator.
Map.Entry<Key_, ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> entry =
comparisonMap.firstEntry();
int comparison = keyComparator.compare(entry.getKey(), indexKey);
var entry = comparisonMap.firstEntry();
var comparison = keyComparator.compare(entry.getKey(), indexKey);
if (comparison >= 0) { // Possibility of reaching the boundary condition.
if (comparison > 0 || !hasOrEquals) {
// Boundary condition reached when we're out of bounds entirely, or when GTE/LTE is not allowed.
Expand All @@ -101,10 +93,9 @@ public int size(ai.timefold.solver.core.impl.score.stream.bavet.common.index.Ind
}
return entry.getValue().size(indexProperties);
} else {
int size = 0;
for (Map.Entry<Key_, ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> entry : comparisonMap
.entrySet()) {
int comparison = keyComparator.compare(entry.getKey(), indexKey);
var size = 0;
for (var entry : comparisonMap.entrySet()) {
var comparison = keyComparator.compare(entry.getKey(), indexKey);
if (comparison >= 0) { // Possibility of reaching the boundary condition.
if (comparison > 0 || !hasOrEquals) {
// Boundary condition reached when we're out of bounds entirely, or when GTE/LTE is not allowed.
Expand All @@ -119,33 +110,29 @@ public int size(ai.timefold.solver.core.impl.score.stream.bavet.common.index.Ind
}

@Override
public void forEach(ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties indexProperties,
Consumer<T> tupleConsumer) {
int size = comparisonMap.size();
public void forEach(IndexProperties indexProperties, Consumer<T> tupleConsumer) {
var size = comparisonMap.size();
if (size == 0) {
return;
}
Key_ indexKey = indexProperties.toKey(propertyIndex);
if (size == 1) { // Avoid creation of the entry set and iterator.
Map.Entry<Key_, ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> entry =
comparisonMap.firstEntry();
var entry = comparisonMap.firstEntry();
visitEntry(indexProperties, tupleConsumer, indexKey, entry);
} else {
for (Map.Entry<Key_, ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> entry : comparisonMap
.entrySet()) {
boolean boundaryReached = visitEntry(indexProperties, tupleConsumer, indexKey, entry);
for (var entry : comparisonMap.entrySet()) {
var boundaryReached = visitEntry(indexProperties, tupleConsumer, indexKey, entry);
if (boundaryReached) {
return;
}
}
}
}

private boolean visitEntry(ai.timefold.solver.core.impl.score.stream.bavet.common.index.IndexProperties indexProperties,
Consumer<T> tupleConsumer,
Key_ indexKey, Map.Entry<Key_, ai.timefold.solver.core.impl.score.stream.bavet.common.index.Indexer<T>> entry) {
private boolean visitEntry(IndexProperties indexProperties, Consumer<T> tupleConsumer, Key_ indexKey,
Map.Entry<Key_, Indexer<T>> entry) {
// Comparator matches the order of iteration of the map, so the boundary is always found from the bottom up.
int comparison = keyComparator.compare(entry.getKey(), indexKey);
var comparison = keyComparator.compare(entry.getKey(), indexKey);
if (comparison >= 0) { // Possibility of reaching the boundary condition.
if (comparison > 0 || !hasOrEquals) {
// Boundary condition reached when we're out of bounds entirely, or when GTE/LTE is not allowed.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
package ai.timefold.solver.core.impl.score.stream.collector;

import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.function.Function;

import ai.timefold.solver.core.impl.util.ConstantLambdaUtils;
import ai.timefold.solver.core.impl.util.MutableInt;

public final class MinMaxUndoableActionable<Result_, Property_> implements UndoableActionable<Result_, Result_> {
public final class MinMaxUndoableActionable<Result_, Property_>
implements UndoableActionable<Result_, Result_> {

private final boolean isMin;
private final NavigableMap<Property_, Map<Result_, MutableInt>> propertyToItemCountMap;
private final NavigableMap<Property_, ItemCount<Result_>> propertyToItemCountMap;
private final Function<? super Result_, ? extends Property_> propertyFunction;

private MinMaxUndoableActionable(boolean isMin,
NavigableMap<Property_, Map<Result_, MutableInt>> propertyToItemCountMap,
private MinMaxUndoableActionable(boolean isMin, NavigableMap<Property_, ItemCount<Result_>> propertyToItemCountMap,
Function<? super Result_, ? extends Property_> propertyFunction) {
this.isMin = isMin;
this.propertyToItemCountMap = propertyToItemCountMap;
Expand All @@ -40,30 +39,29 @@ public static <Result> MinMaxUndoableActionable<Result, Result> maxCalculator(Co
}

public static <Result, Property extends Comparable<? super Property>> MinMaxUndoableActionable<Result, Property>
minCalculator(
Function<? super Result, ? extends Property> propertyMapper) {
minCalculator(Function<? super Result, ? extends Property> propertyMapper) {
return new MinMaxUndoableActionable<>(true, new TreeMap<>(), propertyMapper);
}

public static <Result, Property extends Comparable<? super Property>> MinMaxUndoableActionable<Result, Property>
maxCalculator(
Function<? super Result, ? extends Property> propertyMapper) {
maxCalculator(Function<? super Result, ? extends Property> propertyMapper) {
return new MinMaxUndoableActionable<>(false, new TreeMap<>(), propertyMapper);
}

@Override
public Runnable insert(Result_ item) {
Property_ key = propertyFunction.apply(item);
Map<Result_, MutableInt> itemCountMap = propertyToItemCountMap.computeIfAbsent(key, ignored -> new LinkedHashMap<>());
MutableInt count = itemCountMap.computeIfAbsent(item, ignored -> new MutableInt());
var value = propertyToItemCountMap.get(key);
if (value == null) {
value = new ItemCount<>(item, new MutableInt());
propertyToItemCountMap.put(key, value);
}
var count = value.count;
count.increment();

return () -> {
if (count.decrement() == 0) {
itemCountMap.remove(item);
if (itemCountMap.isEmpty()) {
propertyToItemCountMap.remove(key);
}
propertyToItemCountMap.remove(key);
}
};
}
Expand All @@ -73,11 +71,11 @@ public Result_ result() {
if (propertyToItemCountMap.isEmpty()) {
return null;
}
return isMin ? getFirstKey(propertyToItemCountMap.firstEntry().getValue())
: getFirstKey(propertyToItemCountMap.lastEntry().getValue());
var itemCount = isMin ? propertyToItemCountMap.firstEntry().getValue() : propertyToItemCountMap.lastEntry().getValue();
return itemCount.item;
}

private static <Key_> Key_ getFirstKey(Map<Key_, ?> map) {
return map.keySet().iterator().next();
private record ItemCount<Item_>(Item_ item, MutableInt count) {
}

}

0 comments on commit 5443347

Please sign in to comment.