diff --git a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
index 318376127c38..92b3a3abf001 100644
--- a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
@@ -371,7 +371,8 @@ private int registerGroupId(Block dictionary, int positionInDictionary)
         return groupId;
     }
 
-    private class AddPageWork
+    @VisibleForTesting
+    class AddPageWork
             implements Work<Void>
     {
         private final Block block;
@@ -412,7 +413,8 @@ public Void getResult()
         }
     }
 
-    private class AddDictionaryPageWork
+    @VisibleForTesting
+    class AddDictionaryPageWork
             implements Work<Void>
     {
         private final Block dictionary;
@@ -456,7 +458,8 @@ public Void getResult()
         }
     }
 
-    private class AddRunLengthEncodedPageWork
+    @VisibleForTesting
+    class AddRunLengthEncodedPageWork
             implements Work<Void>
     {
         private final RunLengthEncodedBlock block;
@@ -497,7 +500,8 @@ public Void getResult()
         }
     }
 
-    private class GetGroupIdsWork
+    @VisibleForTesting
+    class GetGroupIdsWork
             implements Work<GroupByIdBlock>
     {
         private final BlockBuilder blockBuilder;
@@ -546,7 +550,8 @@ public GroupByIdBlock getResult()
         }
     }
 
-    private class GetDictionaryGroupIdsWork
+    @VisibleForTesting
+    class GetDictionaryGroupIdsWork
             implements Work<GroupByIdBlock>
     {
         private final BlockBuilder blockBuilder;
@@ -600,7 +605,8 @@ public GroupByIdBlock getResult()
         }
     }
 
-    private class GetRunLengthEncodedGroupIdsWork
+    @VisibleForTesting
+    class GetRunLengthEncodedGroupIdsWork
             implements Work<GroupByIdBlock>
     {
         private final RunLengthEncodedBlock block;
diff --git a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
index 16f9c8c6ae28..bfc2fccad8d8 100644
--- a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java
@@ -23,6 +23,7 @@
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.block.DictionaryBlock;
+import io.trino.spi.block.LongArrayBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.type.Type;
 import io.trino.sql.gen.JoinCompiler;
@@ -48,6 +49,7 @@
 import static io.trino.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
 import static it.unimi.dsi.fastutil.HashCommon.arraySize;
 import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
+import static java.lang.Math.multiplyExact;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
 
@@ -57,6 +59,9 @@ public class MultiChannelGroupByHash
 {
     private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize();
     private static final float FILL_RATIO = 0.75f;
+    // Max (page value count / cumulative dictionary size) to trigger the low cardinality case
+    private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25;
+
     private final List<Type> types;
     private final List<Type> hashTypes;
     private final int[] channels;
@@ -222,6 +227,9 @@ public Work<?> addPage(Page page)
         if (canProcessDictionary(page)) {
             return new AddDictionaryPageWork(page);
         }
+        if (canProcessLowCardinalityDictionary(page)) {
+            return new AddLowCardinalityDictionaryPageWork(page);
+        }
         return new AddNonDictionaryPageWork(page);
     }
 
@@ -236,6 +244,9 @@ public Work<GroupByIdBlock> getGroupIds(Page page)
         if (canProcessDictionary(page)) {
             return new GetDictionaryGroupIdsWork(page);
         }
+        if (canProcessLowCardinalityDictionary(page)) {
+            return new GetLowCardinalityDictionaryGroupIdsWork(page);
+        }
         return new GetNonDictionaryGroupIdsWork(page);
     }
 
@@ -507,6 +518,25 @@ private boolean canProcessDictionary(Page page)
         return true;
     }
 
+    private boolean canProcessLowCardinalityDictionary(Page page)
+    {
+        // We don't have to rely on 'optimizer.dictionary-aggregations' here since there is little to none chance of regression
+        int positionCount = page.getPositionCount();
+        long cardinality = 1;
+        for (int channel : channels) {
+            if (!(page.getBlock(channel) instanceof DictionaryBlock)) {
+                return false;
+            }
+            cardinality = multiplyExact(cardinality, ((DictionaryBlock) page.getBlock(channel)).getDictionary().getPositionCount());
+            if (cardinality > positionCount * SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO
+                    || cardinality > Short.MAX_VALUE) { // Need to fit into short array
+                return false;
+            }
+        }
+
+        return true;
+    }
+
     private boolean isRunLengthEncoded(Page page)
     {
         for (int channel : channels) {
@@ -561,7 +591,8 @@ public void setProcessed(int position, int groupId)
         }
     }
 
-    private class AddNonDictionaryPageWork
+    @VisibleForTesting
+    class AddNonDictionaryPageWork
             implements Work<Void>
     {
         private final Page page;
@@ -602,7 +633,8 @@ public Void getResult()
         }
     }
 
-    private class AddDictionaryPageWork
+    @VisibleForTesting
+    class AddDictionaryPageWork
             implements Work<Void>
     {
         private final Page page;
@@ -649,7 +681,49 @@ public Void getResult()
         }
     }
 
-    private class AddRunLengthEncodedPageWork
+    class AddLowCardinalityDictionaryPageWork
+            extends LowCardinalityDictionaryWork<Void>
+    {
+        public AddLowCardinalityDictionaryPageWork(Page page)
+        {
+            super(page);
+        }
+
+        @Override
+        public boolean process()
+        {
+            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
+            // We can only proceed if tryRehash() successfully did a rehash.
+            if (needRehash() && !tryRehash()) {
+                return false;
+            }
+
+            int[] combinationIdToPosition = new int[maxCardinality];
+            Arrays.fill(combinationIdToPosition, -1);
+            calculateCombinationIdsToPositionMapping(combinationIdToPosition);
+
+            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
+            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
+            for (int i = 0; i < maxCardinality; i++) {
+                if (needRehash()) {
+                    return false;
+                }
+                if (combinationIdToPosition[i] != -1) {
+                    putIfAbsent(combinationIdToPosition[i], page);
+                }
+            }
+            return true;
+        }
+
+        @Override
+        public Void getResult()
+        {
+            throw new UnsupportedOperationException();
+        }
+    }
+
+    @VisibleForTesting
+    class AddRunLengthEncodedPageWork
             implements Work<Void>
     {
         private final Page page;
@@ -690,7 +764,8 @@ public Void getResult()
         }
     }
 
-    private class GetNonDictionaryGroupIdsWork
+    @VisibleForTesting
+    class GetNonDictionaryGroupIdsWork
             implements Work<GroupByIdBlock>
     {
         private final BlockBuilder blockBuilder;
@@ -739,7 +814,64 @@ public GroupByIdBlock getResult()
         }
     }
 
-    private class GetDictionaryGroupIdsWork
+    @VisibleForTesting
+    class GetLowCardinalityDictionaryGroupIdsWork
+            extends LowCardinalityDictionaryWork<GroupByIdBlock>
+    {
+        private final long[] groupIds;
+        private boolean finished;
+
+        public GetLowCardinalityDictionaryGroupIdsWork(Page page)
+        {
+            super(page);
+            groupIds = new long[page.getPositionCount()];
+        }
+
+        @Override
+        public boolean process()
+        {
+            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
+            // We can only proceed if tryRehash() successfully did a rehash.
+            if (needRehash() && !tryRehash()) {
+                return false;
+            }
+
+            int positionCount = page.getPositionCount();
+            int[] combinationIdToPosition = new int[maxCardinality];
+            Arrays.fill(combinationIdToPosition, -1);
+            short[] positionToCombinationId = calculateCombinationIdsToPositionMapping(combinationIdToPosition);
+            int[] combinationIdToGroupId = new int[maxCardinality];
+
+            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
+            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
+            for (int i = 0; i < maxCardinality; i++) {
+                if (needRehash()) {
+                    return false;
+                }
+                if (combinationIdToPosition[i] != -1) {
+                    combinationIdToGroupId[i] = putIfAbsent(combinationIdToPosition[i], page);
+                }
+                else {
+                    combinationIdToGroupId[i] = -1;
+                }
+            }
+            for (int i = 0; i < positionCount; i++) {
+                groupIds[i] = combinationIdToGroupId[positionToCombinationId[i]];
+            }
+            return true;
+        }
+
+        @Override
+        public GroupByIdBlock getResult()
+        {
+            checkState(!finished, "result has produced");
+            finished = true;
+            return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds));
+        }
+    }
+
+    @VisibleForTesting
+    class GetDictionaryGroupIdsWork
             implements Work<GroupByIdBlock>
     {
         private final BlockBuilder blockBuilder;
@@ -797,7 +929,8 @@ public GroupByIdBlock getResult()
         }
     }
 
-    private class GetRunLengthEncodedGroupIdsWork
+    @VisibleForTesting
+    class GetRunLengthEncodedGroupIdsWork
             implements Work<GroupByIdBlock>
     {
         private final Page page;
@@ -846,4 +979,56 @@ public GroupByIdBlock getResult()
                     page.getPositionCount()));
         }
     }
+
+    private abstract class LowCardinalityDictionaryWork<T>
+            implements Work<T>
+    {
+        protected final Page page;
+        protected final int maxCardinality;
+        protected final int[] dictionarySizes;
+        protected final DictionaryBlock[] blocks;
+
+        public LowCardinalityDictionaryWork(Page page)
+        {
+            this.page = requireNonNull(page, "page is null");
+            dictionarySizes = new int[channels.length];
+            blocks = new DictionaryBlock[channels.length];
+            int maxCardinality = 1;
+            for (int i = 0; i < channels.length; i++) {
+                Block block = page.getBlock(channels[i]);
+                verify(block instanceof DictionaryBlock, "Only dictionary blocks are supported");
+                blocks[i] = (DictionaryBlock) block;
+                int blockPositionCount = blocks[i].getDictionary().getPositionCount();
+                dictionarySizes[i] = blockPositionCount;
+                maxCardinality *= blockPositionCount;
+            }
+            this.maxCardinality = maxCardinality;
+        }
+
+        /**
+         * Returns combinations of all dictionaries ids for every position and populates
+         * samplePositions array with a single occurrence of every used combination
+         */
+        protected short[] calculateCombinationIdsToPositionMapping(int[] combinationIdToPosition)
+        {
+            int positionCount = page.getPositionCount();
+            // short arrays improve performance compared to int
+            short[] combinationIds = new short[positionCount];
+
+            for (int i = 0; i < positionCount; i++) {
+                combinationIds[i] = (short) blocks[0].getId(i);
+            }
+            for (int j = 1; j < channels.length; j++) {
+                for (int i = 0; i < positionCount; i++) {
+                    combinationIds[i] *= dictionarySizes[j];
+                    combinationIds[i] += blocks[j].getId(i);
+                }
+            }
+
+            for (int i = 0; i < positionCount; i++) {
+                combinationIdToPosition[combinationIds[i]] = i;
+            }
+            return combinationIds;
+        }
+    }
 }
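
For reference, the combination-id technique introduced by LowCardinalityDictionaryWork above can be shown in isolation. The sketch below is plain Java with no Trino types; the class and array names are illustrative only and are not part of the patch. It packs each row's per-channel dictionary ids into one mixed-radix combination id, resolves a group id once per distinct combination (a running counter stands in for the real hash table's putIfAbsent), and fans the group ids back out to all rows.

import java.util.Arrays;

// Standalone illustration of the "combination id" technique: per-channel dictionary ids are
// packed into one mixed-radix id per row, each distinct combination is resolved to a group id
// exactly once, and the per-combination group ids are then fanned back out to all rows.
public final class LowCardinalityCombinationSketch
{
    public static long[] groupIds(int[][] dictionaryIdsPerChannel, int[] dictionarySizes)
    {
        int channelCount = dictionaryIdsPerChannel.length;
        int positionCount = dictionaryIdsPerChannel[0].length;
        int maxCardinality = 1;
        for (int size : dictionarySizes) {
            maxCardinality *= size;
        }

        // Mixed-radix packing: combinationId = ((id0 * size1) + id1) * size2 + id2 ...
        short[] positionToCombinationId = new short[positionCount];
        for (int i = 0; i < positionCount; i++) {
            positionToCombinationId[i] = (short) dictionaryIdsPerChannel[0][i];
        }
        for (int channel = 1; channel < channelCount; channel++) {
            for (int i = 0; i < positionCount; i++) {
                positionToCombinationId[i] *= dictionarySizes[channel];
                positionToCombinationId[i] += dictionaryIdsPerChannel[channel][i];
            }
        }

        // Remember one representative position per combination that actually occurs
        int[] combinationIdToPosition = new int[maxCardinality];
        Arrays.fill(combinationIdToPosition, -1);
        for (int i = 0; i < positionCount; i++) {
            combinationIdToPosition[positionToCombinationId[i]] = i;
        }

        // Resolve a group id once per used combination (a counter stands in for the real
        // hash table's putIfAbsent), then fan the group ids back out to every position
        int[] combinationIdToGroupId = new int[maxCardinality];
        int nextGroupId = 0;
        for (int combinationId = 0; combinationId < maxCardinality; combinationId++) {
            combinationIdToGroupId[combinationId] = combinationIdToPosition[combinationId] == -1 ? -1 : nextGroupId++;
        }

        long[] groupIds = new long[positionCount];
        for (int i = 0; i < positionCount; i++) {
            groupIds[i] = combinationIdToGroupId[positionToCombinationId[i]];
        }
        return groupIds;
    }

    public static void main(String[] args)
    {
        // Two dictionary-encoded channels with dictionary sizes 2 and 3 -> at most 6 combinations
        int[][] ids = {
                {0, 1, 0, 1, 0},
                {0, 2, 0, 1, 2},
        };
        System.out.println(Arrays.toString(groupIds(ids, new int[] {2, 3})));
        // prints [0, 3, 0, 2, 1]
    }
}

Only the numbering of group ids can differ from the row-by-row path; rows that share a combination of dictionary ids still land in the same group, which is what the low cardinality works rely on.
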
diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
index ceeb849bf0aa..3915455490d7 100644
--- a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
+++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java
@@ -14,14 +14,17 @@
 package io.trino.operator;
 
 import com.google.common.collect.ImmutableList;
+import io.airlift.slice.Slices;
 import io.trino.Session;
 import io.trino.block.BlockAssertions;
+import io.trino.operator.MultiChannelGroupByHash.GetLowCardinalityDictionaryGroupIdsWork;
 import io.trino.spi.Page;
 import io.trino.spi.PageBuilder;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.DictionaryBlock;
 import io.trino.spi.block.DictionaryId;
 import io.trino.spi.block.RunLengthEncodedBlock;
+import io.trino.spi.block.VariableWidthBlock;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
 import io.trino.sql.gen.JoinCompiler;
@@ -49,6 +52,7 @@
 import static io.trino.spi.type.DoubleType.DOUBLE;
 import static io.trino.spi.type.VarcharType.VARCHAR;
 import static io.trino.type.TypeTestUtils.getHashBlock;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertFalse;
 import static org.testng.Assert.assertTrue;
@@ -502,4 +506,130 @@ public void testMemoryReservationYieldWithDictionary()
         assertEquals(currentQuota.get(), 10);
         assertEquals(currentQuota.get() / 3, yields);
     }
+
+    @Test
+    public void testLowCardinalityDictionariesAddPage()
+    {
+        GroupByHash groupByHash = createGroupByHash(
+                TEST_SESSION,
+                ImmutableList.of(BIGINT, BIGINT),
+                new int[] {0, 1},
+                Optional.empty(),
+                100,
+                JOIN_COMPILER,
+                TYPE_OPERATOR_FACTORY,
+                NOOP);
+        Block firstBlock = BlockAssertions.createLongDictionaryBlock(0, 1000, 10);
+        Block secondBlock = BlockAssertions.createLongDictionaryBlock(0, 1000, 10);
+        Page page = new Page(firstBlock, secondBlock);
+
+        Work<?> work = groupByHash.addPage(page);
+        assertThat(work).isInstanceOf(MultiChannelGroupByHash.AddLowCardinalityDictionaryPageWork.class);
+        work.process();
+        assertThat(groupByHash.getGroupCount()).isEqualTo(10); // Blocks are identical so only 10 distinct groups
+
+        firstBlock = BlockAssertions.createLongDictionaryBlock(10, 1000, 5);
+        secondBlock = BlockAssertions.createLongDictionaryBlock(10, 1000, 7);
+        page = new Page(firstBlock, secondBlock);
+
+        groupByHash.addPage(page).process();
+        assertThat(groupByHash.getGroupCount()).isEqualTo(45); // Old 10 groups and 35 new
+    }
+
+    @Test
+    public void testLowCardinalityDictionariesGetGroupIds()
+    {
+        // Compare group id results from page with dictionaries only (processed via low cardinality work) and the same page processed normally
+        GroupByHash groupByHash = createGroupByHash(
+                TEST_SESSION,
+                ImmutableList.of(BIGINT, BIGINT, BIGINT, BIGINT, BIGINT),
+                new int[] {0, 1, 2, 3, 4},
+                Optional.empty(),
+                100,
+                JOIN_COMPILER,
+                TYPE_OPERATOR_FACTORY,
+                NOOP);
+
+        GroupByHash lowCardinalityGroupByHash = createGroupByHash(
+                TEST_SESSION,
+                ImmutableList.of(BIGINT, BIGINT, BIGINT, BIGINT),
+                new int[] {0, 1, 2, 3},
+                Optional.empty(),
+                100,
+                JOIN_COMPILER,
+                TYPE_OPERATOR_FACTORY,
+                NOOP);
+        Block sameValueBlock = BlockAssertions.createLongRepeatBlock(0, 100);
+        Block block1 = BlockAssertions.createLongDictionaryBlock(0, 100, 1);
+        Block block2 = BlockAssertions.createLongDictionaryBlock(0, 100, 2);
+        Block block3 = BlockAssertions.createLongDictionaryBlock(0, 100, 3);
+        Block block4 = BlockAssertions.createLongDictionaryBlock(0, 100, 4);
+        // Combining block 2 and 4 will result in only 4 distinct values since 2 and 4 are not coprime
+
+        Page lowCardinalityPage = new Page(block1, block2, block3, block4);
+        Page page = new Page(block1, block2, block3, block4, sameValueBlock); // sameValueBlock will prevent low cardinality optimization to fire
+
+        Work<GroupByIdBlock> lowCardinalityWork = lowCardinalityGroupByHash.getGroupIds(lowCardinalityPage);
+        assertThat(lowCardinalityWork).isInstanceOf(GetLowCardinalityDictionaryGroupIdsWork.class);
+        Work<GroupByIdBlock> work = groupByHash.getGroupIds(page);
+
+        lowCardinalityWork.process();
+        work.process();
+        GroupByIdBlock lowCardinalityResults = lowCardinalityWork.getResult();
+        GroupByIdBlock results = work.getResult();
+
+        assertThat(lowCardinalityResults.getGroupCount()).isEqualTo(results.getGroupCount());
+    }
+
+    @Test
+    public void testProperWorkTypesSelected()
+    {
+        Block bigintBlock = BlockAssertions.createLongsBlock(1, 2, 3, 4, 5, 6, 7, 8);
+        Block bigintDictionaryBlock = BlockAssertions.createLongDictionaryBlock(0, 8);
+        Block bigintRleBlock = BlockAssertions.createRLEBlock(42, 8);
+        Block varcharBlock = BlockAssertions.createStringsBlock("1", "2", "3", "4", "5", "6", "7", "8");
+        Block varcharDictionaryBlock = BlockAssertions.createStringDictionaryBlock(1, 8);
+        Block varcharRleBlock = new RunLengthEncodedBlock(new VariableWidthBlock(1, Slices.EMPTY_SLICE, new int[] {0, 1}, Optional.empty()), 8);
+        Block bigintBigDictionaryBlock = BlockAssertions.createLongDictionaryBlock(1, 8, 1000);
+        Block bigintSingletonDictionaryBlock = BlockAssertions.createLongDictionaryBlock(1, 500000, 1);
+        Block bigintHugeDictionaryBlock = BlockAssertions.createLongDictionaryBlock(1, 500000, 66000); // Above Short.MAX_VALUE
+
+        Page singleBigintPage = new Page(bigintBlock);
+        assertGroupByHashWork(singleBigintPage, ImmutableList.of(BIGINT), BigintGroupByHash.GetGroupIdsWork.class);
+        Page singleBigintDictionaryPage = new Page(bigintDictionaryBlock);
+        assertGroupByHashWork(singleBigintDictionaryPage, ImmutableList.of(BIGINT), BigintGroupByHash.GetDictionaryGroupIdsWork.class);
+        Page singleBigintRlePage = new Page(bigintRleBlock);
+        assertGroupByHashWork(singleBigintRlePage, ImmutableList.of(BIGINT), BigintGroupByHash.GetRunLengthEncodedGroupIdsWork.class);
+        Page singleVarcharPage = new Page(varcharBlock);
+        assertGroupByHashWork(singleVarcharPage, ImmutableList.of(VARCHAR), MultiChannelGroupByHash.GetNonDictionaryGroupIdsWork.class);
+        Page singleVarcharDictionaryPage = new Page(varcharDictionaryBlock);
+        assertGroupByHashWork(singleVarcharDictionaryPage, ImmutableList.of(VARCHAR), MultiChannelGroupByHash.GetDictionaryGroupIdsWork.class);
+        Page singleVarcharRlePage = new Page(varcharRleBlock);
+        assertGroupByHashWork(singleVarcharRlePage, ImmutableList.of(VARCHAR), MultiChannelGroupByHash.GetRunLengthEncodedGroupIdsWork.class);
+
+        Page lowCardinalityDictionaryPage = new Page(bigintDictionaryBlock, varcharDictionaryBlock);
+        assertGroupByHashWork(lowCardinalityDictionaryPage, ImmutableList.of(BIGINT, VARCHAR), MultiChannelGroupByHash.GetLowCardinalityDictionaryGroupIdsWork.class);
+        Page highCardinalityDictionaryPage = new Page(bigintDictionaryBlock, bigintBigDictionaryBlock);
+        assertGroupByHashWork(highCardinalityDictionaryPage, ImmutableList.of(BIGINT, VARCHAR), MultiChannelGroupByHash.GetNonDictionaryGroupIdsWork.class);
+
+        // Cardinality above Short.MAX_VALUE
+        Page lowCardinalityHugeDictionaryPage = new Page(bigintSingletonDictionaryBlock, bigintHugeDictionaryBlock);
+        assertGroupByHashWork(lowCardinalityHugeDictionaryPage, ImmutableList.of(BIGINT, BIGINT), MultiChannelGroupByHash.GetNonDictionaryGroupIdsWork.class);
+    }
+
+    private void assertGroupByHashWork(Page page, List<Type> types, Class<?> clazz)
+    {
+        GroupByHash groupByHash = createGroupByHash(
+                types,
+                IntStream.range(0, types.size()).toArray(),
+                Optional.empty(),
+                100,
+                true,
+                JOIN_COMPILER,
+                TYPE_OPERATOR_FACTORY,
+                NOOP);
+        Work<GroupByIdBlock> work = groupByHash.getGroupIds(page);
+        // Compare by name since classes are private
+        assertThat(work).isInstanceOf(clazz);
+    }
 }
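
The work-type expectations in testProperWorkTypesSelected follow from the gate in canProcessLowCardinalityDictionary: every grouping channel must be dictionary encoded, and the product of the dictionary sizes must stay at or below a quarter of the page's position count and within Short.MAX_VALUE. The following is a standalone sketch of that decision in plain Java; the method name and the negative-size convention for "not a dictionary block" are illustrative, not Trino's API.

// Standalone sketch of the low-cardinality gate: accept a page only if the product of all
// dictionary sizes is small relative to the page and fits into a short-indexed array.
public final class LowCardinalityGateSketch
{
    private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = 0.25;

    // dictionarySizes[i] < 0 stands in for "channel i is not dictionary encoded"
    static boolean canUseLowCardinalityPath(int positionCount, int[] dictionarySizes)
    {
        long cardinality = 1;
        for (int size : dictionarySizes) {
            if (size < 0) {
                return false; // not a DictionaryBlock
            }
            cardinality = Math.multiplyExact(cardinality, size);
            if (cardinality > positionCount * SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO
                    || cardinality > Short.MAX_VALUE) { // must fit into a short-indexed array
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args)
    {
        System.out.println(canUseLowCardinalityPath(8, new int[] {1, 1000}));         // false: 1000 > 8 * 0.25
        System.out.println(canUseLowCardinalityPath(1000, new int[] {10, 10}));       // true: 100 <= 250 and <= Short.MAX_VALUE
        System.out.println(canUseLowCardinalityPath(500_000, new int[] {1, 66_000})); // false: 66000 > Short.MAX_VALUE
    }
}
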