With only GlobalAggregation in request causes unnecessary wrapping with MultiCollector #8125

Merged: 1 commit, Jun 17, 2023
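
For context, here is a minimal sketch of the kind of request this change targets: a search whose only aggregation is a global aggregation. The index name `idx`, the field name `number`, and the test scaffolding are illustrative assumptions, not part of this PR.

```java
import org.opensearch.action.search.SearchResponse;
import org.opensearch.test.OpenSearchIntegTestCase;

import static org.opensearch.search.aggregations.AggregationBuilders.global;
import static org.opensearch.search.aggregations.AggregationBuilders.stats;

public class GlobalAggOnlySearchSketchIT extends OpenSearchIntegTestCase {
    public void testGlobalAggOnlyRequest() {
        // The only aggregation attached is a global aggregation (with a stats sub-aggregation).
        // Before this fix, the query phase still wrapped the remaining (empty) set of
        // query collector managers in a multi-collector context for such a request.
        SearchResponse response = client().prepareSearch("idx")
            .setSize(0)
            .addAggregation(global("global").subAggregation(stats("value_stats").field("number")))
            .get();
        assertEquals(0, response.getFailedShards());
    }
}
```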
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -95,6 +95,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Replaces ZipInputStream with ZipFile to fix Zip Slip vulnerability ([#7230](https://github.com/opensearch-project/OpenSearch/pull/7230))
- Add missing validation/parsing of SearchBackpressureMode of SearchBackpressureSettings ([#7541](https://github.com/opensearch-project/OpenSearch/pull/7541))
- Fix mapping char_filter when mapping a hashtag ([#7591](https://github.com/opensearch-project/OpenSearch/pull/7591))
- With only GlobalAggregation in request causes unnecessary wrapping with MultiCollector ([#8125](https://github.com/opensearch-project/OpenSearch/pull/8125))

### Security

AggregationProfilerIT.java
@@ -32,14 +32,19 @@

package org.opensearch.search.profile.aggregation;

import org.hamcrest.core.IsNull;
import org.opensearch.action.index.IndexRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.search.aggregations.Aggregator.SubAggCollectionMode;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.bucket.global.Global;
import org.opensearch.search.aggregations.bucket.sampler.DiversifiedOrdinalsSamplerAggregator;
import org.opensearch.search.aggregations.bucket.terms.GlobalOrdinalsStringTermsAggregator;
import org.opensearch.search.aggregations.metrics.Stats;
import org.opensearch.search.profile.ProfileResult;
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.query.QueryProfileShardResult;
import org.opensearch.test.OpenSearchIntegTestCase;

import java.util.ArrayList;
@@ -48,11 +53,15 @@
import java.util.Set;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
import static org.opensearch.search.aggregations.AggregationBuilders.avg;
import static org.opensearch.search.aggregations.AggregationBuilders.diversifiedSampler;
import static org.opensearch.search.aggregations.AggregationBuilders.global;
import static org.opensearch.search.aggregations.AggregationBuilders.histogram;
import static org.opensearch.search.aggregations.AggregationBuilders.max;
import static org.opensearch.search.aggregations.AggregationBuilders.stats;
import static org.opensearch.search.aggregations.AggregationBuilders.terms;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse;
@@ -95,6 +104,7 @@ public class AggregationProfilerIT extends OpenSearchIntegTestCase {
private static final String NUMBER_FIELD = "number";
private static final String TAG_FIELD = "tag";
private static final String STRING_FIELD = "string_field";
private final int numDocs = 5;

@Override
protected int numberOfShards() {
@@ -118,7 +128,7 @@ protected void setupSuiteScopeCluster() throws Exception {
randomStrings[i] = randomAlphaOfLength(10);
}

for (int i = 0; i < 5; i++) {
for (int i = 0; i < numDocs; i++) {
builders.add(
client().prepareIndex("idx")
.setSource(
@@ -633,4 +643,68 @@ public void testNoProfile() {
assertThat(profileResults, notNullValue());
assertThat(profileResults.size(), equalTo(0));
}

public void testGlobalAggWithStatsSubAggregatorProfile() {
boolean profileEnabled = true;
SearchResponse response = client().prepareSearch("idx")
.addAggregation(global("global").subAggregation(stats("value_stats").field(NUMBER_FIELD)))
.setProfile(profileEnabled)
.get();

assertSearchResponse(response);

Global global = response.getAggregations().get("global");
assertThat(global, IsNull.notNullValue());
assertThat(global.getName(), equalTo("global"));
assertThat(global.getDocCount(), equalTo((long) numDocs));
assertThat((long) ((InternalAggregation) global).getProperty("_count"), equalTo((long) numDocs));
assertThat(global.getAggregations().asList().isEmpty(), is(false));

Stats stats = global.getAggregations().get("value_stats");
assertThat((Stats) ((InternalAggregation) global).getProperty("value_stats"), sameInstance(stats));
assertThat(stats, IsNull.notNullValue());
assertThat(stats.getName(), equalTo("value_stats"));

Map<String, ProfileShardResult> profileResults = response.getProfileResults();
assertThat(profileResults, notNullValue());
assertThat(profileResults.size(), equalTo(getNumShards("idx").numPrimaries));
for (ProfileShardResult profileShardResult : profileResults.values()) {
assertThat(profileShardResult, notNullValue());
List<QueryProfileShardResult> queryProfileShardResults = profileShardResult.getQueryProfileResults();
assertEquals(2, queryProfileShardResults.size());
// ensure there is no multi collector getting added with only global agg
for (QueryProfileShardResult queryProfileShardResult : queryProfileShardResults) {
assertEquals(1, queryProfileShardResult.getQueryResults().size());
if (queryProfileShardResult.getQueryResults().get(0).getQueryName().equals("MatchAllDocsQuery")) {
assertEquals(0, queryProfileShardResult.getQueryResults().get(0).getProfiledChildren().size());
assertEquals("search_top_hits", queryProfileShardResult.getCollectorResult().getReason());
assertEquals(0, queryProfileShardResult.getCollectorResult().getProfiledChildren().size());
} else if (queryProfileShardResult.getQueryResults().get(0).getQueryName().equals("ConstantScoreQuery")) {
assertEquals(1, queryProfileShardResult.getQueryResults().get(0).getProfiledChildren().size());
assertEquals("aggregation_global", queryProfileShardResult.getCollectorResult().getReason());
assertEquals(0, queryProfileShardResult.getCollectorResult().getProfiledChildren().size());
} else {
fail("unexpected profile shard result in the response");
}
}
AggregationProfileShardResult aggProfileResults = profileShardResult.getAggregationProfileResults();
assertThat(aggProfileResults, notNullValue());
List<ProfileResult> aggProfileResultsList = aggProfileResults.getProfileResults();
assertThat(aggProfileResultsList, notNullValue());
assertEquals(1, aggProfileResultsList.size());
ProfileResult globalAggResult = aggProfileResultsList.get(0);
assertThat(globalAggResult, notNullValue());
assertEquals("GlobalAggregator", globalAggResult.getQueryName());
assertEquals("global", globalAggResult.getLuceneDescription());
assertEquals(1, globalAggResult.getProfiledChildren().size());
assertThat(globalAggResult.getTime(), greaterThan(0L));
Map<String, Long> breakdown = globalAggResult.getTimeBreakdown();
assertThat(breakdown, notNullValue());
assertEquals(BREAKDOWN_KEYS, breakdown.keySet());
assertThat(breakdown.get(INITIALIZE), greaterThan(0L));
assertThat(breakdown.get(COLLECT), greaterThan(0L));
assertThat(breakdown.get(BUILD_AGGREGATION).longValue(), greaterThan(0L));
assertEquals(0, breakdown.get(REDUCE).intValue());
}
}
}
26 changes: 14 additions & 12 deletions server/src/main/java/org/opensearch/search/query/QueryPhase.java
@@ -39,6 +39,7 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
@@ -71,6 +72,7 @@

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
@@ -234,19 +236,19 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q
// this collector can filter documents during the collection
hasFilterCollector = true;
}
if (searchContext.queryCollectorManagers().isEmpty() == false) {
// plug in additional collectors, like aggregations except global aggregations
collectors.add(
createMultiCollectorContext(
searchContext.queryCollectorManagers()
.entrySet()
.stream()
.filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class)))
.map(Map.Entry::getValue)
.collect(Collectors.toList())
)
);

// plug in additional collectors, like aggregations except global aggregations
final List<CollectorManager<? extends Collector, ReduceableSearchResult>> managersExceptGlobalAgg = searchContext
.queryCollectorManagers()
.entrySet()
.stream()
.filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class)))
.map(Map.Entry::getValue)
.collect(Collectors.toList());
if (managersExceptGlobalAgg.isEmpty() == false) {
collectors.add(createMultiCollectorContext(managersExceptGlobalAgg));
}

if (searchContext.minimumScore() != null) {
// apply the minimum score after multi collector so we filter aggs as well
collectors.add(createMinScoreCollectorContext(searchContext.minimumScore()));
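
To summarize the change above, here is a simplified, self-contained sketch of the pattern now used in QueryPhase: the global-aggregation collector manager is filtered out first, and the multi-collector wrapper is only created when some other manager remains. The types and names below are assumed stand-ins for illustration, not the actual OpenSearch classes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

final class CollectorWiringSketch {
    // Stand-ins for CollectorManager<?, ReduceableSearchResult> and its global-agg variant.
    interface Manager {}
    static final class GlobalAggManager implements Manager {}

    static List<String> wire(Map<Class<?>, Manager> queryCollectorManagers) {
        List<String> collectors = new ArrayList<>();
        // Filter the global aggregation manager out first; it runs in its own collection pass.
        List<Manager> managersExceptGlobalAgg = queryCollectorManagers.entrySet()
            .stream()
            .filter(entry -> !entry.getKey().equals(GlobalAggManager.class))
            .map(Map.Entry::getValue)
            .collect(Collectors.toList());
        // The old code checked queryCollectorManagers.isEmpty(), so a request whose only entry
        // was the global aggregation manager still produced an empty multi-collector wrapper.
        if (managersExceptGlobalAgg.isEmpty() == false) {
            collectors.add("multi-collector context over " + managersExceptGlobalAgg.size() + " manager(s)");
        }
        return collectors;
    }
}
```

With a map whose only entry is the global-aggregation manager, `wire(...)` returns an empty list, which matches the collector assertions in the new `testGlobalAggWithStatsSubAggregatorProfile` test.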