Skip to content

Commit

Permalink
Adding null check for case when hybrid query wrapped into bool query
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 11, 2023
1 parent cda2f82 commit ef57d1d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
Fixed exception for case when Hybrid query being wrapped into bool query ([#XXX](https://github.com/opensearch-project/neural-search/compare/main...martin-gaievski:neural-search:fix_npe_for_wrapped_hybrid_query?expand=1)
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public DocIdSetIterator iterator() {
*/
@Override
public float getMaxScore(int upTo) throws IOException {
return subScorers.stream().filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
return subScorers.stream().filter(Objects::nonNull).filter(scorer -> scorer.docID() <= upTo).map(scorer -> {
try {
return scorer.getMaxScore(upTo);
} catch (IOException e) {
Expand Down
88 changes: 66 additions & 22 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.junit.After;
import org.junit.Before;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -33,6 +34,7 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index";
private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index";
private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index";
private static final String TEST_QUERY_TEXT = "greetings";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final String TEST_QUERY_TEXT3 = "hello";
Expand Down Expand Up @@ -188,6 +190,35 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult(
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
public void testNestedQuery_whenHybridQueryIsWrappedIntoOtherQuery_thenSuccess() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);

MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
MatchQueryBuilder matchQuery2Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
HybridQueryBuilder hybridQueryBuilderOnlyTerm = new HybridQueryBuilder();
hybridQueryBuilderOnlyTerm.add(matchQueryBuilder);
hybridQueryBuilderOnlyTerm.add(matchQuery2Builder);
MatchQueryBuilder matchQuery3Builder = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery().should(hybridQueryBuilderOnlyTerm).should(matchQuery3Builder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
boolQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertTrue(getHitCount(searchResponseAsMap) > 0);
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(getMaxScore(searchResponseAsMap).get() > 0.0f);

Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertTrue((int) total.get("value") > 0);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -242,32 +273,45 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
TEST_MULTI_DOC_INDEX_NAME,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE))
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"1",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector1).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1)
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"2",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector2).toArray())
);
addKnnDoc(
TEST_MULTI_DOC_INDEX_NAME,
"3",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector3).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT2)
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
}

if (TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD)) {
prepareKnnIndex(
TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)),
1
);
assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME));
addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
}
}

private void addDocsToIndex(final String testMultiDocIndexName) {
addKnnDoc(
testMultiDocIndexName,
"1",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector1).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1)
);
addKnnDoc(
testMultiDocIndexName,
"2",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector2).toArray())
);
addKnnDoc(
testMultiDocIndexName,
"3",
Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
Collections.singletonList(Floats.asList(testVector3).toArray()),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT2)
);
assertEquals(3, getDocCount(testMultiDocIndexName));
}

private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
return (List<Map<String, Object>>) hitsMap.get("hits");
Expand Down

0 comments on commit ef57d1d

Please sign in to comment.