diff --git a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
index b6a437d1b675..3b7112e341cc 100644
--- a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
@@ -47,6 +47,7 @@
 import static io.trino.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
 import static it.unimi.dsi.fastutil.HashCommon.arraySize;
 import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
+import static java.lang.Math.min;
 import static java.lang.Math.multiplyExact;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
@@ -57,6 +58,8 @@ public class MultiChannelGroupByHash
 {
     private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize();
     private static final float FILL_RATIO = 0.75f;
+    // Max number of positions added to the hash table between two capacity checks
+    private static final int BATCH_SIZE = 1024;
     // Max (page value count / cumulative dictionary size) to trigger the low cardinality case
     private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25;
     private static final int VALUES_PAGE_BITS = 14; // 16k positions
@@ -578,7 +581,6 @@ class AddNonDictionaryPageWork
             implements Work
     {
         private final Page page;
-
         private int lastPosition;
 
         public AddNonDictionaryPageWork(Page page)
@@ -591,21 +593,25 @@ public boolean process()
         {
             int positionCount = page.getPositionCount();
-            checkState(lastPosition < positionCount, "position count out of bound");
+            // lastPosition == positionCount is allowed: nothing is left to add (e.g. an empty page)
+            checkState(lastPosition <= positionCount, "position count out of bound");
+            int remainingPositions = positionCount - lastPosition;
 
-            // needRehash() == true indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                // Make room for the whole batch up front; give up for now if the rehash cannot be done
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
 
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                // get the group for the current row
-                putIfAbsent(lastPosition, page);
-                lastPosition++;
+                for (int i = lastPosition; i < lastPosition + batchSize; i++) {
+                    putIfAbsent(i, page);
+                }
+
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
-            return lastPosition == positionCount;
+            verify(lastPosition == positionCount);
+            return true;
         }
 
         @Override
@@ -774,23 +780,29 @@ public GetNonDictionaryGroupIdsWork(Page page)
         public boolean process()
         {
             int positionCount = page.getPositionCount();
+            // lastPosition == positionCount is allowed: nothing is left to process (e.g. an empty page)
             checkState(lastPosition <= positionCount, "position count out of bound");
             checkState(!finished);
 
-            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
+            int remainingPositions = positionCount - lastPosition;
 
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                // output the group id for this row
-                groupIds[lastPosition] = putIfAbsent(lastPosition, page);
-                lastPosition++;
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                // Make room for the whole batch up front; give up for now if the rehash cannot be done
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
+
+                for (int i = lastPosition; i < lastPosition + batchSize; i++) {
+                    // output the group id for this row
+                    groupIds[i] = putIfAbsent(i, page);
+                }
+
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
-            return lastPosition == positionCount;
+            verify(lastPosition == positionCount);
+            return true;
         }
 
         @Override
@@ -1021,4 +1033,21 @@ private int calculatePositionToCombinationIdMapping(Page page, short[] positionT
         }
         return maxCardinality;
     }
+
+    /**
+     * Calls {@code tryRehash()} until {@code maxFill - nextGroupId} is at least {@code batchSize}.
+     *
+     * @return {@code false} as soon as a {@code tryRehash()} call fails, {@code true} otherwise
+     */
+    private boolean ensureHashTableSize(int batchSize)
+    {
+        int positionCountUntilRehash = maxFill - nextGroupId;
+        while (positionCountUntilRehash < batchSize) {
+            if (!tryRehash()) {
+                return false;
+            }
+            positionCountUntilRehash = maxFill - nextGroupId;
+        }
+        return true;
+    }
 }