From 44a2d68f197094dc3af8d053ff166e0228256a71 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 28 Sep 2023 17:45:17 -0700 Subject: [PATCH] Replace recursive with loop in PackedValuesBlockHash (#99992) This change replaces the recursion with a loop when packing and hashing multiple keys. While the recursive version is clever, it may not be as straightforward for future readers. Using a loop also helps us avoid StackOverflow when grouping by a large number of keys. --- .../blockhash/PackedValuesBlockHash.java | 192 +++++++++--------- 1 file changed, 95 insertions(+), 97 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index 31f65e9b70053..7ecaddf2092fa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -22,8 +22,6 @@ import org.elasticsearch.compute.operator.BatchEncoder; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.MultivalueDedupe; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import java.util.Arrays; import java.util.List; @@ -51,19 +49,20 @@ * } */ final class PackedValuesBlockHash extends BlockHash { - private static final Logger logger = LogManager.getLogger(PackedValuesBlockHash.class); static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes()); - private final List groups; private final int emitBatchSize; private final BytesRefHash bytesRefHash; private final int nullTrackingBytes; + private final BytesRef scratch = new BytesRef(); + private final BytesRefBuilder bytes = new BytesRefBuilder(); + private final Group[] groups; - PackedValuesBlockHash(List groups, BigArrays bigArrays, int emitBatchSize) { - this.groups = groups; + PackedValuesBlockHash(List specs, BigArrays bigArrays, int emitBatchSize) { + this.groups = specs.stream().map(Group::new).toArray(Group[]::new); this.emitBatchSize = emitBatchSize; this.bytesRefHash = new BytesRefHash(1, bigArrays); - this.nullTrackingBytes = groups.size() / 8 + 1; + this.nullTrackingBytes = (groups.length + 7) / 8; } @Override @@ -75,23 +74,28 @@ void add(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) new AddWork(page, addInput, batchSize).add(); } + private static class Group { + final HashAggregationOperator.GroupSpec spec; + BatchEncoder encoder; + int positionOffset; + int valueOffset; + int loopedIndex; + int valueCount; + int bytesStart; + + Group(HashAggregationOperator.GroupSpec spec) { + this.spec = spec; + } + } + class AddWork extends LongLongBlockHash.AbstractAddBlock { - final BatchEncoder[] encoders = new BatchEncoder[groups.size()]; - final int[] positionOffsets = new int[groups.size()]; - final int[] valueOffsets = new int[groups.size()]; - final BytesRef[] scratches = new BytesRef[groups.size()]; - final BytesRefBuilder bytes = new BytesRefBuilder(); final int positionCount; - int position; - int count; - int bufferedGroup; AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) { super(emitBatchSize, addInput); - for (int g = 0; g < groups.size(); g++) { - encoders[g] = MultivalueDedupe.batchEncoder(page.getBlock(groups.get(g).channel()), batchSize); - scratches[g] = new BytesRef(); + for (Group group : groups) { + group.encoder = MultivalueDedupe.batchEncoder(page.getBlock(group.spec.channel()), batchSize); } bytes.grow(nullTrackingBytes); this.positionCount = page.getPositionCount(); @@ -104,91 +108,86 @@ class AddWork extends LongLongBlockHash.AbstractAddBlock { */ void add() { for (position = 0; position < positionCount; position++) { - if (logger.isTraceEnabled()) { - logger.trace("position {}", position); - } // Make sure all encoders have encoded the current position and the offsets are queued to it's start - for (int g = 0; g < encoders.length; g++) { - positionOffsets[g]++; - while (positionOffsets[g] >= encoders[g].positionCount()) { - encoders[g].encodeNextBatch(); - positionOffsets[g] = 0; - valueOffsets[g] = 0; + boolean singleEntry = true; + for (Group g : groups) { + var encoder = g.encoder; + g.positionOffset++; + while (g.positionOffset >= encoder.positionCount()) { + encoder.encodeNextBatch(); + g.positionOffset = 0; + g.valueOffset = 0; } + g.valueCount = encoder.valueCount(g.positionOffset); + singleEntry &= (g.valueCount == 1); } - - count = 0; Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0); bytes.setLength(nullTrackingBytes); - addPosition(0); - switch (count) { - case 0 -> throw new IllegalStateException("didn't find any values"); - case 1 -> { - ords.appendInt(bufferedGroup); - addedValue(position); - } - default -> ords.endPositionEntry(); - } - for (int g = 0; g < encoders.length; g++) { - valueOffsets[g] += encoders[g].valueCount(positionOffsets[g]); + if (singleEntry) { + addSingleEntry(); + } else { + addMultipleEntries(); } } emitOrds(); } - private void addPosition(int g) { - if (g == groups.size()) { - addBytes(); - return; - } - int start = bytes.length(); - int count = encoders[g].valueCount(positionOffsets[g]); - assert count > 0; - int valueOffset = valueOffsets[g]; - BytesRef v = encoders[g].read(valueOffset++, scratches[g]); - if (logger.isTraceEnabled()) { - logger.trace("\t".repeat(g + 1) + v); - } - if (v.length == 0) { - assert count == 1 : "null value in non-singleton list"; - int nullByte = g / 8; - int nullShift = g % 8; - bytes.bytes()[nullByte] |= (byte) (1 << nullShift); - } - bytes.setLength(start); - bytes.append(v); - addPosition(g + 1); // TODO stack overflow protection - for (int i = 1; i < count; i++) { - v = encoders[g].read(valueOffset++, scratches[g]); - if (logger.isTraceEnabled()) { - logger.trace("\t".repeat(g + 1) + v); + private void addSingleEntry() { + for (int g = 0; g < groups.length; g++) { + Group group = groups[g]; + BytesRef v = group.encoder.read(group.valueOffset++, scratch); + if (v.length == 0) { + int nullByte = g / 8; + int nullShift = g % 8; + bytes.bytes()[nullByte] |= (byte) (1 << nullShift); + } else { + bytes.append(v); } - assert v.length > 0 : "null value after the first position"; - bytes.setLength(start); - bytes.append(v); - addPosition(g + 1); } + int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); + ords.appendInt(ord); + addedValue(position); } - private void addBytes() { - int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); - switch (count) { - case 0 -> bufferedGroup = group; - case 1 -> { - ords.beginPositionEntry(); - ords.appendInt(bufferedGroup); - addedValueInMultivaluePosition(position); - ords.appendInt(group); - addedValueInMultivaluePosition(position); + private void addMultipleEntries() { + ords.beginPositionEntry(); + int g = 0; + outer: for (;;) { + for (; g < groups.length; g++) { + Group group = groups[g]; + group.bytesStart = bytes.length(); + BytesRef v = group.encoder.read(group.valueOffset + group.loopedIndex, scratch); + ++group.loopedIndex; + if (v.length == 0) { + assert group.valueCount == 1 : "null value in non-singleton list"; + int nullByte = g / 8; + int nullShift = g % 8; + bytes.bytes()[nullByte] |= (byte) (1 << nullShift); + } else { + bytes.append(v); + } } - default -> { - ords.appendInt(group); - addedValueInMultivaluePosition(position); + // emit ords + int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); + ords.appendInt(ord); + addedValueInMultivaluePosition(position); + + // rewind + Group group = groups[--g]; + bytes.setLength(group.bytesStart); + while (group.loopedIndex == group.valueCount) { + group.loopedIndex = 0; + if (g == 0) { + break outer; + } else { + group = groups[--g]; + bytes.setLength(group.bytesStart); + } } } - count++; - if (logger.isTraceEnabled()) { - logger.trace("{} = {}", bytes.get(), group); + ords.endPositionEntry(); + for (Group group : groups) { + group.valueOffset += group.valueCount; } } } @@ -196,16 +195,16 @@ private void addBytes() { @Override public Block[] getKeys() { int size = Math.toIntExact(bytesRefHash.size()); - BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.size()]; - Block.Builder[] builders = new Block.Builder[groups.size()]; + BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.length]; + Block.Builder[] builders = new Block.Builder[groups.length]; for (int g = 0; g < builders.length; g++) { - ElementType elementType = groups.get(g).elementType(); + ElementType elementType = groups[g].spec.elementType(); decoders[g] = BatchEncoder.decoder(elementType); builders[g] = elementType.newBlockBuilder(size); } - BytesRef values[] = new BytesRef[(int) Math.min(100, bytesRefHash.size())]; - BytesRef nulls[] = new BytesRef[values.length]; + BytesRef[] values = new BytesRef[(int) Math.min(100, bytesRefHash.size())]; + BytesRef[] nulls = new BytesRef[values.length]; for (int offset = 0; offset < values.length; offset++) { values[offset] = new BytesRef(); nulls[offset] = new BytesRef(); @@ -231,7 +230,7 @@ public Block[] getKeys() { readKeys(decoders, builders, nulls, values, offset); } - Block[] keyBlocks = new Block[groups.size()]; + Block[] keyBlocks = new Block[groups.length]; for (int g = 0; g < keyBlocks.length; g++) { keyBlocks[g] = builders[g].build(); } @@ -271,13 +270,12 @@ public String toString() { StringBuilder b = new StringBuilder(); b.append("PackedValuesBlockHash{groups=["); boolean first = true; - for (HashAggregationOperator.GroupSpec spec : groups) { - if (first) { - first = false; - } else { + for (int i = 0; i < groups.length; i++) { + if (i > 0) { b.append(", "); } - b.append(spec.channel()).append(':').append(spec.elementType()); + Group group = groups[i]; + b.append(group.spec.channel()).append(':').append(group.spec.elementType()); } b.append("], entries=").append(bytesRefHash.size()); b.append(", size=").append(ByteSizeValue.ofBytes(bytesRefHash.ramBytesUsed()));