diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 2c79c56e5..c96187a65 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -148,6 +148,11 @@ public Collection getSubQueries() { return Collections.unmodifiableCollection(subQueries); } + public void addSubQuery(final Query query) { + Objects.requireNonNull(subQueries, "collection of queries must not be null"); + subQueries.add(query); + } + /** * Create the Weight used to score this query * diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index f65e30222..c838a7d8e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -19,12 +19,17 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHitCountCollector; import org.apache.lucene.search.TotalHits; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.SeqNoFieldMapper; +import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; @@ -48,6 +53,8 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { + final static int MAX_NESTED_SUBQUERY_LIMIT = 20; + public HybridQueryPhaseSearcher() { super(); } @@ -55,17 +62,79 @@ public HybridQueryPhaseSearcher() { public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, - final Query query, + Query query, final LinkedList collectors, final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (query instanceof HybridQuery) { + if (isHybridQuery(query, searchContext)) { + query = extractHybridQuery(searchContext, query); return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + validateHybridQuery(query); return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } + private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + if (query instanceof HybridQuery) { + return true; + } else if (hasNestedFieldOrNestedDocs(query, searchContext) && mightBeWrappedHybridQuery(query)) { + BooleanQuery booleanQuery = (BooleanQuery) query; + return booleanQuery.clauses() + .stream() + .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .allMatch( + clause -> clause.getOccur() == BooleanClause.Occur.FILTER + && clause.getQuery() instanceof FieldExistsQuery + && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField()) + ); + } + return false; + } + + private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); + } + + private boolean mightBeWrappedHybridQuery(final Query query) { + return query instanceof BooleanQuery + && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); + } + + private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + if (hasNestedFieldOrNestedDocs(query, searchContext) + && mightBeWrappedHybridQuery(query) + && ((BooleanQuery) query).clauses().size() > 0) { + // extract hybrid query and replace bool with hybrid query + List booleanClauses = ((BooleanQuery) query).clauses(); + return booleanClauses.stream().findFirst().get().getQuery(); + } + return query; + } + + private void validateHybridQuery(final Query query) { + if (query instanceof BooleanQuery) { + List booleanClauses = ((BooleanQuery) query).clauses(); + for (BooleanClause booleanClause : booleanClauses) { + validateNestedBooleanQuery(booleanClause.getQuery(), 1); + } + } + } + + private void validateNestedBooleanQuery(final Query query, int level) { + if (query instanceof HybridQuery) { + throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries"); + } + if (level >= MAX_NESTED_SUBQUERY_LIMIT) { + throw new IllegalStateException("reached max nested query limit, cannot process query"); + } + if (query instanceof BooleanQuery) { + for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) { + validateNestedBooleanQuery(booleanClause.getQuery(), level + 1); + } + } + } + @VisibleForTesting protected boolean searchWithCollector( final SearchContext searchContext, diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 33cdff9a0..61c6c13df 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -523,7 +523,16 @@ protected boolean checkComplete(Map node) { } @SneakyThrows - private String buildIndexConfiguration(List knnFieldConfigs, int numberOfShards) { + protected String buildIndexConfiguration(final List knnFieldConfigs, final int numberOfShards) { + return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards); + } + + @SneakyThrows + protected String buildIndexConfiguration( + final List knnFieldConfigs, + final List nestedFields, + final int numberOfShards + ) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject("settings") @@ -544,6 +553,11 @@ private String buildIndexConfiguration(List knnFieldConfigs, int .endObject() .endObject(); } + + for (String nestedField : nestedFields) { + xContentBuilder.startObject(nestedField).field("type", "nested").endObject(); + } + xContentBuilder.endObject().endObject().endObject(); return xContentBuilder.toString(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 171d2f4a4..176628eb8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.query; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; @@ -21,6 +23,7 @@ import org.junit.After; import org.junit.Before; +import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -35,6 +38,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = + "test-neural-multi-doc-nested-type--single-shard-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -191,7 +196,7 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( } @SneakyThrows - public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() { + public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenFail() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); @@ -202,23 +207,71 @@ public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder); + ResponseException exceptionNoNestedTypes = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, boolQueryBuilder, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + + org.hamcrest.MatcherAssert.assertThat( + exceptionNoNestedTypes.getMessage(), + allOf( + containsString("hybrid query must be a top level query and cannot be wrapped into other queries"), + containsString("illegal_argument_exception") + ) + ); + + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + ResponseException exceptionQWithNestedTypes = expectThrows( + ResponseException.class, + () -> search( + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + boolQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ) + ); + + org.hamcrest.MatcherAssert.assertThat( + exceptionQWithNestedTypes.getMessage(), + allOf( + containsString("hybrid query must be a top level query and cannot be wrapped into other queries"), + containsString("illegal_argument_exception") + ) + ); + } + + @SneakyThrows + public void testIndexWithNestedFields_whenHybridQuery_thenSuccess() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + TermQueryBuilder termQuery2Builder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT2); + HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder(); + hybridQueryBuilderOnlyTerm.add(termQueryBuilder); + hybridQueryBuilderOnlyTerm.add(termQuery2Builder); + Map searchResponseAsMap = search( - TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, - boolQueryBuilder, + TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD, + hybridQueryBuilderOnlyTerm, null, 10, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertTrue(getHitCount(searchResponseAsMap) > 0); + assertEquals(0, getHitCount(searchResponseAsMap)); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f); + assertEquals(0.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); Map total = getTotalHits(searchResponseAsMap); assertNotNull(total.get("value")); - assertTrue((int) total.get("value") > 0); + assertEquals(0, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex( @@ -284,6 +337,21 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { ); addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD); } + + if (TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of("user"), + 1 + ), + "" + ); + + addDocsToIndex(TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD); + } } private void addDocsToIndex(final String testMultiDocIndexName) { diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e9c55cc54..b8dc8fa4a 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -41,6 +41,7 @@ import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -204,6 +205,8 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() Query query = termSubQuery.toQuery(mockQueryShardContext); when(searchContext.query()).thenReturn(query); + MapperService mapperService = mock(MapperService.class); + when(searchContext.mapperService()).thenReturn(mapperService); hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); @@ -217,7 +220,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); Directory directory = newDirectory(); @@ -265,6 +269,7 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { QuerySearchResult querySearchResult = new QuerySearchResult(); when(searchContext.queryResult()).thenReturn(querySearchResult); when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); LinkedList collectors = new LinkedList<>(); boolean hasFilterCollector = randomBoolean();