Skip to content

Commit

Permalink
Add low cardinality dictionary aggregation
Browse files Browse the repository at this point in the history
If the number of combinations of all dictionaries in a page is below a certain threshold,
we can store the results in a small array and reuse the groups already found
  • Loading branch information
skrzypo987 authored and sopel39 committed Mar 7, 2022
1 parent d4417ea commit 82b4c47
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,8 @@ private int registerGroupId(Block dictionary, int positionInDictionary)
return groupId;
}

private class AddPageWork
@VisibleForTesting
class AddPageWork
implements Work<Void>
{
private final Block block;
Expand Down Expand Up @@ -412,7 +413,8 @@ public Void getResult()
}
}

private class AddDictionaryPageWork
@VisibleForTesting
class AddDictionaryPageWork
implements Work<Void>
{
private final Block dictionary;
Expand Down Expand Up @@ -456,7 +458,8 @@ public Void getResult()
}
}

private class AddRunLengthEncodedPageWork
@VisibleForTesting
class AddRunLengthEncodedPageWork
implements Work<Void>
{
private final RunLengthEncodedBlock block;
Expand Down Expand Up @@ -497,7 +500,8 @@ public Void getResult()
}
}

private class GetGroupIdsWork
@VisibleForTesting
class GetGroupIdsWork
implements Work<GroupByIdBlock>
{
private final BlockBuilder blockBuilder;
Expand Down Expand Up @@ -546,7 +550,8 @@ public GroupByIdBlock getResult()
}
}

private class GetDictionaryGroupIdsWork
@VisibleForTesting
class GetDictionaryGroupIdsWork
implements Work<GroupByIdBlock>
{
private final BlockBuilder blockBuilder;
Expand Down Expand Up @@ -600,7 +605,8 @@ public GroupByIdBlock getResult()
}
}

private class GetRunLengthEncodedGroupIdsWork
@VisibleForTesting
class GetRunLengthEncodedGroupIdsWork
implements Work<GroupByIdBlock>
{
private final RunLengthEncodedBlock block;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
Expand All @@ -48,6 +49,7 @@
import static io.trino.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
import static java.lang.Math.multiplyExact;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

Expand All @@ -57,6 +59,9 @@ public class MultiChannelGroupByHash
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize();
private static final float FILL_RATIO = 0.75f;
// Max (page value count / cumulative dictionary size) to trigger the low cardinality case
private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25;

private final List<Type> types;
private final List<Type> hashTypes;
private final int[] channels;
Expand Down Expand Up @@ -222,6 +227,9 @@ public Work<?> addPage(Page page)
if (canProcessDictionary(page)) {
return new AddDictionaryPageWork(page);
}
if (canProcessLowCardinalityDictionary(page)) {
return new AddLowCardinalityDictionaryPageWork(page);
}

return new AddNonDictionaryPageWork(page);
}
Expand All @@ -236,6 +244,9 @@ public Work<GroupByIdBlock> getGroupIds(Page page)
if (canProcessDictionary(page)) {
return new GetDictionaryGroupIdsWork(page);
}
if (canProcessLowCardinalityDictionary(page)) {
return new GetLowCardinalityDictionaryGroupIdsWork(page);
}

return new GetNonDictionaryGroupIdsWork(page);
}
Expand Down Expand Up @@ -507,6 +518,25 @@ private boolean canProcessDictionary(Page page)
return true;
}

/**
 * Checks whether every grouping channel of the page is dictionary-encoded and the
 * product of all dictionary sizes is small enough for the low-cardinality fast path
 * (well below the page position count, and addressable by a {@code short}).
 */
private boolean canProcessLowCardinalityDictionary(Page page)
{
    // There is little to no chance of regression on this path, so the
    // 'optimizer.dictionary-aggregations' session toggle is intentionally not consulted.
    int positionCount = page.getPositionCount();
    long combinations = 1;
    for (int channel : channels) {
        Block channelBlock = page.getBlock(channel);
        if (!(channelBlock instanceof DictionaryBlock)) {
            return false;
        }
        combinations = multiplyExact(combinations, ((DictionaryBlock) channelBlock).getDictionary().getPositionCount());
        // Bail out as soon as the combination count stops being "low": it must stay
        // below a fraction of the page size and must fit into a short-indexed array.
        boolean tooManyForPage = combinations > positionCount * SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO;
        if (tooManyForPage || combinations > Short.MAX_VALUE) {
            return false;
        }
    }

    return true;
}

private boolean isRunLengthEncoded(Page page)
{
for (int channel : channels) {
Expand Down Expand Up @@ -561,7 +591,8 @@ public void setProcessed(int position, int groupId)
}
}

private class AddNonDictionaryPageWork
@VisibleForTesting
class AddNonDictionaryPageWork
implements Work<Void>
{
private final Page page;
Expand Down Expand Up @@ -602,7 +633,8 @@ public Void getResult()
}
}

private class AddDictionaryPageWork
@VisibleForTesting
class AddDictionaryPageWork
implements Work<Void>
{
private final Page page;
Expand Down Expand Up @@ -649,7 +681,49 @@ public Void getResult()
}
}

private class AddRunLengthEncodedPageWork
/**
 * Work that registers every distinct value combination of a low-cardinality,
 * all-dictionary page into the hash. Add-only: produces no group-id result.
 */
class AddLowCardinalityDictionaryPageWork
        extends LowCardinalityDictionaryWork<Void>
{
    public AddLowCardinalityDictionaryPageWork(Page page)
    {
        super(page);
    }

    @Override
    public boolean process()
    {
        // needRehash() == true indicates we have reached the capacity boundary and a rehash is needed.
        // We can only proceed if tryRehash() successfully did a rehash.
        if (needRehash() && !tryRehash()) {
            return false;
        }

        // One representative page position per combination id; -1 marks combinations absent from this page
        int[] combinationIdToPosition = new int[maxCardinality];
        Arrays.fill(combinationIdToPosition, -1);
        calculateCombinationIdsToPositionMapping(combinationIdToPosition);

        // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
        // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
        for (int i = 0; i < maxCardinality; i++) {
            if (needRehash()) {
                return false;
            }
            if (combinationIdToPosition[i] != -1) {
                // Insert each distinct combination exactly once, via its representative position
                putIfAbsent(combinationIdToPosition[i], page);
            }
        }
        return true;
    }

    @Override
    public Void getResult()
    {
        // Add-only work: there is no result to return
        throw new UnsupportedOperationException();
    }
}

@VisibleForTesting
class AddRunLengthEncodedPageWork
implements Work<Void>
{
private final Page page;
Expand Down Expand Up @@ -690,7 +764,8 @@ public Void getResult()
}
}

private class GetNonDictionaryGroupIdsWork
@VisibleForTesting
class GetNonDictionaryGroupIdsWork
implements Work<GroupByIdBlock>
{
private final BlockBuilder blockBuilder;
Expand Down Expand Up @@ -739,7 +814,64 @@ public GroupByIdBlock getResult()
}
}

private class GetDictionaryGroupIdsWork
/**
 * Work that resolves group ids for a low-cardinality, all-dictionary page.
 * Each distinct value combination is looked up (or inserted) once, and the
 * resulting group id is fanned out to every page position sharing it.
 */
@VisibleForTesting
class GetLowCardinalityDictionaryGroupIdsWork
        extends LowCardinalityDictionaryWork<GroupByIdBlock>
{
    // One group id per input position; filled by process(), wrapped into a block by getResult()
    private final long[] groupIds;
    private boolean finished;

    public GetLowCardinalityDictionaryGroupIdsWork(Page page)
    {
        super(page);
        groupIds = new long[page.getPositionCount()];
    }

    @Override
    public boolean process()
    {
        // needRehash() == true indicates we have reached the capacity boundary and a rehash is needed.
        // We can only proceed if tryRehash() successfully did a rehash.
        if (needRehash() && !tryRehash()) {
            return false;
        }

        int positionCount = page.getPositionCount();
        int[] combinationIdToPosition = new int[maxCardinality];
        Arrays.fill(combinationIdToPosition, -1);
        short[] positionToCombinationId = calculateCombinationIdsToPositionMapping(combinationIdToPosition);
        int[] combinationIdToGroupId = new int[maxCardinality];

        // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
        // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
        for (int i = 0; i < maxCardinality; i++) {
            if (needRehash()) {
                return false;
            }
            if (combinationIdToPosition[i] != -1) {
                // Resolve each distinct combination once, via its representative position
                combinationIdToGroupId[i] = putIfAbsent(combinationIdToPosition[i], page);
            }
            else {
                // Combination does not occur in this page; its slot is never read below
                combinationIdToGroupId[i] = -1;
            }
        }
        // Fan the per-combination group ids back out to every position
        for (int i = 0; i < positionCount; i++) {
            groupIds[i] = combinationIdToGroupId[positionToCombinationId[i]];
        }
        return true;
    }

    @Override
    public GroupByIdBlock getResult()
    {
        checkState(!finished, "result has already been produced");
        finished = true;
        return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds));
    }
}

@VisibleForTesting
class GetDictionaryGroupIdsWork
implements Work<GroupByIdBlock>
{
private final BlockBuilder blockBuilder;
Expand Down Expand Up @@ -797,7 +929,8 @@ public GroupByIdBlock getResult()
}
}

private class GetRunLengthEncodedGroupIdsWork
@VisibleForTesting
class GetRunLengthEncodedGroupIdsWork
implements Work<GroupByIdBlock>
{
private final Page page;
Expand Down Expand Up @@ -846,4 +979,56 @@ public GroupByIdBlock getResult()
page.getPositionCount()));
}
}

/**
 * Base class for work over pages whose grouping channels are all dictionary-encoded
 * with a small combined cardinality (guarded by canProcessLowCardinalityDictionary,
 * which bounds the product of dictionary sizes by Short.MAX_VALUE).
 */
private abstract class LowCardinalityDictionaryWork<T>
        implements Work<T>
{
    protected final Page page;
    // Product of all dictionary sizes: an upper bound on distinct value combinations
    protected final int maxCardinality;
    // Dictionary size per grouping channel, used as the radix when combining ids
    protected final int[] dictionarySizes;
    protected final DictionaryBlock[] blocks;

    public LowCardinalityDictionaryWork(Page page)
    {
        this.page = requireNonNull(page, "page is null");
        dictionarySizes = new int[channels.length];
        blocks = new DictionaryBlock[channels.length];
        int maxCardinality = 1;
        for (int i = 0; i < channels.length; i++) {
            Block block = page.getBlock(channels[i]);
            verify(block instanceof DictionaryBlock, "Only dictionary blocks are supported");
            blocks[i] = (DictionaryBlock) block;
            int blockPositionCount = blocks[i].getDictionary().getPositionCount();
            dictionarySizes[i] = blockPositionCount;
            maxCardinality *= blockPositionCount;
        }
        this.maxCardinality = maxCardinality;
    }

    /**
     * Computes the combined dictionary id (mixed-radix over all grouping channels)
     * for every position, and populates the given combinationIdToPosition array with
     * a single representative position for each combination that occurs in the page.
     *
     * @param combinationIdToPosition output array of size maxCardinality; entries for
     *        combinations present in the page are overwritten with a position holding them
     * @return the combined combination id for every position of the page
     */
    protected short[] calculateCombinationIdsToPositionMapping(int[] combinationIdToPosition)
    {
        int positionCount = page.getPositionCount();
        // short arrays improve performance compared to int
        short[] combinationIds = new short[positionCount];

        // Mixed-radix accumulation: id = ((id0 * size1) + id1) * size2 + id2 ...
        // cannot overflow short since maxCardinality <= Short.MAX_VALUE
        for (int i = 0; i < positionCount; i++) {
            combinationIds[i] = (short) blocks[0].getId(i);
        }
        for (int j = 1; j < channels.length; j++) {
            for (int i = 0; i < positionCount; i++) {
                combinationIds[i] *= dictionarySizes[j];
                combinationIds[i] += blocks[j].getId(i);
            }
        }

        // Record one occurrence (the last position seen) of every used combination
        for (int i = 0; i < positionCount; i++) {
            combinationIdToPosition[combinationIds[i]] = i;
        }
        return combinationIds;
    }
}
}
Loading

0 comments on commit 82b4c47

Please sign in to comment.