From aeec7162f9d436ee7115ad7ff9bcb041e5b1b438 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 27 Sep 2023 22:48:26 -0700 Subject: [PATCH 1/6] Replace recursive with loop in PackedValuesBlockHash --- .../blockhash/PackedValuesBlockHash.java | 126 ++++++++---------- 1 file changed, 57 insertions(+), 69 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index 31f65e9b70053..46605b2395e80 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -79,19 +79,19 @@ class AddWork extends LongLongBlockHash.AbstractAddBlock { final BatchEncoder[] encoders = new BatchEncoder[groups.size()]; final int[] positionOffsets = new int[groups.size()]; final int[] valueOffsets = new int[groups.size()]; - final BytesRef[] scratches = new BytesRef[groups.size()]; + final BytesRef scratch = new BytesRef(); + final int[] loopedIndices = new int[groups.size()]; + final int[] valueCounts = new int[groups.size()]; + final int[] bytesStarts = new int[groups.size()]; final BytesRefBuilder bytes = new BytesRefBuilder(); final int positionCount; int position; - int count; - int bufferedGroup; AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) { super(emitBatchSize, addInput); for (int g = 0; g < groups.size(); g++) { encoders[g] = MultivalueDedupe.batchEncoder(page.getBlock(groups.get(g).channel()), batchSize); - scratches[g] = new BytesRef(); } bytes.grow(nullTrackingBytes); this.positionCount = page.getPositionCount(); @@ -104,91 +104,79 @@ class AddWork extends LongLongBlockHash.AbstractAddBlock { */ void add() { for (position = 0; position < positionCount; position++) { - if (logger.isTraceEnabled()) { - logger.trace("position {}", position); - } // Make sure all encoders have encoded the current position and the offsets are queued to it's start + boolean singleEntry = true; for (int g = 0; g < encoders.length; g++) { positionOffsets[g]++; - while (positionOffsets[g] >= encoders[g].positionCount()) { + if (positionOffsets[g] >= encoders[g].positionCount()) { encoders[g].encodeNextBatch(); positionOffsets[g] = 0; valueOffsets[g] = 0; } + valueCounts[g] = encoders[g].valueCount(positionOffsets[g]); + singleEntry &= (valueCounts[g] == 1); } - - count = 0; Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0); bytes.setLength(nullTrackingBytes); - addPosition(0); - switch (count) { - case 0 -> throw new IllegalStateException("didn't find any values"); - case 1 -> { - ords.appendInt(bufferedGroup); - addedValue(position); - } - default -> ords.endPositionEntry(); - } - for (int g = 0; g < encoders.length; g++) { - valueOffsets[g] += encoders[g].valueCount(positionOffsets[g]); + if (singleEntry) { + addSingleEntry(); + } else { + addMultipleEntries(); } } emitOrds(); } - private void addPosition(int g) { - if (g == groups.size()) { - addBytes(); - return; - } - int start = bytes.length(); - int count = encoders[g].valueCount(positionOffsets[g]); - assert count > 0; - int valueOffset = valueOffsets[g]; - BytesRef v = encoders[g].read(valueOffset++, scratches[g]); - if (logger.isTraceEnabled()) { - logger.trace("\t".repeat(g + 1) + v); - } - if (v.length == 0) { - assert count == 1 : "null value in non-singleton list"; - int nullByte = g / 8; - int nullShift = g % 8; - bytes.bytes()[nullByte] |= (byte) (1 << nullShift); - } - bytes.setLength(start); - bytes.append(v); - addPosition(g + 1); // TODO stack overflow protection - for (int i = 1; i < count; i++) { - v = encoders[g].read(valueOffset++, scratches[g]); - if (logger.isTraceEnabled()) { - logger.trace("\t".repeat(g + 1) + v); + private void addSingleEntry() { + for (int g = 0; g < encoders.length; g++) { + BytesRef v = encoders[g].read(valueOffsets[g]++, scratch); + if (v.length == 0) { + int nullByte = g / 8; + int nullShift = g % 8; + bytes.bytes()[nullByte] |= (byte) (1 << nullShift); + } else { + bytes.append(v); } - assert v.length > 0 : "null value after the first position"; - bytes.setLength(start); - bytes.append(v); - addPosition(g + 1); } + int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); + ords.appendInt(group); + addedValue(position); } - private void addBytes() { - int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); - switch (count) { - case 0 -> bufferedGroup = group; - case 1 -> { - ords.beginPositionEntry(); - ords.appendInt(bufferedGroup); - addedValueInMultivaluePosition(position); - ords.appendInt(group); - addedValueInMultivaluePosition(position); + private void addMultipleEntries() { + ords.beginPositionEntry(); + int g = 0; + outer: for (;;) { + for (; g < encoders.length; g++) { + BytesRef v = encoders[g].read(valueOffsets[g] + loopedIndices[g], scratch); + ++loopedIndices[g]; + if (v.length == 0) { + int nullByte = g / 8; + int nullShift = g % 8; + bytes.bytes()[nullByte] |= (byte) (1 << nullShift); + } else { + bytes.append(v); + bytesStarts[g] = bytes.length(); + } } - default -> { - ords.appendInt(group); - addedValueInMultivaluePosition(position); + // emit ords + int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); + ords.appendInt(group); + addedValueInMultivaluePosition(position); + + // rewind + --g; + while (loopedIndices[g] == valueCounts[g]) { + loopedIndices[g] = 0; + if (g == 0) { + break outer; + } + bytes.setLength(bytesStarts[g--]); } } - count++; - if (logger.isTraceEnabled()) { - logger.trace("{} = {}", bytes.get(), group); + ords.endPositionEntry(); + for (g = 0; g < encoders.length; g++) { + valueOffsets[g] += valueCounts[g]; } } } @@ -204,8 +192,8 @@ public Block[] getKeys() { builders[g] = elementType.newBlockBuilder(size); } - BytesRef values[] = new BytesRef[(int) Math.min(100, bytesRefHash.size())]; - BytesRef nulls[] = new BytesRef[values.length]; + BytesRef[] values = new BytesRef[(int) Math.min(100, bytesRefHash.size())]; + BytesRef[] nulls = new BytesRef[values.length]; for (int offset = 0; offset < values.length; offset++) { values[offset] = new BytesRef(); nulls[offset] = new BytesRef(); From db4a8a2458d4e2a27c8526e3942b2fa3fbde5447 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 28 Sep 2023 08:57:22 -0700 Subject: [PATCH 2/6] Fix BytesStarts --- .../aggregation/blockhash/PackedValuesBlockHash.java | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index 46605b2395e80..c4ea8f5da98a6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -22,8 +22,6 @@ import org.elasticsearch.compute.operator.BatchEncoder; import org.elasticsearch.compute.operator.HashAggregationOperator; import org.elasticsearch.compute.operator.MultivalueDedupe; -import org.elasticsearch.logging.LogManager; -import org.elasticsearch.logging.Logger; import java.util.Arrays; import java.util.List; @@ -51,7 +49,6 @@ * } */ final class PackedValuesBlockHash extends BlockHash { - private static final Logger logger = LogManager.getLogger(PackedValuesBlockHash.class); static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes()); private final List groups; @@ -108,7 +105,7 @@ void add() { boolean singleEntry = true; for (int g = 0; g < encoders.length; g++) { positionOffsets[g]++; - if (positionOffsets[g] >= encoders[g].positionCount()) { + while (positionOffsets[g] >= encoders[g].positionCount()) { encoders[g].encodeNextBatch(); positionOffsets[g] = 0; valueOffsets[g] = 0; @@ -148,6 +145,7 @@ private void addMultipleEntries() { int g = 0; outer: for (;;) { for (; g < encoders.length; g++) { + bytesStarts[g] = bytes.length(); BytesRef v = encoders[g].read(valueOffsets[g] + loopedIndices[g], scratch); ++loopedIndices[g]; if (v.length == 0) { @@ -156,7 +154,6 @@ private void addMultipleEntries() { bytes.bytes()[nullByte] |= (byte) (1 << nullShift); } else { bytes.append(v); - bytesStarts[g] = bytes.length(); } } // emit ords @@ -165,13 +162,14 @@ private void addMultipleEntries() { addedValueInMultivaluePosition(position); // rewind - --g; + bytes.setLength(bytesStarts[--g]); while (loopedIndices[g] == valueCounts[g]) { loopedIndices[g] = 0; if (g == 0) { break outer; + } else{ + bytes.setLength(bytesStarts[--g]); } - bytes.setLength(bytesStarts[g--]); } } ords.endPositionEntry(); From 3b1b0c4f5aae7cdd7f65f4a0e03f144acd4d18d1 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 28 Sep 2023 09:17:00 -0700 Subject: [PATCH 3/6] stylecheck --- .../compute/aggregation/blockhash/PackedValuesBlockHash.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index c4ea8f5da98a6..cd4e6315c7d17 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -105,7 +105,7 @@ void add() { boolean singleEntry = true; for (int g = 0; g < encoders.length; g++) { positionOffsets[g]++; - while (positionOffsets[g] >= encoders[g].positionCount()) { + if (positionOffsets[g] >= encoders[g].positionCount()) { encoders[g].encodeNextBatch(); positionOffsets[g] = 0; valueOffsets[g] = 0; @@ -167,7 +167,7 @@ private void addMultipleEntries() { loopedIndices[g] = 0; if (g == 0) { break outer; - } else{ + } else { bytes.setLength(bytesStarts[--g]); } } From 5ae70486efd06d45864424a841565f85a1664840 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 28 Sep 2023 10:30:47 -0700 Subject: [PATCH 4/6] encode next batch --- .../compute/aggregation/blockhash/PackedValuesBlockHash.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index cd4e6315c7d17..11966be021f94 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -105,7 +105,7 @@ void add() { boolean singleEntry = true; for (int g = 0; g < encoders.length; g++) { positionOffsets[g]++; - if (positionOffsets[g] >= encoders[g].positionCount()) { + while (positionOffsets[g] >= encoders[g].positionCount()) { encoders[g].encodeNextBatch(); positionOffsets[g] = 0; valueOffsets[g] = 0; From 472ee39698791b9f5371dbcc2a958ccbe911b95f Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 28 Sep 2023 10:35:45 -0700 Subject: [PATCH 5/6] assertion --- .../compute/aggregation/blockhash/PackedValuesBlockHash.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index 11966be021f94..291465e076e6c 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -149,6 +149,7 @@ private void addMultipleEntries() { BytesRef v = encoders[g].read(valueOffsets[g] + loopedIndices[g], scratch); ++loopedIndices[g]; if (v.length == 0) { + assert valueCounts[g] == 1 : "null value in non-singleton list"; int nullByte = g / 8; int nullShift = g % 8; bytes.bytes()[nullByte] |= (byte) (1 << nullShift); From f41c6e83f1c03f92c021c1d1a8d3ded3bbd0d659 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 28 Sep 2023 12:00:41 -0700 Subject: [PATCH 6/6] Group --- .../blockhash/PackedValuesBlockHash.java | 109 ++++++++++-------- 1 file changed, 60 insertions(+), 49 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java index 291465e076e6c..7ecaddf2092fa 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/PackedValuesBlockHash.java @@ -51,16 +51,18 @@ final class PackedValuesBlockHash extends BlockHash { static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes()); - private final List groups; private final int emitBatchSize; private final BytesRefHash bytesRefHash; private final int nullTrackingBytes; + private final BytesRef scratch = new BytesRef(); + private final BytesRefBuilder bytes = new BytesRefBuilder(); + private final Group[] groups; - PackedValuesBlockHash(List groups, BigArrays bigArrays, int emitBatchSize) { - this.groups = groups; + PackedValuesBlockHash(List specs, BigArrays bigArrays, int emitBatchSize) { + this.groups = specs.stream().map(Group::new).toArray(Group[]::new); this.emitBatchSize = emitBatchSize; this.bytesRefHash = new BytesRefHash(1, bigArrays); - this.nullTrackingBytes = groups.size() / 8 + 1; + this.nullTrackingBytes = (groups.length + 7) / 8; } @Override @@ -72,23 +74,28 @@ void add(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) new AddWork(page, addInput, batchSize).add(); } + private static class Group { + final HashAggregationOperator.GroupSpec spec; + BatchEncoder encoder; + int positionOffset; + int valueOffset; + int loopedIndex; + int valueCount; + int bytesStart; + + Group(HashAggregationOperator.GroupSpec spec) { + this.spec = spec; + } + } + class AddWork extends LongLongBlockHash.AbstractAddBlock { - final BatchEncoder[] encoders = new BatchEncoder[groups.size()]; - final int[] positionOffsets = new int[groups.size()]; - final int[] valueOffsets = new int[groups.size()]; - final BytesRef scratch = new BytesRef(); - final int[] loopedIndices = new int[groups.size()]; - final int[] valueCounts = new int[groups.size()]; - final int[] bytesStarts = new int[groups.size()]; - final BytesRefBuilder bytes = new BytesRefBuilder(); final int positionCount; - int position; AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) { super(emitBatchSize, addInput); - for (int g = 0; g < groups.size(); g++) { - encoders[g] = MultivalueDedupe.batchEncoder(page.getBlock(groups.get(g).channel()), batchSize); + for (Group group : groups) { + group.encoder = MultivalueDedupe.batchEncoder(page.getBlock(group.spec.channel()), batchSize); } bytes.grow(nullTrackingBytes); this.positionCount = page.getPositionCount(); @@ -103,15 +110,16 @@ void add() { for (position = 0; position < positionCount; position++) { // Make sure all encoders have encoded the current position and the offsets are queued to it's start boolean singleEntry = true; - for (int g = 0; g < encoders.length; g++) { - positionOffsets[g]++; - while (positionOffsets[g] >= encoders[g].positionCount()) { - encoders[g].encodeNextBatch(); - positionOffsets[g] = 0; - valueOffsets[g] = 0; + for (Group g : groups) { + var encoder = g.encoder; + g.positionOffset++; + while (g.positionOffset >= encoder.positionCount()) { + encoder.encodeNextBatch(); + g.positionOffset = 0; + g.valueOffset = 0; } - valueCounts[g] = encoders[g].valueCount(positionOffsets[g]); - singleEntry &= (valueCounts[g] == 1); + g.valueCount = encoder.valueCount(g.positionOffset); + singleEntry &= (g.valueCount == 1); } Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0); bytes.setLength(nullTrackingBytes); @@ -125,8 +133,9 @@ void add() { } private void addSingleEntry() { - for (int g = 0; g < encoders.length; g++) { - BytesRef v = encoders[g].read(valueOffsets[g]++, scratch); + for (int g = 0; g < groups.length; g++) { + Group group = groups[g]; + BytesRef v = group.encoder.read(group.valueOffset++, scratch); if (v.length == 0) { int nullByte = g / 8; int nullShift = g % 8; @@ -135,8 +144,8 @@ private void addSingleEntry() { bytes.append(v); } } - int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); - ords.appendInt(group); + int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); + ords.appendInt(ord); addedValue(position); } @@ -144,12 +153,13 @@ private void addMultipleEntries() { ords.beginPositionEntry(); int g = 0; outer: for (;;) { - for (; g < encoders.length; g++) { - bytesStarts[g] = bytes.length(); - BytesRef v = encoders[g].read(valueOffsets[g] + loopedIndices[g], scratch); - ++loopedIndices[g]; + for (; g < groups.length; g++) { + Group group = groups[g]; + group.bytesStart = bytes.length(); + BytesRef v = group.encoder.read(group.valueOffset + group.loopedIndex, scratch); + ++group.loopedIndex; if (v.length == 0) { - assert valueCounts[g] == 1 : "null value in non-singleton list"; + assert group.valueCount == 1 : "null value in non-singleton list"; int nullByte = g / 8; int nullShift = g % 8; bytes.bytes()[nullByte] |= (byte) (1 << nullShift); @@ -158,24 +168,26 @@ private void addMultipleEntries() { } } // emit ords - int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); - ords.appendInt(group); + int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get()))); + ords.appendInt(ord); addedValueInMultivaluePosition(position); // rewind - bytes.setLength(bytesStarts[--g]); - while (loopedIndices[g] == valueCounts[g]) { - loopedIndices[g] = 0; + Group group = groups[--g]; + bytes.setLength(group.bytesStart); + while (group.loopedIndex == group.valueCount) { + group.loopedIndex = 0; if (g == 0) { break outer; } else { - bytes.setLength(bytesStarts[--g]); + group = groups[--g]; + bytes.setLength(group.bytesStart); } } } ords.endPositionEntry(); - for (g = 0; g < encoders.length; g++) { - valueOffsets[g] += valueCounts[g]; + for (Group group : groups) { + group.valueOffset += group.valueCount; } } } @@ -183,10 +195,10 @@ private void addMultipleEntries() { @Override public Block[] getKeys() { int size = Math.toIntExact(bytesRefHash.size()); - BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.size()]; - Block.Builder[] builders = new Block.Builder[groups.size()]; + BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.length]; + Block.Builder[] builders = new Block.Builder[groups.length]; for (int g = 0; g < builders.length; g++) { - ElementType elementType = groups.get(g).elementType(); + ElementType elementType = groups[g].spec.elementType(); decoders[g] = BatchEncoder.decoder(elementType); builders[g] = elementType.newBlockBuilder(size); } @@ -218,7 +230,7 @@ public Block[] getKeys() { readKeys(decoders, builders, nulls, values, offset); } - Block[] keyBlocks = new Block[groups.size()]; + Block[] keyBlocks = new Block[groups.length]; for (int g = 0; g < keyBlocks.length; g++) { keyBlocks[g] = builders[g].build(); } @@ -258,13 +270,12 @@ public String toString() { StringBuilder b = new StringBuilder(); b.append("PackedValuesBlockHash{groups=["); boolean first = true; - for (HashAggregationOperator.GroupSpec spec : groups) { - if (first) { - first = false; - } else { + for (int i = 0; i < groups.length; i++) { + if (i > 0) { b.append(", "); } - b.append(spec.channel()).append(':').append(spec.elementType()); + Group group = groups[i]; + b.append(group.spec.channel()).append(':').append(group.spec.elementType()); } b.append("], entries=").append(bytesRefHash.size()); b.append(", size=").append(ByteSizeValue.ofBytes(bytesRefHash.ramBytesUsed()));