Skip to content

Commit

Permalink
Add bulkScorer to script score query (#46336) (#49734)
Browse files Browse the repository at this point in the history
Some queries return bulk scorers that can be significantly faster than
iterating naively over the scorer. By giving script_score a BulkScorer
that would delegate to the wrapped query, we could make it faster in some cases.

Closes #40837
  • Loading branch information
mayya-sharipova authored Nov 29, 2019
1 parent 1d745f1 commit 62a891b
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,17 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.FilterLeafCollector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.util.Bits;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
Expand Down Expand Up @@ -83,6 +88,19 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
Weight subQueryWeight = subQuery.createWeight(searcher, subQueryScoreMode, boost);

return new Weight(this){
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
if (minScore == null) {
final BulkScorer subQueryBulkScorer = subQueryWeight.bulkScorer(context);
if (subQueryBulkScorer == null) {
return null;
}
return new ScriptScoreBulkScorer(subQueryBulkScorer, subQueryScoreMode, makeScoreScript(context));
} else {
return super.bulkScorer(context);
}
}

@Override
public void extractTerms(Set<Term> terms) {
subQueryWeight.extractTerms(terms);
Expand All @@ -94,8 +112,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
if (subQueryScorer == null) {
return null;
}
Scorer scriptScorer = makeScriptScorer(subQueryScorer, context, null);

Scorer scriptScorer = new ScriptScorer(this, makeScoreScript(context), subQueryScorer, subQueryScoreMode, null);
if (minScore != null) {
scriptScorer = new MinScoreScorer(this, scriptScorer, minScore);
}
Expand All @@ -109,7 +126,8 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
return subQueryExplanation;
}
ExplanationHolder explanationHolder = new ExplanationHolder();
Scorer scorer = makeScriptScorer(subQueryWeight.scorer(context), context, explanationHolder);
Scorer scorer = new ScriptScorer(this, makeScoreScript(context),
subQueryWeight.scorer(context), subQueryScoreMode, explanationHolder);
int newDoc = scorer.iterator().advance(doc);
assert doc == newDoc; // subquery should have already matched above
float score = scorer.score();
Expand All @@ -132,42 +150,13 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
}
return explanation;
}

private Scorer makeScriptScorer(Scorer subQueryScorer, LeafReaderContext context,
ExplanationHolder explanation) throws IOException {

private ScoreScript makeScoreScript(LeafReaderContext context) throws IOException {
final ScoreScript scoreScript = scriptBuilder.newInstance(context);
scoreScript.setScorer(subQueryScorer);
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
scoreScript._setIndexVersion(indexVersion);

return new Scorer(this) {
@Override
public float score() throws IOException {
int docId = docID();
scoreScript.setDocument(docId);
float score = (float) scoreScript.execute(explanation);
if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) {
throw new ElasticsearchException(
"script score query returned an invalid score: " + score + " for doc: " + docId);
}
return score;
}
@Override
public int docID() {
return subQueryScorer.docID();
}

@Override
public DocIdSetIterator iterator() {
return subQueryScorer.iterator();
}

@Override
public float getMaxScore(int upTo) {
return Float.MAX_VALUE; // TODO: what would be a good upper bound?
}
};
return scoreScript;
}

@Override
Expand All @@ -187,7 +176,7 @@ public void visit(QueryVisitor visitor) {
@Override
public String toString(String field) {
StringBuilder sb = new StringBuilder();
sb.append("script score (").append(subQuery.toString(field)).append(", script: ");
sb.append("script_score (").append(subQuery.toString(field)).append(", script: ");
sb.append("{" + script.toString() + "}");
return sb.toString();
}
Expand All @@ -209,4 +198,118 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(subQuery, script, minScore, indexName, shardId, indexVersion);
}


private static class ScriptScorer extends Scorer {
private final ScoreScript scoreScript;
private final Scorer subQueryScorer;
private final ExplanationHolder explanation;

ScriptScorer(Weight weight, ScoreScript scoreScript, Scorer subQueryScorer,
ScoreMode subQueryScoreMode, ExplanationHolder explanation) {
super(weight);
this.scoreScript = scoreScript;
if (subQueryScoreMode == ScoreMode.COMPLETE) {
scoreScript.setScorer(subQueryScorer);
}
this.subQueryScorer = subQueryScorer;
this.explanation = explanation;
}

@Override
public float score() throws IOException {
int docId = docID();
scoreScript.setDocument(docId);
float score = (float) scoreScript.execute(explanation);
if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) {
throw new ElasticsearchException(
"script_score query returned an invalid score [" + score + "] for doc [" + docId + "].");
}
return score;
}
@Override
public int docID() {
return subQueryScorer.docID();
}

@Override
public DocIdSetIterator iterator() {
return subQueryScorer.iterator();
}

@Override
public float getMaxScore(int upTo) {
return Float.MAX_VALUE; // TODO: what would be a good upper bound?
}

}

private static class ScriptScorable extends Scorable {
private final ScoreScript scoreScript;
private final Scorable subQueryScorer;
private final ExplanationHolder explanation;

ScriptScorable(ScoreScript scoreScript, Scorable subQueryScorer,
ScoreMode subQueryScoreMode, ExplanationHolder explanation) {
this.scoreScript = scoreScript;
if (subQueryScoreMode == ScoreMode.COMPLETE) {
scoreScript.setScorer(subQueryScorer);
}
this.subQueryScorer = subQueryScorer;
this.explanation = explanation;
}

@Override
public float score() throws IOException {
int docId = docID();
scoreScript.setDocument(docId);
float score = (float) scoreScript.execute(explanation);
if (score == Float.NEGATIVE_INFINITY || Float.isNaN(score)) {
throw new ElasticsearchException(
"script_score query returned an invalid score [" + score + "] for doc [" + docId + "].");
}
return score;
}
@Override
public int docID() {
return subQueryScorer.docID();
}
}

/**
* Use the {@link BulkScorer} of the sub-query,
* as it may be significantly faster (e.g. BooleanScorer) than iterating over the scorer
*/
private static class ScriptScoreBulkScorer extends BulkScorer {
private final BulkScorer subQueryBulkScorer;
private final ScoreMode subQueryScoreMode;
private final ScoreScript scoreScript;

ScriptScoreBulkScorer(BulkScorer subQueryBulkScorer, ScoreMode subQueryScoreMode, ScoreScript scoreScript) {
this.subQueryBulkScorer = subQueryBulkScorer;
this.subQueryScoreMode = subQueryScoreMode;
this.scoreScript = scoreScript;
}

@Override
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
return subQueryBulkScorer.score(wrapCollector(collector), acceptDocs, min, max);
}

private LeafCollector wrapCollector(LeafCollector collector) {
return new FilterLeafCollector(collector) {
@Override
public void setScorer(Scorable scorer) throws IOException {
in.setScorer(new ScriptScorable(scoreScript, scorer, subQueryScoreMode, null));
}
};
}

@Override
public long cost() {
return subQueryBulkScorer.cost();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.script.MockScriptPlugin;
Expand All @@ -35,6 +36,7 @@
import java.util.Map;
import java.util.function.Function;

import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
import static org.elasticsearch.index.query.QueryBuilders.scriptScoreQuery;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
Expand Down Expand Up @@ -104,6 +106,33 @@ public void testScriptScore() {
assertOrderedSearchHits(resp, "10", "8", "6");
}

public void testScriptScoreBoolQuery() {
assertAcked(
prepareCreate("test-index").addMapping("_doc", "field1", "type=text", "field2", "type=double")
);
int docCount = 10;
for (int i = 1; i <= docCount; i++) {
client().prepareIndex("test-index", "_doc", "" + i)
.setSource("field1", "text" + i, "field2", i)
.get();
}
refresh();

Map<String, Object> params = new HashMap<>();
params.put("param1", 0.1);
Script script = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "doc['field2'].value * param1", params);
QueryBuilder boolQuery = boolQuery().should(matchQuery("field1", "text1")).should(matchQuery("field1", "text10"));
SearchResponse resp = client()
.prepareSearch("test-index")
.setQuery(scriptScoreQuery(boolQuery, script))
.get();
assertNoFailures(resp);
assertOrderedSearchHits(resp, "10", "1");
assertFirstHit(resp, hasScore(1.0f));
assertSecondHit(resp, hasScore(0.1f));
}


// test that when the internal query is rewritten script_score works well
public void testRewrittenQuery() {
assertAcked(
Expand Down

0 comments on commit 62a891b

Please sign in to comment.