ES|QL deserves a new hash table #98749

Status: Draft · wants to merge 2 commits into base: main
@@ -9,7 +9,11 @@
package org.elasticsearch.benchmark.compute.operator;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier;
import org.elasticsearch.compute.aggregation.AggregatorMode;
import org.elasticsearch.compute.aggregation.CountAggregatorFunction;
@@ -36,6 +40,8 @@
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.HashAggregationOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.indices.breaker.HierarchyCircuitBreakerService;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
@@ -58,14 +64,12 @@
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Fork(1)
@Fork(value = 1, jvmArgsAppend = { "--enable-preview", "--add-modules", "jdk.incubator.vector" })
Member Author commented:
I've updated this to work with the new hash, but it doesn't produce the lovely performance numbers - yet. Partly that's because we're not integrating with the hash super well - the vector case needs to consume the array somehow. Or something similar. But that feels like something for another time.
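
A minimal sketch of the kind of bulk integration that comment is pointing at: let the vector case hand the hash a whole array in one call instead of adding row by row. This is not part of this change - the method below is hypothetical and only reuses the Ordinator64 bulk add that the HashBenchmark further down exercises.

// Hypothetical sketch: feed a long vector's backing values to the hash in bulk.
// Ordinator64.add(long[], int[], int) is the bulk entry point used by the
// ordinatorArray benchmark below; the surrounding method is illustrative only.
void addVectorInput(long[] values, Ordinator64 hash) {
    int[] ords = new int[values.length];
    hash.add(values, ords, values.length); // one call instead of values.length calls
    // ... hand ords to the grouping aggregators ...
}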

Member Author commented:
The other reason this doesn't show the performance bump we expect is that we don't enable all of the other aggregations - and that we don't aggregate much larger groups. Either way, this benchmark is much better at showing the performance of the aggs than of the groupings. At least for now.
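
One way to make the grouping side dominate, as the comment suggests, would be to turn the fixed group count into a benchmark parameter. A sketch only - not part of this change, and the parameter values are made up:

// Hypothetical: replace the GROUPS = 5 constant with a JMH parameter so the hash
// is also exercised at much larger cardinalities.
@Param({ "5", "10000", "1000000" })
public int groups;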

public class AggregatorBenchmark {
static final int BLOCK_LENGTH = 8 * 1024;
private static final int OP_COUNT = 1024;
private static final int GROUPS = 5;

private static final BigArrays BIG_ARRAYS = BigArrays.NON_RECYCLING_INSTANCE; // TODO real big arrays?

private static final String LONGS = "longs";
private static final String INTS = "ints";
private static final String DOUBLES = "doubles";
@@ -96,7 +100,7 @@ public class AggregatorBenchmark {
for (String grouping : AggregatorBenchmark.class.getField("grouping").getAnnotationsByType(Param.class)[0].value()) {
for (String op : AggregatorBenchmark.class.getField("op").getAnnotationsByType(Param.class)[0].value()) {
for (String blockType : AggregatorBenchmark.class.getField("blockType").getAnnotationsByType(Param.class)[0].value()) {
run(grouping, op, blockType, 50);
new AggregatorBenchmark().run(grouping, op, blockType, 50);
}
}
}
@@ -105,6 +109,14 @@ public class AggregatorBenchmark {
}
}

private final PageCacheRecycler recycler = new PageCacheRecycler(Settings.EMPTY);
private final CircuitBreakerService breakerService = new HierarchyCircuitBreakerService(
Settings.EMPTY,
List.of(),
ClusterSettings.createBuiltInClusterSettings()
);
private final BigArrays bigArrays = new BigArrays(recycler, breakerService, CircuitBreaker.REQUEST);

@Param({ NONE, LONGS, INTS, DOUBLES, BOOLEANS, BYTES_REFS, TWO_LONGS, LONGS_AND_BYTES_REFS, TWO_LONGS_AND_BYTES_REFS })
public String grouping;

@@ -114,7 +126,7 @@ public class AggregatorBenchmark {
@Param({ VECTOR_LONGS, HALF_NULL_LONGS, VECTOR_DOUBLES, HALF_NULL_DOUBLES })
public String blockType;

private static Operator operator(String grouping, String op, String dataType) {
private Operator operator(String grouping, String op, String dataType) {
if (grouping.equals("none")) {
return new AggregationOperator(List.of(supplier(op, dataType, 0).aggregatorFactory(AggregatorMode.SINGLE).get()));
}
@@ -139,34 +151,35 @@ private static Operator operator(String grouping, String op, String dataType) {
);
default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
};
BlockHash.Factory factory = new BlockHash.Factory(bigArrays, recycler, () -> breakerService.getBreaker(CircuitBreaker.REQUEST));
return new HashAggregationOperator(
List.of(supplier(op, dataType, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
() -> BlockHash.build(groups, BIG_ARRAYS, 16 * 1024),
() -> factory.build(groups, 16 * 1024),
new DriverContext()
);
}

private static AggregatorFunctionSupplier supplier(String op, String dataType, int dataChannel) {
private AggregatorFunctionSupplier supplier(String op, String dataType, int dataChannel) {
return switch (op) {
case COUNT -> CountAggregatorFunction.supplier(BIG_ARRAYS, List.of(dataChannel));
case COUNT -> CountAggregatorFunction.supplier(bigArrays, List.of(dataChannel));
case COUNT_DISTINCT -> switch (dataType) {
case LONGS -> new CountDistinctLongAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel), 3000);
case DOUBLES -> new CountDistinctDoubleAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel), 3000);
case LONGS -> new CountDistinctLongAggregatorFunctionSupplier(bigArrays, List.of(dataChannel), 3000);
case DOUBLES -> new CountDistinctDoubleAggregatorFunctionSupplier(bigArrays, List.of(dataChannel), 3000);
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
};
case MAX -> switch (dataType) {
case LONGS -> new MaxLongAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel));
case DOUBLES -> new MaxDoubleAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel));
case LONGS -> new MaxLongAggregatorFunctionSupplier(bigArrays, List.of(dataChannel));
case DOUBLES -> new MaxDoubleAggregatorFunctionSupplier(bigArrays, List.of(dataChannel));
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
};
case MIN -> switch (dataType) {
case LONGS -> new MinLongAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel));
case DOUBLES -> new MinDoubleAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel));
case LONGS -> new MinLongAggregatorFunctionSupplier(bigArrays, List.of(dataChannel));
case DOUBLES -> new MinDoubleAggregatorFunctionSupplier(bigArrays, List.of(dataChannel));
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
};
case SUM -> switch (dataType) {
case LONGS -> new SumLongAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel));
case DOUBLES -> new SumDoubleAggregatorFunctionSupplier(BIG_ARRAYS, List.of(dataChannel));
case LONGS -> new SumLongAggregatorFunctionSupplier(bigArrays, List.of(dataChannel));
case DOUBLES -> new SumDoubleAggregatorFunctionSupplier(bigArrays, List.of(dataChannel));
default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]");
};
default -> throw new IllegalArgumentException("unsupported op [" + op + "]");
@@ -561,19 +574,20 @@ public void run() {
run(grouping, op, blockType, OP_COUNT);
}

private static void run(String grouping, String op, String blockType, int opCount) {
private void run(String grouping, String op, String blockType, int opCount) {
String dataType = switch (blockType) {
case VECTOR_LONGS, HALF_NULL_LONGS -> LONGS;
case VECTOR_DOUBLES, HALF_NULL_DOUBLES -> DOUBLES;
default -> throw new IllegalArgumentException();
};

Operator operator = operator(grouping, op, dataType);
Page page = page(grouping, blockType);
for (int i = 0; i < opCount; i++) {
operator.addInput(page);
try (Operator operator = operator(grouping, op, dataType)) {
Page page = page(grouping, blockType);
for (int i = 0; i < opCount; i++) {
operator.addInput(page);
}
operator.finish();
checkExpected(grouping, op, blockType, dataType, operator.getOutput(), opCount);
}
operator.finish();
checkExpected(grouping, op, blockType, dataType, operator.getOutput(), opCount);
}
}
@@ -0,0 +1,163 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.benchmark.compute.operator;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.breaker.NoopCircuitBreaker;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.BytesRefHash;
import org.elasticsearch.common.util.LongHash;
import org.elasticsearch.common.util.LongLongHash;
import org.elasticsearch.common.util.LongObjectPagedHashMap;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.compute.aggregation.blockhash.Ordinator64;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

@Warmup(iterations = 5)
@Measurement(iterations = 7)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@State(Scope.Thread)
@Fork(value = 1, jvmArgsAppend = { "--enable-preview", "--add-modules", "jdk.incubator.vector" })
public class HashBenchmark {
static {
// Smoke test all the expected values and force loading subclasses more like prod
try {
for (String unique : HashBenchmark.class.getField("unique").getAnnotationsByType(Param.class)[0].value()) {
HashBenchmark bench = new HashBenchmark();
bench.unique = Integer.parseInt(unique);
bench.initTestData();
bench.longHash();
bench.bytesRefHash();
bench.longLongHash();
bench.longObjectHash();
bench.ordinator();
bench.ordinatorArray();
}
} catch (NoSuchFieldException e) {
throw new AssertionError();
}
}

private static final int ITERATIONS = 10_000_000;

@Param({ "5", "1000", "10000", "100000", "1000000" })
public int unique;

private long[] testLongs;
private BytesRef[] testBytes;
private int[] targetInts;
private long[] targetLongs;
private Object[] targetObject;

@Setup
public void initTestData() {
testLongs = LongStream.range(0, ITERATIONS).map(l -> l % unique).toArray();
BytesRef[] uniqueBytes = IntStream.range(0, unique).mapToObj(i -> new BytesRef(Integer.toString(i))).toArray(BytesRef[]::new);
testBytes = IntStream.range(0, ITERATIONS).mapToObj(i -> uniqueBytes[i % unique]).toArray(BytesRef[]::new);
targetInts = new int[ITERATIONS];
targetLongs = new long[ITERATIONS];
targetObject = new Object[ITERATIONS];
}

@Benchmark
@OperationsPerInvocation(ITERATIONS)
public void longHash() {
LongHash hash = new LongHash(16, BigArrays.NON_RECYCLING_INSTANCE);
for (int i = 0; i < testLongs.length; i++) {
targetLongs[i] = hash.add(testLongs[i]);
}
if (hash.size() != unique) {
throw new AssertionError();
}
}

@Benchmark
@OperationsPerInvocation(ITERATIONS)
public void bytesRefHash() {
BytesRefHash hash = new BytesRefHash(16, BigArrays.NON_RECYCLING_INSTANCE);
for (int i = 0; i < testLongs.length; i++) {
targetLongs[i] = hash.add(testBytes[i]);
}
if (hash.size() != unique) {
throw new AssertionError();
}
}

@Benchmark
@OperationsPerInvocation(ITERATIONS)
public void longLongHash() {
LongLongHash hash = new LongLongHash(16, BigArrays.NON_RECYCLING_INSTANCE);
for (int i = 0; i < testLongs.length; i++) {
targetLongs[i] = hash.add(testLongs[i], testLongs[i]);
}
if (hash.size() != unique) {
throw new AssertionError();
}
}

@Benchmark
@OperationsPerInvocation(ITERATIONS)
public void longObjectHash() {
LongObjectPagedHashMap<Object> hash = new LongObjectPagedHashMap<>(16, BigArrays.NON_RECYCLING_INSTANCE);
Object o = new Object();
for (int i = 0; i < testLongs.length; i++) {
targetObject[i] = hash.put(testLongs[i], o);
}
if (hash.size() != unique) {
throw new AssertionError();
}
}

@Benchmark
@OperationsPerInvocation(ITERATIONS)
public void ordinator() {
Ordinator64 hash = new Ordinator64(
new PageCacheRecycler(Settings.EMPTY),
new NoopCircuitBreaker("bench"),
new Ordinator64.IdSpace()
);
for (int i = 0; i < testLongs.length; i++) {
targetInts[i] = hash.add(testLongs[i]);
}
if (hash.currentSize() != unique) {
throw new AssertionError("expected " + hash.currentSize() + " to be " + unique);
}
}

@Benchmark
@OperationsPerInvocation(ITERATIONS)
public void ordinatorArray() {
Ordinator64 hash = new Ordinator64(
new PageCacheRecycler(Settings.EMPTY),
new NoopCircuitBreaker("bench"),
new Ordinator64.IdSpace()
);
hash.add(testLongs, targetInts, testLongs.length);
if (hash.currentSize() != unique) {
throw new AssertionError();
}
}
}
@@ -161,7 +161,9 @@ if (providers.systemProperty('idea.active').getOrNull() == 'true') {
'--add-opens=java.base/java.nio.file=ALL-UNNAMED',
'--add-opens=java.base/java.time=ALL-UNNAMED',
'--add-opens=java.base/java.lang=ALL-UNNAMED',
'--add-opens=java.management/java.lang.management=ALL-UNNAMED'
'--add-opens=java.management/java.lang.management=ALL-UNNAMED',
'--enable-preview',
'--add-modules=jdk.incubator.vector'
].join(' ')
}
}
@@ -123,7 +123,7 @@ public static void configureCompile(Project project) {
// fail on all javac warnings.
// TODO Discuss moving compileOptions.getCompilerArgs() to use provider api with Gradle team.
List<String> compilerArgs = compileOptions.getCompilerArgs();
compilerArgs.add("-Werror");
// compilerArgs.add("-Werror"); NOCOMMIT add me back once we figure out how to not fail compiling with preview features
Member Author commented:
This'd be a huge problem to commit, but I can't figure out a good way around it. If I enable the vector API it'll emit the warning. I think Lucene has some kind of hack for accessing the vector API that would fix this, and we'd want to steal that.
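
For reference, a rough sketch of the Lucene-style workaround the comment alludes to: look the incubating Vector API up reflectively at runtime instead of compiling against it, so javac never references the incubating module, the warning goes away, and -Werror can stay on. This is illustrative only - it is not the PR's code, and the shim class below is an assumption about how such an approach could look.

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;

// Sketch of a runtime-only lookup of the incubating Vector API. The lookup succeeds
// only when the JVM was started with --add-modules jdk.incubator.vector; otherwise
// callers fall back to scalar code, and nothing here is referenced at compile time.
final class VectorApiShim {
    static final MethodHandle SPECIES_PREFERRED_GETTER;
    static final boolean AVAILABLE;

    static {
        MethodHandle getter = null;
        try {
            Class<?> longVector = Class.forName("jdk.incubator.vector.LongVector");
            Class<?> species = Class.forName("jdk.incubator.vector.VectorSpecies");
            getter = MethodHandles.publicLookup().findStaticGetter(longVector, "SPECIES_PREFERRED", species);
        } catch (ReflectiveOperationException e) {
            // Module not resolved or API shape changed: stay on the scalar path.
        }
        SPECIES_PREFERRED_GETTER = getter;
        AVAILABLE = getter != null;
    }

    private VectorApiShim() {}
}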

compilerArgs.add("-Xlint:all,-path,-serial,-options,-deprecation,-try,-removal");
compilerArgs.add("-Xdoclint:all");
compilerArgs.add("-Xdoclint:-missing");
@@ -108,7 +108,9 @@ public void execute(Task t) {
"--add-opens=java.base/java.nio.file=ALL-UNNAMED",
"--add-opens=java.base/java.time=ALL-UNNAMED",
"--add-opens=java.management/java.lang.management=ALL-UNNAMED",
"-XX:+HeapDumpOnOutOfMemoryError"
"-XX:+HeapDumpOnOutOfMemoryError",
"--enable-preview",
"--add-modules=jdk.incubator.vector"
);

test.getJvmArgumentProviders().add(new SimpleCommandLineArgumentProvider("-XX:HeapDumpPath=" + heapdumpDir));
20 changes: 20 additions & 0 deletions x-pack/plugin/esql/compute/build.gradle
@@ -14,6 +14,10 @@ tasks.named("compileJava").configure {
options.compilerArgs.addAll(["-s", "${projectDir}/src/main/generated"])
}

tasks.named('forbiddenApisMain').configure {
failOnMissingClasses = false // Ignore the vector apis
}

tasks.named('checkstyleMain').configure {
source = "src/main/java"
excludes = [ "**/*.java.st" ]
@@ -396,4 +400,20 @@ tasks.named('stringTemplates').configure {
it.inputFile = multivalueDedupeInputFile
it.outputFile = "org/elasticsearch/compute/operator/MultivalueDedupeBytesRef.java"
}
File blockHashInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/blockhash/X-BlockHash.java.st")
template {
it.properties = intProperties
it.inputFile = blockHashInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/blockhash/IntBlockHash.java"
}
template {
it.properties = longProperties
it.inputFile = blockHashInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/blockhash/LongBlockHash.java"
}
template {
it.properties = doubleProperties
it.inputFile = blockHashInputFile
it.outputFile = "org/elasticsearch/compute/aggregation/blockhash/DoubleBlockHash.java"
}
Member Author commented:
I'm generating these so it's easier to keep them updated. I'll generate some more Ordinators at some point - a 32-bit and a 128-bit one at least. But that's another follow-up.

}