Skip to content

Commit

Permalink
Handle pagination_depth when from = 0 (#1132)
Browse files Browse the repository at this point in the history
* Handle pagination_depth when from = 0

Signed-off-by: Varun Jain <[email protected]>

* Add changelog

Signed-off-by: Varun Jain <[email protected]>

* Remove unnecessary logs

Signed-off-by: Varun Jain <[email protected]>

---------

Signed-off-by: Varun Jain <[email protected]>
  • Loading branch information
vibrantvarun authored Jan 22, 2025
1 parent a4ca7b4 commit 3dbdcba
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Update NeuralQueryBuilder doEquals() and doHashCode() to cater the missing parameters information ([#1045](https://github.com/opensearch-project/neural-search/pull/1045)).
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
- Handle pagination_depth when from = 0 and remove the default value of pagination_depth ([#1132](https://github.com/opensearch-project/neural-search/pull/1132))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu
private Integer paginationDepth;

static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
private final static int DEFAULT_PAGINATION_DEPTH = 10;
private static final int LOWER_BOUND_OF_PAGINATION_DEPTH = 0;

public HybridQueryBuilder(StreamInput in) throws IOException {
Expand Down Expand Up @@ -167,7 +166,7 @@ protected Query doToQuery(QueryShardContext queryShardContext) throws IOExceptio
public static HybridQueryBuilder fromXContent(XContentParser parser) throws IOException {
float boost = AbstractQueryBuilder.DEFAULT_BOOST;

int paginationDepth = DEFAULT_PAGINATION_DEPTH;
Integer paginationDepth = null;
final List<QueryBuilder> queries = new ArrayList<>();
String queryName = null;

Expand Down Expand Up @@ -324,7 +323,7 @@ private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, Quer
return queries;
}

private static void validatePaginationDepth(final int paginationDepth, final QueryShardContext queryShardContext) {
private static void validatePaginationDepth(final Integer paginationDepth, final QueryShardContext queryShardContext) {
if (Objects.isNull(paginationDepth)) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,19 @@ private ReduceableSearchResult reduceSearchResults(final List<ReduceableSearchRe
*/
private static int getSubqueryResultsRetrievalSize(final SearchContext searchContext) {
HybridQuery hybridQuery = unwrapHybridQuery(searchContext);
int paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();
Integer paginationDepth = hybridQuery.getQueryContext().getPaginationDepth();

// Switch to from+size retrieval size during standard hybrid query execution.
if (searchContext.from() == 0) {
return searchContext.size();
// Pagination is expected to work only when pagination_depth is provided in the search request.
if (Objects.isNull(paginationDepth) && searchContext.from() > 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "pagination_depth param is missing in the search request"));
}
log.info("pagination_depth is {}", paginationDepth);
return paginationDepth;

if (Objects.nonNull(paginationDepth)) {
return paginationDepth;
}

// Switch to from+size retrieval size during standard hybrid query execution where from is 0.
return searchContext.size();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ public void testFromXContent_whenMultipleSubQueries_thenBuildSuccessfully() {
.endObject()
.endObject()
.endArray()
.field("pagination_depth", 10)
.endObject();

NamedXContentRegistry namedXContentRegistry = new NamedXContentRegistry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -870,40 +870,6 @@ public void testHybridQuery_whenFromAndPaginationDepthIsGreaterThanZero_thenSucc
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testHybridQuery_whenFromIsGreaterThanZeroAndPaginationDepthIsNotSent_thenSuccessful() {
    try {
        // Run against a single-shard index with concurrent segment search disabled.
        updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
        createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);

        // Hybrid query with a single match_all sub-query and no pagination_depth set.
        HybridQueryBuilder matchAllHybridQuery = new HybridQueryBuilder();
        matchAllHybridQuery.add(new MatchAllQueryBuilder());

        // size=10, from=2 — paginate past the first two hits without an explicit depth.
        Map<String, Object> response = search(
            TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
            matchAllHybridQuery,
            null,
            10,
            Map.of("search_pipeline", SEARCH_PIPELINE),
            null,
            null,
            null,
            false,
            null,
            2
        );

        // Two hits remain on the page; the index holds four matching docs in total.
        assertEquals(2, getHitCount(response));
        Map<String, Object> totalHits = getTotalHits(response);
        assertNotNull(totalHits.get("value"));
        assertEquals(4, totalHits.get("value"));
        assertNotNull(totalHits.get("relation"));
        assertEquals(RELATION_EQUAL_TO, totalHits.get("relation"));
    } finally {
        // Always clean up the index and pipeline, even on assertion failure.
        wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE);
    }
}

@SneakyThrows
public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
try {
Expand All @@ -912,6 +878,7 @@ public void testHybridQuery_whenFromIsGreaterThanTotalResultCount_thenFail() {
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
HybridQueryBuilder hybridQueryBuilderOnlyMatchAll = new HybridQueryBuilder();
hybridQueryBuilderOnlyMatchAll.add(new MatchAllQueryBuilder());
hybridQueryBuilderOnlyMatchAll.paginationDepth(10);

ResponseException responseException = assertThrows(
ResponseException.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ public void testReduce_whenMatchedDocsAndSortingIsApplied_thenSuccessful() {
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();

HybridQuery hybridQueryWithMatchAll = new HybridQuery(
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
Expand Down Expand Up @@ -633,7 +633,7 @@ public void testReduceWithConcurrentSegmentSearch_whenMultipleCollectorsMatchedD
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().paginationDepth(10).build();
HybridQueryContext hybridQueryContext = HybridQueryContext.builder().build();

HybridQuery hybridQueryWithTerm = new HybridQuery(
List.of(QueryBuilders.matchAllQuery().toQuery(mockQueryShardContext)),
Expand Down Expand Up @@ -1169,4 +1169,41 @@ public void testScrollWithHybridQuery_thenFail() {
illegalArgumentException.getMessage()
);
}

@SneakyThrows
public void testCreateCollectorManager_whenPaginationDepthIsEqualToNullAndFromIsGreaterThanZero_thenFail() {
    // A search context that paginates (from > 0) while the hybrid query omits pagination_depth.
    SearchContext mockSearchContext = mock(SearchContext.class);
    when(mockSearchContext.from()).thenReturn(5);

    QueryShardContext shardContext = mock(QueryShardContext.class);
    TextFieldMapper.TextFieldType textFieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
    when(shardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(textFieldType);

    TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1);
    // Context built without paginationDepth, i.e. it stays null.
    HybridQuery hybridQuery = new HybridQuery(
        List.of(termSubQuery.toQuery(shardContext)),
        HybridQueryContext.builder().build()
    );
    when(mockSearchContext.query()).thenReturn(hybridQuery);

    // Minimal searcher/reader/mapper wiring so manager creation reaches the validation.
    ContextIndexSearcher searcher = mock(ContextIndexSearcher.class);
    IndexReader reader = mock(IndexReader.class);
    when(searcher.getIndexReader()).thenReturn(reader);
    when(mockSearchContext.searcher()).thenReturn(searcher);
    MapperService mappers = createMapperService();
    when(mockSearchContext.mapperService()).thenReturn(mappers);

    Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> collectorManagers = new HashMap<>();
    when(mockSearchContext.queryCollectorManagers()).thenReturn(collectorManagers);
    when(mockSearchContext.shouldUseConcurrentSearch()).thenReturn(false);

    // Missing pagination_depth with from > 0 must be rejected.
    IllegalArgumentException thrown = assertThrows(
        IllegalArgumentException.class,
        () -> HybridCollectorManager.createHybridCollectorManager(mockSearchContext)
    );
    assertEquals(
        String.format(Locale.ROOT, "pagination_depth param is missing in the search request"),
        thrown.getMessage()
    );
}
}

0 comments on commit 3dbdcba

Please sign in to comment.