Skip to content

Commit

Permalink
Add a post-collection hook to LeafCollector. (apache#12380)
Browse files Browse the repository at this point in the history
This adds `LeafCollector#finish` as a per-segment post-collection hook. While
it was already possible to do this sort of things on top of the collector API
before, a downside is that the last leaf would need to be post-collected in the
current thread instead of using the executor, which is a missed opportunity for
making queries concurrent.
  • Loading branch information
jpountz committed Jun 30, 2023
1 parent aa04747 commit de07bdd
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 80 deletions.
4 changes: 3 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ API Changes

New Features
---------------------
(No changes)

* GITHUB#12383: Introduced LeafCollector#finish, a hook that runs after
collection has finished running on a leaf. (Adrien Grand)

Improvements
---------------------
Expand Down
71 changes: 35 additions & 36 deletions lucene/core/src/java/org/apache/lucene/search/CachingCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ private static class NoScoreCachingCollector extends CachingCollector {
List<LeafReaderContext> contexts;
List<int[]> docs;
int maxDocsToCache;
NoScoreCachingLeafCollector lastCollector;

NoScoreCachingCollector(Collector in, int maxDocsToCache) {
super(in);
Expand All @@ -76,7 +75,7 @@ private static class NoScoreCachingCollector extends CachingCollector {
}

protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache) {
return new NoScoreCachingLeafCollector(in, maxDocsToCache);
return new NoScoreCachingLeafCollector(in, maxDocsToCache, this);
}

// note: do *not* override needScore to say false. Just because we aren't caching the score
Expand All @@ -85,13 +84,12 @@ protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache)

@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
postCollection();
final LeafCollector in = this.in.getLeafCollector(context);
if (contexts != null) {
contexts.add(context);
}
if (maxDocsToCache >= 0) {
return lastCollector = wrap(in, maxDocsToCache);
if (contexts != null) {
contexts.add(context);
}
return wrap(in, maxDocsToCache);
} else {
return in;
}
Expand All @@ -103,33 +101,16 @@ protected void invalidate() {
this.docs = null;
}

protected void postCollect(NoScoreCachingLeafCollector collector) {
final int[] docs = collector.cachedDocs();
maxDocsToCache -= docs.length;
this.docs.add(docs);
}

private void postCollection() {
if (lastCollector != null) {
if (!lastCollector.hasCache()) {
invalidate();
} else {
postCollect(lastCollector);
}
lastCollector = null;
}
}

protected void collect(LeafCollector collector, int i) throws IOException {
final int[] docs = this.docs.get(i);
for (int doc : docs) {
collector.collect(doc);
}
collector.finish();
}

@Override
public void replay(Collector other) throws IOException {
postCollection();
if (!isCached()) {
throw new IllegalStateException(
"cannot replay: cache was cleared because too much RAM was required");
Expand All @@ -154,14 +135,7 @@ private static class ScoreCachingCollector extends NoScoreCachingCollector {

@Override
protected NoScoreCachingLeafCollector wrap(LeafCollector in, int maxDocsToCache) {
return new ScoreCachingLeafCollector(in, maxDocsToCache);
}

@Override
protected void postCollect(NoScoreCachingLeafCollector collector) {
final ScoreCachingLeafCollector coll = (ScoreCachingLeafCollector) collector;
super.postCollect(coll);
scores.add(coll.cachedScores());
return new ScoreCachingLeafCollector(in, maxDocsToCache, this);
}

/**
Expand Down Expand Up @@ -191,12 +165,15 @@ protected void collect(LeafCollector collector, int i) throws IOException {
private class NoScoreCachingLeafCollector extends FilterLeafCollector {

final int maxDocsToCache;
final NoScoreCachingCollector collector;
int[] docs;
int docCount;

NoScoreCachingLeafCollector(LeafCollector in, int maxDocsToCache) {
NoScoreCachingLeafCollector(
LeafCollector in, int maxDocsToCache, NoScoreCachingCollector collector) {
super(in);
this.maxDocsToCache = maxDocsToCache;
this.collector = collector;
docs = new int[Math.min(maxDocsToCache, INITIAL_ARRAY_SIZE)];
docCount = 0;
}
Expand Down Expand Up @@ -235,6 +212,21 @@ public void collect(int doc) throws IOException {
super.collect(doc);
}

protected void postCollect() {
final int[] docs = cachedDocs();
collector.maxDocsToCache -= docs.length;
collector.docs.add(docs);
}

@Override
public void finish() {
if (!hasCache()) {
collector.invalidate();
} else {
postCollect();
}
}

boolean hasCache() {
return docs != null;
}
Expand All @@ -249,8 +241,9 @@ private class ScoreCachingLeafCollector extends NoScoreCachingLeafCollector {
Scorable scorer;
float[] scores;

ScoreCachingLeafCollector(LeafCollector in, int maxDocsToCache) {
super(in, maxDocsToCache);
ScoreCachingLeafCollector(
LeafCollector in, int maxDocsToCache, ScoreCachingCollector collector) {
super(in, maxDocsToCache, collector);
scores = new float[docs.length];
}

Expand Down Expand Up @@ -281,6 +274,12 @@ protected void buffer(int doc) throws IOException {
float[] cachedScores() {
return docs == null ? null : ArrayUtil.copyOfSubArray(scores, 0, docCount);
}

@Override
protected void postCollect() {
super.postCollect();
((ScoreCachingCollector) collector).scores.add(cachedScores());
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ public void collect(int doc) throws IOException {
in.collect(doc);
}

@Override
public void finish() throws IOException {
in.finish();
}

@Override
public String toString() {
String name = getClass().getSimpleName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,9 @@ protected void search(List<LeafReaderContext> leaves, Weight weight, Collector c
partialResult = true;
}
}
// Note: this is called if collection ran successfully, including the above special cases of
// CollectionTerminatedException and TimeExceededException, but no other exception.
leafCollector.finish();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,13 @@ public interface LeafCollector {
default DocIdSetIterator competitiveIterator() throws IOException {
return null;
}

/**
* Hook that gets called once the leaf that is associated with this collector has finished
* collecting successfully, including when a {@link CollectionTerminatedException} is thrown. This
* is typically useful to compile data that has been collected on this leaf, e.g. to convert facet
* counts on leaf ordinals to facet counts on global ordinals. The default implementation does
* nothing.
*/
default void finish() throws IOException {}
}
10 changes: 10 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/MultiCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ public void collect(int doc) throws IOException {
} catch (
@SuppressWarnings("unused")
CollectionTerminatedException e) {
collectors[i].finish();
collectors[i] = null;
if (allCollectorsTerminated()) {
throw new CollectionTerminatedException();
Expand All @@ -232,6 +233,15 @@ public void collect(int doc) throws IOException {
}
}

@Override
public void finish() throws IOException {
for (LeafCollector collector : collectors) {
if (collector != null) {
collector.finish();
}
}
}

private boolean allCollectorsTerminated() {
for (int i = 0; i < collectors.length; i++) {
if (collectors[i] != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public void testBasic() throws Exception {
for (int i = 0; i < 1000; i++) {
acc.collect(i);
}
acc.finish();

// now replay them
cc.replay(
Expand Down Expand Up @@ -127,6 +128,7 @@ public void testNoWrappedCollector() throws Exception {
acc.collect(0);

assertTrue(cc.isCached());
acc.finish();
cc.replay(new NoOpCollector());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -198,6 +199,7 @@ private void doQueryFirstScoringSingleDim(

docID = baseApproximation.nextDoc();
}
finish(collector, Collections.singleton(dim));
}

/**
Expand Down Expand Up @@ -334,6 +336,8 @@ protected boolean lessThan(DocsAndCost a, DocsAndCost b) {

docID = baseApproximation.nextDoc();
}

finish(collector, sidewaysDims);
}

private static int advanceIfBehind(int docID, DocIdSetIterator iterator) throws IOException {
Expand Down Expand Up @@ -552,6 +556,7 @@ private void doDrillDownAdvanceScoring(

nextChunkStart += CHUNK;
}
finish(collector, Arrays.asList(dims));
}

private void doUnionScoring(Bits acceptDocs, LeafCollector collector, DocsAndCost[] dims)
Expand Down Expand Up @@ -706,6 +711,8 @@ private void doUnionScoring(Bits acceptDocs, LeafCollector collector, DocsAndCos

nextChunkStart += CHUNK;
}

finish(collector, Arrays.asList(dims));
}

private void collectHit(LeafCollector collector, DocsAndCost[] dims) throws IOException {
Expand Down Expand Up @@ -757,6 +764,16 @@ private void collectNearMiss(LeafCollector sidewaysCollector) throws IOException
sidewaysCollector.collect(collectDocID);
}

private void finish(LeafCollector collector, Collection<DocsAndCost> dims) throws IOException {
collector.finish();
if (drillDownLeafCollector != null) {
drillDownLeafCollector.finish();
}
for (DocsAndCost dim : dims) {
dim.sidewaysLeafCollector.finish();
}
}

private void setScorer(LeafCollector mainCollector, Scorable scorer) throws IOException {
mainCollector.setScorer(scorer);
if (drillDownLeafCollector != null) {
Expand Down
19 changes: 9 additions & 10 deletions lucene/facet/src/java/org/apache/lucene/facet/FacetsCollector.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ public final boolean getKeepScores() {

/** Returns the documents matched by the query, one {@link MatchingDocs} per visited segment. */
public List<MatchingDocs> getMatchingDocs() {
if (docsBuilder != null) {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
docsBuilder = null;
scores = null;
context = null;
}

return matchingDocs;
}

Expand Down Expand Up @@ -139,9 +132,7 @@ public final void setScorer(Scorable scorer) throws IOException {

@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
if (docsBuilder != null) {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
}
assert docsBuilder == null;
docsBuilder = new DocIdSetBuilder(context.reader().maxDoc());
totalHits = 0;
if (keepScores) {
Expand All @@ -150,6 +141,14 @@ protected void doSetNextReader(LeafReaderContext context) throws IOException {
this.context = context;
}

@Override
public void finish() throws IOException {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
docsBuilder = null;
scores = null;
context = null;
}

/** Utility method, to search and also collect all hits into the provided {@link Collector}. */
public static TopDocs search(IndexSearcher searcher, Query q, int n, Collector fc)
throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,6 @@ public TopGroups<?> getTopGroups(
// if (queueFull) {
// System.out.println("getTopGroups groupOffset=" + groupOffset + " topNGroups=" + topNGroups);
// }
if (subDocUpto != 0) {
processGroup();
}
if (groupOffset >= groupQueue.size()) {
return null;
}
Expand Down Expand Up @@ -472,9 +469,6 @@ public void collect(int doc) throws IOException {

@Override
protected void doSetNextReader(LeafReaderContext readerContext) throws IOException {
if (subDocUpto != 0) {
processGroup();
}
subDocUpto = 0;
docBase = readerContext.docBase;
// System.out.println("setNextReader base=" + docBase + " r=" + readerContext.reader);
Expand All @@ -492,6 +486,13 @@ protected void doSetNextReader(LeafReaderContext readerContext) throws IOExcepti
}
}

@Override
public void finish() throws IOException {
if (subDocUpto != 0) {
processGroup();
}
}

@Override
public ScoreMode scoreMode() {
return needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ protected GroupFacetCollector(String groupField, String facetField, BytesRef fac
*/
public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean orderByCount)
throws IOException {
if (segmentFacetCounts != null) {
segmentResults.add(createSegmentResult());
segmentFacetCounts = null; // reset
}

int totalCount = 0;
int missingCount = 0;
SegmentResultPriorityQueue segments = new SegmentResultPriorityQueue(segmentResults.size());
Expand Down Expand Up @@ -109,6 +104,12 @@ public GroupedFacetResult mergeSegmentResults(int size, int minCount, boolean or
return facetResult;
}

@Override
public void finish() throws IOException {
segmentResults.add(createSegmentResult());
segmentFacetCounts = null;
}

protected abstract SegmentResult createSegmentResult() throws IOException;

@Override
Expand Down
Loading

0 comments on commit de07bdd

Please sign in to comment.