diff --git a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
index 342c39039d3b..3b7112e341cc 100644
--- a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
@@ -16,12 +16,10 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
-import io.trino.array.LongBigArray;
 import io.trino.spi.Page;
 import io.trino.spi.PageBuilder;
 import io.trino.spi.TrinoException;
 import io.trino.spi.block.Block;
-import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.block.DictionaryBlock;
 import io.trino.spi.block.LongArrayBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
@@ -42,8 +40,6 @@
 import static com.google.common.base.Preconditions.checkState;
 import static com.google.common.base.Verify.verify;
 import static io.airlift.slice.SizeOf.sizeOf;
-import static io.trino.operator.SyntheticAddress.decodePosition;
-import static io.trino.operator.SyntheticAddress.decodeSliceIndex;
 import static io.trino.operator.SyntheticAddress.encodeSyntheticAddress;
 import static io.trino.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES;
 import static io.trino.spi.type.BigintType.BIGINT;
@@ -51,6 +47,7 @@
 import static io.trino.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
 import static it.unimi.dsi.fastutil.HashCommon.arraySize;
 import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
+import static java.lang.Math.min;
 import static java.lang.Math.multiplyExact;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
@@ -61,8 +58,12 @@ public class MultiChannelGroupByHash
 {
     private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize();
     private static final float FILL_RATIO = 0.75f;
+    private static final int BATCH_SIZE = 1024;
     // Max (page value count / cumulative dictionary size) to trigger the low cardinality case
    private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25;
+    private static final int VALUES_PAGE_BITS = 14; // 16k positions
+    private static final int VALUES_PAGE_MAX_ROW_COUNT = 1 << VALUES_PAGE_BITS;
+    private static final int VALUES_PAGE_MASK = VALUES_PAGE_MAX_ROW_COUNT - 1;
 
     private final List<Type> types;
     private final List<Type> hashTypes;
@@ -81,12 +82,11 @@ public class MultiChannelGroupByHash
     private int hashCapacity;
     private int maxFill;
     private int mask;
-    private long[] groupAddressByHash;
+    // Group ids are assigned incrementally. Therefore, since values page size is constant and power of two,
+    // the group id is also an address (slice index and position within slice) to group row in channelBuilders.
     private int[] groupIdsByHash;
     private byte[] rawHashByHashPosition;
 
-    private final LongBigArray groupAddressByGroupId;
-
     private int nextGroupId;
     private DictionaryLookBack dictionaryLookBack;
     private long hashCollisions;
@@ -148,15 +148,9 @@ public MultiChannelGroupByHash(
         maxFill = calculateMaxFill(hashCapacity);
         mask = hashCapacity - 1;
-        groupAddressByHash = new long[hashCapacity];
-        Arrays.fill(groupAddressByHash, -1);
-
         rawHashByHashPosition = new byte[hashCapacity];
-
         groupIdsByHash = new int[hashCapacity];
-
-        groupAddressByGroupId = new LongBigArray();
-        groupAddressByGroupId.ensureCapacity(maxFill);
+        Arrays.fill(groupIdsByHash, -1);
 
         // This interface is used for actively reserving memory (push model) for rehash.
         // The caller can also query memory usage on this object (pull model)
@@ -166,9 +160,8 @@ public MultiChannelGroupByHash(
     @Override
     public long getRawHash(int groupId)
     {
-        long address = groupAddressByGroupId.get(groupId);
-        int blockIndex = decodeSliceIndex(address);
-        int position = decodePosition(address);
+        int blockIndex = groupId >> VALUES_PAGE_BITS;
+        int position = groupId & VALUES_PAGE_MASK;
         return hashStrategy.hashPosition(blockIndex, position);
     }
 
@@ -179,9 +172,7 @@ public long getEstimatedSize()
                 (sizeOf(channelBuilders.get(0).elements()) * channelBuilders.size()) +
                 completedPagesMemorySize +
                 currentPageBuilder.getRetainedSizeInBytes() +
-                sizeOf(groupAddressByHash) +
                 sizeOf(groupIdsByHash) +
-                groupAddressByGroupId.sizeOf() +
                 sizeOf(rawHashByHashPosition) +
                 preallocatedMemoryInBytes;
     }
@@ -213,9 +204,8 @@ public int getGroupCount()
     @Override
     public void appendValuesTo(int groupId, PageBuilder pageBuilder)
     {
-        long address = groupAddressByGroupId.get(groupId);
-        int blockIndex = decodeSliceIndex(address);
-        int position = decodePosition(address);
+        int blockIndex = groupId >> VALUES_PAGE_BITS;
+        int position = groupId & VALUES_PAGE_MASK;
         hashStrategy.appendTo(blockIndex, position, pageBuilder, 0);
     }
 
@@ -263,11 +253,11 @@ public boolean contains(int position, Page page, int[] hashChannels)
     @Override
     public boolean contains(int position, Page page, int[] hashChannels, long rawHash)
     {
-        int hashPosition = (int) getHashPosition(rawHash, mask);
+        int hashPosition = getHashPosition(rawHash, mask);
 
         // look for a slot containing this key
-        while (groupAddressByHash[hashPosition] != -1) {
-            if (positionNotDistinctFromCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) {
+        while (groupIdsByHash[hashPosition] != -1) {
+            if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) {
                 // found an existing slot for this key
                 return true;
             }
@@ -293,12 +283,12 @@ private int putIfAbsent(int position, Page page)
     private int putIfAbsent(int position, Page page, long rawHash)
     {
-        int hashPosition = (int) getHashPosition(rawHash, mask);
+        int hashPosition = getHashPosition(rawHash, mask);
 
         // look for an empty slot or a slot containing this key
         int groupId = -1;
-        while (groupAddressByHash[hashPosition] != -1) {
-            if (positionNotDistinctFromCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) {
+        while (groupIdsByHash[hashPosition] != -1) {
+            if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) {
                 // found an existing slot for this key
                 groupId = groupIdsByHash[hashPosition];
 
@@ -337,13 +327,11 @@ private int addNewGroup(int hashPosition, int position, Page page, long rawHash)
         // record group id in hash
         int groupId = nextGroupId++;
 
-        groupAddressByHash[hashPosition] = address;
         rawHashByHashPosition[hashPosition] = (byte) rawHash;
         groupIdsByHash[hashPosition] = groupId;
-        groupAddressByGroupId.set(groupId, address);
 
         // create new page builder if this page is full
-        if (currentPageBuilder.isFull()) {
+        if (currentPageBuilder.getPositionCount() == VALUES_PAGE_MAX_ROW_COUNT) {
             startNewPage();
         }
@@ -363,6 +351,7 @@ private void startNewPage()
     {
         if (currentPageBuilder != null) {
             completedPagesMemorySize += currentPageBuilder.getRetainedSizeInBytes();
+            // TODO: (https://github.com/trinodb/trino/issues/12484) pre-size new PageBuilder to OUTPUT_PAGE_SIZE
             currentPageBuilder = currentPageBuilder.newPageBuilderLike();
         }
         else {
@@ -383,10 +372,9 @@ private boolean tryRehash()
         int newCapacity = toIntExact(newCapacityLong);
 
         // An estimate of how much extra memory is needed before we can go ahead and expand the hash table.
-        // This includes the new capacity for groupAddressByHash, rawHashByHashPosition, groupIdsByHash, and groupAddressByGroupId as well as the size of the current page
-        preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Long.BYTES + Integer.BYTES + Byte.BYTES) +
-                (long) (calculateMaxFill(newCapacity) - maxFill) * Long.BYTES +
-                currentPageSizeInBytes;
+        // This includes the new capacity for rawHashByHashPosition, groupIdsByHash as well as the size of the current page
+        preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Integer.BYTES + Byte.BYTES) +
+                currentPageSizeInBytes;
         if (!updateMemory.update()) {
             // reserved memory but has exceeded the limit
             return false;
         }
@@ -396,54 +384,46 @@ private boolean tryRehash()
         expectedHashCollisions += estimateNumberOfHashCollisions(getGroupCount(), hashCapacity);
 
         int newMask = newCapacity - 1;
-        long[] newKey = new long[newCapacity];
         byte[] rawHashes = new byte[newCapacity];
-        Arrays.fill(newKey, -1);
-        int[] newValue = new int[newCapacity];
+        int[] newGroupIdByHash = new int[newCapacity];
+        Arrays.fill(newGroupIdByHash, -1);
 
-        int oldIndex = 0;
-        for (int groupId = 0; groupId < nextGroupId; groupId++) {
+        for (int i = 0; i < hashCapacity; i++) {
             // seek to the next used slot
-            while (groupAddressByHash[oldIndex] == -1) {
-                oldIndex++;
+            int groupId = groupIdsByHash[i];
+            if (groupId == -1) {
+                continue;
             }
 
-            // get the address for this slot
-            long address = groupAddressByHash[oldIndex];
-
-            long rawHash = hashPosition(address);
+            long rawHash = hashPosition(groupId);
 
             // find an empty slot for the address
-            int pos = (int) getHashPosition(rawHash, newMask);
-            while (newKey[pos] != -1) {
+            int pos = getHashPosition(rawHash, newMask);
+            while (newGroupIdByHash[pos] != -1) {
                 pos = (pos + 1) & newMask;
                 hashCollisions++;
             }
 
             // record the mapping
-            newKey[pos] = address;
             rawHashes[pos] = (byte) rawHash;
-            newValue[pos] = groupIdsByHash[oldIndex];
-            oldIndex++;
+            newGroupIdByHash[pos] = groupId;
         }
 
         this.mask = newMask;
         this.hashCapacity = newCapacity;
         this.maxFill = calculateMaxFill(newCapacity);
-        this.groupAddressByHash = newKey;
         this.rawHashByHashPosition = rawHashes;
-        this.groupIdsByHash = newValue;
-        groupAddressByGroupId.ensureCapacity(maxFill);
+        this.groupIdsByHash = newGroupIdByHash;
         return true;
     }
 
-    private long hashPosition(long sliceAddress)
+    private long hashPosition(int groupId)
     {
-        int sliceIndex = decodeSliceIndex(sliceAddress);
-        int position = decodePosition(sliceAddress);
+        int blockIndex = groupId >> VALUES_PAGE_BITS;
+        int blockPosition = groupId & VALUES_PAGE_MASK;
         if (precomputedHashChannel.isPresent()) {
-            return getRawHash(sliceIndex, position, precomputedHashChannel.getAsInt());
+            return getRawHash(blockIndex, blockPosition, precomputedHashChannel.getAsInt());
         }
-        return hashStrategy.hashPosition(sliceIndex, position);
+        return hashStrategy.hashPosition(blockIndex, blockPosition);
     }
 
     private long getRawHash(int sliceIndex, int position, int hashChannel)
@@ -451,17 +431,19 @@ private long getRawHash(int sliceIndex, int position, int hashChannel)
         return channelBuilders.get(hashChannel).get(sliceIndex).getLong(position, 0);
     }
 
-    private boolean positionNotDistinctFromCurrentRow(long address, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels)
+    private boolean positionNotDistinctFromCurrentRow(int groupId, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels)
     {
         if (rawHashByHashPosition[hashPosition] != rawHash) {
             return false;
         }
-        return hashStrategy.positionNotDistinctFromRow(decodeSliceIndex(address), decodePosition(address), position, page, hashChannels);
+        int blockIndex = groupId >> VALUES_PAGE_BITS;
+        int blockPosition = groupId & VALUES_PAGE_MASK;
+        return hashStrategy.positionNotDistinctFromRow(blockIndex, blockPosition, position, page, hashChannels);
     }
 
-    private static long getHashPosition(long rawHash, int mask)
+    private static int getHashPosition(long rawHash, int mask)
     {
-        return murmurHash3(rawHash) & mask;
+        return (int) (murmurHash3(rawHash) & mask); // mask is int so casting is safe
     }
 
     private static int calculateMaxFill(int hashSize)
@@ -598,7 +580,6 @@ class AddNonDictionaryPageWork
             implements Work<Void>
     {
         private final Page page;
-
         private int lastPosition;
 
         public AddNonDictionaryPageWork(Page page)
@@ -611,21 +592,23 @@ public boolean process()
         {
             int positionCount = page.getPositionCount();
             checkState(lastPosition < positionCount, "position count out of bound");
+            int remainingPositions = positionCount - lastPosition;
 
-            // needRehash() == true indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
 
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                // get the group for the current row
-                putIfAbsent(lastPosition, page);
-                lastPosition++;
+                for (int i = lastPosition; i < lastPosition + batchSize; i++) {
+                    putIfAbsent(i, page);
+                }
+
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
-            return lastPosition == positionCount;
+
+            verify(lastPosition == positionCount);
+            return true;
         }
 
         @Override
@@ -777,7 +760,7 @@ public Void getResult()
     class GetNonDictionaryGroupIdsWork
             implements Work<GroupByIdBlock>
     {
-        private final BlockBuilder blockBuilder;
+        private final long[] groupIds;
         private final Page page;
 
         private boolean finished;
@@ -787,30 +770,34 @@ public GetNonDictionaryGroupIdsWork(Page page)
         {
             this.page = requireNonNull(page, "page is null");
             // we know the exact size required for the block
-            this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(page.getPositionCount());
+            groupIds = new long[page.getPositionCount()];
         }
 
         @Override
         public boolean process()
         {
             int positionCount = page.getPositionCount();
-            checkState(lastPosition <= positionCount, "position count out of bound");
+            checkState(lastPosition < positionCount, "position count out of bound");
             checkState(!finished);
 
-            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
+            int remainingPositions = positionCount - lastPosition;
 
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                // output the group id for this row
-                BIGINT.writeLong(blockBuilder, putIfAbsent(lastPosition, page));
-                lastPosition++;
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
+
+                for (int i = lastPosition; i < lastPosition + batchSize; i++) {
+                    // output the group id for this row
+                    groupIds[i] = putIfAbsent(i, page);
+                }
+
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
-            return lastPosition == positionCount;
+            verify(lastPosition == positionCount);
+            return true;
         }
 
         @Override
@@ -819,7 +806,7 @@ public GroupByIdBlock getResult()
             checkState(lastPosition == page.getPositionCount(), "process has not yet finished");
             checkState(!finished, "result has produced");
             finished = true;
-            return new GroupByIdBlock(nextGroupId, blockBuilder.build());
+            return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds));
         }
     }
 
@@ -889,7 +876,7 @@
     class GetDictionaryGroupIdsWork
             implements Work<GroupByIdBlock>
     {
-        private final BlockBuilder blockBuilder;
+        private final long[] groupIds;
         private final Page page;
         private final Page dictionaryPage;
         private final DictionaryBlock dictionaryBlock;
@@ -905,9 +892,7 @@ public GetDictionaryGroupIdsWork(Page page)
             this.dictionaryBlock = (DictionaryBlock) page.getBlock(channels[0]);
             updateDictionaryLookBack(dictionaryBlock.getDictionary());
             this.dictionaryPage = createPageWithExtractedDictionary(page);
-
-            // we know the exact size required for the block
-            this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(page.getPositionCount());
+            groupIds = new long[page.getPositionCount()];
         }
 
         @Override
@@ -927,8 +912,7 @@ public boolean process()
             // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
             while (lastPosition < positionCount && !needRehash()) {
                 int positionInDictionary = dictionaryBlock.getId(lastPosition);
-                int groupId = registerGroupId(hashGenerator, dictionaryPage, positionInDictionary);
-                BIGINT.writeLong(blockBuilder, groupId);
+                groupIds[lastPosition] = registerGroupId(hashGenerator, dictionaryPage, positionInDictionary);
                 lastPosition++;
             }
             return lastPosition == positionCount;
@@ -940,7 +924,7 @@ public GroupByIdBlock getResult()
             checkState(lastPosition == page.getPositionCount(), "process has not yet finished");
             checkState(!finished, "result has produced");
             finished = true;
-            return new GroupByIdBlock(nextGroupId, blockBuilder.build());
+            return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds));
         }
     }
 
@@ -1044,4 +1028,16 @@ private int calculatePositionToCombinationIdMapping(Page page, short[] positionT
         }
         return maxCardinality;
     }
+
+    private boolean ensureHashTableSize(int batchSize)
+    {
+        int positionCountUntilRehash = maxFill - nextGroupId;
+        while (positionCountUntilRehash < batchSize) {
+            if (!tryRehash()) {
+                return false;
+            }
+            positionCountUntilRehash = maxFill - nextGroupId;
+        }
+        return true;
+    }
 }
diff --git a/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java b/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java
index a3e30fba19ef..e1f8f85ce0c4 100644
--- a/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java
+++ b/core/trino-main/src/test/java/io/trino/operator/GroupByHashYieldAssertion.java
@@ -127,7 +127,10 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<
                 // free the pool for the next iteration
                 memoryPool.free(anotherTaskId, "test", reservedMemoryInBytes);
                 // this required in case input is blocked
-                operator.getOutput();
+                output = operator.getOutput();
+                if (output != null) {
+                    result.add(output);
+                }
                 continue;
             }
 
@@ -140,7 +143,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<
             assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
 
             // assert the hash capacity is not changed; otherwise, we should have yielded
-            assertTrue(oldCapacity == getHashCapacity.apply(operator));
+            assertEquals((int) getHashCapacity.apply(operator), oldCapacity);
 
             // We are not going to rehash; therefore, assert the memory increase only comes from the aggregator
             assertLessThan(actualIncreasedMemory, additionalMemoryInBytes);
@@ -164,8 +167,8 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<
                 expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES) + page.getRetainedSizeInBytes();
             }
             else {
-                // groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity / 0.75
-                expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES) + page.getRetainedSizeInBytes();
+                // groupIdsByHash and rawHashByHashPosition double by hashCapacity
+                expectedReservedExtraBytes = oldCapacity * (long) (Integer.BYTES + Byte.BYTES);
             }
             assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, expectedReservedExtraBytes + additionalMemoryInBytes);
@@ -187,10 +190,24 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<
 
             // Assert the estimated reserved memory before rehash is very close to the one after rehash
             long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
-            assertBetweenInclusive(rehashedMemoryUsage * 1.0 / newMemoryUsage, 0.99, 1.01);
+            double memoryUsageErrorUpperBound = 1.01;
+            double memoryUsageError = rehashedMemoryUsage * 1.0 / newMemoryUsage;
+            if (memoryUsageError > memoryUsageErrorUpperBound) {
+                // Usually the error is < 1%, but since MultiChannelGroupByHash.getEstimatedSize
+                // accounts for changes in completedPagesMemorySize, which is increased if a new page is
+                // added by addNewGroup (an event that cannot be predicted as it depends on the number of unique groups
+                // in the current page being processed), the difference includes the size of the added new page.
+                // Lower bound is 1% lower than normal because additionalMemoryInBytes also includes aggregator state.
+                assertBetweenInclusive(rehashedMemoryUsage * 1.0 / (newMemoryUsage + additionalMemoryInBytes), 0.98, memoryUsageErrorUpperBound,
+                        "rehashedMemoryUsage " + rehashedMemoryUsage + ", newMemoryUsage: " + newMemoryUsage);
+            }
+            else {
+                assertBetweenInclusive(memoryUsageError, 0.99, memoryUsageErrorUpperBound);
+            }
 
             // unblocked
             assertTrue(operator.needsInput());
+            assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone());
         }
     }
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDistinctLimitOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestDistinctLimitOperator.java
index e5ff25fa33bd..c02353d65aee 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestDistinctLimitOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestDistinctLimitOperator.java
@@ -34,7 +34,7 @@
 import java.util.concurrent.ScheduledExecutorService;
 
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
-import static io.airlift.testing.Assertions.assertGreaterThan;
+import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
 import static io.trino.RowPagesBuilder.rowPagesBuilder;
 import static io.trino.SessionTestUtils.TEST_SESSION;
 import static io.trino.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys;
@@ -167,9 +167,9 @@ public void testMemoryReservationYield(Type type)
                 joinCompiler,
                 blockTypeOperators);
 
-        GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((DistinctLimitOperator) operator).getCapacity(), 1_400_000);
-        assertGreaterThan(result.getYieldCount(), 5);
-        assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
+        GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((DistinctLimitOperator) operator).getCapacity(), 450_000);
+        assertGreaterThanOrEqual(result.getYieldCount(), 5);
+        assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);
         assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 6_000 * 600);
     }
 }
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java
index 496b6eb8d066..3f05b637aaf2 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestHashAggregationOperator.java
@@ -62,6 +62,7 @@
 import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
 import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder;
 import static io.airlift.testing.Assertions.assertGreaterThan;
+import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
 import static io.airlift.units.DataSize.Unit.KILOBYTE;
 import static io.airlift.units.DataSize.Unit.MEGABYTE;
 import static io.airlift.units.DataSize.succinctBytes;
@@ -410,9 +411,9 @@ public void testMemoryReservationYield(Type type)
 
         // get result with yield; pick a relatively small buffer for aggregator's memory usage
         GroupByHashYieldResult result;
-        result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 1_400_000);
-        assertGreaterThan(result.getYieldCount(), 5);
-        assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
+        result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 450_000);
+        assertGreaterThanOrEqual(result.getYieldCount(), 5);
+        assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);
 
         int count = 0;
         for (Page page : result.getOutput()) {
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java
index b124080c2280..ce5b31dfa0f8 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java
@@ -38,7 +38,6 @@
 import static com.google.common.collect.Iterables.concat;
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
-import static io.airlift.testing.Assertions.assertGreaterThan;
 import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
 import static io.trino.RowPagesBuilder.rowPagesBuilder;
 import static io.trino.SessionTestUtils.TEST_SESSION;
@@ -244,10 +243,10 @@ public void testSemiJoinMemoryReservationYield(Type type)
                 type,
                 setBuilderOperatorFactory,
                 operator -> ((SetBuilderOperator) operator).getCapacity(),
-                1_400_000);
+                450_000);
 
-        assertGreaterThanOrEqual(result.getYieldCount(), 5);
-        assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
+        assertGreaterThanOrEqual(result.getYieldCount(), 4);
+        assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 19);
         assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 0);
     }
 
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java
index 3ee3e6f9d816..a5807719e758 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestMarkDistinctOperator.java
@@ -38,7 +38,7 @@
 import static com.google.common.base.Throwables.throwIfUnchecked;
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
-import static io.airlift.testing.Assertions.assertGreaterThan;
+import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
 import static io.airlift.testing.Assertions.assertInstanceOf;
 import static io.trino.RowPagesBuilder.rowPagesBuilder;
 import static io.trino.SessionTestUtils.TEST_SESSION;
@@ -176,9 +176,9 @@ public void testMemoryReservationYield(Type type)
         OperatorFactory operatorFactory = new MarkDistinctOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(type), ImmutableList.of(0), Optional.of(1), joinCompiler, blockTypeOperators);
 
         // get result with yield; pick a relatively small buffer for partitionRowCount's memory usage
-        GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((MarkDistinctOperator) operator).getCapacity(), 1_400_000);
-        assertGreaterThan(result.getYieldCount(), 5);
-        assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
+        GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((MarkDistinctOperator) operator).getCapacity(), 450_000);
+        assertGreaterThanOrEqual(result.getYieldCount(), 5);
+        assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);
 
         int count = 0;
         for (Page page : result.getOutput()) {
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestRowNumberOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestRowNumberOperator.java
index 5e2a62569ad0..2c8bad1809b4 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestRowNumberOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestRowNumberOperator.java
@@ -40,7 +40,7 @@
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.airlift.concurrent.Threads.daemonThreadsNamed;
 import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder;
-import static io.airlift.testing.Assertions.assertGreaterThan;
+import static io.airlift.testing.Assertions.assertGreaterThanOrEqual;
 import static io.trino.RowPagesBuilder.rowPagesBuilder;
 import static io.trino.SessionTestUtils.TEST_SESSION;
 import static io.trino.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys;
@@ -171,9 +171,9 @@ public void testMemoryReservationYield(Type type)
                 blockTypeOperators);
 
         // get result with yield; pick a relatively small buffer for partitionRowCount's memory usage
-        GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((RowNumberOperator) operator).getCapacity(), 1_400_000);
-        assertGreaterThan(result.getYieldCount(), 5);
-        assertGreaterThan(result.getMaxReservedBytes(), 20L << 20);
+        GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((RowNumberOperator) operator).getCapacity(), 280_000);
+        assertGreaterThanOrEqual(result.getYieldCount(), 5);
+        assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20);
 
         int count = 0;
         for (Page page : result.getOutput()) {
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTopNRankingOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestTopNRankingOperator.java
index c71f50c97233..714bdb047c29 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestTopNRankingOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestTopNRankingOperator.java
@@ -302,7 +302,7 @@ public void testMemoryReservationYield()
                 type,
                 operatorFactory,
                 operator -> ((TopNRankingOperator) operator).getGroupedTopNBuilder() == null ? 0 : ((GroupedTopNRowNumberBuilder) ((TopNRankingOperator) operator).getGroupedTopNBuilder()).getGroupByHash().getCapacity(),
-                1_000_000);
+                450_000);
 
         assertGreaterThan(result.getYieldCount(), 3);
         assertGreaterThan(result.getMaxReservedBytes(), 5L << 20);
diff --git a/testing/trino-tests/src/test/java/io/trino/memory/TestMemoryManager.java b/testing/trino-tests/src/test/java/io/trino/memory/TestMemoryManager.java
index 2bbe55477b09..5c0dedb5220b 100644
--- a/testing/trino-tests/src/test/java/io/trino/memory/TestMemoryManager.java
+++ b/testing/trino-tests/src/test/java/io/trino/memory/TestMemoryManager.java
@@ -234,7 +234,8 @@ public void testClusterPools()
         List> queryFutures = new ArrayList<>();
         for (int i = 0; i < 2; i++) {
-            queryFutures.add(executor.submit(() -> queryRunner.execute("SELECT COUNT(*), clerk FROM orders GROUP BY clerk")));
+            // for this test to work, the query has to have enough groups for HashAggregationOperator to go over QueryContext.GUARANTEED_MEMORY
+            queryFutures.add(executor.submit(() -> queryRunner.execute("SELECT COUNT(*), cast(orderkey as varchar), partkey FROM lineitem GROUP BY cast(orderkey as varchar), partkey")));
         }
 
         ClusterMemoryManager memoryManager = queryRunner.getCoordinator().getClusterMemoryManager();
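
Reviewer note (not part of the patch): a minimal, self-contained sketch of the addressing scheme the change above relies on. Because group ids are assigned incrementally and every values page holds exactly VALUES_PAGE_MAX_ROW_COUNT = 1 << VALUES_PAGE_BITS rows, a group id can be decomposed into a page index and a position within that page with a shift and a mask, which is what lets the patch drop the SyntheticAddress/groupAddressByGroupId indirection. The class and method names below are illustrative only; only the constants and the shift/mask expressions come from the patch.

public final class GroupIdAddressingSketch
{
    // Same constants as introduced in MultiChannelGroupByHash
    private static final int VALUES_PAGE_BITS = 14; // 16k positions per values page
    private static final int VALUES_PAGE_MAX_ROW_COUNT = 1 << VALUES_PAGE_BITS;
    private static final int VALUES_PAGE_MASK = VALUES_PAGE_MAX_ROW_COUNT - 1;

    private GroupIdAddressingSketch() {}

    // Index of the values page (slice) that holds the row for this group id
    static int blockIndex(int groupId)
    {
        return groupId >> VALUES_PAGE_BITS;
    }

    // Position of the row within that values page
    static int blockPosition(int groupId)
    {
        return groupId & VALUES_PAGE_MASK;
    }

    public static void main(String[] args)
    {
        // Group ids are assigned incrementally, so id 16384 is the first row of the second values page
        int groupId = VALUES_PAGE_MAX_ROW_COUNT;
        System.out.println(blockIndex(groupId));    // prints 1
        System.out.println(blockPosition(groupId)); // prints 0
    }
}

The batched processing added to AddNonDictionaryPageWork and GetNonDictionaryGroupIdsWork follows the same contract: ensureHashTableSize(batchSize) rehashes until maxFill - nextGroupId is at least batchSize (returning false when the memory reservation cannot be granted), after which min(remainingPositions, BATCH_SIZE) rows can be inserted without any per-row rehash check.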