Skip to content

Commit

Permalink
Added comments, improved exception handling
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 29, 2023
1 parent 19791bb commit e33de9b
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.SeqNoFieldMapper;
import org.opensearch.index.search.NestedHelper;
import org.opensearch.neuralsearch.query.HybridQuery;
Expand All @@ -53,8 +55,6 @@
@Log4j2
public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper {

final static int MAX_NESTED_SUBQUERY_LIMIT = 50;

public HybridQueryPhaseSearcher() {
super();
}
Expand All @@ -71,7 +71,7 @@ public boolean searchWith(
query = extractHybridQuery(searchContext, query);
return searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
validateQuery(query);
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}

Expand Down Expand Up @@ -100,17 +100,18 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte
// }
// }
// }
// TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression
if (query instanceof BooleanQuery == false) {
return false;
}
return ((BooleanQuery) query).clauses()
.stream()
.filter(clause -> clause.getQuery() instanceof HybridQuery == false)
.allMatch(
clause -> clause.getOccur() == BooleanClause.Occur.FILTER
.allMatch(clause -> {
return clause.getOccur() == BooleanClause.Occur.FILTER
&& clause.getQuery() instanceof FieldExistsQuery
&& SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField())
);
&& SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField());
});
}
return false;
}
Expand All @@ -130,20 +131,38 @@ && mightBeWrappedHybridQuery(query)
&& ((BooleanQuery) query).clauses().size() > 0) {
// extract hybrid query and replace bool with hybrid query
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
Query hybridQuery = booleanClauses.stream().findFirst().get().getQuery();
if (!(hybridQuery instanceof HybridQuery)) {
throw new IllegalStateException("cannot find hybrid type query in expected location");
if (booleanClauses.isEmpty() || booleanClauses.get(0).getQuery() instanceof HybridQuery == false) {
throw new IllegalStateException("cannot process hybrid query due to incorrect structure of top level bool query");
}
return hybridQuery;
return booleanClauses.get(0).getQuery();
}
return query;
}

private void validateQuery(final Query query) {
/**
* Validate the query from neural-search plugin point of view. Current main goal for validation is to block cases
* when hybrid query is wrapped into other compound queries.
* For example, if we have Bool query like below we need to throw an error
* bool: {
* should: [
* match: {},
* hybrid: {
* sub_query1 {}
* sub_query2 {}
* }
* ]
* }
* TODO add similar validation for other compound type queries like dis_max, constant_score etc.
*
* @param query query to validate
*/
private void validateQuery(final SearchContext searchContext, final Query query) {
if (query instanceof BooleanQuery) {
Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings();
int maxDepthLimit = MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue();
List<BooleanClause> booleanClauses = ((BooleanQuery) query).clauses();
for (BooleanClause booleanClause : booleanClauses) {
validateNestedBooleanQuery(booleanClause.getQuery(), 1);
validateNestedBooleanQuery(booleanClause.getQuery(), maxDepthLimit);
}
}
}
Expand All @@ -152,12 +171,12 @@ private void validateNestedBooleanQuery(final Query query, int level) {
if (query instanceof HybridQuery) {
throw new IllegalArgumentException("hybrid query must be a top level query and cannot be wrapped into other queries");
}
if (level >= MAX_NESTED_SUBQUERY_LIMIT) {
if (level <= 0) {
throw new IllegalStateException("reached max nested query limit, cannot process query");
}
if (query instanceof BooleanQuery) {
for (BooleanClause booleanClause : ((BooleanQuery) query).clauses()) {
validateNestedBooleanQuery(booleanClause.getQuery(), level + 1);
validateNestedBooleanQuery(booleanClause.getQuery(), level - 1);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.MAX_NESTED_SUBQUERY_LIMIT;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

Expand All @@ -25,6 +24,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import lombok.SneakyThrows;

Expand All @@ -44,10 +44,13 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.opensearch.action.OriginalIndices;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.common.lucene.search.Queries;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.BoolQueryBuilder;
Expand Down Expand Up @@ -82,6 +85,8 @@ public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
private static final int K = 10;
private static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder();
private static final UUID INDEX_UUID = UUID.randomUUID();
private static final String TEST_INDEX = "index";

@SneakyThrows
public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() {
Expand Down Expand Up @@ -473,6 +478,13 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString()));
when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY);
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand Down Expand Up @@ -509,6 +521,114 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBool_thenFail() {
releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridWrappedIntoBoolAndIncorrectStructure_thenFail() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
MapperService mapperService = createMapperService(mapping(b -> {
b.startObject("field");
b.field("type", "text")
.field("fielddata", true)
.startObject("fielddata_frequency_filter")
.field("min", 2d)
.field("min_segment_size", 1000)
.endObject();
b.endObject();
b.startObject("user");
b.field("type", "nested");
b.endObject();
}));

TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME);
when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

Directory directory = newDirectory();
IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
ft.setOmitNorms(random().nextBoolean());
ft.freeze();
int docId1 = RandomizedTest.randomInt();
w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft));
w.commit();

IndexReader reader = DirectoryReader.open(w);
SearchContext searchContext = mock(SearchContext.class);

ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null,
searchContext
);

ShardId shardId = new ShardId(dummyIndex, 1);
SearchShardTarget shardTarget = new SearchShardTarget(
randomAlphaOfLength(10),
shardId,
randomAlphaOfLength(10),
OriginalIndices.NONE
);
when(searchContext.shardTarget()).thenReturn(shardTarget);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.size()).thenReturn(4);
QuerySearchResult querySearchResult = new QuerySearchResult();
when(searchContext.queryResult()).thenReturn(querySearchResult);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
IndexShard indexShard = mock(IndexShard.class);
when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString()));
when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY);
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
boolean hasTimeout = randomBoolean();

HybridQueryBuilder queryBuilder = new HybridQueryBuilder();

queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1));
queryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2));

BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(Queries.newNonNestedFilter(), BooleanClause.Occur.FILTER)
.add(queryBuilder.toQuery(mockQueryShardContext), BooleanClause.Occur.SHOULD);
Query query = builder.build();

when(searchContext.query()).thenReturn(query);

IllegalStateException exception = expectThrows(
IllegalStateException.class,
() -> hybridQueryPhaseSearcher.searchWith(
searchContext,
contextIndexSearcher,
query,
collectors,
hasFilterCollector,
hasTimeout
)
);

org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
containsString("cannot process hybrid query due to incorrect structure of top level bool query")
);

releaseResources(directory, w, reader);
}

@SneakyThrows
public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_thenSuccess() {
HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher();
Expand Down Expand Up @@ -677,6 +797,13 @@ public void testBoolQuery_whenTooManyNestedLevels_thenFail() {
when(searchContext.indexShard()).thenReturn(indexShard);
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
when(searchContext.mapperService()).thenReturn(mapperService);
when(searchContext.getQueryShardContext()).thenReturn(mockQueryShardContext);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.getIndex()).thenReturn(new Index(TEST_INDEX, INDEX_UUID.toString()));
when(indexMetadata.getSettings()).thenReturn(Settings.EMPTY);
Settings settings = Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, Integer.toString(1)).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings);

LinkedList<QueryCollectorContext> collectors = new LinkedList<>();
boolean hasFilterCollector = randomBoolean();
Expand All @@ -685,7 +812,7 @@ public void testBoolQuery_whenTooManyNestedLevels_thenFail() {
Query query = createNestedBoolQuery(
QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1).toQuery(mockQueryShardContext),
QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT2).toQuery(mockQueryShardContext),
MAX_NESTED_SUBQUERY_LIMIT + 1
(int) (MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.getDefault(null) + 1)
);

when(searchContext.query()).thenReturn(query);
Expand Down

0 comments on commit e33de9b

Please sign in to comment.