diff --git a/CHANGELOG.md b/CHANGELOG.md index 32d8ccf86f055..0648cc9947f69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -147,6 +147,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add base class for parameterizing the search based tests #9083 ([#9083](https://github.com/opensearch-project/OpenSearch/pull/9083)) - Add support for wrapping CollectorManager with profiling during concurrent execution ([#9129](https://github.com/opensearch-project/OpenSearch/pull/9129)) - Rethrow OpenSearch exception for non-concurrent path while using concurrent search ([#9177](https://github.com/opensearch-project/OpenSearch/pull/9177)) +- Improve performance of encoding composite keys in multi-term aggregations ([#9412](https://github.com/opensearch-project/OpenSearch/pull/9412)) ### Deprecated diff --git a/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java b/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java index 53288cc8c8359..a61278c0cc4de 100644 --- a/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java +++ b/libs/core/src/main/java/org/opensearch/core/common/io/stream/StreamOutput.java @@ -804,6 +804,23 @@ private static Class getGenericType(Object value) { } } + /** + * Returns the registered writer for the given class type. + */ + @SuppressWarnings("unchecked") + public static > W getWriter(Class type) { + Writer writer = WriteableRegistry.getWriter(type); + if (writer == null) { + // fallback to this local hashmap + // todo: move all writers to the registry + writer = WRITERS.get(type); + } + if (writer == null) { + throw new IllegalArgumentException("can not write type [" + type + "]"); + } + return (W) writer; + } + /** * Notice: when serialization a map, the stream out map with the stream in map maybe have the * different key-value orders, they will maybe have different stream order. @@ -816,17 +833,8 @@ public void writeGenericValue(@Nullable Object value) throws IOException { return; } final Class type = getGenericType(value); - Writer writer = WriteableRegistry.getWriter(type); - if (writer == null) { - // fallback to this local hashmap - // todo: move all writers to the registry - writer = WRITERS.get(type); - } - if (writer != null) { - writer.write(this, value); - } else { - throw new IllegalArgumentException("can not write type [" + type + "]"); - } + final Writer writer = getWriter(type); + writer.write(this, value); } public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java index c70c00ff69cba..0482ef823818c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java @@ -11,17 +11,18 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.NumericUtils; import org.apache.lucene.util.PriorityQueue; import org.opensearch.ExceptionsHelper; import org.opensearch.common.CheckedSupplier; import org.opensearch.common.Numbers; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.index.fielddata.SortedBinaryDocValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; import org.opensearch.search.DocValueFormat; @@ -218,8 +219,8 @@ protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucket return new LeafBucketCollector() { @Override public void collect(int doc, long owningBucketOrd) throws IOException { - for (List value : collector.apply(doc)) { - long bucketOrd = bucketOrds.add(owningBucketOrd, encode(value)); + for (BytesRef compositeKey : collector.apply(doc)) { + long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey); if (bucketOrd < 0) { bucketOrd = -1 - bucketOrd; collectExistingBucket(sub, doc, bucketOrd); @@ -233,16 +234,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException { @Override protected void doClose() { - Releasables.close(bucketOrds); - } - - private static BytesRef encode(List values) { - try (BytesStreamOutput output = new BytesStreamOutput()) { - output.writeCollection(values, StreamOutput::writeGenericValue); - return output.bytes().toBytesRef(); - } catch (IOException e) { - throw ExceptionsHelper.convertToRuntime(e); - } + Releasables.close(bucketOrds, multiTermsValue); } private static List decode(BytesRef bytesRef) { @@ -279,8 +271,8 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); // brute force for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) { - for (List value : collector.apply(docId)) { - bucketOrds.add(owningBucketOrd, encode(value)); + for (BytesRef compositeKey : collector.apply(docId)) { + bucketOrds.add(owningBucketOrd, compositeKey); } } } @@ -295,7 +287,7 @@ interface MultiTermsValuesSourceCollector { * Collect a list values of multi_terms on each doc. * Each terms could have multi_values, so the result is the cartesian product of each term's values. */ - List> apply(int doc) throws IOException; + List apply(int doc) throws IOException; } @FunctionalInterface @@ -314,7 +306,46 @@ interface InternalValuesSourceCollector { /** * Collect a list values of a term on specific doc. */ - List apply(int doc) throws IOException; + List> apply(int doc) throws IOException; + } + + /** + * Represents an individual term value. + */ + static class TermValue implements Writeable { + private static final Writer BYTES_REF_WRITER = StreamOutput.getWriter(BytesRef.class); + private static final Writer LONG_WRITER = StreamOutput.getWriter(Long.class); + private static final Writer BIG_INTEGER_WRITER = StreamOutput.getWriter(BigInteger.class); + private static final Writer DOUBLE_WRITER = StreamOutput.getWriter(Double.class); + + private final T value; + private final Writer writer; + + private TermValue(T value, Writer writer) { + this.value = value; + this.writer = writer; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + writer.write(out, value); + } + + public static TermValue of(BytesRef value) { + return new TermValue<>(value, BYTES_REF_WRITER); + } + + public static TermValue of(Long value) { + return new TermValue<>(value, LONG_WRITER); + } + + public static TermValue of(BigInteger value) { + return new TermValue<>(value, BIG_INTEGER_WRITER); + } + + public static TermValue of(Double value) { + return new TermValue<>(value, DOUBLE_WRITER); + } } /** @@ -322,8 +353,9 @@ interface InternalValuesSourceCollector { * * @opensearch.internal */ - static class MultiTermsValuesSource { + static class MultiTermsValuesSource implements Releasable { private final List valuesSources; + private final BytesStreamOutput scratch = new BytesStreamOutput(); public MultiTermsValuesSource(List valuesSources) { this.valuesSources = valuesSources; @@ -336,37 +368,50 @@ public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws I } return new MultiTermsValuesSourceCollector() { @Override - public List> apply(int doc) throws IOException { - List, IOException>> collectedValues = new ArrayList<>(); + public List apply(int doc) throws IOException { + List>> collectedValues = new ArrayList<>(); for (InternalValuesSourceCollector collector : collectors) { - collectedValues.add(() -> collector.apply(doc)); + collectedValues.add(collector.apply(doc)); } - List> result = new ArrayList<>(); - apply(0, collectedValues, new ArrayList<>(), result); + List result = new ArrayList<>(); + scratch.seek(0); + scratch.writeVInt(collectors.size()); // number of fields per composite key + cartesianProduct(result, scratch, collectedValues, 0); return result; } /** - * DFS traverse each term's values and add cartesian product to results lists. + * Cartesian product using depth first search. + * + *

+ * Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue}, + * but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work. */ - private void apply( - int index, - List, IOException>> collectedValues, - List current, - List> results + private void cartesianProduct( + List compositeKeys, + BytesStreamOutput scratch, + List>> collectedValues, + int index ) throws IOException { - if (index == collectedValues.size()) { - results.add(List.copyOf(current)); - } else if (null != collectedValues.get(index)) { - for (Object value : collectedValues.get(index).get()) { - current.add(value); - apply(index + 1, collectedValues, current, results); - current.remove(current.size() - 1); - } + if (collectedValues.size() == index) { + compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef())); + return; + } + + long position = scratch.position(); + for (TermValue value : collectedValues.get(index)) { + value.writeTo(scratch); // encode the value + cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs + scratch.seek(position); // backtrack } } }; } + + @Override + public void close() { + scratch.close(); + } } /** @@ -379,27 +424,26 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include return ctx -> { SortedBinaryDocValues values = valuesSource.bytesValues(ctx); return doc -> { - BytesRefBuilder previous = new BytesRefBuilder(); - if (false == values.advanceExact(doc)) { return Collections.emptyList(); } int valuesCount = values.docValueCount(); - List termValues = new ArrayList<>(valuesCount); + List> termValues = new ArrayList<>(valuesCount); // SortedBinaryDocValues don't guarantee uniqueness so we // need to take care of dups - previous.clear(); + BytesRef previous = null; for (int i = 0; i < valuesCount; ++i) { BytesRef bytes = values.nextValue(); if (includeExclude != null && false == includeExclude.accept(bytes)) { continue; } - if (i > 0 && previous.get().equals(bytes)) { + if (i > 0 && bytes.equals(previous)) { continue; } - previous.copyBytes(bytes); - termValues.add(BytesRef.deepCopyOf(bytes)); + BytesRef copy = BytesRef.deepCopyOf(bytes); + termValues.add(TermValue.of(copy)); + previous = copy; } return termValues; }; @@ -414,12 +458,12 @@ static InternalValuesSource unsignedLongValuesSource(ValuesSource.Numeric values int valuesCount = values.docValueCount(); BigInteger previous = Numbers.MAX_UNSIGNED_LONG_VALUE; - List termValues = new ArrayList<>(valuesCount); + List> termValues = new ArrayList<>(valuesCount); for (int i = 0; i < valuesCount; ++i) { BigInteger val = Numbers.toUnsignedBigInteger(values.nextValue()); if (previous.compareTo(val) != 0 || i == 0) { if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val.doubleValue()))) { - termValues.add(val); + termValues.add(TermValue.of(val)); } previous = val; } @@ -439,12 +483,12 @@ static InternalValuesSource longValuesSource(ValuesSource.Numeric valuesSource, int valuesCount = values.docValueCount(); long previous = Long.MAX_VALUE; - List termValues = new ArrayList<>(valuesCount); + List> termValues = new ArrayList<>(valuesCount); for (int i = 0; i < valuesCount; ++i) { long val = values.nextValue(); if (previous != val || i == 0) { if (longFilter == null || longFilter.accept(val)) { - termValues.add(val); + termValues.add(TermValue.of(val)); } previous = val; } @@ -464,12 +508,12 @@ static InternalValuesSource doubleValueSource(ValuesSource.Numeric valuesSource, int valuesCount = values.docValueCount(); double previous = Double.MAX_VALUE; - List termValues = new ArrayList<>(valuesCount); + List> termValues = new ArrayList<>(valuesCount); for (int i = 0; i < valuesCount; ++i) { double val = values.nextValue(); if (previous != val || i == 0) { if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val))) { - termValues.add(val); + termValues.add(TermValue.of(val)); } previous = val; }