Allow rescorer with field collapsing
This change adds the support for rescoring collapsed documents.
The rescoring is applied on the top document per group on each shard.

Closes elastic#27243
jimczi committed Apr 23, 2024
1 parent 1c4e0b2 commit 66f4754
Showing 8 changed files with 269 additions and 65 deletions.
@@ -47,7 +47,7 @@ NOTE: Collapsing is applied to the top hits only and does not affect aggregation
[[expand-collapse-results]]
==== Expand collapse results

It is also possible to expand each collapsed top hits with the `inner_hits` option.
It is also possible to expand each collapsed top hit with the <<inner-hits, `inner hits`>> option.

[source,console]
----
@@ -86,7 +86,7 @@ GET /my-index-000001/_search

See <<inner-hits, inner hits>> for the complete list of supported options and the format of the response.

It is also possible to request multiple `inner_hits` for each collapsed hit. This can be useful when you want to get
It is also possible to request multiple <<inner-hits, `inner hits`>> for each collapsed hit. This can be useful when you want to get
multiple representations of the collapsed hits.

[source,console]
@@ -145,8 +145,7 @@ The `max_concurrent_group_searches` request parameter can be used to control
the maximum number of concurrent searches allowed in this phase.
The default is based on the number of data nodes and the default search thread pool size.

WARNING: `collapse` cannot be used in conjunction with <<scroll-search-results, scroll>> or
<<rescore, rescore>>.
WARNING: `collapse` cannot be used in conjunction with <<scroll-search-results, scroll>>.

[discrete]
[[collapsing-with-search-after]]
@@ -175,6 +174,65 @@ GET /my-index-000001/_search
----
// TEST[setup:my_index]

[discrete]
[[rescore-collapse-results]]
==== Rescore collapse results

You can use field collapsing alongside the <<rescore, `rescore`>> search parameter.
Rescorers run on every shard for the top-ranked document per collapsed field.
To keep the order reliable, it is recommended to route documents that share the same collapse
field value to a single shard.
This is achieved by assigning the collapse field value as the <<search-routing, routing key>>
during indexing:

[source,console]
----
POST /my-index-000001/_doc?routing=xyz <1>
{
"@timestamp": "2099-11-15T13:12:00",
"message": "You know for search!",
"user.id": "xyz"
}
----
// TEST[setup:my_index]
<1> Assign routing with the collapse field value (`user.id`).

By doing this, you guarantee that only one top document per
collapse key gets rescored globally.
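Why routing matters can be sketched with a small stand-alone simulation (plain Python, not an Elasticsearch API; the document ids and scores are made up): each shard collapses to one top document per group before its rescorer runs, so a group split across shards can contribute several rescored documents, while routing the whole group to one shard leaves exactly one.

```python
def collapse_then_rescore(shard_docs, rescore):
    """Per shard: keep the best first-pass doc per group, then rescore it."""
    best = {}
    for doc in shard_docs:
        if doc["group"] not in best or doc["score"] > best[doc["group"]]["score"]:
            best[doc["group"]] = doc
    return [{**d, "score": rescore(d)} for d in best.values()]

docs = [
    {"id": 1, "group": "a", "score": 2.0, "rescore": 5.0},
    {"id": 2, "group": "a", "score": 1.0, "rescore": 9.0},
]
rescore = lambda d: d["rescore"]

# Group "a" split across two shards: both docs survive the shard-local
# collapse, so two docs for one collapse key reach the rescorer.
split = [hit for shard in ([docs[0]], [docs[1]]) for hit in collapse_then_rescore(shard, rescore)]

# Routed by collapse key: one shard holds the whole group, and only its
# single first-pass winner (id 1) is rescored.
routed = collapse_then_rescore(docs, rescore)
print(len(split), len(routed))  # → 2 1
```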

The following request uses field collapsing on the `user.id`
field and then rescores the top groups with a <<query-rescorer, query rescorer>>:

[source,console]
----
GET /my-index-000001/_search
{
"query": {
"match": {
"message": "you know for search"
}
},
"collapse": {
"field": "user.id"
},
"rescore" : {
"window_size" : 50,
"query" : {
"rescore_query" : {
"match_phrase": {
"message": "you know for search"
}
},
"query_weight" : 0.3,
"rescore_query_weight" : 1.4
}
}
}
----
// TEST[setup:my_index]
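For intuition on the weights above: with the query rescorer's default `score_mode` (`total`), the two passes are combined as a weighted sum, so a hypothetical document scoring 2.0 on the original query and 3.0 on the `match_phrase` rescore query ends up at 0.3 * 2.0 + 1.4 * 3.0 ≈ 4.8. A minimal sketch (plain Python, scores are made up):

```python
def combined_score(original, rescore, query_weight=1.0, rescore_query_weight=1.0):
    # Query rescorer default score_mode "total": weighted sum of both passes.
    return query_weight * original + rescore_query_weight * rescore

# Weights from the request above: the rescore query dominates the final order.
score = combined_score(2.0, 3.0, query_weight=0.3, rescore_query_weight=1.4)
print(round(score, 6))  # → 4.8
```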

WARNING: Rescorers are not applied to <<inner-hits, `inner hits`>>.

[discrete]
[[second-level-of-collapsing]]
==== Second level of collapsing
@@ -64,12 +64,6 @@ When exposing pagination to users, `window_size` should remain constant as each

Depending on how your model is trained, it’s possible that the model will return negative scores for documents. While negative scores are not allowed from first-stage retrieval and ranking, it is possible to use them in the LTR rescorer.

[discrete]
[[learning-to-rank-rescorer-limitations-field-collapsing]]
====== Compatibility with field collapsing

LTR rescorers are not compatible with the <<collapse-search-results, collapse feature>>.

[discrete]
[[learning-to-rank-rescorer-limitations-term-statistics]]
====== Term statistics as features
@@ -281,24 +281,6 @@ setup:
- match: { hits.hits.1.fields.numeric_group: [1] }
- match: { hits.hits.1.sort: [1] }

---
"field collapsing and rescore":

- do:
catch: /cannot use \`collapse\` in conjunction with \`rescore\`/
search:
rest_total_hits_as_int: true
index: test
body:
collapse: { field: numeric_group }
rescore:
window_size: 20
query:
rescore_query:
match_all: {}
query_weight: 1
rescore_query_weight: 2

---
"no hits and inner_hits":

@@ -23,15 +23,18 @@
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.rescore.QueryRescoreMode;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
import static org.elasticsearch.common.lucene.search.function.CombineFunction.REPLACE;
@@ -845,4 +848,134 @@ public void testRescorePhaseWithInvalidSort() throws Exception {
}
);
}

record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {}

public void testRescoreAfterCollapse() throws Exception {
assertAcked(prepareCreate("test").setMapping("group", "type=keyword", "shouldFilter", "type=boolean"));
ensureGreen("test");
GroupDoc[] groupDocs = new GroupDoc[] {
new GroupDoc("1", "c", 200, 1, false),
new GroupDoc("2", "a", 1, 10, true),
new GroupDoc("3", "b", 2, 30, false),
new GroupDoc("4", "c", 1, 1000, false),
// should be highest on rescore, but filtered out during collapse
new GroupDoc("5", "b", 1, 40, false),
new GroupDoc("6", "a", 2, 20, false) };
List<IndexRequestBuilder> requests = new ArrayList<>();
for (var groupDoc : groupDocs) {
requests.add(
client().prepareIndex("test")
.setId(groupDoc.id())
.setRouting(groupDoc.group())
.setSource(
"group",
groupDoc.group(),
"firstPassScore",
groupDoc.firstPassScore(),
"secondPassScore",
groupDoc.secondPassScore(),
"shouldFilter",
groupDoc.shouldFilter()
)
);
}
indexRandom(true, requests);

var request = client().prepareSearch("test")
.setQuery(fieldValueScoreQuery("firstPassScore"))
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")))
.setCollapse(new CollapseBuilder("group"));
assertResponse(request, resp -> {
assertThat(resp.getHits().getTotalHits().value, equalTo(5L));
assertThat(resp.getHits().getHits().length, equalTo(3));

SearchHit hit1 = resp.getHits().getAt(0);
assertThat(hit1.getId(), equalTo("1"));
assertThat(hit1.getScore(), equalTo(201F));
assertThat(hit1.field("group").getValues().size(), equalTo(1));
assertThat(hit1.field("group").getValues().get(0), equalTo("c"));

SearchHit hit2 = resp.getHits().getAt(1);
assertThat(hit2.getId(), equalTo("3"));
assertThat(hit2.getScore(), equalTo(32F));
assertThat(hit2.field("group").getValues().size(), equalTo(1));
assertThat(hit2.field("group").getValues().get(0), equalTo("b"));

SearchHit hit3 = resp.getHits().getAt(2);
assertThat(hit3.getId(), equalTo("6"));
assertThat(hit3.getScore(), equalTo(22F));
assertThat(hit3.field("group").getValues().size(), equalTo(1));
assertThat(hit3.field("group").getValues().get(0), equalTo("a"));
});
}

public void testRescoreAfterCollapseRandom() throws Exception {
assertAcked(prepareCreate("test").setMapping("group", "type=keyword", "shouldFilter", "type=boolean"));
ensureGreen("test");
int numGroups = randomIntBetween(1, 100);
int numDocs = atLeast(100);
GroupDoc[] groups = new GroupDoc[numGroups];
int numHits = 0;
List<IndexRequestBuilder> requests = new ArrayList<>();
for (int i = 0; i < numDocs; i++) {
int group = randomIntBetween(0, numGroups - 1);
boolean shouldFilter = rarely();
String id = randomUUID();
float firstPassScore = randomFloat();
float secondPassScore = randomFloat();
float bestScore = groups[group] == null ? -1 : groups[group].firstPassScore;
var groupDoc = new GroupDoc(id, Integer.toString(group), firstPassScore, secondPassScore, shouldFilter);
if (shouldFilter == false) {
numHits++;
if (firstPassScore > bestScore) {
groups[group] = groupDoc;
}
}
requests.add(
client().prepareIndex("test")
.setId(groupDoc.id())
.setRouting(groupDoc.group())
.setSource(
"group",
groupDoc.group(),
"firstPassScore",
groupDoc.firstPassScore(),
"secondPassScore",
groupDoc.secondPassScore(),
"shouldFilter",
groupDoc.shouldFilter()
)
);
}
indexRandom(true, requests);

// expected order: group heads by descending second-pass score, since the
// rescorer below replaces the first-pass score entirely
GroupDoc[] sortedGroups = Arrays.stream(groups)
.filter(g -> g != null)
.sorted(Comparator.comparingDouble(GroupDoc::secondPassScore).reversed())
.toArray(GroupDoc[]::new);

var request = client().prepareSearch("test")
.setQuery(fieldValueScoreQuery("firstPassScore"))
// query_weight 0 drops the first-pass score, so hits are ordered purely by the second pass
.addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).setQueryWeight(0f).windowSize(numGroups))
.setCollapse(new CollapseBuilder("group"))
.setSize(Math.min(numGroups, 10));
long expectedNumHits = numHits;
assertResponse(request, resp -> {
assertThat(resp.getHits().getTotalHits().value, equalTo(expectedNumHits));
for (int pos = 0; pos < resp.getHits().getHits().length; pos++) {
SearchHit hit = resp.getHits().getAt(pos);
assertThat(hit.getId(), equalTo(sortedGroups[pos].id()));
String group = hit.field("group").getValue();
assertThat(group, equalTo(sortedGroups[pos].group()));
assertThat(hit.getScore(), equalTo(sortedGroups[pos].secondPassScore()));
}
});
}

private QueryBuilder fieldValueScoreQuery(String scoreField) {
return functionScoreQuery(termQuery("shouldFilter", false), ScoreFunctionBuilders.fieldValueFactorFunction(scoreField)).boostMode(
CombineFunction.REPLACE
);
}
}
@@ -366,9 +366,6 @@ public ActionRequestValidationException validate() {
validationException
);
}
if (source.collapse() != null && source.rescores() != null && source.rescores().isEmpty() == false) {
validationException = addValidationError("cannot use `collapse` in conjunction with `rescore`", validationException);
}
if (source.storedFields() != null) {
if (source.storedFields().fetchFields() == false) {
if (source.fetchSource() != null && source.fetchSource().fetchSource()) {
@@ -1455,9 +1455,6 @@ private static void validateSearchSource(SearchSourceBuilder source, boolean has
if (hasScroll) {
throw new IllegalArgumentException("cannot use `collapse` in a scroll context");
}
if (source.rescores() != null && source.rescores().isEmpty() == false) {
throw new IllegalArgumentException("cannot use `collapse` in conjunction with `rescore`");
}
}
if (source.slice() != null) {
if (source.pointInTimeBuilder() == null && (hasScroll == false)) {
@@ -256,21 +256,6 @@ static CollectorManager<Collector, QueryPhaseResult> createQueryPhaseCollectorMa
searchContext.scrollContext(),
searchContext.numberOfShards()
);
} else if (searchContext.collapse() != null) {
boolean trackScores = searchContext.sort() == null || searchContext.trackScores();
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
return forCollapsing(
postFilterWeight,
terminateAfterChecker,
aggsCollectorManager,
searchContext.minimumScore(),
searchContext.getProfilers() != null,
searchContext.collapse(),
searchContext.sort(),
numDocs,
trackScores,
searchContext.searchAfter()
);
} else {
int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
final boolean rescore = searchContext.rescore().isEmpty() == false;
@@ -280,21 +265,37 @@ static CollectorManager<Collector, QueryPhaseResult> createQueryPhaseCollectorMa
numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
}
}
return new WithHits(
postFilterWeight,
terminateAfterChecker,
aggsCollectorManager,
searchContext.minimumScore(),
searchContext.getProfilers() != null,
reader,
query,
searchContext.sort(),
searchContext.searchAfter(),
numDocs,
searchContext.trackScores(),
searchContext.trackTotalHitsUpTo(),
hasFilterCollector
);
if (searchContext.collapse() != null) {
boolean trackScores = searchContext.sort() == null || searchContext.trackScores();
return forCollapsing(
postFilterWeight,
terminateAfterChecker,
aggsCollectorManager,
searchContext.minimumScore(),
searchContext.getProfilers() != null,
searchContext.collapse(),
searchContext.sort(),
numDocs,
trackScores,
searchContext.searchAfter()
);
} else {
return new WithHits(
postFilterWeight,
terminateAfterChecker,
aggsCollectorManager,
searchContext.minimumScore(),
searchContext.getProfilers() != null,
reader,
query,
searchContext.sort(),
searchContext.searchAfter(),
numDocs,
searchContext.trackScores(),
searchContext.trackTotalHitsUpTo(),
hasFilterCollector
);
}
}
}
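The control flow above can be condensed as follows (a Python sketch of the intent, not the actual Elasticsearch code): the collapse branch now sits after the rescore-window computation, so a collapsing search collects enough shard-level hits to fill the largest rescore window.

```python
def num_docs_to_collect(from_, size, total_docs, rescore_windows):
    # Mirrors the reordered logic: rescore windows can only grow the
    # per-shard top-docs count, for collapsing and non-collapsing searches alike.
    num_docs = min(from_ + size, total_docs)
    for window in rescore_windows:
        num_docs = max(num_docs, window)
    return num_docs

# Previously a collapsing search ignored rescore windows entirely; now a
# window of 50 widens per-shard collection from 10 to 50 docs before
# collapsing and rescoring.
print(num_docs_to_collect(0, 10, 1000, [50]))  # → 50
```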
