From 66f47541ad6cc7227df7e2337aaa54bda9e0f227 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Tue, 23 Apr 2024 16:27:00 +0100 Subject: [PATCH] Allow rescorer with field collapsing This change adds the support for rescoring collapsed documents. The rescoring is applied on the top document per group on each shard. Closes #27243 --- .../collapse-search-results.asciidoc | 66 ++++++++- .../learning-to-rank-search-usage.asciidoc | 6 - .../test/search/110_field_collapsing.yml | 18 --- .../search/functionscore/QueryRescorerIT.java | 133 ++++++++++++++++++ .../action/search/SearchRequest.java | 3 - .../elasticsearch/search/SearchService.java | 3 - .../query/QueryPhaseCollectorManager.java | 61 ++++---- .../search/rescore/RescorePhase.java | 44 +++++- 8 files changed, 269 insertions(+), 65 deletions(-) diff --git a/docs/reference/search/search-your-data/collapse-search-results.asciidoc b/docs/reference/search/search-your-data/collapse-search-results.asciidoc index ffb6238c89e10..1f4161eced4f5 100644 --- a/docs/reference/search/search-your-data/collapse-search-results.asciidoc +++ b/docs/reference/search/search-your-data/collapse-search-results.asciidoc @@ -47,7 +47,7 @@ NOTE: Collapsing is applied to the top hits only and does not affect aggregation [[expand-collapse-results]] ==== Expand collapse results -It is also possible to expand each collapsed top hits with the `inner_hits` option. +It is also possible to expand each collapsed top hits with the <> option. [source,console] ---- @@ -86,7 +86,7 @@ GET /my-index-000001/_search See <> for the complete list of supported options and the format of the response. -It is also possible to request multiple `inner_hits` for each collapsed hit. This can be useful when you want to get +It is also possible to request multiple <> for each collapsed hit. This can be useful when you want to get multiple representations of the collapsed hits. [source,console] @@ -145,8 +145,7 @@ The `max_concurrent_group_searches` request parameter can be used to control the maximum number of concurrent searches allowed in this phase. The default is based on the number of data nodes and the default search thread pool size. -WARNING: `collapse` cannot be used in conjunction with <> or -<>. +WARNING: `collapse` cannot be used in conjunction with <>. [discrete] [[collapsing-with-search-after]] @@ -175,6 +174,65 @@ GET /my-index-000001/_search ---- // TEST[setup:my_index] +[discrete] +[[expand-collapse-results]] +==== Rescore collapse results + +You can use field collapsing alongside the <> search parameter. +Rescorers runs on every shard for the top-ranked document per collapsed field. +To maintain a reliable order, it is recommended to cluster documents sharing the same collapse +field value on one shard. +This is achieved by assigning the collapse field value as the <> +during indexing: + +[source,console] +---- +POST /my-index-000001/_doc?routing=xyz <1> +{ + "@timestamp": "2099-11-15T13:12:00", + "message": "You know for search!", + "user.id": "xyz" +} +---- +// TEST[setup:my_index] +<1> Assign routing with the collapse field value (`user.id`). + +By doing this, you guarantee that only one top document per +collapse key gets rescored globally. + +The following request utilizes field collapsing on the `user.id` +field and then rescores the top groups with a <>: + +[source,console] +---- +GET /my-index-000001/_search +{ + "query": { + "match": { + "message": "you know for search" + } + }, + "collapse": { + "field": "user.id" + }, + "rescore" : { + "window_size" : 50, + "query" : { + "rescore_query" : { + "match_phrase": { + "message": "you know for search" + } + }, + "query_weight" : 0.3, + "rescore_query_weight" : 1.4 + } + } +} +---- +// TEST[setup:my_index] + +WARNING: Rescorers are not applied to <>. + [discrete] [[second-level-of-collapsing]] ==== Second level of collapsing diff --git a/docs/reference/search/search-your-data/learning-to-rank-search-usage.asciidoc b/docs/reference/search/search-your-data/learning-to-rank-search-usage.asciidoc index 1d040a116ad9a..2e9693eff0451 100644 --- a/docs/reference/search/search-your-data/learning-to-rank-search-usage.asciidoc +++ b/docs/reference/search/search-your-data/learning-to-rank-search-usage.asciidoc @@ -64,12 +64,6 @@ When exposing pagination to users, `window_size` should remain constant as each Depending on how your model is trained, it’s possible that the model will return negative scores for documents. While negative scores are not allowed from first-stage retrieval and ranking, it is possible to use them in the LTR rescorer. -[discrete] -[[learning-to-rank-rescorer-limitations-field-collapsing]] -====== Compatibility with field collapsing - -LTR rescorers are not compatible with the <>. - [discrete] [[learning-to-rank-rescorer-limitations-term-statistics]] ====== Term statistics as features diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/110_field_collapsing.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/110_field_collapsing.yml index 76207fd76e45b..c10d3c48259f1 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/110_field_collapsing.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search/110_field_collapsing.yml @@ -281,24 +281,6 @@ setup: - match: { hits.hits.1.fields.numeric_group: [1] } - match: { hits.hits.1.sort: [1] } ---- -"field collapsing and rescore": - - - do: - catch: /cannot use \`collapse\` in conjunction with \`rescore\`/ - search: - rest_total_hits_as_int: true - index: test - body: - collapse: { field: numeric_group } - rescore: - window_size: 20 - query: - rescore_query: - match_all: {} - query_weight: 1 - rescore_query_weight: 2 - --- "no hits and inner_hits": diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java index 110ac76849e0b..79a149255d997 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.rescore.QueryRescoreMode; import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.SortBuilders; @@ -30,8 +31,10 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.List; import static org.elasticsearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS; import static org.elasticsearch.common.lucene.search.function.CombineFunction.REPLACE; @@ -845,4 +848,134 @@ public void testRescorePhaseWithInvalidSort() throws Exception { } ); } + + record GroupDoc(String id, String group, float firstPassScore, float secondPassScore, boolean shouldFilter) {} + + public void testRescoreAfterCollapse() throws Exception { + assertAcked(prepareCreate("test").setMapping("group", "type=keyword", "shouldFilter", "type=boolean")); + ensureGreen("test"); + GroupDoc[] groupDocs = new GroupDoc[] { + new GroupDoc("1", "c", 200, 1, false), + new GroupDoc("2", "a", 1, 10, true), + new GroupDoc("3", "b", 2, 30, false), + new GroupDoc("4", "c", 1, 1000, false), + // should be highest on rescore, but filtered out during collapse + new GroupDoc("5", "b", 1, 40, false), + new GroupDoc("6", "a", 2, 20, false) }; + List requests = new ArrayList<>(); + for (var groupDoc : groupDocs) { + requests.add( + client().prepareIndex("test") + .setId(groupDoc.id()) + .setRouting(groupDoc.group()) + .setSource( + "group", + groupDoc.group(), + "firstPassScore", + groupDoc.firstPassScore(), + "secondPassScore", + groupDoc.secondPassScore(), + "shouldFilter", + groupDoc.shouldFilter() + ) + ); + } + indexRandom(true, requests); + + var request = client().prepareSearch("test") + .setQuery(fieldValueScoreQuery("firstPassScore")) + .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore"))) + .setCollapse(new CollapseBuilder("group")); + assertResponse(request, resp -> { + assertThat(resp.getHits().getTotalHits().value, equalTo(5L)); + assertThat(resp.getHits().getHits().length, equalTo(3)); + + SearchHit hit1 = resp.getHits().getAt(0); + assertThat(hit1.getId(), equalTo("1")); + assertThat(hit1.getScore(), equalTo(201F)); + assertThat(hit1.field("group").getValues().size(), equalTo(1)); + assertThat(hit1.field("group").getValues().get(0), equalTo("c")); + + SearchHit hit2 = resp.getHits().getAt(1); + assertThat(hit2.getId(), equalTo("3")); + assertThat(hit2.getScore(), equalTo(32F)); + assertThat(hit2.field("group").getValues().size(), equalTo(1)); + assertThat(hit2.field("group").getValues().get(0), equalTo("b")); + + SearchHit hit3 = resp.getHits().getAt(2); + assertThat(hit3.getId(), equalTo("6")); + assertThat(hit3.getScore(), equalTo(22F)); + assertThat(hit3.field("group").getValues().size(), equalTo(1)); + assertThat(hit3.field("group").getValues().get(0), equalTo("a")); + }); + } + + public void testRescoreAfterCollapseRandom() throws Exception { + assertAcked(prepareCreate("test").setMapping("group", "type=keyword", "shouldFilter", "type=boolean")); + ensureGreen("test"); + int numGroups = randomIntBetween(1, 100); + int numDocs = atLeast(100); + GroupDoc[] groups = new GroupDoc[numGroups]; + int numHits = 0; + List requests = new ArrayList<>(); + for (int i = 0; i < numDocs; i++) { + int group = randomIntBetween(0, numGroups - 1); + boolean shouldFilter = rarely(); + String id = randomUUID(); + float firstPassScore = randomFloat(); + float secondPassScore = randomFloat(); + float bestScore = groups[group] == null ? -1 : groups[group].firstPassScore; + var groupDoc = new GroupDoc(id, Integer.toString(group), firstPassScore, secondPassScore, shouldFilter); + if (shouldFilter == false) { + numHits++; + if (firstPassScore > bestScore) { + groups[group] = groupDoc; + } + } + requests.add( + client().prepareIndex("test") + .setId(groupDoc.id()) + .setRouting(groupDoc.group()) + .setSource( + "group", + groupDoc.group(), + "firstPassScore", + groupDoc.firstPassScore(), + "secondPassScore", + groupDoc.secondPassScore(), + "shouldFilter", + groupDoc.shouldFilter() + ) + ); + } + indexRandom(true, requests); + + GroupDoc[] sortedGroups = Arrays.stream(groups) + .filter(g -> g != null) + .sorted(Comparator.comparingDouble(GroupDoc::secondPassScore)) + .toArray(GroupDoc[]::new); + + var request = client().prepareSearch("test") + .setQuery(fieldValueScoreQuery("firstPassScore")) + .addRescorer(new QueryRescorerBuilder(fieldValueScoreQuery("secondPassScore")).windowSize(numGroups)) + .setCollapse(new CollapseBuilder("group")) + .setSize(Math.min(numGroups, 10)); + long expectedNumHits = numHits; + assertResponse(request, resp -> { + assertThat(resp.getHits().getTotalHits().value, equalTo((long) expectedNumHits)); + for (int pos = 0; pos < resp.getHits().getHits().length; pos++) { + SearchHit hit = resp.getHits().getAt(pos); + assertThat(hit.getId(), equalTo(groups[pos].id())); + int group = Integer.valueOf(hit.field("group").getValue()); + assertThat(group, equalTo(sortedGroups[pos].group())); + assertThat(hit.getScore(), equalTo(sortedGroups[pos].secondPassScore)); + } + }); + } + + private QueryBuilder fieldValueScoreQuery(String scoreField) { + return functionScoreQuery(termQuery("shouldFilter", false), ScoreFunctionBuilders.fieldValueFactorFunction(scoreField)).boostMode( + CombineFunction.REPLACE + ); + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index 5c0db65868dbc..dd204016b7544 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -366,9 +366,6 @@ public ActionRequestValidationException validate() { validationException ); } - if (source.collapse() != null && source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("cannot use `collapse` in conjunction with `rescore`", validationException); - } if (source.storedFields() != null) { if (source.storedFields().fetchFields() == false) { if (source.fetchSource() != null && source.fetchSource().fetchSource()) { diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 5cbb97976dbc9..0432c960081cf 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -1455,9 +1455,6 @@ private static void validateSearchSource(SearchSourceBuilder source, boolean has if (hasScroll) { throw new IllegalArgumentException("cannot use `collapse` in a scroll context"); } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - throw new IllegalArgumentException("cannot use `collapse` in conjunction with `rescore`"); - } } if (source.slice() != null) { if (source.pointInTimeBuilder() == null && (hasScroll == false)) { diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java index 7fd09d3ddfdf1..2286eb2e69f88 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhaseCollectorManager.java @@ -256,21 +256,6 @@ static CollectorManager createQueryPhaseCollectorMa searchContext.scrollContext(), searchContext.numberOfShards() ); - } else if (searchContext.collapse() != null) { - boolean trackScores = searchContext.sort() == null || searchContext.trackScores(); - int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); - return forCollapsing( - postFilterWeight, - terminateAfterChecker, - aggsCollectorManager, - searchContext.minimumScore(), - searchContext.getProfilers() != null, - searchContext.collapse(), - searchContext.sort(), - numDocs, - trackScores, - searchContext.searchAfter() - ); } else { int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); final boolean rescore = searchContext.rescore().isEmpty() == false; @@ -280,21 +265,37 @@ static CollectorManager createQueryPhaseCollectorMa numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); } } - return new WithHits( - postFilterWeight, - terminateAfterChecker, - aggsCollectorManager, - searchContext.minimumScore(), - searchContext.getProfilers() != null, - reader, - query, - searchContext.sort(), - searchContext.searchAfter(), - numDocs, - searchContext.trackScores(), - searchContext.trackTotalHitsUpTo(), - hasFilterCollector - ); + if (searchContext.collapse() != null) { + boolean trackScores = searchContext.sort() == null || searchContext.trackScores(); + return forCollapsing( + postFilterWeight, + terminateAfterChecker, + aggsCollectorManager, + searchContext.minimumScore(), + searchContext.getProfilers() != null, + searchContext.collapse(), + searchContext.sort(), + numDocs, + trackScores, + searchContext.searchAfter() + ); + } else { + return new WithHits( + postFilterWeight, + terminateAfterChecker, + aggsCollectorManager, + searchContext.minimumScore(), + searchContext.getProfilers() != null, + reader, + query, + searchContext.sort(), + searchContext.searchAfter(), + numDocs, + searchContext.trackScores(), + searchContext.trackTotalHitsUpTo(), + hasFilterCollector + ); + } } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java index 81f079b74c18f..e93829be5fb2f 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java @@ -8,13 +8,18 @@ package org.elasticsearch.search.rescore; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.lucene.grouping.TopFieldGroups; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; /** * Rescore phase of a search request, used to run potentially expensive scoring models against the top matching documents. @@ -24,7 +29,7 @@ public class RescorePhase { private RescorePhase() {} public static void execute(SearchContext context) { - if (context.size() == 0 || context.collapse() != null || context.rescore() == null || context.rescore().isEmpty()) { + if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) { return; } @@ -32,6 +37,11 @@ public static void execute(SearchContext context) { if (topDocs.scoreDocs.length == 0) { return; } + TopFieldGroups topGroups = null; + if (topDocs instanceof TopFieldGroups topFieldGroups) { + assert context.collapse() != null; + topGroups = topFieldGroups; + } try { for (RescoreContext ctx : context.rescore()) { topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx); @@ -39,6 +49,15 @@ public static void execute(SearchContext context) { // here we only assert that this condition is met. assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore"; } + if (topGroups != null) { + assert context.collapse() != null; + /** + * Since rescorers don't preserve collapsing, we must reconstruct the group and field + * values from the originalTopGroups to create a new {@link TopFieldGroups} from the + * rescored top documents. + */ + topDocs = rewriteTopGroups(topGroups, topDocs); + } context.queryResult() .topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats()); } catch (IOException e) { @@ -46,6 +65,29 @@ public static void execute(SearchContext context) { } } + private static TopFieldGroups rewriteTopGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) { + assert originalTopGroups.fields.length == 1 && SortField.FIELD_SCORE.equals(originalTopGroups.fields[0]) + : "rescore must always sort by score descending"; + Map docIdToGroupValue = new HashMap<>(); + for (int i = 0; i < originalTopGroups.scoreDocs.length; i++) { + docIdToGroupValue.put(originalTopGroups.scoreDocs[i].doc, originalTopGroups.groupValues[i]); + } + var newScoreDocs = new FieldDoc[rescoredTopDocs.scoreDocs.length]; + var newGroupValues = new Object[originalTopGroups.groupValues.length]; + int pos = 0; + for (var doc : rescoredTopDocs.scoreDocs) { + newScoreDocs[pos] = new FieldDoc(doc.doc, doc.score, new Object[] { doc.score }); + newGroupValues[pos++] = docIdToGroupValue.get(doc.doc); + } + return new TopFieldGroups( + originalTopGroups.field, + originalTopGroups.totalHits, + newScoreDocs, + originalTopGroups.fields, + newGroupValues + ); + } + /** * Returns true if the provided docs are sorted by score. */