Skip to content

Commit

Permalink
Change rehashing strategy in MultiChannelGroupByHash
Browse files Browse the repository at this point in the history
Previously the hash table capacity was checked on every row to see whether a rehash
was needed. Now the input page is split into batches; it is assumed that every
row in a batch will create a new group (which is rarely the case), and any rehashing
is done in advance, before the batch is processed.
This may slightly increase the memory footprint for a small number of groups; however,
there is a tiny performance gain as the capacity is no longer checked on every row.
  • Loading branch information
skrzypo987 authored and lukasz-stec committed May 26, 2022
1 parent e8d0978 commit 88cd492
Showing 1 changed file with 45 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import static io.trino.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
import static it.unimi.dsi.fastutil.HashCommon.arraySize;
import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
import static java.lang.Math.min;
import static java.lang.Math.multiplyExact;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
Expand All @@ -57,6 +58,7 @@ public class MultiChannelGroupByHash
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(MultiChannelGroupByHash.class).instanceSize();
private static final float FILL_RATIO = 0.75f;
private static final int BATCH_SIZE = 1024;
// Max (page value count / cumulative dictionary size) to trigger the low cardinality case
private static final double SMALL_DICTIONARIES_MAX_CARDINALITY_RATIO = .25;
private static final int VALUES_PAGE_BITS = 14; // 16k positions
Expand Down Expand Up @@ -578,7 +580,6 @@ class AddNonDictionaryPageWork
implements Work<Void>
{
private final Page page;

private int lastPosition;

public AddNonDictionaryPageWork(Page page)
Expand All @@ -591,21 +592,23 @@ public boolean process()
{
int positionCount = page.getPositionCount();
checkState(lastPosition < positionCount, "position count out of bound");
int remainingPositions = positionCount - lastPosition;

// needRehash() == true indicates we have reached capacity boundary and a rehash is needed.
// We can only proceed if tryRehash() successfully did a rehash.
if (needRehash() && !tryRehash()) {
return false;
}
while (remainingPositions != 0) {
int batchSize = min(remainingPositions, BATCH_SIZE);
if (!ensureHashTableSize(batchSize)) {
return false;
}

// putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
// Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
while (lastPosition < positionCount && !needRehash()) {
// get the group for the current row
putIfAbsent(lastPosition, page);
lastPosition++;
for (int i = lastPosition; i < lastPosition + batchSize; i++) {
putIfAbsent(i, page);
}

lastPosition += batchSize;
remainingPositions -= batchSize;
}
return lastPosition == positionCount;
verify(lastPosition == positionCount);
return true;
}

@Override
Expand Down Expand Up @@ -774,23 +777,27 @@ public GetNonDictionaryGroupIdsWork(Page page)
public boolean process()
{
int positionCount = page.getPositionCount();
checkState(lastPosition <= positionCount, "position count out of bound");
checkState(lastPosition < positionCount, "position count out of bound");
checkState(!finished);

// needRehash() == true indicates we have reached capacity boundary and a rehash is needed.
// We can only proceed if tryRehash() successfully did a rehash.
if (needRehash() && !tryRehash()) {
return false;
}
int remainingPositions = positionCount - lastPosition;

// putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
// Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
while (lastPosition < positionCount && !needRehash()) {
// output the group id for this row
groupIds[lastPosition] = putIfAbsent(lastPosition, page);
lastPosition++;
while (remainingPositions != 0) {
int batchSize = min(remainingPositions, BATCH_SIZE);
if (!ensureHashTableSize(batchSize)) {
return false;
}

for (int i = lastPosition; i < lastPosition + batchSize; i++) {
// output the group id for this row
groupIds[i] = putIfAbsent(i, page);
}

lastPosition += batchSize;
remainingPositions -= batchSize;
}
return lastPosition == positionCount;
verify(lastPosition == positionCount);
return true;
}

@Override
Expand Down Expand Up @@ -1021,4 +1028,16 @@ private int calculatePositionToCombinationIdMapping(Page page, short[] positionT
}
return maxCardinality;
}

/**
 * Grows the hash table ahead of time so that a whole batch can be inserted
 * without checking capacity on every row.
 *
 * <p>The headroom is computed as {@code maxFill - nextGroupId}; the table is
 * rehashed repeatedly until that headroom is at least {@code batchSize}, i.e.
 * the worst case where every row of the batch produces a new group is covered.
 *
 * @param batchSize number of rows about to be inserted
 * @return {@code true} once the headroom is sufficient; {@code false} as soon as
 *         {@code tryRehash()} reports that it could not rehash
 */
private boolean ensureHashTableSize(int batchSize)
{
    // Headroom is re-evaluated on every iteration because a successful rehash changes it.
    while (maxFill - nextGroupId < batchSize) {
        boolean rehashed = tryRehash();
        if (!rehashed) {
            return false;
        }
    }
    return true;
}
}

0 comments on commit 88cd492

Please sign in to comment.