diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index c96187a65..2c79c56e5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -148,11 +148,6 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } - public void addSubQuery(final Query query) { - Objects.requireNonNull(subQueries, "collection of queries must not be null"); - subQueries.add(query); - } - /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index c838a7d8e..7b9224c39 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -71,7 +71,7 @@ public boolean searchWith( query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - validateHybridQuery(query); + validateQuery(query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } @@ -112,7 +112,7 @@ && mightBeWrappedHybridQuery(query) return query; } - private void validateHybridQuery(final Query query) { + private void validateQuery(final Query query) { if (query instanceof BooleanQuery) { List booleanClauses = ((BooleanQuery) query).clauses(); for (BooleanClause booleanClause : booleanClauses) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 176628eb8..ce622fb2d 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -51,6 +51,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; @@ -344,7 +345,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { indexName, buildIndexConfiguration( Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), - List.of("user"), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), 1 ), "" diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index b8dc8fa4a..fb062a4ef 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -5,8 +5,10 @@ package org.opensearch.neuralsearch.search.query; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -21,6 +23,7 @@ import java.util.ArrayList; import java.util.LinkedList; import java.util.List; +import java.util.Set; import lombok.SneakyThrows; @@ -30,6 +33,8 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -38,11 +43,13 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -83,7 +90,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -126,6 +134,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -152,7 +161,8 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -196,6 +206,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.queryResult()).thenReturn(new QuerySearchResult()); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -205,8 +216,6 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); - MapperService mapperService = mock(MapperService.class); - when(searchContext.mapperService()).thenReturn(mapperService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -315,7 +324,8 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -365,6 +375,7 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); when(searchContext.indexShard()).thenReturn(indexShard); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean(); @@ -409,6 +420,223 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes releaseResources(directory, w, reader); } + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + MapperService mapperService = mock(MapperService.class); + when(mapperService.hasNested()).thenReturn(false); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + TermQueryBuilder termQuery3 = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(queryBuilder).should(termQuery3); + + Query query = boolQueryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> hybridQueryPhaseSearcher.searchWith( + searchContext, + contextIndexSearcher, + query, + collectors, + hasFilterCollector, + hasTimeout + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + containsString("hybrid query must be a top level query and cannot be wrapped into other queries") + ); + + releaseResources(directory, w, reader); + } + + @SneakyThrows + public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + + MapperService mapperService = createMapperService(mapping(b -> { + b.startObject("field"); + b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject(); + b.endObject(); + b.startObject("field2"); + b.field("type", "nested"); + b.endObject(); + })); + + /*MapperService mapperService = createMapperService( + fieldMapping( + b -> b.field("type", "text") + .field("fielddata", true) + .startObject("fielddata_frequency_filter") + .field("min", 2d) + .field("min_segment_size", 1000) + .endObject() + .field("type", "nested") + ) + );*/ + + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + // when(mapperService.hasNested()).thenReturn(true); + when(mockQueryShardContext.getMapperService()).thenReturn(mapperService); + when(mockQueryShardContext.simpleMatchToIndexNames(anyString())).thenReturn(Set.of(TEXT_FIELD_NAME)); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + int docId4 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId4, TEST_DOC_TEXT4, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.size()).thenReturn(4); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1)); + queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2)); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD) + .add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER); + Query query = builder.build(); + + when(searchContext.query()).thenReturn(query); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + assertNotNull(querySearchResult.topDocs()); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + TopDocs topDocs = topDocsAndMaxScore.topDocs; + assertTrue(topDocs.totalHits.value > 0); + ScoreDoc[] scoreDocs = topDocs.scoreDocs; + assertNotNull(scoreDocs); + assertTrue(scoreDocs.length > 0); + assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); + assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); + List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + + TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); + List expectedIds1 = List.of(docId1); + assertQueryResults(subQueryTopDocs1, expectedIds1, reader); + + TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); + List expectedIds2 = List.of(); + assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + + releaseResources(directory, w, reader); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value);