diff --git a/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java b/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java index 62b263b5140f..856f9dc4c2ee 100644 --- a/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java +++ b/presto-main/src/main/java/io/prestosql/operator/BigintGroupByHash.java @@ -22,11 +22,14 @@ import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.block.BlockBuilder; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.RunLengthEncodedBlock; import io.prestosql.spi.type.AbstractLongType; import io.prestosql.spi.type.BigintType; import io.prestosql.spi.type.Type; import org.openjdk.jol.info.ClassLayout; +import java.util.Arrays; import java.util.List; import static com.google.common.base.Preconditions.checkArgument; @@ -67,6 +70,7 @@ public class BigintGroupByHash private final LongBigArray valuesByGroupId; private int nextGroupId; + private DictionaryLookBack dictionaryLookBack; private long hashCollisions; private double expectedHashCollisions; @@ -161,13 +165,29 @@ public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChann public Work addPage(Page page) { currentPageSizeInBytes = page.getRetainedSizeInBytes(); - return new AddPageWork(page.getBlock(hashChannel)); + Block block = page.getBlock(hashChannel); + if (block instanceof RunLengthEncodedBlock) { + return new AddRunLengthEncodedPageWork((RunLengthEncodedBlock) block); + } + if (block instanceof DictionaryBlock) { + return new AddDictionaryPageWork((DictionaryBlock) block); + } + + return new AddPageWork(block); } @Override public Work getGroupIds(Page page) { currentPageSizeInBytes = page.getRetainedSizeInBytes(); + Block block = page.getBlock(hashChannel); + if (block instanceof RunLengthEncodedBlock) { + return new GetRunLengthEncodedGroupIdsWork((RunLengthEncodedBlock) block); + } + if (block instanceof DictionaryBlock) { + return new GetDictionaryGroupIdsWork((DictionaryBlock) block); + } + return new GetGroupIdsWork(page.getBlock(hashChannel)); } @@ -269,7 +289,7 @@ private boolean tryRehash() // An estimate of how much extra memory is needed before we can go ahead and expand the hash table. 
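The addPage and getGroupIds changes above choose a specialized Work implementation from the block's encoding before any hashing happens. As a rough illustration of why that dispatch pays off, the sketch below (illustrative names and types, not the Presto SPI) counts hash-table probes for the three paths: a run-length encoded block needs a single probe regardless of its position count, a dictionary block needs at most one probe per distinct dictionary id it references, and the generic path probes once per row.

import java.util.HashSet;
import java.util.Set;

// Illustrative probe counts for the three code paths chosen by the dispatch above.
// These names and numbers are simplified stand-ins, not the Presto SPI.
final class EncodingProbeCountSketch
{
    static long probesForRle(int positionCount)
    {
        return 1;                           // AddRunLengthEncodedPageWork hashes only position 0
    }

    static long probesForDictionary(int[] ids)
    {
        Set<Integer> distinctIds = new HashSet<>();
        for (int id : ids) {
            distinctIds.add(id);            // the look-back cache resolves each dictionary id once
        }
        return distinctIds.size();
    }

    static long probesForGenericBlock(int positionCount)
    {
        return positionCount;               // AddPageWork calls putIfAbsent for every row
    }

    public static void main(String[] args)
    {
        int[] ids = new int[8192];
        for (int i = 0; i < ids.length; i++) {
            ids[i] = i % 16;                // 8192 rows, 16 distinct dictionary ids
        }
        System.out.println(probesForRle(8192));          // 1
        System.out.println(probesForDictionary(ids));    // 16
        System.out.println(probesForGenericBlock(8192)); // 8192
    }
}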
// This includes the new capacity for values, groupIds, and valuesByGroupId as well as the size of the current page - preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Long.BYTES + Integer.BYTES) + (calculateMaxFill(newCapacity) - maxFill) * Long.BYTES + currentPageSizeInBytes; + preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Long.BYTES + Integer.BYTES) + (long) (calculateMaxFill(newCapacity) - maxFill) * Long.BYTES + currentPageSizeInBytes; if (!updateMemory.update()) { // reserved memory but has exceeded the limit return false; @@ -333,7 +353,26 @@ private static int calculateMaxFill(int hashSize) return maxFill; } - private class AddPageWork + private void updateDictionaryLookBack(Block dictionary) + { + if (dictionaryLookBack == null || dictionaryLookBack.getDictionary() != dictionary) { + dictionaryLookBack = new DictionaryLookBack(dictionary); + } + } + + private int getGroupId(Block dictionary, int positionInDictionary) + { + if (dictionaryLookBack.isProcessed(positionInDictionary)) { + return dictionaryLookBack.getGroupId(positionInDictionary); + } + + int groupId = putIfAbsent(positionInDictionary, dictionary); + dictionaryLookBack.setProcessed(positionInDictionary, groupId); + return groupId; + } + + @VisibleForTesting + class AddPageWork implements Work { private final Block block; @@ -349,7 +388,7 @@ public AddPageWork(Block block) public boolean process() { int positionCount = block.getPositionCount(); - checkState(lastPosition < positionCount, "position count out of bound"); + checkState(lastPosition <= positionCount, "position count out of bound"); // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. // We can only proceed if tryRehash() successfully did a rehash. @@ -374,7 +413,95 @@ public Void getResult() } } - private class GetGroupIdsWork + @VisibleForTesting + class AddDictionaryPageWork + implements Work + { + private final Block dictionary; + private final DictionaryBlock block; + + private int lastPosition; + + public AddDictionaryPageWork(DictionaryBlock block) + { + this.block = requireNonNull(block, "block is null"); + this.dictionary = block.getDictionary(); + updateDictionaryLookBack(dictionary); + } + + @Override + public boolean process() + { + int positionCount = block.getPositionCount(); + checkState(lastPosition <= positionCount, "position count out of bound"); + + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. + // We can only proceed if tryRehash() successfully did a rehash. + if (needRehash() && !tryRehash()) { + return false; + } + + // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. + // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. 
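updateDictionaryLookBack and getGroupId above memoize group-id resolution per dictionary entry: the cache is an int array indexed by dictionary position, initialized to -1, and it is reused as long as incoming blocks share the same dictionary instance. Below is a minimal self-contained sketch of that idea; the HashMap stands in for the real open-addressing hash table and the class name is illustrative.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Sketch of the dictionary look-back cache: group ids are resolved at most once per
// dictionary entry and then reused for every row that references the same entry.
final class DictionaryLookBackSketch
{
    private final Map<Long, Integer> table = new HashMap<>();   // stand-in for the hash table
    private long[] cachedDictionary;
    private int[] groupIdByDictionaryPosition;

    int getGroupId(long[] dictionary, int positionInDictionary)
    {
        if (cachedDictionary != dictionary) {                    // same identity check as the patch
            cachedDictionary = dictionary;
            groupIdByDictionaryPosition = new int[dictionary.length];
            Arrays.fill(groupIdByDictionaryPosition, -1);        // -1 means "not resolved yet"
        }
        int cached = groupIdByDictionaryPosition[positionInDictionary];
        if (cached != -1) {
            return cached;                                       // already resolved, no hash probe
        }
        int groupId = table.computeIfAbsent(dictionary[positionInDictionary], value -> table.size());
        groupIdByDictionaryPosition[positionInDictionary] = groupId;
        return groupId;
    }

    public static void main(String[] args)
    {
        DictionaryLookBackSketch sketch = new DictionaryLookBackSketch();
        long[] dictionary = {10L, 20L, 30L};
        int[] ids = {0, 1, 0, 2, 1, 0};                          // 6 rows, 3 distinct entries
        for (int id : ids) {
            System.out.println(sketch.getGroupId(dictionary, id));
        }
        // prints 0 1 0 2 1 0 and performs only three table insertions
    }
}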
+ while (lastPosition < positionCount && !needRehash()) { + int positionInDictionary = block.getId(lastPosition); + getGroupId(dictionary, positionInDictionary); + lastPosition++; + } + return lastPosition == positionCount; + } + + @Override + public Void getResult() + { + throw new UnsupportedOperationException(); + } + } + + @VisibleForTesting + class AddRunLengthEncodedPageWork + implements Work + { + private final RunLengthEncodedBlock block; + + private boolean finished; + + public AddRunLengthEncodedPageWork(RunLengthEncodedBlock block) + { + this.block = requireNonNull(block, "block is null"); + } + + @Override + public boolean process() + { + checkState(!finished); + if (block.getPositionCount() == 0) { + finished = true; + return true; + } + + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. + // We can only proceed if tryRehash() successfully did a rehash. + if (needRehash() && !tryRehash()) { + return false; + } + + // Only needs to process the first row since it is Run Length Encoded + putIfAbsent(0, block.getValue()); + finished = true; + + return true; + } + + @Override + public Void getResult() + { + throw new UnsupportedOperationException(); + } + } + + @VisibleForTesting + class GetGroupIdsWork implements Work { private final BlockBuilder blockBuilder; @@ -394,7 +521,7 @@ public GetGroupIdsWork(Block block) public boolean process() { int positionCount = block.getPositionCount(); - checkState(lastPosition < positionCount, "position count out of bound"); + checkState(lastPosition <= positionCount, "position count out of bound"); checkState(!finished); // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. @@ -422,4 +549,143 @@ public GroupByIdBlock getResult() return new GroupByIdBlock(nextGroupId, blockBuilder.build()); } } + + @VisibleForTesting + class GetDictionaryGroupIdsWork + implements Work + { + private final BlockBuilder blockBuilder; + private final Block dictionary; + private final DictionaryBlock block; + + private boolean finished; + private int lastPosition; + + public GetDictionaryGroupIdsWork(DictionaryBlock block) + { + this.block = requireNonNull(block, "block is null"); + this.dictionary = block.getDictionary(); + updateDictionaryLookBack(dictionary); + + // we know the exact size required for the block + this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(block.getPositionCount()); + } + + @Override + public boolean process() + { + int positionCount = block.getPositionCount(); + checkState(lastPosition <= positionCount, "position count out of bound"); + checkState(!finished); + + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. + // We can only proceed if tryRehash() successfully did a rehash. + if (needRehash() && !tryRehash()) { + return false; + } + + // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. + // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. 
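AddRunLengthEncodedPageWork above touches only position 0, since every other position of a run-length encoded block carries the same value, and the matching GetRunLengthEncodedGroupIdsWork further down returns the single group id wrapped in another run-length encoded block instead of materializing it per row. A compact sketch of that shape, with a plain map standing in for the hash table and a (groupId, positionCount) pair standing in for the RLE result block:

import java.util.HashMap;
import java.util.Map;

// Sketch: for a run-length encoded input only the first position is hashed, and the
// result is itself "run-length encoded" as a single group id plus a repeat count.
final class RleGroupIdsSketch
{
    record RleGroupIds(int groupId, int positionCount) {}       // stand-in for an RLE GroupByIdBlock

    private final Map<Long, Integer> table = new HashMap<>();   // stand-in for the hash table

    RleGroupIds getGroupIds(long rleValue, int positionCount)
    {
        if (positionCount == 0) {
            return new RleGroupIds(-1, 0);                       // nothing to assign
        }
        int groupId = table.computeIfAbsent(rleValue, value -> table.size());  // one probe total
        return new RleGroupIds(groupId, positionCount);
    }

    public static void main(String[] args)
    {
        RleGroupIdsSketch sketch = new RleGroupIdsSketch();
        System.out.println(sketch.getGroupIds(42L, 1_000_000));  // RleGroupIds[groupId=0, positionCount=1000000]
        System.out.println(sketch.getGroupIds(42L, 8));          // same group id, no new insertion
    }
}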
+ while (lastPosition < positionCount && !needRehash()) { + int positionInDictionary = block.getId(lastPosition); + int groupId = getGroupId(dictionary, positionInDictionary); + BIGINT.writeLong(blockBuilder, groupId); + lastPosition++; + } + return lastPosition == positionCount; + } + + @Override + public GroupByIdBlock getResult() + { + checkState(lastPosition == block.getPositionCount(), "process has not yet finished"); + checkState(!finished, "result has produced"); + finished = true; + return new GroupByIdBlock(nextGroupId, blockBuilder.build()); + } + } + + @VisibleForTesting + class GetRunLengthEncodedGroupIdsWork + implements Work + { + private final RunLengthEncodedBlock block; + + int groupId = -1; + private boolean processFinished; + private boolean resultProduced; + + public GetRunLengthEncodedGroupIdsWork(RunLengthEncodedBlock block) + { + this.block = requireNonNull(block, "block is null"); + } + + @Override + public boolean process() + { + checkState(!processFinished); + if (block.getPositionCount() == 0) { + processFinished = true; + return true; + } + + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. + // We can only proceed if tryRehash() successfully did a rehash. + if (needRehash() && !tryRehash()) { + return false; + } + + // Only needs to process the first row since it is Run Length Encoded + groupId = putIfAbsent(0, block.getValue()); + processFinished = true; + return true; + } + + @Override + public GroupByIdBlock getResult() + { + checkState(processFinished); + checkState(!resultProduced); + resultProduced = true; + + return new GroupByIdBlock( + nextGroupId, + new RunLengthEncodedBlock( + BIGINT.createFixedSizeBlockBuilder(1).writeLong(groupId).build(), + block.getPositionCount())); + } + } + + private static final class DictionaryLookBack + { + private final Block dictionary; + private final int[] processed; + + public DictionaryLookBack(Block dictionary) + { + this.dictionary = dictionary; + this.processed = new int[dictionary.getPositionCount()]; + Arrays.fill(processed, -1); + } + + public Block getDictionary() + { + return dictionary; + } + + public int getGroupId(int position) + { + return processed[position]; + } + + public boolean isProcessed(int position) + { + return processed[position] != -1; + } + + public void setProcessed(int position, int groupId) + { + processed[position] = groupId; + } + } } diff --git a/presto-main/src/main/java/io/prestosql/operator/MultiChannelGroupByHash.java b/presto-main/src/main/java/io/prestosql/operator/MultiChannelGroupByHash.java index b92f94b00cb3..28b1ea02a90b 100644 --- a/presto-main/src/main/java/io/prestosql/operator/MultiChannelGroupByHash.java +++ b/presto-main/src/main/java/io/prestosql/operator/MultiChannelGroupByHash.java @@ -16,13 +16,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; -import io.prestosql.array.LongBigArray; import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.BlockBuilder; import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.LongArrayBlock; import io.prestosql.spi.block.RunLengthEncodedBlock; import io.prestosql.spi.type.Type; import io.prestosql.sql.gen.JoinCompiler; @@ -30,6 +29,8 @@ import it.unimi.dsi.fastutil.objects.ObjectArrayList; import org.openjdk.jol.info.ClassLayout; +import 
javax.annotation.Nullable; + import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -39,8 +40,6 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static io.airlift.slice.SizeOf.sizeOf; -import static io.prestosql.operator.SyntheticAddress.decodePosition; -import static io.prestosql.operator.SyntheticAddress.decodeSliceIndex; import static io.prestosql.operator.SyntheticAddress.encodeSyntheticAddress; import static io.prestosql.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; import static io.prestosql.spi.type.BigintType.BIGINT; @@ -48,6 +47,8 @@ import static io.prestosql.util.HashCollisionsEstimator.estimateNumberOfHashCollisions; import static it.unimi.dsi.fastutil.HashCommon.arraySize; import static it.unimi.dsi.fastutil.HashCommon.murmurHash3; +import static java.lang.Math.min; +import static java.lang.Math.multiplyExact; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -57,6 +58,13 @@ public class MultiChannelGroupByHash { private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize(); private static final float FILL_RATIO = 0.75f; + private static final int BATCH_SIZE = 1024; + // Max (page value count / cumulative dictionary size) to trigger the low cardinality case + private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25; + private static final int VALUES_PAGE_BITS = 14; // 16k positions + private static final int VALUES_PAGE_MAX_ROW_COUNT = 1 << VALUES_PAGE_BITS; + private static final int VALUES_PAGE_MASK = VALUES_PAGE_MAX_ROW_COUNT - 1; + private final List types; private final List hashTypes; private final int[] channels; @@ -74,12 +82,11 @@ public class MultiChannelGroupByHash private int hashCapacity; private int maxFill; private int mask; - private long[] groupAddressByHash; + // Group ids are assigned incrementally. Therefore, since values page size is constant and power of two, + // the group id is also an address (slice index and position within slice) to group row in channelBuilders. private int[] groupIdsByHash; private byte[] rawHashByHashPosition; - private final LongBigArray groupAddressByGroupId; - private int nextGroupId; private DictionaryLookBack dictionaryLookBack; private long hashCollisions; @@ -141,15 +148,11 @@ public MultiChannelGroupByHash( maxFill = calculateMaxFill(hashCapacity); mask = hashCapacity - 1; - groupAddressByHash = new long[hashCapacity]; - Arrays.fill(groupAddressByHash, -1); rawHashByHashPosition = new byte[hashCapacity]; - groupIdsByHash = new int[hashCapacity]; - groupAddressByGroupId = new LongBigArray(); - groupAddressByGroupId.ensureCapacity(maxFill); + Arrays.fill(groupIdsByHash, -1); // This interface is used for actively reserving memory (push model) for rehash. 
// The caller can also query memory usage on this object (pull model) @@ -159,9 +162,8 @@ public MultiChannelGroupByHash( @Override public long getRawHash(int groupId) { - long address = groupAddressByGroupId.get(groupId); - int blockIndex = decodeSliceIndex(address); - int position = decodePosition(address); + int blockIndex = groupId >> VALUES_PAGE_BITS; + int position = groupId & VALUES_PAGE_MASK; return hashStrategy.hashPosition(blockIndex, position); } @@ -172,9 +174,7 @@ public long getEstimatedSize() (sizeOf(channelBuilders.get(0).elements()) * channelBuilders.size()) + completedPagesMemorySize + currentPageBuilder.getRetainedSizeInBytes() + - sizeOf(groupAddressByHash) + sizeOf(groupIdsByHash) + - groupAddressByGroupId.sizeOf() + sizeOf(rawHashByHashPosition) + preallocatedMemoryInBytes; } @@ -206,9 +206,8 @@ public int getGroupCount() @Override public void appendValuesTo(int groupId, PageBuilder pageBuilder, int outputChannelOffset) { - long address = groupAddressByGroupId.get(groupId); - int blockIndex = decodeSliceIndex(address); - int position = decodePosition(address); + int blockIndex = groupId >> VALUES_PAGE_BITS; + int position = groupId & VALUES_PAGE_MASK; hashStrategy.appendTo(blockIndex, position, pageBuilder, outputChannelOffset); } @@ -222,6 +221,9 @@ public Work addPage(Page page) if (canProcessDictionary(page)) { return new AddDictionaryPageWork(page); } + if (canProcessLowCardinalityDictionary(page)) { + return new AddLowCardinalityDictionaryPageWork(page); + } return new AddNonDictionaryPageWork(page); } @@ -236,6 +238,9 @@ public Work getGroupIds(Page page) if (canProcessDictionary(page)) { return new GetDictionaryGroupIdsWork(page); } + if (canProcessLowCardinalityDictionary(page)) { + return new GetLowCardinalityDictionaryGroupIdsWork(page); + } return new GetNonDictionaryGroupIdsWork(page); } @@ -250,11 +255,11 @@ public boolean contains(int position, Page page, int[] hashChannels) @Override public boolean contains(int position, Page page, int[] hashChannels, long rawHash) { - int hashPosition = (int) getHashPosition(rawHash, mask); + int hashPosition = getHashPosition(rawHash, mask); // look for a slot containing this key - while (groupAddressByHash[hashPosition] != -1) { - if (positionNotDistinctFromCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) { + while (groupIdsByHash[hashPosition] != -1) { + if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, hashChannels)) { // found an existing slot for this key return true; } @@ -280,12 +285,12 @@ private int putIfAbsent(int position, Page page) private int putIfAbsent(int position, Page page, long rawHash) { - int hashPosition = (int) getHashPosition(rawHash, mask); + int hashPosition = getHashPosition(rawHash, mask); // look for an empty slot or a slot containing this key int groupId = -1; - while (groupAddressByHash[hashPosition] != -1) { - if (positionNotDistinctFromCurrentRow(groupAddressByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) { + while (groupIdsByHash[hashPosition] != -1) { + if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) { // found an existing slot for this key groupId = groupIdsByHash[hashPosition]; @@ -322,13 +327,11 @@ private int addNewGroup(int hashPosition, int position, Page page, long rawHash) // record group id in hash int groupId = nextGroupId++; - 
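The shift-and-mask decoding above works because group ids are handed out sequentially while every values page holds exactly VALUES_PAGE_MAX_ROW_COUNT = 2^14 rows, so a group id doubles as an address: the high bits are the values-page index and the low 14 bits are the position inside that page. That is what allows the patch to drop groupAddressByHash and groupAddressByGroupId. A small round-trip sketch of the encoding (class and method names are illustrative):

// Sketch of the packed group-id addressing used above: with fixed, power-of-two sized
// value pages, (pageIndex, positionInPage) and the sequential group id are the same number.
final class PackedGroupIdSketch
{
    static final int VALUES_PAGE_BITS = 14;                     // 16k rows per values page
    static final int VALUES_PAGE_MAX_ROW_COUNT = 1 << VALUES_PAGE_BITS;
    static final int VALUES_PAGE_MASK = VALUES_PAGE_MAX_ROW_COUNT - 1;

    static int encode(int pageIndex, int positionInPage)
    {
        return (pageIndex << VALUES_PAGE_BITS) | positionInPage;
    }

    static int pageIndex(int groupId)
    {
        return groupId >> VALUES_PAGE_BITS;
    }

    static int positionInPage(int groupId)
    {
        return groupId & VALUES_PAGE_MASK;
    }

    public static void main(String[] args)
    {
        // Group ids are assigned incrementally, so group id 16384 is the first row
        // of the second values page; no separate id-to-address array is needed.
        int groupId = 16_384;
        System.out.println(pageIndex(groupId));                  // 1
        System.out.println(positionInPage(groupId));             // 0
        assert encode(1, 0) == groupId;
    }
}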
groupAddressByHash[hashPosition] = address; rawHashByHashPosition[hashPosition] = (byte) rawHash; groupIdsByHash[hashPosition] = groupId; - groupAddressByGroupId.set(groupId, address); // create new page builder if this page is full - if (currentPageBuilder.isFull()) { + if (currentPageBuilder.getPositionCount() == VALUES_PAGE_MAX_ROW_COUNT) { startNewPage(); } @@ -348,6 +351,7 @@ private void startNewPage() { if (currentPageBuilder != null) { completedPagesMemorySize += currentPageBuilder.getRetainedSizeInBytes(); + // TODO: (https://github.com/trinodb/trino/issues/12484) pre-size new PageBuilder to OUTPUT_PAGE_SIZE currentPageBuilder = currentPageBuilder.newPageBuilderLike(); } else { @@ -368,10 +372,9 @@ private boolean tryRehash() int newCapacity = toIntExact(newCapacityLong); // An estimate of how much extra memory is needed before we can go ahead and expand the hash table. - // This includes the new capacity for groupAddressByHash, rawHashByHashPosition, groupIdsByHash, and groupAddressByGroupId as well as the size of the current page - preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Long.BYTES + Integer.BYTES + Byte.BYTES) + - (calculateMaxFill(newCapacity) - maxFill) * Long.BYTES + - currentPageSizeInBytes; + // This includes the new capacity for rawHashByHashPosition, groupIdsByHash as well as the size of the current page + preallocatedMemoryInBytes = (newCapacity - hashCapacity) * (long) (Integer.BYTES + Byte.BYTES) + + currentPageSizeInBytes; if (!updateMemory.update()) { // reserved memory but has exceeded the limit return false; @@ -381,72 +384,66 @@ private boolean tryRehash() expectedHashCollisions += estimateNumberOfHashCollisions(getGroupCount(), hashCapacity); int newMask = newCapacity - 1; - long[] newKey = new long[newCapacity]; byte[] rawHashes = new byte[newCapacity]; - Arrays.fill(newKey, -1); - int[] newValue = new int[newCapacity]; + int[] newGroupIdByHash = new int[newCapacity]; + Arrays.fill(newGroupIdByHash, -1); - int oldIndex = 0; - for (int groupId = 0; groupId < nextGroupId; groupId++) { + for (int i = 0; i < hashCapacity; i++) { // seek to the next used slot - while (groupAddressByHash[oldIndex] == -1) { - oldIndex++; + int groupId = groupIdsByHash[i]; + if (groupId == -1) { + continue; } - // get the address for this slot - long address = groupAddressByHash[oldIndex]; - - long rawHash = hashPosition(address); + long rawHash = hashPosition(groupId); // find an empty slot for the address - int pos = (int) getHashPosition(rawHash, newMask); - while (newKey[pos] != -1) { + int pos = getHashPosition(rawHash, newMask); + while (newGroupIdByHash[pos] != -1) { pos = (pos + 1) & newMask; hashCollisions++; } // record the mapping - newKey[pos] = address; rawHashes[pos] = (byte) rawHash; - newValue[pos] = groupIdsByHash[oldIndex]; - oldIndex++; + newGroupIdByHash[pos] = groupId; } this.mask = newMask; this.hashCapacity = newCapacity; this.maxFill = calculateMaxFill(newCapacity); - this.groupAddressByHash = newKey; this.rawHashByHashPosition = rawHashes; - this.groupIdsByHash = newValue; - groupAddressByGroupId.ensureCapacity(maxFill); + this.groupIdsByHash = newGroupIdByHash; return true; } - private long hashPosition(long sliceAddress) + private long hashPosition(int groupId) { - int sliceIndex = decodeSliceIndex(sliceAddress); - int position = decodePosition(sliceAddress); + int blockIndex = groupId >> VALUES_PAGE_BITS; + int blockPosition = groupId & VALUES_PAGE_MASK; if (precomputedHashChannel.isPresent()) { - return 
getRawHash(sliceIndex, position); + return getRawHash(blockIndex, blockPosition, precomputedHashChannel.getAsInt()); } - return hashStrategy.hashPosition(sliceIndex, position); + return hashStrategy.hashPosition(blockIndex, blockPosition); } - private long getRawHash(int sliceIndex, int position) + private long getRawHash(int sliceIndex, int position, int hashChannel) { - return channelBuilders.get(precomputedHashChannel.getAsInt()).get(sliceIndex).getLong(position, 0); + return channelBuilders.get(hashChannel).get(sliceIndex).getLong(position, 0); } - private boolean positionNotDistinctFromCurrentRow(long address, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels) + private boolean positionNotDistinctFromCurrentRow(int groupId, int hashPosition, int position, Page page, byte rawHash, int[] hashChannels) { if (rawHashByHashPosition[hashPosition] != rawHash) { return false; } - return hashStrategy.positionNotDistinctFromRow(decodeSliceIndex(address), decodePosition(address), position, page, hashChannels); + int blockIndex = groupId >> VALUES_PAGE_BITS; + int blockPosition = groupId & VALUES_PAGE_MASK; + return hashStrategy.positionNotDistinctFromRow(blockIndex, blockPosition, position, page, hashChannels); } - private static long getHashPosition(long rawHash, int mask) + private static int getHashPosition(long rawHash, int mask) { - return murmurHash3(rawHash) & mask; + return (int) (murmurHash3(rawHash) & mask); // mask is int so casting is safe } private static int calculateMaxFill(int hashSize) @@ -479,9 +476,7 @@ private Page createPageWithExtractedDictionary(Page page) blocks[channels[0]] = dictionary; // extract hash dictionary - if (inputHashChannel.isPresent()) { - blocks[inputHashChannel.get()] = ((DictionaryBlock) page.getBlock(inputHashChannel.get())).getDictionary(); - } + inputHashChannel.ifPresent(integer -> blocks[integer] = ((DictionaryBlock) page.getBlock(integer)).getDictionary()); return new Page(dictionary.getPositionCount(), blocks); } @@ -500,8 +495,25 @@ private boolean canProcessDictionary(Page page) // data channel is dictionary encoded but hash channel is not return false; } - if (!((DictionaryBlock) inputHashBlock).getDictionarySourceId().equals(inputDataBlock.getDictionarySourceId())) { - // dictionarySourceIds of data block and hash block do not match + // dictionarySourceIds of data block and hash block do not match + return ((DictionaryBlock) inputHashBlock).getDictionarySourceId().equals(inputDataBlock.getDictionarySourceId()); + } + + return true; + } + + private boolean canProcessLowCardinalityDictionary(Page page) + { + // We don't have to rely on 'optimizer.dictionary-aggregations' here since there is little to none chance of regression + int positionCount = page.getPositionCount(); + long cardinality = 1; + for (int channel : channels) { + if (!(page.getBlock(channel) instanceof DictionaryBlock)) { + return false; + } + cardinality = multiplyExact(cardinality, ((DictionaryBlock) page.getBlock(channel)).getDictionary().getPositionCount()); + if (cardinality > positionCount * SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO + || cardinality > Short.MAX_VALUE) { // Need to fit into short array return false; } } @@ -511,8 +523,8 @@ private boolean canProcessDictionary(Page page) private boolean isRunLengthEncoded(Page page) { - for (int i = 0; i < channels.length; i++) { - if (!(page.getBlock(channels[i]) instanceof RunLengthEncodedBlock)) { + for (int channel : channels) { + if (!(page.getBlock(channel) instanceof RunLengthEncodedBlock)) { 
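canProcessLowCardinalityDictionary above gates the new low-cardinality path: every grouping channel must be dictionary encoded, and the product of the dictionary sizes, i.e. the number of possible key combinations, must stay below both a quarter of the page's position count and Short.MAX_VALUE, because combination ids are later stored in a short array. A standalone sketch of that cardinality check follows; it assumes the dictionary-encoding check already passed, and the class and method names are illustrative.

import static java.lang.Math.multiplyExact;

// Sketch of the eligibility test for the low-cardinality dictionary path: the product of
// the per-channel dictionary sizes must be small relative to the page and fit in a short.
final class LowCardinalityEligibilitySketch
{
    static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = 0.25;

    static boolean cardinalityIsLowEnough(int positionCount, int... dictionarySizes)
    {
        long cardinality = 1;
        for (int dictionarySize : dictionarySizes) {
            cardinality = multiplyExact(cardinality, dictionarySize);
            if (cardinality > positionCount * SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO
                    || cardinality > Short.MAX_VALUE) {          // combination ids must fit in a short[]
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args)
    {
        System.out.println(cardinalityIsLowEnough(1000, 10, 10));      // true: 100 combinations for 1000 rows
        System.out.println(cardinalityIsLowEnough(1000, 50, 10));      // false: 500 > 1000 * 0.25
        System.out.println(cardinalityIsLowEnough(500_000, 300, 300)); // false: 90000 > Short.MAX_VALUE
    }
}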
return false; } } @@ -563,11 +575,11 @@ public void setProcessed(int position, int groupId) } } - private class AddNonDictionaryPageWork + @VisibleForTesting + class AddNonDictionaryPageWork implements Work { private final Page page; - private int lastPosition; public AddNonDictionaryPageWork(Page page) @@ -579,22 +591,24 @@ public AddNonDictionaryPageWork(Page page) public boolean process() { int positionCount = page.getPositionCount(); - checkState(lastPosition < positionCount, "position count out of bound"); + checkState(lastPosition <= positionCount, "position count out of bound"); + int remainingPositions = positionCount - lastPosition; - // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. - // We can only proceed if tryRehash() successfully did a rehash. - if (needRehash() && !tryRehash()) { - return false; - } + while (remainingPositions != 0) { + int batchSize = min(remainingPositions, BATCH_SIZE); + if (!ensureHashTableSize(batchSize)) { + return false; + } - // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. - // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. - while (lastPosition < positionCount && !needRehash()) { - // get the group for the current row - putIfAbsent(lastPosition, page); - lastPosition++; + for (int i = lastPosition; i < lastPosition + batchSize; i++) { + putIfAbsent(i, page); + } + + lastPosition += batchSize; + remainingPositions -= batchSize; } - return lastPosition == positionCount; + verify(lastPosition == positionCount); + return true; } @Override @@ -604,7 +618,8 @@ public Void getResult() } } - private class AddDictionaryPageWork + @VisibleForTesting + class AddDictionaryPageWork implements Work { private final Page page; @@ -626,7 +641,7 @@ public AddDictionaryPageWork(Page page) public boolean process() { int positionCount = page.getPositionCount(); - checkState(lastPosition < positionCount, "position count out of bound"); + checkState(lastPosition <= positionCount, "position count out of bound"); // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. // We can only proceed if tryRehash() successfully did a rehash. @@ -651,7 +666,56 @@ public Void getResult() } } - private class AddRunLengthEncodedPageWork + class AddLowCardinalityDictionaryPageWork + implements Work + { + private final Page page; + @Nullable + private int[] combinationIdToPosition; + private int nextCombinationId; + + public AddLowCardinalityDictionaryPageWork(Page page) + { + this.page = requireNonNull(page, "page is null"); + } + + @Override + public boolean process() + { + // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. + // We can only proceed if tryRehash() successfully did a rehash. + if (needRehash() && !tryRehash()) { + return false; + } + + if (combinationIdToPosition == null) { + combinationIdToPosition = calculateCombinationIdToPositionMapping(page); + } + + // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. + // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. 
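AddNonDictionaryPageWork above (and GetNonDictionaryGroupIdsWork below) no longer re-check needRehash() on every row: positions are processed in batches of BATCH_SIZE = 1024, and ensureHashTableSize() grows the table up front so an entire batch is guaranteed to fit before the inner loop runs. A simplified sketch of that control flow; freeSlots, tryRehash and insertPosition are hypothetical callbacks standing in for maxFill - nextGroupId, tryRehash() and putIfAbsent().

import java.util.function.BooleanSupplier;
import java.util.function.IntConsumer;
import java.util.function.IntSupplier;

import static java.lang.Math.min;

// Sketch of the batched processing loop: capacity is ensured once per 1024-row batch,
// so the hot inner loop contains no rehash checks. Note that the real Work object keeps
// lastPosition in a field, so a call that yielded (returned false) can resume later.
final class BatchedAddSketch
{
    static final int BATCH_SIZE = 1024;

    static boolean processPage(
            int positionCount,
            IntSupplier freeSlots,
            BooleanSupplier tryRehash,
            IntConsumer insertPosition)
    {
        int lastPosition = 0;
        while (lastPosition < positionCount) {
            int batchSize = min(positionCount - lastPosition, BATCH_SIZE);

            // grow until the whole batch fits; yield (return false) if memory is not granted
            while (freeSlots.getAsInt() < batchSize) {
                if (!tryRehash.getAsBoolean()) {
                    return false;
                }
            }

            for (int i = lastPosition; i < lastPosition + batchSize; i++) {
                insertPosition.accept(i);                        // no per-row capacity check
            }
            lastPosition += batchSize;
        }
        return true;
    }
}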
+ for (int combinationId = nextCombinationId; combinationId < combinationIdToPosition.length; combinationId++) { + int position = combinationIdToPosition[combinationId]; + if (position != -1) { + if (needRehash()) { + nextCombinationId = combinationId; + return false; + } + putIfAbsent(position, page); + } + } + return true; + } + + @Override + public Void getResult() + { + throw new UnsupportedOperationException(); + } + } + + @VisibleForTesting + class AddRunLengthEncodedPageWork implements Work { private final Page page; @@ -692,10 +756,11 @@ public Void getResult() } } - private class GetNonDictionaryGroupIdsWork + @VisibleForTesting + class GetNonDictionaryGroupIdsWork implements Work { - private final BlockBuilder blockBuilder; + private final long[] groupIds; private final Page page; private boolean finished; @@ -705,7 +770,7 @@ public GetNonDictionaryGroupIdsWork(Page page) { this.page = requireNonNull(page, "page is null"); // we know the exact size required for the block - this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(page.getPositionCount()); + groupIds = new long[page.getPositionCount()]; } @Override @@ -715,36 +780,103 @@ public boolean process() checkState(lastPosition <= positionCount, "position count out of bound"); checkState(!finished); + int remainingPositions = positionCount - lastPosition; + + while (remainingPositions != 0) { + int batchSize = min(remainingPositions, BATCH_SIZE); + if (!ensureHashTableSize(batchSize)) { + return false; + } + + for (int i = lastPosition; i < lastPosition + batchSize; i++) { + // output the group id for this row + groupIds[i] = putIfAbsent(i, page); + } + + lastPosition += batchSize; + remainingPositions -= batchSize; + } + verify(lastPosition == positionCount); + return true; + } + + @Override + public GroupByIdBlock getResult() + { + checkState(lastPosition == page.getPositionCount(), "process has not yet finished"); + checkState(!finished, "result has produced"); + finished = true; + return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds)); + } + } + + @VisibleForTesting + class GetLowCardinalityDictionaryGroupIdsWork + implements Work + { + private final Page page; + private final long[] groupIds; + @Nullable + private short[] positionToCombinationId; + @Nullable + private int[] combinationIdToGroupId; + private int nextPosition; + private boolean finished; + + public GetLowCardinalityDictionaryGroupIdsWork(Page page) + { + this.page = requireNonNull(page, "page is null"); + groupIds = new long[page.getPositionCount()]; + } + + @Override + public boolean process() + { // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. // We can only proceed if tryRehash() successfully did a rehash. if (needRehash() && !tryRehash()) { return false; } - // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. - // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. 
- while (lastPosition < positionCount && !needRehash()) { - // output the group id for this row - BIGINT.writeLong(blockBuilder, putIfAbsent(lastPosition, page)); - lastPosition++; + if (positionToCombinationId == null) { + positionToCombinationId = new short[groupIds.length]; + int maxCardinality = calculatePositionToCombinationIdMapping(page, positionToCombinationId); + combinationIdToGroupId = new int[maxCardinality]; + Arrays.fill(combinationIdToGroupId, -1); } - return lastPosition == positionCount; + + for (int position = nextPosition; position < groupIds.length; position++) { + short combinationId = positionToCombinationId[position]; + int groupId = combinationIdToGroupId[combinationId]; + if (groupId == -1) { + // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so. + // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. + if (needRehash()) { + nextPosition = position; + return false; + } + groupId = putIfAbsent(position, page); + combinationIdToGroupId[combinationId] = groupId; + } + groupIds[position] = groupId; + } + return true; } @Override public GroupByIdBlock getResult() { - checkState(lastPosition == page.getPositionCount(), "process has not yet finished"); checkState(!finished, "result has produced"); finished = true; - return new GroupByIdBlock(nextGroupId, blockBuilder.build()); + return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds)); } } - private class GetDictionaryGroupIdsWork + @VisibleForTesting + class GetDictionaryGroupIdsWork implements Work { - private final BlockBuilder blockBuilder; + private final long[] groupIds; private final Page page; private final Page dictionaryPage; private final DictionaryBlock dictionaryBlock; @@ -760,16 +892,14 @@ public GetDictionaryGroupIdsWork(Page page) this.dictionaryBlock = (DictionaryBlock) page.getBlock(channels[0]); updateDictionaryLookBack(dictionaryBlock.getDictionary()); this.dictionaryPage = createPageWithExtractedDictionary(page); - - // we know the exact size required for the block - this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(page.getPositionCount()); + groupIds = new long[page.getPositionCount()]; } @Override public boolean process() { int positionCount = page.getPositionCount(); - checkState(lastPosition < positionCount, "position count out of bound"); + checkState(lastPosition <= positionCount, "position count out of bound"); checkState(!finished); // needRehash() == false indicates we have reached capacity boundary and a rehash is needed. @@ -782,8 +912,7 @@ public boolean process() // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary. 
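The low-cardinality works above collapse the dictionary ids of all grouping channels into one combination id per row, built mixed-radix style over the dictionary sizes (the id of the first channel, then times the next dictionary size plus the next id, and so on), so it ranges over at most the product of the dictionary sizes and fits in a short. Each distinct combination is pushed through putIfAbsent only once; a small translation array then maps combination ids to group ids for every remaining row. A self-contained sketch of those two steps, with a HashMap standing in for the shared hash table:

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Sketch of the low-cardinality dictionary path: rows are first reduced to mixed-radix
// combination ids over the dictionary ids of each channel, then each distinct combination
// is resolved to a group id exactly once.
final class LowCardinalityGroupIdsSketch
{
    // idsPerChannel[channel][position] is the dictionary id of that row in that channel;
    // dictionarySizes[channel] is the size of that channel's dictionary.
    static long[] getGroupIds(int[][] idsPerChannel, int[] dictionarySizes)
    {
        int positionCount = idsPerChannel[0].length;
        short[] combinationIds = new short[positionCount];
        int maxCardinality = 1;
        for (int channel = 0; channel < idsPerChannel.length; channel++) {
            maxCardinality *= dictionarySizes[channel];
            for (int position = 0; position < positionCount; position++) {
                combinationIds[position] = (short) (combinationIds[position] * dictionarySizes[channel]
                        + idsPerChannel[channel][position]);
            }
        }

        int[] combinationIdToGroupId = new int[maxCardinality];
        Arrays.fill(combinationIdToGroupId, -1);
        Map<Short, Integer> table = new HashMap<>();             // stand-in for putIfAbsent on the real table
        long[] groupIds = new long[positionCount];
        for (int position = 0; position < positionCount; position++) {
            short combinationId = combinationIds[position];
            int groupId = combinationIdToGroupId[combinationId];
            if (groupId == -1) {                                 // first row with this key combination
                groupId = table.computeIfAbsent(combinationId, ignored -> table.size());
                combinationIdToGroupId[combinationId] = groupId;
            }
            groupIds[position] = groupId;
        }
        return groupIds;
    }

    public static void main(String[] args)
    {
        int[][] ids = {{0, 1, 0, 1}, {0, 0, 1, 0}};              // two dictionary-encoded channels
        System.out.println(Arrays.toString(getGroupIds(ids, new int[] {2, 2})));
        // prints [0, 1, 2, 1]: the (1, 0) combination at positions 1 and 3 reuses its group id
    }
}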
while (lastPosition < positionCount && !needRehash()) { int positionInDictionary = dictionaryBlock.getId(lastPosition); - int groupId = getGroupId(hashGenerator, dictionaryPage, positionInDictionary); - BIGINT.writeLong(blockBuilder, groupId); + groupIds[lastPosition] = getGroupId(hashGenerator, dictionaryPage, positionInDictionary); lastPosition++; } return lastPosition == positionCount; @@ -795,11 +924,12 @@ public GroupByIdBlock getResult() checkState(lastPosition == page.getPositionCount(), "process has not yet finished"); checkState(!finished, "result has produced"); finished = true; - return new GroupByIdBlock(nextGroupId, blockBuilder.build()); + return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds)); } } - private class GetRunLengthEncodedGroupIdsWork + @VisibleForTesting + class GetRunLengthEncodedGroupIdsWork implements Work { private final Page page; @@ -848,4 +978,65 @@ public GroupByIdBlock getResult() page.getPositionCount())); } } + + /** + * Returns an array containing a position that corresponds to the low cardinality + * dictionary combinationId, or a value of -1 if no position exists within the page + * for that combinationId. + */ + private int[] calculateCombinationIdToPositionMapping(Page page) + { + short[] positionToCombinationId = new short[page.getPositionCount()]; + int maxCardinality = calculatePositionToCombinationIdMapping(page, positionToCombinationId); + int[] combinationIdToPosition = new int[maxCardinality]; + Arrays.fill(combinationIdToPosition, -1); + for (int position = 0; position < positionToCombinationId.length; position++) { + combinationIdToPosition[positionToCombinationId[position]] = position; + } + return combinationIdToPosition; + } + + /** + * Returns the number of combinations of all dictionary ids in input page blocks and populates + * positionToCombinationIds with the combinationId for each position in the input Page + */ + private int calculatePositionToCombinationIdMapping(Page page, short[] positionToCombinationIds) + { + checkArgument(positionToCombinationIds.length == page.getPositionCount()); + + int maxCardinality = 1; + for (int channel = 0; channel < channels.length; channel++) { + Block block = page.getBlock(channels[channel]); + verify(block instanceof DictionaryBlock, "Only dictionary blocks are supported"); + DictionaryBlock dictionaryBlock = (DictionaryBlock) block; + int dictionarySize = dictionaryBlock.getDictionary().getPositionCount(); + maxCardinality *= dictionarySize; + if (channel == 0) { + for (int position = 0; position < positionToCombinationIds.length; position++) { + positionToCombinationIds[position] = (short) dictionaryBlock.getId(position); + } + } + else { + for (int position = 0; position < positionToCombinationIds.length; position++) { + short combinationId = positionToCombinationIds[position]; + combinationId *= dictionarySize; + combinationId += dictionaryBlock.getId(position); + positionToCombinationIds[position] = combinationId; + } + } + } + return maxCardinality; + } + + private boolean ensureHashTableSize(int batchSize) + { + int positionCountUntilRehash = maxFill - nextGroupId; + while (positionCountUntilRehash < batchSize) { + if (!tryRehash()) { + return false; + } + positionCountUntilRehash = maxFill - nextGroupId; + } + return true; + } } diff --git a/presto-main/src/test/java/io/prestosql/block/BlockAssertions.java b/presto-main/src/test/java/io/prestosql/block/BlockAssertions.java index 04291d79f8c1..0a4a3180977c 100644 --- 
a/presto-main/src/test/java/io/prestosql/block/BlockAssertions.java +++ b/presto-main/src/test/java/io/prestosql/block/BlockAssertions.java @@ -383,8 +383,13 @@ public static Block createLongSequenceBlock(int start, int end) public static Block createLongDictionaryBlock(int start, int length) { checkArgument(length > 5, "block must have more than 5 entries"); + return createLongDictionaryBlock(start, length, length / 5); + } + + public static Block createLongDictionaryBlock(int start, int length, int dictionarySize) + { + checkArgument(dictionarySize > 0, "dictionarySize must be greater than 0"); - int dictionarySize = length / 5; BlockBuilder builder = BIGINT.createBlockBuilder(null, dictionarySize); for (int i = start; i < start + dictionarySize; i++) { BIGINT.writeLong(builder, i); diff --git a/presto-main/src/test/java/io/prestosql/operator/BenchmarkGroupByHash.java b/presto-main/src/test/java/io/prestosql/operator/BenchmarkGroupByHash.java index 37646aa1da6a..d86b0b283b1a 100644 --- a/presto-main/src/test/java/io/prestosql/operator/BenchmarkGroupByHash.java +++ b/presto-main/src/test/java/io/prestosql/operator/BenchmarkGroupByHash.java @@ -22,6 +22,8 @@ import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.DictionaryBlock; +import io.prestosql.spi.block.RunLengthEncodedBlock; import io.prestosql.spi.type.AbstractLongType; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeOperators; @@ -53,6 +55,7 @@ import java.util.Optional; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; import static io.prestosql.operator.UpdateMemory.NOOP; import static io.prestosql.spi.type.BigintType.BIGINT; @@ -196,7 +199,7 @@ public long baselineBigArray(BaselinePagesData data) return groupIds; } - private static List createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled) + private static List createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled, boolean pollute) { List types = Collections.nCopies(channelCount, BIGINT); ImmutableList.Builder pages = ImmutableList.builder(); @@ -205,6 +208,7 @@ private static List createBigintPages(int positionCount, int groupCount, i } PageBuilder pageBuilder = new PageBuilder(types); + int pageCount = 0; for (int position = 0; position < positionCount; position++) { int rand = ThreadLocalRandom.current().nextInt(groupCount); pageBuilder.declarePosition(); @@ -215,8 +219,34 @@ private static List createBigintPages(int positionCount, int groupCount, i BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), AbstractLongType.hash((long) rand)); } if (pageBuilder.isFull()) { - pages.add(pageBuilder.build()); + Page page = pageBuilder.build(); pageBuilder.reset(); + if (pollute) { + if (pageCount % 3 == 0) { + pages.add(page); + } + else if (pageCount % 3 == 1) { + // rle page + Block[] blocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < blocks.length; ++channel) { + blocks[channel] = new RunLengthEncodedBlock(page.getBlock(channel).getSingleValueBlock(0), page.getPositionCount()); + } + pages.add(new Page(blocks)); + } + else { + // dictionary page + int[] positions = IntStream.range(0, page.getPositionCount()).toArray(); + Block[] blocks = new Block[page.getChannelCount()]; + for (int channel = 0; channel < page.getChannelCount(); ++channel) { + blocks[channel] = new DictionaryBlock(page.getBlock(channel), 
positions); + } + pages.add(new Page(blocks)); + } + } + else { + pages.add(page); + } + pageCount++; } } pages.add(pageBuilder.build()); @@ -269,7 +299,7 @@ public static class BaselinePagesData @Setup public void setup() { - pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled); + pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled, false); } public List getPages() @@ -295,7 +325,12 @@ public static class SingleChannelBenchmarkData @Setup public void setup() { - pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, hashEnabled); + setup(false); + } + + public void setup(boolean pollute) + { + pages = createBigintPages(POSITIONS, GROUP_COUNT, channelCount, hashEnabled, pollute); types = Collections.nCopies(1, BIGINT); channels = new int[1]; for (int i = 0; i < 1; i++) { @@ -351,7 +386,7 @@ public void setup() break; case "BIGINT": types = Collections.nCopies(channelCount, BIGINT); - pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled); + pages = createBigintPages(POSITIONS, groupCount, channelCount, hashEnabled, false); break; default: throw new UnsupportedOperationException("Unsupported dataType"); @@ -389,6 +424,16 @@ private static JoinCompiler getJoinCompiler() return new JoinCompiler(TYPE_OPERATORS); } + static { + // pollute BigintGroupByHash profile by different block types + SingleChannelBenchmarkData singleChannelBenchmarkData = new SingleChannelBenchmarkData(); + singleChannelBenchmarkData.setup(true); + BenchmarkGroupByHash hash = new BenchmarkGroupByHash(); + for (int i = 0; i < 5; ++i) { + hash.bigintGroupByHash(singleChannelBenchmarkData); + } + } + public static void main(String[] args) throws RunnerException { diff --git a/presto-main/src/test/java/io/prestosql/operator/GroupByHashYieldAssertion.java b/presto-main/src/test/java/io/prestosql/operator/GroupByHashYieldAssertion.java index 331a77a581ee..6d5f38df9bb3 100644 --- a/presto-main/src/test/java/io/prestosql/operator/GroupByHashYieldAssertion.java +++ b/presto-main/src/test/java/io/prestosql/operator/GroupByHashYieldAssertion.java @@ -126,7 +126,10 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< // free the pool for the next iteration memoryPool.free(queryId, "test", reservedMemoryInBytes); // this required in case input is blocked - operator.getOutput(); + output = operator.getOutput(); + if (output != null) { + result.add(output); + } continue; } @@ -139,7 +142,7 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone()); // assert the hash capacity is not changed; otherwise, we should have yielded - assertTrue(oldCapacity == getHashCapacity.apply(operator)); + assertEquals((int) getHashCapacity.apply(operator), oldCapacity); // We are not going to rehash; therefore, assert the memory increase only comes from the aggregator assertLessThan(actualIncreasedMemory, additionalMemoryInBytes); @@ -163,8 +166,8 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES) + page.getRetainedSizeInBytes(); } else { - // groupAddressByHash, groupIdsByHash, and rawHashByHashPosition double by hashCapacity; while groupAddressByGroupId double by maxFill = hashCapacity / 0.75 - expectedReservedExtraBytes = oldCapacity * (long) (Long.BYTES * 1.75 + Integer.BYTES + Byte.BYTES) + page.getRetainedSizeInBytes(); + // 
groupIdsByHash, and rawHashByHashPosition double by hashCapacity + expectedReservedExtraBytes = oldCapacity * (long) (Integer.BYTES + Byte.BYTES); } assertBetweenInclusive(actualIncreasedMemory, expectedReservedExtraBytes, expectedReservedExtraBytes + additionalMemoryInBytes); @@ -186,10 +189,24 @@ public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List< // Assert the estimated reserved memory before rehash is very close to the one after rehash long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage(); - assertBetweenInclusive(rehashedMemoryUsage * 1.0 / newMemoryUsage, 0.99, 1.01); + double memoryUsageErrorUpperBound = 1.01; + double memoryUsageError = rehashedMemoryUsage * 1.0 / newMemoryUsage; + if (memoryUsageError > memoryUsageErrorUpperBound) { + // Usually the error is < 1%, but since MultiChannelGroupByHash.getEstimatedSize + // accounts for changes in completedPagesMemorySize, which is increased if new page is + // added by addNewGroup (an even that cannot be predicted as it depends on the number of unique groups + // in the current page being processed), the difference includes size of the added new page. + // Lower bound is 1% lower than normal because additionalMemoryInBytes includes also aggregator state. + assertBetweenInclusive(rehashedMemoryUsage * 1.0 / (newMemoryUsage + additionalMemoryInBytes), 0.98, memoryUsageErrorUpperBound, + "rehashedMemoryUsage " + rehashedMemoryUsage + ", newMemoryUsage: " + newMemoryUsage); + } + else { + assertBetweenInclusive(memoryUsageError, 0.99, memoryUsageErrorUpperBound); + } // unblocked assertTrue(operator.needsInput()); + assertTrue(operator.getOperatorContext().isWaitingForMemory().isDone()); } } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestDistinctLimitOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestDistinctLimitOperator.java index eb61f057a063..db8dfe161a23 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestDistinctLimitOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestDistinctLimitOperator.java @@ -34,7 +34,7 @@ import java.util.concurrent.ScheduledExecutorService; import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.airlift.testing.Assertions.assertGreaterThan; +import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.prestosql.RowPagesBuilder.rowPagesBuilder; import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; @@ -167,9 +167,9 @@ public void testMemoryReservationYield(Type type) joinCompiler, blockTypeOperators); - GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((DistinctLimitOperator) operator).getCapacity(), 1_400_000); - assertGreaterThan(result.getYieldCount(), 5); - assertGreaterThan(result.getMaxReservedBytes(), 20L << 20); + GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((DistinctLimitOperator) operator).getCapacity(), 450_000); + assertGreaterThanOrEqual(result.getYieldCount(), 5); + assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20); assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 6_000 * 600); } } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestGroupByHash.java 
b/presto-main/src/test/java/io/prestosql/operator/TestGroupByHash.java index 7ec7759cc150..84e2b56f7584 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestGroupByHash.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestGroupByHash.java @@ -14,6 +14,7 @@ package io.prestosql.operator; import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slices; import io.prestosql.Session; import io.prestosql.block.BlockAssertions; import io.prestosql.spi.Page; @@ -21,6 +22,9 @@ import io.prestosql.spi.block.Block; import io.prestosql.spi.block.DictionaryBlock; import io.prestosql.spi.block.DictionaryId; +import io.prestosql.spi.block.LongArrayBlock; +import io.prestosql.spi.block.RunLengthEncodedBlock; +import io.prestosql.spi.block.VariableWidthBlock; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeOperators; import io.prestosql.sql.gen.JoinCompiler; @@ -42,11 +46,13 @@ import static io.prestosql.block.BlockAssertions.createLongsBlock; import static io.prestosql.block.BlockAssertions.createStringSequenceBlock; import static io.prestosql.operator.GroupByHash.createGroupByHash; +import static io.prestosql.operator.UpdateMemory.NOOP; import static io.prestosql.spi.block.DictionaryId.randomDictionaryId; import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.spi.type.DoubleType.DOUBLE; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.type.TypeTestUtils.getHashBlock; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; @@ -95,6 +101,61 @@ public void testAddPage() } } + @Test + public void testRunLengthEncodedBigintGroupByHash() + { + GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY); + Block block = BlockAssertions.createLongsBlock(0L); + Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), block); + Page page = new Page( + new RunLengthEncodedBlock(block, 2), + new RunLengthEncodedBlock(hashBlock, 2)); + + groupByHash.addPage(page).process(); + + assertEquals(groupByHash.getGroupCount(), 1); + + Work work = groupByHash.getGroupIds(page); + work.process(); + GroupByIdBlock groupIds = work.getResult(); + + assertEquals(groupIds.getGroupCount(), 1); + assertEquals(groupIds.getPositionCount(), 2); + assertEquals(groupIds.getGroupId(0), 0); + assertEquals(groupIds.getGroupId(1), 0); + + List children = groupIds.getChildren(); + assertEquals(children.size(), 1); + assertTrue(children.get(0) instanceof RunLengthEncodedBlock); + } + + @Test + public void testDictionaryBigintGroupByHash() + { + GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT), new int[] {0}, Optional.of(1), 100, JOIN_COMPILER, TYPE_OPERATOR_FACTORY); + Block block = BlockAssertions.createLongsBlock(0L, 1L); + Block hashBlock = TypeTestUtils.getHashBlock(ImmutableList.of(BIGINT), block); + int[] ids = new int[] {0, 0, 1, 1}; + Page page = new Page( + new DictionaryBlock(block, ids), + new DictionaryBlock(hashBlock, ids)); + + groupByHash.addPage(page).process(); + + assertEquals(groupByHash.getGroupCount(), 2); + + Work work = groupByHash.getGroupIds(page); + work.process(); + GroupByIdBlock groupIds = work.getResult(); + + assertEquals(groupIds.getGroupCount(), 2); + assertEquals(groupIds.getPositionCount(), 4); + 
assertEquals(groupIds.getGroupId(0), 0); + assertEquals(groupIds.getGroupId(1), 0); + assertEquals(groupIds.getGroupId(2), 1); + assertEquals(groupIds.getGroupId(3), 1); + } + @Test public void testNullGroup() { @@ -445,4 +506,163 @@ public void testMemoryReservationYieldWithDictionary() assertEquals(currentQuota.get(), 10); assertEquals(currentQuota.get() / 3, yields); } + + @Test + public void testLowCardinalityDictionariesAddPage() + { + GroupByHash groupByHash = createGroupByHash( + TEST_SESSION, + ImmutableList.of(BIGINT, BIGINT), + new int[] {0, 1}, + Optional.empty(), + 100, + JOIN_COMPILER, + TYPE_OPERATOR_FACTORY); + Block firstBlock = BlockAssertions.createLongDictionaryBlock(0, 1000, 10); + Block secondBlock = BlockAssertions.createLongDictionaryBlock(0, 1000, 10); + Page page = new Page(firstBlock, secondBlock); + + Work work = groupByHash.addPage(page); + assertThat(work).isInstanceOf(MultiChannelGroupByHash.AddLowCardinalityDictionaryPageWork.class); + work.process(); + assertThat(groupByHash.getGroupCount()).isEqualTo(10); // Blocks are identical so only 10 distinct groups + + firstBlock = BlockAssertions.createLongDictionaryBlock(10, 1000, 5); + secondBlock = BlockAssertions.createLongDictionaryBlock(10, 1000, 7); + page = new Page(firstBlock, secondBlock); + + groupByHash.addPage(page).process(); + assertThat(groupByHash.getGroupCount()).isEqualTo(45); // Old 10 groups and 35 new + } + + @Test + public void testLowCardinalityDictionariesGetGroupIds() + { + // Compare group id results from page with dictionaries only (processed via low cardinality work) and the same page processed normally + GroupByHash groupByHash = createGroupByHash( + TEST_SESSION, + ImmutableList.of(BIGINT, BIGINT, BIGINT, BIGINT, BIGINT), + new int[] {0, 1, 2, 3, 4}, + Optional.empty(), + 100, + JOIN_COMPILER, + TYPE_OPERATOR_FACTORY); + + GroupByHash lowCardinalityGroupByHash = createGroupByHash( + TEST_SESSION, + ImmutableList.of(BIGINT, BIGINT, BIGINT, BIGINT), + new int[] {0, 1, 2, 3}, + Optional.empty(), + 100, + JOIN_COMPILER, + TYPE_OPERATOR_FACTORY); + Block sameValueBlock = BlockAssertions.createLongRepeatBlock(0, 100); + Block block1 = BlockAssertions.createLongDictionaryBlock(0, 100, 1); + Block block2 = BlockAssertions.createLongDictionaryBlock(0, 100, 2); + Block block3 = BlockAssertions.createLongDictionaryBlock(0, 100, 3); + Block block4 = BlockAssertions.createLongDictionaryBlock(0, 100, 4); + // Combining block 2 and 4 will result in only 4 distinct values since 2 and 4 are not coprime + + Page lowCardinalityPage = new Page(block1, block2, block3, block4); + Page page = new Page(block1, block2, block3, block4, sameValueBlock); // sameValueBlock will prevent low cardinality optimization to fire + + Work lowCardinalityWork = lowCardinalityGroupByHash.getGroupIds(lowCardinalityPage); + assertThat(lowCardinalityWork).isInstanceOf(MultiChannelGroupByHash.GetLowCardinalityDictionaryGroupIdsWork.class); + Work work = groupByHash.getGroupIds(page); + + lowCardinalityWork.process(); + work.process(); + GroupByIdBlock lowCardinalityResults = lowCardinalityWork.getResult(); + GroupByIdBlock results = work.getResult(); + + assertThat(lowCardinalityResults.getGroupCount()).isEqualTo(results.getGroupCount()); + } + + @Test + public void testLowCardinalityDictionariesProperGroupIdOrder() + { + GroupByHash groupByHash = createGroupByHash( + TEST_SESSION, + ImmutableList.of(BIGINT, BIGINT), + new int[] {0, 1}, + Optional.empty(), + 100, + JOIN_COMPILER, + TYPE_OPERATOR_FACTORY); + + Block dictionary = 
new LongArrayBlock(2, Optional.empty(), new long[] {0, 1}); + int[] ids = new int[32]; + for (int i = 0; i < 16; i++) { + ids[i] = 1; + } + Block block1 = new DictionaryBlock(dictionary, ids); + Block block2 = new DictionaryBlock(dictionary, ids); + + Page page = new Page(block1, block2); + + Work work = groupByHash.getGroupIds(page); + assertThat(work).isInstanceOf(MultiChannelGroupByHash.GetLowCardinalityDictionaryGroupIdsWork.class); + + work.process(); + GroupByIdBlock results = work.getResult(); + // Records with group id '0' should come before '1' despite being in the end of the block + for (int i = 0; i < 16; i++) { + assertThat(results.getGroupId(i)).isEqualTo(0); + } + for (int i = 16; i < 32; i++) { + assertThat(results.getGroupId(i)).isEqualTo(1); + } + } + + @Test + public void testProperWorkTypesSelected() + { + Block bigintBlock = BlockAssertions.createLongsBlock(1, 2, 3, 4, 5, 6, 7, 8); + Block bigintDictionaryBlock = BlockAssertions.createLongDictionaryBlock(0, 8); + Block bigintRleBlock = BlockAssertions.createRLEBlock(42, 8); + Block varcharBlock = BlockAssertions.createStringsBlock("1", "2", "3", "4", "5", "6", "7", "8"); + Block varcharDictionaryBlock = BlockAssertions.createStringDictionaryBlock(1, 8); + Block varcharRleBlock = new RunLengthEncodedBlock(new VariableWidthBlock(1, Slices.EMPTY_SLICE, new int[] {0, 1}, Optional.empty()), 8); + Block bigintBigDictionaryBlock = BlockAssertions.createLongDictionaryBlock(1, 8, 1000); + Block bigintSingletonDictionaryBlock = BlockAssertions.createLongDictionaryBlock(1, 500000, 1); + Block bigintHugeDictionaryBlock = BlockAssertions.createLongDictionaryBlock(1, 500000, 66000); // Above Short.MAX_VALUE + + Page singleBigintPage = new Page(bigintBlock); + assertGroupByHashWork(singleBigintPage, ImmutableList.of(BIGINT), BigintGroupByHash.GetGroupIdsWork.class); + Page singleBigintDictionaryPage = new Page(bigintDictionaryBlock); + assertGroupByHashWork(singleBigintDictionaryPage, ImmutableList.of(BIGINT), BigintGroupByHash.GetDictionaryGroupIdsWork.class); + Page singleBigintRlePage = new Page(bigintRleBlock); + assertGroupByHashWork(singleBigintRlePage, ImmutableList.of(BIGINT), BigintGroupByHash.GetRunLengthEncodedGroupIdsWork.class); + Page singleVarcharPage = new Page(varcharBlock); + assertGroupByHashWork(singleVarcharPage, ImmutableList.of(VARCHAR), MultiChannelGroupByHash.GetNonDictionaryGroupIdsWork.class); + Page singleVarcharDictionaryPage = new Page(varcharDictionaryBlock); + assertGroupByHashWork(singleVarcharDictionaryPage, ImmutableList.of(VARCHAR), MultiChannelGroupByHash.GetDictionaryGroupIdsWork.class); + Page singleVarcharRlePage = new Page(varcharRleBlock); + assertGroupByHashWork(singleVarcharRlePage, ImmutableList.of(VARCHAR), MultiChannelGroupByHash.GetRunLengthEncodedGroupIdsWork.class); + + Page lowCardinalityDictionaryPage = new Page(bigintDictionaryBlock, varcharDictionaryBlock); + assertGroupByHashWork(lowCardinalityDictionaryPage, ImmutableList.of(BIGINT, VARCHAR), MultiChannelGroupByHash.GetLowCardinalityDictionaryGroupIdsWork.class); + Page highCardinalityDictionaryPage = new Page(bigintDictionaryBlock, bigintBigDictionaryBlock); + assertGroupByHashWork(highCardinalityDictionaryPage, ImmutableList.of(BIGINT, VARCHAR), MultiChannelGroupByHash.GetNonDictionaryGroupIdsWork.class); + + // Cardinality above Short.MAX_VALUE + Page lowCardinalityHugeDictionaryPage = new Page(bigintSingletonDictionaryBlock, bigintHugeDictionaryBlock); + assertGroupByHashWork(lowCardinalityHugeDictionaryPage, 
ImmutableList.of(BIGINT, BIGINT), MultiChannelGroupByHash.GetNonDictionaryGroupIdsWork.class); + } + + private void assertGroupByHashWork(Page page, List<? extends Type> types, Class<?> clazz) + { + GroupByHash groupByHash = createGroupByHash( + types, + IntStream.range(0, types.size()).toArray(), + Optional.empty(), + 100, + true, + JOIN_COMPILER, + TYPE_OPERATOR_FACTORY, + NOOP); + Work<GroupByIdBlock> work = groupByHash.getGroupIds(page); + // The Work implementations are package-private (@VisibleForTesting), so we can assert on the concrete type directly + assertThat(work).isInstanceOf(clazz); + } } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java index df6e01c30922..3b2bf3139e95 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestHashAggregationOperator.java @@ -62,6 +62,7 @@ import static io.airlift.slice.SizeOf.SIZE_OF_LONG; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; import static io.airlift.testing.Assertions.assertGreaterThan; +import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.airlift.units.DataSize.succinctBytes; @@ -420,9 +421,9 @@ public void testMemoryReservationYield(Type type) // get result with yield; pick a relatively small buffer for aggregator's memory usage GroupByHashYieldResult result; - result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 1_400_000); - assertGreaterThan(result.getYieldCount(), 5); - assertGreaterThan(result.getMaxReservedBytes(), 20L << 20); + result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, this::getHashCapacity, 450_000); + assertGreaterThanOrEqual(result.getYieldCount(), 5); + assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20); int count = 0; for (Page page : result.getOutput()) { diff --git a/presto-main/src/test/java/io/prestosql/operator/TestHashSemiJoinOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestHashSemiJoinOperator.java index e795686564c4..f14a5f8a5d61 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestHashSemiJoinOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestHashSemiJoinOperator.java @@ -38,7 +38,6 @@ import static com.google.common.collect.Iterables.concat; import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.airlift.testing.Assertions.assertGreaterThan; import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.prestosql.RowPagesBuilder.rowPagesBuilder; import static io.prestosql.SessionTestUtils.TEST_SESSION; @@ -244,10 +243,10 @@ public void testSemiJoinMemoryReservationYield(Type type) type, setBuilderOperatorFactory, operator -> ((SetBuilderOperator) operator).getCapacity(), - 1_400_000); + 450_000); - assertGreaterThanOrEqual(result.getYieldCount(), 5); - assertGreaterThan(result.getMaxReservedBytes(), 20L << 20); + assertGreaterThanOrEqual(result.getYieldCount(), 4); + assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 19); assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 0); } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestMarkDistinctOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestMarkDistinctOperator.java index 6fd3086e7cda..d884f9f052db 100644 ---
a/presto-main/src/test/java/io/prestosql/operator/TestMarkDistinctOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestMarkDistinctOperator.java @@ -35,7 +35,7 @@ import java.util.concurrent.ScheduledExecutorService; import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.airlift.testing.Assertions.assertGreaterThan; +import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.prestosql.RowPagesBuilder.rowPagesBuilder; import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; @@ -116,9 +116,9 @@ public void testMemoryReservationYield(Type type) OperatorFactory operatorFactory = new MarkDistinctOperatorFactory(0, new PlanNodeId("test"), ImmutableList.of(type), ImmutableList.of(0), Optional.of(1), joinCompiler, blockTypeOperators); // get result with yield; pick a relatively small buffer for partitionRowCount's memory usage - GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((MarkDistinctOperator) operator).getCapacity(), 1_400_000); - assertGreaterThan(result.getYieldCount(), 5); - assertGreaterThan(result.getMaxReservedBytes(), 20L << 20); + GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((MarkDistinctOperator) operator).getCapacity(), 450_000); + assertGreaterThanOrEqual(result.getYieldCount(), 5); + assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20); int count = 0; for (Page page : result.getOutput()) { diff --git a/presto-main/src/test/java/io/prestosql/operator/TestRowNumberOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestRowNumberOperator.java index 88721e0a7352..ced3448b69de 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestRowNumberOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestRowNumberOperator.java @@ -40,7 +40,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.testing.Assertions.assertEqualsIgnoreOrder; -import static io.airlift.testing.Assertions.assertGreaterThan; +import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.prestosql.RowPagesBuilder.rowPagesBuilder; import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; @@ -172,9 +172,9 @@ public void testMemoryReservationYield(Type type) blockTypeOperators); // get result with yield; pick a relatively small buffer for partitionRowCount's memory usage - GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((RowNumberOperator) operator).getCapacity(), 1_400_000); - assertGreaterThan(result.getYieldCount(), 5); - assertGreaterThan(result.getMaxReservedBytes(), 20L << 20); + GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash(input, type, operatorFactory, operator -> ((RowNumberOperator) operator).getCapacity(), 280_000); + assertGreaterThanOrEqual(result.getYieldCount(), 5); + assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 20); int count = 0; for (Page page : result.getOutput()) { diff --git 
a/presto-tests/src/test/java/io/prestosql/memory/TestMemoryManager.java b/presto-tests/src/test/java/io/prestosql/memory/TestMemoryManager.java index 9af851cb7c5f..c133d0f85a07 100644 --- a/presto-tests/src/test/java/io/prestosql/memory/TestMemoryManager.java +++ b/presto-tests/src/test/java/io/prestosql/memory/TestMemoryManager.java @@ -130,7 +130,8 @@ public void testOutOfMemoryKiller() List<Future<?>> queryFutures = new ArrayList<>(); for (int i = 0; i < 2; i++) { - queryFutures.add(executor.submit(() -> queryRunner.execute("SELECT COUNT(*), clerk FROM orders GROUP BY clerk"))); + // for this test to work, the query has to produce enough groups for HashAggregationOperator to exceed QueryContext.GUARANTEED_MEMORY + queryFutures.add(executor.submit(() -> queryRunner.execute("SELECT COUNT(*), cast(orderkey as varchar), partkey FROM lineitem GROUP BY cast(orderkey as varchar), partkey"))); } // Wait for one of the queries to die