Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace recursion with a loop in PackedValuesBlockHash #99992

Merged
merged 9 commits into from
Sep 29, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import org.elasticsearch.compute.operator.BatchEncoder;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.MultivalueDedupe;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;

import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -51,19 +49,20 @@
* }</pre>
*/
final class PackedValuesBlockHash extends BlockHash {
private static final Logger logger = LogManager.getLogger(PackedValuesBlockHash.class);
static final int DEFAULT_BATCH_SIZE = Math.toIntExact(ByteSizeValue.ofKb(10).getBytes());

private final List<HashAggregationOperator.GroupSpec> groups;
private final int emitBatchSize;
private final BytesRefHash bytesRefHash;
private final int nullTrackingBytes;
private final BytesRef scratch = new BytesRef();
private final BytesRefBuilder bytes = new BytesRefBuilder();
private final Group[] groups;

PackedValuesBlockHash(List<HashAggregationOperator.GroupSpec> groups, BigArrays bigArrays, int emitBatchSize) {
this.groups = groups;
PackedValuesBlockHash(List<HashAggregationOperator.GroupSpec> specs, BigArrays bigArrays, int emitBatchSize) {
this.groups = specs.stream().map(Group::new).toArray(Group[]::new);
this.emitBatchSize = emitBatchSize;
this.bytesRefHash = new BytesRefHash(1, bigArrays);
this.nullTrackingBytes = groups.size() / 8 + 1;
this.nullTrackingBytes = (groups.length + 7) / 8;
}

@Override
Expand All @@ -75,23 +74,28 @@ void add(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize)
new AddWork(page, addInput, batchSize).add();
}

/**
 * Mutable per-group cursor state, one instance per {@link HashAggregationOperator.GroupSpec},
 * reused across positions while values are encoded and combined into hash keys.
 */
private static class Group {
// The grouping spec this cursor belongs to (channel + element type).
final HashAggregationOperator.GroupSpec spec;
// Batch encoder for this group's block; assigned when an AddWork is constructed.
BatchEncoder encoder;
// Offset of the current position within the encoder's current batch.
int positionOffset;
// Offset of the current position's first value within the encoder's current batch.
int valueOffset;
// Index of the value being walked in the multi-value cartesian-product loop
// (relative to valueOffset); only meaningful for multi-entry positions.
int loopedIndex;
// Number of values this group has at the current position.
int valueCount;
// Length of the key bytes before this group appended its value, so the
// multi-entry loop can rewind to here and try this group's next value.
int bytesStart;

Group(HashAggregationOperator.GroupSpec spec) {
this.spec = spec;
}
}

class AddWork extends LongLongBlockHash.AbstractAddBlock {
final BatchEncoder[] encoders = new BatchEncoder[groups.size()];
final int[] positionOffsets = new int[groups.size()];
final int[] valueOffsets = new int[groups.size()];
final BytesRef[] scratches = new BytesRef[groups.size()];
final BytesRefBuilder bytes = new BytesRefBuilder();
final int positionCount;

int position;
int count;
int bufferedGroup;

AddWork(Page page, GroupingAggregatorFunction.AddInput addInput, int batchSize) {
super(emitBatchSize, addInput);
for (int g = 0; g < groups.size(); g++) {
encoders[g] = MultivalueDedupe.batchEncoder(page.getBlock(groups.get(g).channel()), batchSize);
scratches[g] = new BytesRef();
for (Group group : groups) {
group.encoder = MultivalueDedupe.batchEncoder(page.getBlock(group.spec.channel()), batchSize);
}
bytes.grow(nullTrackingBytes);
this.positionCount = page.getPositionCount();
Expand All @@ -104,108 +108,103 @@ class AddWork extends LongLongBlockHash.AbstractAddBlock {
*/
void add() {
for (position = 0; position < positionCount; position++) {
if (logger.isTraceEnabled()) {
logger.trace("position {}", position);
}
// Make sure all encoders have encoded the current position and the offsets are queued to its start
for (int g = 0; g < encoders.length; g++) {
positionOffsets[g]++;
while (positionOffsets[g] >= encoders[g].positionCount()) {
encoders[g].encodeNextBatch();
positionOffsets[g] = 0;
valueOffsets[g] = 0;
boolean singleEntry = true;
for (Group g : groups) {
var encoder = g.encoder;
g.positionOffset++;
while (g.positionOffset >= encoder.positionCount()) {
encoder.encodeNextBatch();
g.positionOffset = 0;
g.valueOffset = 0;
}
g.valueCount = encoder.valueCount(g.positionOffset);
singleEntry &= (g.valueCount == 1);
}

count = 0;
Arrays.fill(bytes.bytes(), 0, nullTrackingBytes, (byte) 0);
bytes.setLength(nullTrackingBytes);
addPosition(0);
switch (count) {
case 0 -> throw new IllegalStateException("didn't find any values");
case 1 -> {
ords.appendInt(bufferedGroup);
addedValue(position);
}
default -> ords.endPositionEntry();
}
for (int g = 0; g < encoders.length; g++) {
valueOffsets[g] += encoders[g].valueCount(positionOffsets[g]);
if (singleEntry) {
addSingleEntry();
} else {
addMultipleEntries();
}
}
emitOrds();
}

private void addPosition(int g) {
if (g == groups.size()) {
addBytes();
return;
}
int start = bytes.length();
int count = encoders[g].valueCount(positionOffsets[g]);
assert count > 0;
int valueOffset = valueOffsets[g];
BytesRef v = encoders[g].read(valueOffset++, scratches[g]);
if (logger.isTraceEnabled()) {
logger.trace("\t".repeat(g + 1) + v);
}
if (v.length == 0) {
assert count == 1 : "null value in non-singleton list";
int nullByte = g / 8;
int nullShift = g % 8;
bytes.bytes()[nullByte] |= (byte) (1 << nullShift);
}
bytes.setLength(start);
bytes.append(v);
addPosition(g + 1); // TODO stack overflow protection
for (int i = 1; i < count; i++) {
v = encoders[g].read(valueOffset++, scratches[g]);
if (logger.isTraceEnabled()) {
logger.trace("\t".repeat(g + 1) + v);
private void addSingleEntry() {
for (int g = 0; g < groups.length; g++) {
Group group = groups[g];
BytesRef v = group.encoder.read(group.valueOffset++, scratch);
if (v.length == 0) {
int nullByte = g / 8;
int nullShift = g % 8;
bytes.bytes()[nullByte] |= (byte) (1 << nullShift);
} else {
bytes.append(v);
}
assert v.length > 0 : "null value after the first position";
bytes.setLength(start);
bytes.append(v);
addPosition(g + 1);
}
int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())));
ords.appendInt(ord);
addedValue(position);
}

private void addBytes() {
int group = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())));
switch (count) {
case 0 -> bufferedGroup = group;
case 1 -> {
ords.beginPositionEntry();
ords.appendInt(bufferedGroup);
addedValueInMultivaluePosition(position);
ords.appendInt(group);
addedValueInMultivaluePosition(position);
private void addMultipleEntries() {
ords.beginPositionEntry();
int g = 0;
outer: for (;;) {
for (; g < groups.length; g++) {
Group group = groups[g];
group.bytesStart = bytes.length();
BytesRef v = group.encoder.read(group.valueOffset + group.loopedIndex, scratch);
++group.loopedIndex;
if (v.length == 0) {
assert group.valueCount == 1 : "null value in non-singleton list";
int nullByte = g / 8;
int nullShift = g % 8;
bytes.bytes()[nullByte] |= (byte) (1 << nullShift);
} else {
bytes.append(v);
}
}
default -> {
ords.appendInt(group);
addedValueInMultivaluePosition(position);
// emit ords
int ord = Math.toIntExact(hashOrdToGroup(bytesRefHash.add(bytes.get())));
ords.appendInt(ord);
addedValueInMultivaluePosition(position);

// rewind
Group group = groups[--g];
bytes.setLength(group.bytesStart);
while (group.loopedIndex == group.valueCount) {
group.loopedIndex = 0;
if (g == 0) {
break outer;
} else {
group = groups[--g];
bytes.setLength(group.bytesStart);
}
}
}
count++;
if (logger.isTraceEnabled()) {
logger.trace("{} = {}", bytes.get(), group);
ords.endPositionEntry();
for (Group group : groups) {
group.valueOffset += group.valueCount;
}
}
}

@Override
public Block[] getKeys() {
int size = Math.toIntExact(bytesRefHash.size());
BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.size()];
Block.Builder[] builders = new Block.Builder[groups.size()];
BatchEncoder.Decoder[] decoders = new BatchEncoder.Decoder[groups.length];
Block.Builder[] builders = new Block.Builder[groups.length];
for (int g = 0; g < builders.length; g++) {
ElementType elementType = groups.get(g).elementType();
ElementType elementType = groups[g].spec.elementType();
decoders[g] = BatchEncoder.decoder(elementType);
builders[g] = elementType.newBlockBuilder(size);
}

BytesRef values[] = new BytesRef[(int) Math.min(100, bytesRefHash.size())];
BytesRef nulls[] = new BytesRef[values.length];
BytesRef[] values = new BytesRef[(int) Math.min(100, bytesRefHash.size())];
BytesRef[] nulls = new BytesRef[values.length];
for (int offset = 0; offset < values.length; offset++) {
values[offset] = new BytesRef();
nulls[offset] = new BytesRef();
Expand All @@ -231,7 +230,7 @@ public Block[] getKeys() {
readKeys(decoders, builders, nulls, values, offset);
}

Block[] keyBlocks = new Block[groups.size()];
Block[] keyBlocks = new Block[groups.length];
for (int g = 0; g < keyBlocks.length; g++) {
keyBlocks[g] = builders[g].build();
}
Expand Down Expand Up @@ -271,13 +270,12 @@ public String toString() {
StringBuilder b = new StringBuilder();
b.append("PackedValuesBlockHash{groups=[");
boolean first = true;
for (HashAggregationOperator.GroupSpec spec : groups) {
if (first) {
first = false;
} else {
for (int i = 0; i < groups.length; i++) {
if (i > 0) {
b.append(", ");
}
b.append(spec.channel()).append(':').append(spec.elementType());
Group group = groups[i];
b.append(group.spec.channel()).append(':').append(group.spec.elementType());
}
b.append("], entries=").append(bytesRefHash.size());
b.append(", size=").append(ByteSizeValue.ofBytes(bytesRefHash.ramBytesUsed()));
Expand Down