Commit 6f17272

Fold scoring into existing Lucene operator, rather than adding a new scoring operator

ChrisHegarty committed Nov 3, 2024
1 parent 15c340b
Showing 11 changed files with 221 additions and 291 deletions.
@@ -79,6 +79,7 @@ public abstract static class Factory implements SourceOperator.SourceOperatorFac
protected final DataPartitioning dataPartitioning;
protected final int taskConcurrency;
protected final int limit;
protected final ScoreMode scoreMode;
protected final LuceneSliceQueue sliceQueue;

/**
@@ -95,6 +96,7 @@ protected Factory(
ScoreMode scoreMode
) {
this.limit = limit;
this.scoreMode = scoreMode;
this.dataPartitioning = dataPartitioning;
var weightFunction = weightFunction(queryFunction, scoreMode);
this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
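The hunks above add a ScoreMode field to the shared operator factory, so the scoring decision is made once and reused when weights are created for each slice. As a rough, standalone illustration of what that flag buys (this is not the Elasticsearch code; WeightFactorySketch and weightFor are invented names), the sketch below shows the usual way such a choice reaches Lucene: the same query yields a scoring or non-scoring Weight depending on the mode.

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

final class WeightFactorySketch {
    private WeightFactorySketch() {}

    // Builds a Weight whose scoring behaviour matches the caller's choice.
    // COMPLETE asks Lucene to compute scores; COMPLETE_NO_SCORES lets it skip that work.
    static Weight weightFor(IndexReader reader, Query query, boolean scoring) throws IOException {
        ScoreMode scoreMode = scoring ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
        IndexSearcher searcher = new IndexSearcher(reader);
        return searcher.createWeight(searcher.rewrite(query), scoreMode, 1f);
    }
}

The remaining hunks below are in LuceneSourceOperator, which picks its collector based on the same mode.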
@@ -14,6 +14,8 @@
import org.apache.lucene.search.ScoreMode;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.DocVector;
import org.elasticsearch.compute.data.DoubleBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.IntBlock;
import org.elasticsearch.compute.data.IntVector;
import org.elasticsearch.compute.data.Page;
@@ -25,6 +27,9 @@
import java.util.List;
import java.util.function.Function;

import static org.apache.lucene.search.ScoreMode.COMPLETE;
import static org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES;

/**
* Source operator that incrementally runs Lucene searches
*/
@@ -34,6 +39,7 @@ public class LuceneSourceOperator extends LuceneOperator {
private int remainingDocs;

private IntVector.Builder docsBuilder;
private DoubleVector.Builder scoreBuilder;
private final LeafCollector leafCollector;
private final int minPageSize;

@@ -47,15 +53,16 @@ public Factory(
DataPartitioning dataPartitioning,
int taskConcurrency,
int maxPageSize,
- int limit
+ int limit,
+ boolean scoring
) {
- super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+ super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
this.maxPageSize = maxPageSize;
}

@Override
public SourceOperator get(DriverContext driverContext) {
- return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit);
+ return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, scoreMode);
}

public int maxPageSize() {
@@ -70,32 +77,56 @@ public String describe() {
+ maxPageSize
+ ", limit = "
+ limit
+ ", scoreMode = "
+ scoreMode
+ "]";
}
}

- public LuceneSourceOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, int limit) {
+ public LuceneSourceOperator(BlockFactory blockFactory, int maxPageSize, LuceneSliceQueue sliceQueue, int limit, ScoreMode scoreMode) {
super(blockFactory, maxPageSize, sliceQueue);
this.minPageSize = Math.max(1, maxPageSize / 2);
this.remainingDocs = limit;
- this.docsBuilder = blockFactory.newIntVectorBuilder(Math.min(limit, maxPageSize));
- this.leafCollector = new LeafCollector() {
-     @Override
-     public void setScorer(Scorable scorer) {
-
-     }
-
-     @Override
-     public void collect(int doc) {
-         if (remainingDocs > 0) {
-             --remainingDocs;
-             docsBuilder.appendInt(doc);
-             currentPagePos++;
-         } else {
-             throw new CollectionTerminatedException();
-         }
-     }
- };
+ int estimatedSize = Math.min(limit, maxPageSize);
+ this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
+ if (scoreMode.needsScores()) {
+     scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
+     this.leafCollector = new ScoringCollector();
+ } else {
+     scoreBuilder = null;
+     this.leafCollector = new LimitingCollector();
+ }
}

+ class LimitingCollector implements LeafCollector {
+     @Override
+     public void setScorer(Scorable scorer) {}
+
+     @Override
+     public void collect(int doc) throws IOException {
+         if (remainingDocs > 0) {
+             --remainingDocs;
+             docsBuilder.appendInt(doc);
+             currentPagePos++;
+         } else {
+             throw new CollectionTerminatedException();
+         }
+     }
+ }
+
+ final class ScoringCollector extends LuceneSourceOperator.LimitingCollector {
+     private Scorable scorable;
+
+     @Override
+     public void setScorer(Scorable scorer) {
+         this.scorable = scorer;
+     }
+
+     @Override
+     public void collect(int doc) throws IOException {
+         super.collect(doc);
+         scoreBuilder.appendDouble(scorable.score());
+     }
+ }

@Override
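To see the collector split above in isolation, here is a standalone sketch of the same limit-then-score pattern against plain Lucene types, with ordinary lists standing in for the compute block builders. It mirrors the LimitingCollector / ScoringCollector pair added above, but LimitingLeafCollector and ScoringLeafCollector are invented names and none of this is part of the commit.

import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Collects up to `limit` doc ids, then asks Lucene to stop collection for the current segment.
class LimitingLeafCollector implements LeafCollector {
    final List<Integer> docs = new ArrayList<>();
    private int remaining;

    LimitingLeafCollector(int limit) {
        this.remaining = limit;
    }

    @Override
    public void setScorer(Scorable scorer) throws IOException {
        // No scores needed in the plain variant.
    }

    @Override
    public void collect(int doc) throws IOException {
        if (remaining > 0) {
            --remaining;
            docs.add(doc);
        } else {
            // Lucene's collection contract: this terminates collection for the segment.
            throw new CollectionTerminatedException();
        }
    }
}

// Same limit logic, but also records the score reported by the current Scorable.
final class ScoringLeafCollector extends LimitingLeafCollector {
    final List<Double> scores = new ArrayList<>();
    private Scorable scorable;

    ScoringLeafCollector(int limit) {
        super(limit);
    }

    @Override
    public void setScorer(Scorable scorer) throws IOException {
        this.scorable = scorer;
    }

    @Override
    public void collect(int doc) throws IOException {
        super.collect(doc);
        scores.add((double) scorable.score());
    }
}

As in the commit, the scoring variant only overrides setScorer and appends a score after delegating, so the limit and early-termination logic lives in one place; the scores are only meaningful when the Weight was created with a ScoreMode that needs them.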
@@ -139,15 +170,23 @@ public Page getCheckedOutput() throws IOException {
IntBlock shard = null;
IntBlock leaf = null;
IntVector docs = null;
DoubleBlock scores = null;
try {
shard = blockFactory.newConstantIntBlockWith(scorer.shardContext().index(), currentPagePos);
leaf = blockFactory.newConstantIntBlockWith(scorer.leafReaderContext().ord, currentPagePos);
docs = docsBuilder.build();
docsBuilder = blockFactory.newIntVectorBuilder(Math.min(remainingDocs, maxPageSize));
- page = new Page(currentPagePos, new DocVector(shard.asVector(), leaf.asVector(), docs, true).asBlock());
+ var docBlock = new DocVector(shard.asVector(), leaf.asVector(), docs, true).asBlock();
+ if (scoreBuilder == null) {
+     page = new Page(currentPagePos, docBlock);
+ } else {
+     scores = scoreBuilder.build().asBlock();
+     scoreBuilder = blockFactory.newDoubleVectorBuilder(Math.min(remainingDocs, maxPageSize));
+     page = new Page(currentPagePos, docBlock, scores);
+ }
} finally {
if (page == null) {
- Releasables.closeExpectNoException(shard, leaf, docs);
+ Releasables.closeExpectNoException(shard, leaf, docs, scores);
}
}
currentPagePos = 0;
@@ -161,6 +200,7 @@ public Page getCheckedOutput() throws IOException {
@Override
public void close() {
docsBuilder.close();
if (scoreBuilder != null) scoreBuilder.close();
}

@Override
Remaining changed files not shown.
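Putting the two sketches together, the driver below (again illustrative, not Elasticsearch code; ScoringSketchMain is an invented name, and it reuses WeightFactorySketch and ScoringLeafCollector from the sketches above) indexes a few documents in memory, builds a scoring Weight, and feeds each segment's BulkScorer into the scoring collector, catching CollectionTerminatedException per segment, which is how a caller is expected to handle a collector that enforces a limit.

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class ScoringSketchMain {
    public static void main(String[] args) throws Exception {
        try (Directory dir = new ByteBuffersDirectory()) {
            try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig(new StandardAnalyzer()))) {
                for (String text : new String[] { "lucene scoring", "lucene operators", "something else" }) {
                    Document doc = new Document();
                    doc.add(new TextField("body", text, Field.Store.NO));
                    writer.addDocument(doc);
                }
            }
            try (DirectoryReader reader = DirectoryReader.open(dir)) {
                // Scoring on: the Weight is created with ScoreMode.COMPLETE.
                Weight weight = WeightFactorySketch.weightFor(reader, new TermQuery(new Term("body", "lucene")), true);
                ScoringLeafCollector collector = new ScoringLeafCollector(10);
                for (LeafReaderContext leaf : reader.leaves()) {
                    BulkScorer bulk = weight.bulkScorer(leaf);
                    if (bulk == null) {
                        continue; // no matches in this segment
                    }
                    try {
                        bulk.score(collector, leaf.reader().getLiveDocs());
                    } catch (CollectionTerminatedException e) {
                        // The collector hit its limit for this segment; move on to the next one.
                    }
                }
                // Doc ids here are segment-local; the real operator records the segment
                // alongside each doc, which is what the DocVector block above carries.
                System.out.println("docs: " + collector.docs + ", scores: " + collector.scores);
            }
        }
    }
}

With scoring disabled, the same loop works with the plain LimitingLeafCollector and a COMPLETE_NO_SCORES weight, which is the branch the constructor hunk above selects.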