Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Handle pagination_depth when from = 0 #1136

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
- Handle pagination_depth when from = 0 and remove the default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
### Infrastructure
- Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852))
- Fix CI for JDK upgrade towards 21 ([#835](https://github.com/opensearch-project/neural-search/pull/835))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
private Integer paginationDepth;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
private final static int DEFAULT_PAGINATION_DEPTH = 10;
private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0;

public HybridQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -167,7 +166,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
float boost = AbstractQueryBuilder.DEFAULT_BOOST;

int paginationDepth = DEFAULT_PAGINATION_DEPTH;
Integer paginationDepth = null;
final List<QueryBuilder> queries = new ArrayList<>();
String queryName = null;

Expand Down Expand Up @@ -324,7 +323,7 @@ private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, Quer
return queries;
}

private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) {
private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) {
if (Objects.isNull(paginationDepth)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,19 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
*/
private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
HybridQuery hybridQuery = unwrapHybridQuery(searchContext);
int paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
Integer paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();

// Switch to from+size retrieval size during standard hybrid query execution.
if (searchContext.from() == 0) {
return searchContext.size();
// Pagination is expected to work only when pagination_depth is provided in the search request.
if (Objects.isNull(paginationDepth) && searchContext.from() > 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth param is missing in the search request"));
}
log.info("pagination_depth is {}", paginationDepth);
return paginationDepth;

if (Objects.nonNull(paginationDepth)) {
return paginationDepth;
}

// Switch to from+size retrieval size during standard hybrid query execution where from is 0.
return searchContext.size();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
.endObject()
.endObject()
.endArray()
.field("pagination_depth", 10)
.endObject();

NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,40 +870,6 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() {
    try {
        // Force the non-concurrent search path on a single-shard index.
        updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
        createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);

        // Hybrid query with a single match_all sub-query; pagination_depth deliberately not set.
        HybridQueryBuilder matchAllHybridQuery = new HybridQueryBuilder();
        matchAllHybridQuery.add(new MatchAllQueryBuilder());

        // size=10, from=2 (last positional argument) — presumably exercises paging
        // without an explicit pagination_depth; TODO confirm against the search() helper signature.
        Map<String, Object> response = search(
            TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
            matchAllHybridQuery,
            null,
            10,
            Map.of("search_pipeline", SEARCH_PIPELINE),
            null,
            null,
            null,
            false,
            null,
            2
        );

        // Page of 2 hits out of 4 total matching documents.
        assertEquals(2, getHitCount(response));
        Map<String, Object> totalHits = getTotalHits(response);
        assertNotNull(totalHits.get("value"));
        assertEquals(4, totalHits.get("value"));
        assertNotNull(totalHits.get("relation"));
        assertEquals(RELATION_EQUAL_TO, totalHits.get("relation"));
    } finally {
        wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE);
    }
}

@SneakyThrows
public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
try {
Expand All @@ -912,6 +878,7 @@ public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());
hybridQueryBuilderOnlyMatchAll.paginationDepth(10);

ResponseException responseException = assertThrows(
ResponseException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();

HybridQuery hybridQueryWithMatchAll = new HybridQuery(
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
Expand Down Expand Up @@ -633,7 +633,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();

HybridQuery hybridQueryWithTerm = new HybridQuery(
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
Expand Down Expand Up @@ -1169,4 +1169,41 @@ public void testScrollWithHybridQuery_thenFail() {
illegalArgumentException.getMessage()
);
}

@SneakyThrows
public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() {
    // A request with from > 0 but no pagination_depth must be rejected by the collector manager.
    SearchContext searchContext = mock(SearchContext.class);
    when(searchContext.from()).thenReturn(5);

    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
    when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

    // Single term sub-query; HybridQueryContext built without paginationDepth, so it stays null.
    TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
    HybridQuery hybridQuery = new HybridQuery(
        List.of(termSubQuery.toQuery(mockQueryShardContext)),
        HybridQueryContext.builder().build()
    );
    when(searchContext.query()).thenReturn(hybridQuery);

    // Minimal searcher/mapper wiring required by createHybridCollectorManager.
    ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class);
    IndexReader indexReader = mock(IndexReader.class);
    when(indexSearcher.getIndexReader()).thenReturn(indexReader);
    when(searchContext.searcher()).thenReturn(indexSearcher);

    MapperService mapperService = createMapperService();
    when(searchContext.mapperService()).thenReturn(mapperService);

    Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagersByClass = new HashMap<>();
    when(searchContext.queryCollectorManagers()).thenReturn(collectorManagersByClass);
    when(searchContext.shouldUseConcurrentSearch()).thenReturn(false);

    IllegalArgumentException illegalArgumentException = assertThrows(
        IllegalArgumentException.class,
        () -> HybridCollectorManager.createHybridCollectorManager(searchContext)
    );
    // Message has no format specifiers, so the plain literal is equivalent to the
    // original String.format(Locale.ROOT, ...) call.
    assertEquals("pagination_depth param is missing in the search request", illegalArgumentException.getMessage());
}
}
Loading