Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Improve performance of encoding composite keys in multi-term aggregations #9434

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add base class for parameterizing the search based tests #9083 ([#9083](https://github.com/opensearch-project/OpenSearch/pull/9083))
- Add support for wrapping CollectorManager with profiling during concurrent execution ([#9129](https://github.com/opensearch-project/OpenSearch/pull/9129))
- Rethrow OpenSearch exception for non-concurrent path while using concurrent search ([#9177](https://github.com/opensearch-project/OpenSearch/pull/9177))
- Improve performance of encoding composite keys in multi-term aggregations ([#9412](https://github.com/opensearch-project/OpenSearch/pull/9412))

### Deprecated

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,23 @@ private static Class<?> getGenericType(Object value) {
}
}

/**
* Returns the registered writer for the given class type.
*/
@SuppressWarnings("unchecked")
public static <W extends Writer<?>> W getWriter(Class<?> type) {
Writer<Object> writer = WriteableRegistry.getWriter(type);
if (writer == null) {
// fallback to this local hashmap
// todo: move all writers to the registry
writer = WRITERS.get(type);
}
if (writer == null) {
throw new IllegalArgumentException("can not write type [" + type + "]");
}
return (W) writer;
}

/**
* Notice: when serialization a map, the stream out map with the stream in map maybe have the
* different key-value orders, they will maybe have different stream order.
Expand All @@ -816,17 +833,8 @@ public void writeGenericValue(@Nullable Object value) throws IOException {
return;
}
final Class<?> type = getGenericType(value);
Writer<Object> writer = WriteableRegistry.getWriter(type);
if (writer == null) {
// fallback to this local hashmap
// todo: move all writers to the registry
writer = WRITERS.get(type);
}
if (writer != null) {
writer.write(this, value);
} else {
throw new IllegalArgumentException("can not write type [" + type + "]");
}
final Writer<Object> writer = getWriter(type);
writer.write(this, value);
}

public static void checkWriteable(@Nullable Object value) throws IllegalArgumentException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.NumericUtils;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.ExceptionsHelper;
import org.opensearch.common.CheckedSupplier;
import org.opensearch.common.Numbers;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.index.fielddata.SortedNumericDoubleValues;
import org.opensearch.search.DocValueFormat;
Expand Down Expand Up @@ -218,8 +219,8 @@ protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucket
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
for (List<Object> value : collector.apply(doc)) {
long bucketOrd = bucketOrds.add(owningBucketOrd, encode(value));
for (BytesRef compositeKey : collector.apply(doc)) {
long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey);
if (bucketOrd < 0) {
bucketOrd = -1 - bucketOrd;
collectExistingBucket(sub, doc, bucketOrd);
Expand All @@ -233,16 +234,7 @@ public void collect(int doc, long owningBucketOrd) throws IOException {

@Override
protected void doClose() {
Releasables.close(bucketOrds);
}

private static BytesRef encode(List<Object> values) {
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.writeCollection(values, StreamOutput::writeGenericValue);
return output.bytes().toBytesRef();
} catch (IOException e) {
throw ExceptionsHelper.convertToRuntime(e);
}
Releasables.close(bucketOrds, multiTermsValue);
}

private static List<Object> decode(BytesRef bytesRef) {
Expand Down Expand Up @@ -279,8 +271,8 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx);
// brute force
for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) {
for (List<Object> value : collector.apply(docId)) {
bucketOrds.add(owningBucketOrd, encode(value));
for (BytesRef compositeKey : collector.apply(docId)) {
bucketOrds.add(owningBucketOrd, compositeKey);
}
}
}
Expand All @@ -295,7 +287,7 @@ interface MultiTermsValuesSourceCollector {
* Collect a list values of multi_terms on each doc.
* Each terms could have multi_values, so the result is the cartesian product of each term's values.
*/
List<List<Object>> apply(int doc) throws IOException;
List<BytesRef> apply(int doc) throws IOException;
}

@FunctionalInterface
Expand All @@ -314,16 +306,56 @@ interface InternalValuesSourceCollector {
/**
* Collect a list values of a term on specific doc.
*/
List<Object> apply(int doc) throws IOException;
List<TermValue<?>> apply(int doc) throws IOException;
}

/**
* Represents an individual term value.
*/
static class TermValue<T> implements Writeable {
private static final Writer<BytesRef> BYTES_REF_WRITER = StreamOutput.getWriter(BytesRef.class);
private static final Writer<Long> LONG_WRITER = StreamOutput.getWriter(Long.class);
private static final Writer<BigInteger> BIG_INTEGER_WRITER = StreamOutput.getWriter(BigInteger.class);
private static final Writer<Double> DOUBLE_WRITER = StreamOutput.getWriter(Double.class);

private final T value;
private final Writer<T> writer;

private TermValue(T value, Writer<T> writer) {
this.value = value;
this.writer = writer;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
writer.write(out, value);
}

public static TermValue<BytesRef> of(BytesRef value) {
return new TermValue<>(value, BYTES_REF_WRITER);
}

public static TermValue<Long> of(Long value) {
return new TermValue<>(value, LONG_WRITER);
}

public static TermValue<BigInteger> of(BigInteger value) {
return new TermValue<>(value, BIG_INTEGER_WRITER);
}

public static TermValue<Double> of(Double value) {
return new TermValue<>(value, DOUBLE_WRITER);
}
}

/**
* Multi_Term ValuesSource, it is a collection of {@link InternalValuesSource}
*
* @opensearch.internal
*/
static class MultiTermsValuesSource {
static class MultiTermsValuesSource implements Releasable {
private final List<InternalValuesSource> valuesSources;
private final BytesStreamOutput scratch = new BytesStreamOutput();

public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) {
this.valuesSources = valuesSources;
Expand All @@ -336,37 +368,50 @@ public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws I
}
return new MultiTermsValuesSourceCollector() {
@Override
public List<List<Object>> apply(int doc) throws IOException {
List<CheckedSupplier<List<Object>, IOException>> collectedValues = new ArrayList<>();
public List<BytesRef> apply(int doc) throws IOException {
List<List<TermValue<?>>> collectedValues = new ArrayList<>();
for (InternalValuesSourceCollector collector : collectors) {
collectedValues.add(() -> collector.apply(doc));
collectedValues.add(collector.apply(doc));
}
List<List<Object>> result = new ArrayList<>();
apply(0, collectedValues, new ArrayList<>(), result);
List<BytesRef> result = new ArrayList<>();
scratch.seek(0);
scratch.writeVInt(collectors.size()); // number of fields per composite key
cartesianProduct(result, scratch, collectedValues, 0);
return result;
}

/**
* DFS traverse each term's values and add cartesian product to results lists.
* Cartesian product using depth first search.
*
* <p>
* Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue},
* but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work.
*/
private void apply(
int index,
List<CheckedSupplier<List<Object>, IOException>> collectedValues,
List<Object> current,
List<List<Object>> results
private void cartesianProduct(
List<BytesRef> compositeKeys,
BytesStreamOutput scratch,
List<List<TermValue<?>>> collectedValues,
int index
) throws IOException {
if (index == collectedValues.size()) {
results.add(List.copyOf(current));
} else if (null != collectedValues.get(index)) {
for (Object value : collectedValues.get(index).get()) {
current.add(value);
apply(index + 1, collectedValues, current, results);
current.remove(current.size() - 1);
}
if (collectedValues.size() == index) {
compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef()));
return;
}

long position = scratch.position();
for (TermValue<?> value : collectedValues.get(index)) {
value.writeTo(scratch); // encode the value
cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs
scratch.seek(position); // backtrack
}
}
};
}

@Override
public void close() {
scratch.close();
}
}

/**
Expand All @@ -379,27 +424,26 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include
return ctx -> {
SortedBinaryDocValues values = valuesSource.bytesValues(ctx);
return doc -> {
BytesRefBuilder previous = new BytesRefBuilder();

if (false == values.advanceExact(doc)) {
return Collections.emptyList();
}
int valuesCount = values.docValueCount();
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);

// SortedBinaryDocValues don't guarantee uniqueness so we
// need to take care of dups
previous.clear();
BytesRef previous = null;
for (int i = 0; i < valuesCount; ++i) {
BytesRef bytes = values.nextValue();
if (includeExclude != null && false == includeExclude.accept(bytes)) {
continue;
}
if (i > 0 && previous.get().equals(bytes)) {
if (i > 0 && bytes.equals(previous)) {
continue;
}
previous.copyBytes(bytes);
termValues.add(BytesRef.deepCopyOf(bytes));
BytesRef copy = BytesRef.deepCopyOf(bytes);
termValues.add(TermValue.of(copy));
previous = copy;
}
return termValues;
};
Expand All @@ -414,12 +458,12 @@ static InternalValuesSource unsignedLongValuesSource(ValuesSource.Numeric values
int valuesCount = values.docValueCount();

BigInteger previous = Numbers.MAX_UNSIGNED_LONG_VALUE;
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);
for (int i = 0; i < valuesCount; ++i) {
BigInteger val = Numbers.toUnsignedBigInteger(values.nextValue());
if (previous.compareTo(val) != 0 || i == 0) {
if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val.doubleValue()))) {
termValues.add(val);
termValues.add(TermValue.of(val));
}
previous = val;
}
Expand All @@ -439,12 +483,12 @@ static InternalValuesSource longValuesSource(ValuesSource.Numeric valuesSource,
int valuesCount = values.docValueCount();

long previous = Long.MAX_VALUE;
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);
for (int i = 0; i < valuesCount; ++i) {
long val = values.nextValue();
if (previous != val || i == 0) {
if (longFilter == null || longFilter.accept(val)) {
termValues.add(val);
termValues.add(TermValue.of(val));
}
previous = val;
}
Expand All @@ -464,12 +508,12 @@ static InternalValuesSource doubleValueSource(ValuesSource.Numeric valuesSource,
int valuesCount = values.docValueCount();

double previous = Double.MAX_VALUE;
List<Object> termValues = new ArrayList<>(valuesCount);
List<TermValue<?>> termValues = new ArrayList<>(valuesCount);
for (int i = 0; i < valuesCount; ++i) {
double val = values.nextValue();
if (previous != val || i == 0) {
if (longFilter == null || longFilter.accept(NumericUtils.doubleToSortableLong(val))) {
termValues.add(val);
termValues.add(TermValue.of(val));
}
previous = val;
}
Expand Down