Skip to content

Commit

Permalink
Restore shortcut total hit count
Browse files Browse the repository at this point in the history
We have removed shortcut total hit count with elastic#89047 and later noticed a
couple of benchmark regressions. While we have moved to skip counting
when possible when not collecting hits (e.g. size=0), which is the case
where Elasticsearch uses TotalHitCountCollector and the shortcutting is
supported natively in Lucene.

For the case where hits are collected, the total hit count is counted as
part of the collection in TopScoreDocCollector and TopFieldCollector,
where Lucene does not support skipping the counting as it is hard to
determine whether more competitive hits need to be collected or not.

The previous change caused a regression specifically when collecting
hits because we ended up removing our manual shortcut in favour of
counting which causes overhead.

With this change we reintroduce the shortcut total hit count method,
and only use it when strictly necessary. When size is 0, we rely
entirely on Lucene to shortcut the total hit counting, while when hits
are collected we do it our way, for now.

While at it, a few more tests are added to cover for situations that
were not covered before.
  • Loading branch information
javanna committed Mar 29, 2023
1 parent 01daaf5 commit 69a800f
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@
import org.apache.lucene.search.ScoreCachingWrappingScorer;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Weight;

import java.io.IOException;

/**
* Collector that wraps another collector and collects only documents that have a score that's greater or equal than the
* provided minimum score. Given that this collector filters documents out, it does and should not override {@link #setWeight(Weight)},
* as that may lead to exposing total hit count that does not reflect the filtering.
*/
public class MinimumScoreCollector extends SimpleCollector {

private final Collector collector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

import java.io.IOException;

/**
* Collector that wraps another collector and collects only documents that match the provided filter.
* Given that this collector filters documents out, it does and should not override {@link #setWeight(Weight)},
* as that may lead to exposing total hit count that does not reflect the filtering.
*/
public class FilteredCollector implements Collector {

private final Collector collector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,21 @@ static boolean executeInternal(SearchContext searchContext) throws QueryPhaseExe
}

final LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
// whether the chain contains a collector that filters documents
boolean hasFilterCollector = false;
if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER) {
// add terminate_after before the filter collectors
// it will only be applied on documents accepted by these filter collectors
collectors.add(createEarlyTerminationCollectorContext(searchContext.terminateAfter()));
// this collector can filter documents during the collection
hasFilterCollector = true;
}
if (searchContext.parsedPostFilter() != null) {
// add post filters before aggregations
// it will only be applied to top hits
collectors.add(createFilteredCollectorContext(searcher, searchContext.parsedPostFilter().query()));
// this collector can filter documents during the collection
hasFilterCollector = true;
}
if (searchContext.getAggsCollector() != null) {
// plug in additional collectors, like aggregations
Expand All @@ -146,6 +152,8 @@ static boolean executeInternal(SearchContext searchContext) throws QueryPhaseExe
if (searchContext.minimumScore() != null) {
// apply the minimum score after multi collector so we filter aggs as well
collectors.add(createMinScoreCollectorContext(searchContext.minimumScore()));
// this collector can filter documents during the collection
hasFilterCollector = true;
}

boolean timeoutSet = scrollContext == null
Expand All @@ -168,7 +176,7 @@ static boolean executeInternal(SearchContext searchContext) throws QueryPhaseExe
}

try {
boolean shouldRescore = searchWithCollector(searchContext, searcher, query, collectors, timeoutSet);
boolean shouldRescore = searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, timeoutSet);
ExecutorService executor = searchContext.indexShard().getThreadPool().executor(ThreadPool.Names.SEARCH);
assert executor instanceof EWMATrackingEsThreadPoolExecutor
|| (executor instanceof EsThreadPoolExecutor == false /* in case thread pool is mocked out in tests */)
Expand All @@ -195,10 +203,11 @@ private static boolean searchWithCollector(
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean timeoutSet
) throws IOException {
// create the top docs collector last when the other collectors are known
final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext);
final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
// add the top docs collector, the first collector context in the chain
collectors.addFirst(topDocsFactory);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,30 @@

package org.elasticsearch.search.query;

import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.queries.spans.SpanQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
Expand Down Expand Up @@ -73,6 +86,11 @@ static class EmptyTopDocsCollectorContext extends TopDocsCollectorContext {
private final Collector collector;
private final Supplier<TotalHits> hitCountSupplier;

/**
* Ctr
* @param sortAndFormats The sort clause if provided
* @param trackTotalHitsUpTo The threshold up to which total hit count needs to be tracked
*/
private EmptyTopDocsCollectorContext(@Nullable SortAndFormats sortAndFormats, int trackTotalHitsUpTo) {
super(REASON_SEARCH_COUNT, 0);
this.sort = sortAndFormats == null ? null : sortAndFormats.sort;
Expand All @@ -82,7 +100,6 @@ private EmptyTopDocsCollectorContext(@Nullable SortAndFormats sortAndFormats, in
this.hitCountSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
} else {
TotalHitCountCollector hitCountCollector = new TotalHitCountCollector();
// implicit total hit counts are valid only when there is no filter collector in the chain
if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
this.collector = hitCountCollector;
this.hitCountSupplier = () -> new TotalHits(hitCountCollector.getTotalHits(), TotalHits.Relation.EQUAL_TO);
Expand Down Expand Up @@ -187,21 +204,25 @@ private static TopDocsCollector<?> createCollector(

/**
* Ctr
* @param reader The index reader
* @param query The Lucene query
* @param sortAndFormats The query sort
* @param sortAndFormats The sort clause if provided
* @param numHits The number of top hits to retrieve
* @param searchAfter The doc this request should "search after"
* @param trackMaxScore True if max score should be tracked
* @param trackTotalHitsUpTo True if the total number of hits should be tracked
* @param trackTotalHitsUpTo Threshold up to which total hit count should be tracked
* @param hasFilterCollector True if the collector chain contains at least one collector that can filter documents out
*/
private SimpleTopDocsCollectorContext(
IndexReader reader,
Query query,
@Nullable SortAndFormats sortAndFormats,
@Nullable ScoreDoc searchAfter,
int numHits,
boolean trackMaxScore,
int trackTotalHitsUpTo
) {
int trackTotalHitsUpTo,
boolean hasFilterCollector
) throws IOException {
super(REASON_SEARCH_TOP_HITS, numHits);
this.sortAndFormats = sortAndFormats;

Expand All @@ -219,9 +240,18 @@ private SimpleTopDocsCollectorContext(
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
} else {
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> topDocsSupplier.get().totalHits;
// implicit total hit counts are valid only when there is no filter collector in the chain
final int hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query);
if (hitCount == -1) {
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> topDocsSupplier.get().totalHits;
} else {
// don't compute hit counts via the collector
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO);
}
}
MaxScoreCollector maxScoreCollector = null;
if (sortAndFormats == null) {
Expand Down Expand Up @@ -263,7 +293,7 @@ TopDocsAndMaxScore newTopDocs() {
}

@Override
void postProcess(QuerySearchResult result) {
void postProcess(QuerySearchResult result) throws IOException {
final TopDocsAndMaxScore topDocs = newTopDocs();
result.topDocs(topDocs, sortAndFormats == null ? null : sortAndFormats.formats);
}
Expand All @@ -274,21 +304,32 @@ static class ScrollingTopDocsCollectorContext extends SimpleTopDocsCollectorCont
private final int numberOfShards;

private ScrollingTopDocsCollectorContext(
IndexReader reader,
Query query,
ScrollContext scrollContext,
@Nullable SortAndFormats sortAndFormats,
int numHits,
boolean trackMaxScore,
int numberOfShards,
int trackTotalHitsUpTo
) {
super(query, sortAndFormats, scrollContext.lastEmittedDoc, numHits, trackMaxScore, trackTotalHitsUpTo);
int trackTotalHitsUpTo,
boolean hasFilterCollector
) throws IOException {
super(
reader,
query,
sortAndFormats,
scrollContext.lastEmittedDoc,
numHits,
trackMaxScore,
trackTotalHitsUpTo,
hasFilterCollector
);
this.scrollContext = Objects.requireNonNull(scrollContext);
this.numberOfShards = numberOfShards;
}

@Override
void postProcess(QuerySearchResult result) {
void postProcess(QuerySearchResult result) throws IOException {
final TopDocsAndMaxScore topDocs = newTopDocs();
if (scrollContext.totalHits == null) {
// first round
Expand All @@ -311,10 +352,72 @@ void postProcess(QuerySearchResult result) {
}
}

/**
* Returns query total hit count if the <code>query</code> is a {@link MatchAllDocsQuery}
* or a {@link TermQuery} and the <code>reader</code> has no deletions,
* -1 otherwise.
*/
static int shortcutTotalHitCount(IndexReader reader, Query query) throws IOException {
while (true) {
// remove wrappers that don't matter for counts
// this is necessary so that we don't only optimize match_all
// queries but also match_all queries that are nested in
// a constant_score query
if (query instanceof ConstantScoreQuery) {
query = ((ConstantScoreQuery) query).getQuery();
} else if (query instanceof BoostQuery) {
query = ((BoostQuery) query).getQuery();
} else {
break;
}
}
if (query.getClass() == MatchAllDocsQuery.class) {
return reader.numDocs();
} else if (query.getClass() == TermQuery.class && reader.hasDeletions() == false) {
final Term term = ((TermQuery) query).getTerm();
int count = 0;
for (LeafReaderContext context : reader.leaves()) {
count += context.reader().docFreq(term);
}
return count;
} else if (query.getClass() == FieldExistsQuery.class && reader.hasDeletions() == false) {
final String field = ((FieldExistsQuery) query).getField();
int count = 0;
for (LeafReaderContext context : reader.leaves()) {
FieldInfos fieldInfos = context.reader().getFieldInfos();
FieldInfo fieldInfo = fieldInfos.fieldInfo(field);
if (fieldInfo != null) {
if (fieldInfo.getDocValuesType() == DocValuesType.NONE) {
// no shortcut possible: it's a text field, empty values are counted as no value.
return -1;
}
if (fieldInfo.getPointIndexDimensionCount() > 0) {
PointValues points = context.reader().getPointValues(field);
if (points != null) {
count += points.getDocCount();
}
} else if (fieldInfo.getIndexOptions() != IndexOptions.NONE) {
Terms terms = context.reader().terms(field);
if (terms != null) {
count += terms.getDocCount();
}
} else {
return -1; // no shortcut possible for fields that are not indexed
}
}
}
return count;
} else {
return -1;
}
}

/**
* Creates a {@link TopDocsCollectorContext} from the provided <code>searchContext</code>.
* @param hasFilterCollector True if the collector chain contains at least one collector that can filters document.
*/
static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext) {
static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector)
throws IOException {
if (searchContext.size() == 0) {
// no matter what the value of from is
return new EmptyTopDocsCollectorContext(searchContext.sort(), searchContext.trackTotalHitsUpTo());
Expand All @@ -332,13 +435,15 @@ static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searc
// no matter what the value of from is
int numDocs = Math.min(searchContext.size(), totalNumDocs);
return new ScrollingTopDocsCollectorContext(
reader,
query,
searchContext.scrollContext(),
searchContext.sort(),
numDocs,
searchContext.trackScores(),
searchContext.numberOfShards(),
trackTotalHitsUpTo
trackTotalHitsUpTo,
hasFilterCollector
);
} else if (searchContext.collapse() != null) {
boolean trackScores = searchContext.sort() == null ? true : searchContext.trackScores();
Expand All @@ -360,12 +465,14 @@ static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searc
}
}
return new SimpleTopDocsCollectorContext(
reader,
query,
searchContext.sort(),
searchContext.searchAfter(),
numDocs,
searchContext.trackScores(),
searchContext.trackTotalHitsUpTo()
searchContext.trackTotalHitsUpTo(),
hasFilterCollector
) {
@Override
boolean shouldRescore() {
Expand Down
Loading

0 comments on commit 69a800f

Please sign in to comment.