diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java index 59f48bd7fbaba..d97d59db2ee42 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java @@ -33,6 +33,7 @@ import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalOrder; import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.bucket.BucketsAggregator; import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator; import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds; import org.opensearch.search.aggregations.support.AggregationPath; @@ -215,19 +216,11 @@ public InternalAggregation buildEmptyAggregation() { @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); + MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, this, sub); return new LeafBucketCollector() { @Override public void collect(int doc, long owningBucketOrd) throws IOException { - for (BytesRef compositeKey : collector.apply(doc)) { - long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey); - if (bucketOrd < 0) { - bucketOrd = -1 - bucketOrd; - collectExistingBucket(sub, doc, bucketOrd); - } else { - collectBucket(sub, doc, bucketOrd); - } - } + collector.apply(doc, owningBucketOrd); } }; } @@ -268,12 +261,10 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept } // we need to fill-in the blanks for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) { - 
MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); + MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, null, null); // brute force for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) { - for (BytesRef compositeKey : collector.apply(docId)) { - bucketOrds.add(owningBucketOrd, compositeKey); - } + collector.apply(docId, owningBucketOrd); } } } @@ -287,7 +278,8 @@ interface MultiTermsValuesSourceCollector { * Collect a list values of multi_terms on each doc. * Each terms could have multi_values, so the result is the cartesian product of each term's values. */ - List<BytesRef> apply(int doc) throws IOException; + void apply(int doc, long owningBucketOrd) throws IOException; + } @FunctionalInterface @@ -361,51 +353,17 @@ public MultiTermsValuesSource(List<InternalValuesSource> valuesSources) { this.valuesSources = valuesSources; } - public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException { + public MultiTermsValuesSourceCollector getValues( + LeafReaderContext ctx, + BytesKeyedBucketOrds bucketOrds, + BucketsAggregator aggregator, + LeafBucketCollector sub + ) throws IOException { List<InternalValuesSourceCollector> collectors = new ArrayList<>(); for (InternalValuesSource valuesSource : valuesSources) { collectors.add(valuesSource.apply(ctx)); } - return new MultiTermsValuesSourceCollector() { - @Override - public List<BytesRef> apply(int doc) throws IOException { - List<List<TermValue<?>>> collectedValues = new ArrayList<>(); - for (InternalValuesSourceCollector collector : collectors) { - collectedValues.add(collector.apply(doc)); - } - List<BytesRef> result = new ArrayList<>(); - scratch.seek(0); - scratch.writeVInt(collectors.size()); // number of fields per composite key - cartesianProduct(result, scratch, collectedValues, 0); - return result; - } - - /** - * Cartesian product using depth first search. - * - * <p>

- * Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue}, - * but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work. - */ - private void cartesianProduct( - List<BytesRef> compositeKeys, - BytesStreamOutput scratch, - List<List<TermValue<?>>> collectedValues, - int index - ) throws IOException { - if (collectedValues.size() == index) { - compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef())); - return; - } - - long position = scratch.position(); - for (TermValue<?> value : collectedValues.get(index)) { - value.writeTo(scratch); // encode the value - cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs - scratch.seek(position); // backtrack - } - } - }; + return new MultiValuesSourceCollectorImpl(collectors, scratch, bucketOrds, aggregator, sub); } @Override @@ -414,6 +372,74 @@ public void close() { } } + static class MultiValuesSourceCollectorImpl implements MultiTermsValuesSourceCollector { + + private final List<InternalValuesSourceCollector> collectors; + private final BytesStreamOutput scratch; + private final BytesKeyedBucketOrds bucketOrds; + private final BucketsAggregator aggregator; + private final LeafBucketCollector sub; + + private final boolean collectViaAggregator; + + public MultiValuesSourceCollectorImpl( + List<InternalValuesSourceCollector> collectors, + BytesStreamOutput scratch, + BytesKeyedBucketOrds bucketOrds, + BucketsAggregator aggregator, + LeafBucketCollector sub + ) { + this.collectors = collectors; + this.scratch = scratch; + this.bucketOrds = bucketOrds; + this.aggregator = aggregator; + this.sub = sub; + this.collectViaAggregator = aggregator != null && sub != null; + } + + @Override + public void apply(int doc, long owningBucketOrd) throws IOException { + List<List<TermValue<?>>> collectedValues = new ArrayList<>(); + for (InternalValuesSourceCollector collector : collectors) { + collectedValues.add(collector.apply(doc)); + } + scratch.seek(0); + scratch.writeVInt(collectors.size()); // number of 
fields per composite key + cartesianProductRecursive(collectedValues, 0, owningBucketOrd, doc); + } + + /** + * Cartesian product using depth first search. + */ + private void cartesianProductRecursive(List<List<TermValue<?>>> collectedValues, int index, long owningBucketOrd, int doc) + throws IOException { + if (collectedValues.size() == index) { + // Avoid performing a deep copy of the composite key + long bucketOrd = bucketOrds.add(owningBucketOrd, scratch.bytes().toBytesRef()); + if (collectViaAggregator) { + if (bucketOrd < 0) { + bucketOrd = -1 - bucketOrd; + aggregator.collectExistingBucket(sub, doc, bucketOrd); + } else { + aggregator.collectBucket(sub, doc, bucketOrd); + } + } + return; + } + + long position = scratch.position(); + List<TermValue<?>> values = collectedValues.get(index); + int numIterations = values.size(); + for (int i = 0; i < numIterations; i++) { + TermValue<?> value = values.get(i); + value.writeTo(scratch); // encode the value + cartesianProductRecursive(collectedValues, index + 1, owningBucketOrd, doc); // dfs + scratch.seek(position); // backtrack + } + } + + } + /** * Factory for construct {@link InternalValuesSource}. * @@ -441,9 +467,13 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include if (i > 0 && bytes.equals(previous)) { continue; } - BytesRef copy = BytesRef.deepCopyOf(bytes); - termValues.add(TermValue.of(copy)); - previous = copy; + if (valuesCount > 1) { + BytesRef copy = BytesRef.deepCopyOf(bytes); + termValues.add(TermValue.of(copy)); + previous = copy; + } else { + termValues.add(TermValue.of(bytes)); + } } return termValues; };