Skip to content

Commit

Permalink
Replace recursive with loop in PackedValuesBlockHash (elastic#99992)
Browse files Browse the repository at this point in the history
This change replaces the recursion with a loop when packing and hashing 
multiple keys. While the recursive version is clever, it may not be as
straightforward for future readers. Using a loop also helps us avoid
StackOverflow when grouping by a large number of keys.
  • Loading branch information
dnhatn authored Sep 29, 2023
1 parent f8d09e9 commit 44a2d68
Showing 1 changed file with 95 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import org.elasticsearch.compute.operator.BatchEncoder;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.MultivalueDedupe;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -51,19 +49,20 @@
* }</pre>
*/
final class PackedValuesBlockHash extends BlockHash {
private static final Logger logger = LogManager.getLogger(PackedValuesBlockHash.class);
static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes());

private final List<HashAggregationOperator.GroupSpec> groups;
private final int emitBatchSize;
private final BytesRefHash bytesRefHash;
private final int nullTrackingBytes;
private final BytesRef scratch = new BytesRef();
private final BytesRefBuilder bytes = new BytesRefBuilder();
private final Group[] groups;

PackedValuesBlockHash(List<HashAggregationOperator.GroupSpec> groups, BigArrays bigArrays, int emitBatchSize) {
this.groups = groups;
PackedValuesBlockHash(List<HashAggregationOperator.GroupSpec> specs, BigArrays bigArrays, int emitBatchSize) {
this.groups = specs.stream().map(Group::new).toArray(Group[]::new);
this.emitBatchSize = emitBatchSize;
this.bytesRefHash = new BytesRefHash(1, bigArrays);
this.nullTrackingBytes = groups.size() / 8 + 1;
this.nullTrackingBytes = (groups.length + 7) / 8;
}

@Override
Expand All @@ -75,23 +74,28 @@ void add(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize)
new AddWork(page, addInput, batchSize).add();
}

private static class Group {
final HashAggregationOperator.GroupSpec spec;
BatchEncoder encoder;
int positionOffset;
int valueOffset;
int loopedIndex;
int valueCount;
int bytesStart;

Group(HashAggregationOperator.GroupSpec spec) {
this.spec = spec;
}
}

class AddWork extends LongLongBlockHash.AbstractAddBlock {
final BatchEncoder[] encoders = new BatchEncoder[groups.size()];
final int[] positionOffsets = new int[groups.size()];
final int[] valueOffsets = new int[groups.size()];
final BytesRef[] scratches = new BytesRef[groups.size()];
final BytesRefBuilder bytes = new BytesRefBuilder();
final int positionCount;

int position;
int count;
int bufferedGroup;

AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) {
super(emitBatchSize, addInput);
for (int g = 0; g < groups.size(); g++) {
encoders[g] = MultivalueDedupe.batchEncoder(page.getBlock(groups.get(g).channel()), batchSize);
scratches[g] = new BytesRef();
for (Group group : groups) {
group.encoder = MultivalueDedupe.batchEncoder(page.getBlock(group.spec.channel()), batchSize);
}
bytes.grow(nullTrackingBytes);
this.positionCount = page.getPositionCount();
Expand All @@ -104,108 +108,103 @@ class AddWork extends LongLongBlockHash.AbstractAddBlock {
*/
void add() {
for (position = 0; position < positionCount; position++) {
if (logger.isTraceEnabled()) {
logger.trace("position {}", position);
}
// Make sure all encoders have encoded the current position and the offsets are queued to it's start
for (int g = 0; g < encoders.length; g++) {
positionOffsets[g]++;
while (positionOffsets[g] >= encoders[g].positionCount()) {
encoders[g].encodeNextBatch();
positionOffsets[g] = 0;
valueOffsets[g] = 0;
boolean singleEntry = true;
for (Group g : groups) {
var encoder = g.encoder;
g.positionOffset++;
while (g.positionOffset >= encoder.positionCount()) {
encoder.encodeNextBatch();
g.positionOffset = 0;
g.valueOffset = 0;
}
g.valueCount = encoder.valueCount(g.positionOffset);
singleEntry &= (g.valueCount == 1);
}

count = 0;
Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0);
bytes.setLength(nullTrackingBytes);
addPosition(0);
switch (count) {
case 0 -> throw new IllegalStateException("didn't find any values");
case 1 -> {
ords.appendInt(bufferedGroup);
addedValue(position);
}
default -> ords.endPositionEntry();
}
for (int g = 0; g < encoders.length; g++) {
valueOffsets[g] += encoders[g].valueCount(positionOffsets[g]);
if (singleEntry) {
addSingleEntry();
} else {
addMultipleEntries();
}
}
emitOrds();
}

private void addPosition(int g) {
if (g == groups.size()) {
addBytes();
return;
}
int start = bytes.length();
int count = encoders[g].valueCount(positionOffsets[g]);
assert count > 0;
int valueOffset = valueOffsets[g];
BytesRef v = encoders[g].read(valueOffset++, scratches[g]);
if (logger.isTraceEnabled()) {
logger.trace("\t".repeat(g + 1) + v);
}
if (v.length == 0) {
assert count == 1 : "null value in non-singleton list";
int nullByte = g / 8;
int nullShift = g % 8;
bytes.bytes()[nullByte] |= (byte) (1 << nullShift);
}
bytes.setLength(start);
bytes.append(v);
addPosition(g + 1); // TODO stack overflow protection
for (int i = 1; i < count; i++) {
v = encoders[g].read(valueOffset++, scratches[g]);
if (logger.isTraceEnabled()) {
logger.trace("\t".repeat(g + 1) + v);
private void addSingleEntry() {
for (int g = 0; g < groups.length; g++) {
Group group = groups[g];
BytesRef v = group.encoder.read(group.valueOffset++, scratch);
if (v.length == 0) {
int nullByte = g / 8;
int nullShift = g % 8;
bytes.bytes()[nullByte] |= (byte) (1 << nullShift);
} else {
bytes.append(v);
}
assert v.length > 0 : "null value after the first position";
bytes.setLength(start);
bytes.append(v);
addPosition(g + 1);
}
int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())));
ords.appendInt(ord);
addedValue(position);
}

private void addBytes() {
int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())));
switch (count) {
case 0 -> bufferedGroup = group;
case 1 -> {
ords.beginPositionEntry();
ords.appendInt(bufferedGroup);
addedValueInMultivaluePosition(position);
ords.appendInt(group);
addedValueInMultivaluePosition(position);
private void addMultipleEntries() {
ords.beginPositionEntry();
int g = 0;
outer: for (;;) {
for (; g < groups.length; g++) {
Group group = groups[g];
group.bytesStart = bytes.length();
BytesRef v = group.encoder.read(group.valueOffset + group.loopedIndex, scratch);
++group.loopedIndex;
if (v.length == 0) {
assert group.valueCount == 1 : "null value in non-singleton list";
int nullByte = g / 8;
int nullShift = g % 8;
bytes.bytes()[nullByte] |= (byte) (1 << nullShift);
} else {
bytes.append(v);
}
}
default -> {
ords.appendInt(group);
addedValueInMultivaluePosition(position);
// emit ords
int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())));
ords.appendInt(ord);
addedValueInMultivaluePosition(position);

// rewind
Group group = groups[--g];
bytes.setLength(group.bytesStart);
while (group.loopedIndex == group.valueCount) {
group.loopedIndex = 0;
if (g == 0) {
break outer;
} else {
group = groups[--g];
bytes.setLength(group.bytesStart);
}
}
}
count++;
if (logger.isTraceEnabled()) {
logger.trace("{} = {}", bytes.get(), group);
ords.endPositionEntry();
for (Group group : groups) {
group.valueOffset += group.valueCount;
}
}
}

@Override
public Block[] getKeys() {
int size = Math.toIntExact(bytesRefHash.size());
BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.size()];
Block.Builder[] builders = new Block.Builder[groups.size()];
BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.length];
Block.Builder[] builders = new Block.Builder[groups.length];
for (int g = 0; g < builders.length; g++) {
ElementType elementType = groups.get(g).elementType();
ElementType elementType = groups[g].spec.elementType();
decoders[g] = BatchEncoder.decoder(elementType);
builders[g] = elementType.newBlockBuilder(size);
}

BytesRef values[] = new BytesRef[(int) Math.min(100, bytesRefHash.size())];
BytesRef nulls[] = new BytesRef[values.length];
BytesRef[] values = new BytesRef[(int) Math.min(100, bytesRefHash.size())];
BytesRef[] nulls = new BytesRef[values.length];
for (int offset = 0; offset < values.length; offset++) {
values[offset] = new BytesRef();
nulls[offset] = new BytesRef();
Expand All @@ -231,7 +230,7 @@ public Block[] getKeys() {
readKeys(decoders, builders, nulls, values, offset);
}

Block[] keyBlocks = new Block[groups.size()];
Block[] keyBlocks = new Block[groups.length];
for (int g = 0; g < keyBlocks.length; g++) {
keyBlocks[g] = builders[g].build();
}
Expand Down Expand Up @@ -271,13 +270,12 @@ public String toString() {
StringBuilder b = new StringBuilder();
b.append("PackedValuesBlockHash{groups=[");
boolean first = true;
for (HashAggregationOperator.GroupSpec spec : groups) {
if (first) {
first = false;
} else {
for (int i = 0; i < groups.length; i++) {
if (i > 0) {
b.append(", ");
}
b.append(spec.channel()).append(':').append(spec.elementType());
Group group = groups[i];
b.append(group.spec.channel()).append(':').append(group.spec.elementType());
}
b.append("], entries=").append(bytesRefHash.size());
b.append(", size=").append(ByteSizeValue.ofBytes(bytesRefHash.ramBytesUsed()));
Expand Down

0 comments on commit 44a2d68

Please sign in to comment.