diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index a20b52517..f317a9e12 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -138,7 +138,13 @@ private void updateOriginalFetchResults(
         // 3. update original scores to normalized and combined values
         // 4. order scores based on normalized and combined values
         FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get();
-        SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult);
+        // check for the case when results are cached
+        boolean requestCache = Objects.nonNull(querySearchResults)
+            && !querySearchResults.isEmpty()
+            && Objects.nonNull(querySearchResults.get(0).getShardSearchRequest().requestCache())
+            && querySearchResults.get(0).getShardSearchRequest().requestCache();
+
+        SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache);

         // create map of docId to index of search hits. This solves (2), duplicates are from
         // delimiter and start/stop elements, they all have same valid doc_id. For this map
@@ -168,7 +174,7 @@ private void updateOriginalFetchResults(
         fetchSearchResult.hits(updatedSearchHits);
     }

-    private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult) {
+    private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) {
         SearchHits searchHits = fetchSearchResult.hits();
         SearchHit[] searchHitArray = searchHits.getHits();
         // validate the both collections are of the same size
@@ -177,7 +183,9 @@ private SearchHit[] getSearchHits(final List<Integer> docIds, final FetchSearchR
                 "score normalization processor cannot produce final query result, fetch query phase returns empty results"
             );
         }
-        if (searchHitArray.length != docIds.size()) {
+        // for a cached request, the fetch and query results may differ in size; the only restriction is
+        // that the number of query results must be greater than or equal to the number of fetch results
+        if ((!requestCache && searchHitArray.length != docIds.size()) || (requestCache && docIds.size() < searchHitArray.length)) {
             throw new IllegalStateException(
                 String.format(
                     Locale.ROOT,
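Reviewer note: the new guard in getSearchHits encodes two invariants, an exact size match without the cache and an upper bound with it. Below is a minimal, self-contained sketch of that rule, assuming docIds comes from the query phase and the hit count from the fetch phase; the class name and the main() driver are illustrative, not plugin code.

import java.util.List;
import java.util.Locale;

// Illustrative sketch of the validation rule introduced in getSearchHits (not plugin code).
public class CachedFetchValidationSketch {

    static void validate(final List<Integer> docIds, final int fetchHitCount, final boolean requestCache) {
        // non-cached request: query and fetch phases must return the same number of results
        boolean nonCachedMismatch = !requestCache && fetchHitCount != docIds.size();
        // cached request: fetch results may be fewer, but never more, than query results
        boolean cachedOverflow = requestCache && docIds.size() < fetchHitCount;
        if (nonCachedMismatch || cachedOverflow) {
            throw new IllegalStateException(
                String.format(Locale.ROOT, "invalid result sizes: query %d, fetch %d", docIds.size(), fetchHitCount)
            );
        }
    }

    public static void main(String[] args) {
        validate(List.of(1, 2, 3, 4), 4, false); // ok: exact match required without the cache
        validate(List.of(1, 2, 3, 4), 3, true);  // ok: a cached fetch may return fewer hits
        try {
            validate(List.of(1, 2), 3, true);    // fails: a cached fetch may not exceed the query hits
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}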
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
index 57698cd7e..dd185e227 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java
@@ -53,6 +53,7 @@
 import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
 import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.fetch.QueryFetchSearchResult;
+import org.opensearch.search.internal.ShardSearchRequest;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.TestThreadPool;
@@ -401,6 +402,9 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz
         QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
         queryFetchSearchResult.setShardIndex(shardId);
+        ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
+        when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
+        querySearchResult.setShardSearchRequest(shardSearchRequest);

         queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
@@ -485,6 +489,9 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail()
         QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
         queryFetchSearchResult.setShardIndex(shardId);
+        ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
+        when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
+        querySearchResult.setShardSearchRequest(shardSearchRequest);

         queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java
index f34f8f59b..9785761d2 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java
@@ -4,7 +4,9 @@
  */
 package org.opensearch.neuralsearch.processor;

+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
 import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

@@ -29,6 +31,7 @@
 import org.opensearch.search.SearchHits;
 import org.opensearch.search.SearchShardTarget;
 import org.opensearch.search.fetch.FetchSearchResult;
+import org.opensearch.search.internal.ShardSearchRequest;
 import org.opensearch.search.query.QuerySearchResult;
 import org.opensearch.test.OpenSearchTestCase;

@@ -156,6 +159,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo
         );
         querySearchResult.setSearchShardTarget(searchShardTarget);
         querySearchResult.setShardIndex(shardId);
+        ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
+        when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
+        querySearchResult.setShardSearchRequest(shardSearchRequest);
         querySearchResults.add(querySearchResult);
         SearchHit[] searchHitArray = new SearchHit[] {
             new SearchHit(0, "10", Map.of(), Map.of()),
@@ -213,6 +219,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
         );
         querySearchResult.setSearchShardTarget(searchShardTarget);
         querySearchResult.setShardIndex(shardId);
+        ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
+        when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
+        querySearchResult.setShardSearchRequest(shardSearchRequest);
         querySearchResults.add(querySearchResult);
         SearchHit[] searchHitArray = new SearchHit[] {
             new SearchHit(-1, "10", Map.of(), Map.of()),
@@ -236,7 +245,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom
         TestUtils.assertFetchResultScores(fetchSearchResult, 4);
     }

-    public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
+    public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() {
         NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
             new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
         );
@@ -270,6 +279,9 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
         );
         querySearchResult.setSearchShardTarget(searchShardTarget);
         querySearchResult.setShardIndex(shardId);
+        ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
+        when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE);
+        querySearchResult.setShardSearchRequest(shardSearchRequest);
         querySearchResults.add(querySearchResult);
         SearchHit[] searchHitArray = new SearchHit[] {
             new SearchHit(-1, "10", Map.of(), Map.of()),
@@ -291,4 +303,63 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then
             )
         );
     }
+
+    public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() {
+        NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
+            new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
+        );
+
+        List<QuerySearchResult> querySearchResults = new ArrayList<>();
+        FetchSearchResult fetchSearchResult = new FetchSearchResult();
+        int shardId = 0;
+        SearchShardTarget searchShardTarget = new SearchShardTarget(
+            "node",
+            new ShardId("index", "uuid", shardId),
+            null,
+            OriginalIndices.NONE
+        );
+        QuerySearchResult querySearchResult = new QuerySearchResult();
+        querySearchResult.topDocs(
+            new TopDocsAndMaxScore(
+                new TopDocs(
+                    new TotalHits(4, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] {
+                        createStartStopElementForHybridSearchResults(0),
+                        createDelimiterElementForHybridSearchResults(0),
+                        new ScoreDoc(0, 0.5f),
+                        new ScoreDoc(2, 0.3f),
+                        new ScoreDoc(4, 0.25f),
+                        new ScoreDoc(10, 0.2f),
+                        createStartStopElementForHybridSearchResults(0) }
+                ),
+                0.5f
+            ),
+            new DocValueFormat[0]
+        );
+        querySearchResult.setSearchShardTarget(searchShardTarget);
+        querySearchResult.setShardIndex(shardId);
+        ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class);
+        when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE);
+        querySearchResult.setShardSearchRequest(shardSearchRequest);
+        querySearchResults.add(querySearchResult);
+        SearchHit[] searchHitArray = new SearchHit[] {
+            new SearchHit(-1, "10", Map.of(), Map.of()),
+            new SearchHit(-1, "10", Map.of(), Map.of()),
+            new SearchHit(-1, "10", Map.of(), Map.of()),
+            new SearchHit(-1, "1", Map.of(), Map.of()),
+            new SearchHit(-1, "2", Map.of(), Map.of()),
+            new SearchHit(-1, "3", Map.of(), Map.of()) };
+        SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
+        fetchSearchResult.hits(searchHits);
+
+        normalizationProcessorWorkflow.execute(
+            querySearchResults,
+            Optional.of(fetchSearchResult),
+            ScoreNormalizationFactory.DEFAULT_METHOD,
+            ScoreCombinationFactory.DEFAULT_METHOD
+        );
+
+        TestUtils.assertQueryResultScores(querySearchResults);
+        TestUtils.assertFetchResultScores(fetchSearchResult, 4);
+    }
 }
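Reviewer note: the three duplicated "10" hits in the new test above are intentional. As the workflow comment in the first file explains, delimiter and start/stop sentinel elements all carry the same valid doc_id, so the docId-to-position bookkeeping must tolerate duplicates. A self-contained sketch of that idea, with illustrative names only, not plugin code:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch of the docId -> positions bookkeeping described in the workflow comment:
// a single doc id can map to several positions in the fetched hit array.
public class DocIdPositionMapSketch {

    static Map<Integer, List<Integer>> docIdToPositions(final int[] fetchedDocIds) {
        Map<Integer, List<Integer>> positions = new HashMap<>();
        for (int i = 0; i < fetchedDocIds.length; i++) {
            positions.computeIfAbsent(fetchedDocIds[i], k -> new ArrayList<>()).add(i);
        }
        return positions;
    }

    public static void main(String[] args) {
        // doc id 10 appears three times, matching the duplicated "10" hits in the cached-results test
        System.out.println(docIdToPositions(new int[] { 10, 10, 10, 1, 2, 3 }));
    }
}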
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index 487a378df..1d89d0a6a 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -29,6 +29,7 @@
 import org.opensearch.index.query.MatchQueryBuilder;
 import org.opensearch.index.query.NestedQueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.query.RangeQueryBuilder;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.neuralsearch.BaseNeuralSearchIT;

@@ -43,14 +44,17 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
     private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-hybrid-multi-doc-single-shard-index";
     private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD =
         "test-hybrid-multi-doc-nested-type-single-shard-index";
+    private static final String TEST_INDEX_WITH_KEYWORDS_ONE_SHARD = "test-hybrid-keywords-single-shard-index";
     private static final String TEST_QUERY_TEXT = "greetings";
     private static final String TEST_QUERY_TEXT2 = "salute";
     private static final String TEST_QUERY_TEXT3 = "hello";
     private static final String TEST_QUERY_TEXT4 = "place";
     private static final String TEST_QUERY_TEXT5 = "welcome";
+    private static final String TEST_QUERY_TEXT6 = "machine";
     private static final String TEST_DOC_TEXT1 = "Hello world";
     private static final String TEST_DOC_TEXT2 = "Hi to this place";
     private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone";
+    private static final String TEST_DOC_TEXT4 = "There was no telling what thoughts would come from the machine";
     private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1";
     private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2";
     private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1";
@@ -59,6 +63,18 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
     private static final String NESTED_FIELD_2 = "lastname";
     private static final String NESTED_FIELD_1_VALUE = "john";
     private static final String NESTED_FIELD_2_VALUE = "black";
+    private static final String KEYWORD_FIELD_1 = "doc_keyword";
+    private static final String KEYWORD_FIELD_1_VALUE = "workable";
+    private static final String KEYWORD_FIELD_2_VALUE = "angry";
+    private static final String KEYWORD_FIELD_3_VALUE = "likeable";
+    private static final String KEYWORD_FIELD_4_VALUE = "entire";
+    private static final String INTEGER_FIELD_PRICE = "doc_price";
+    private static final int INTEGER_FIELD_PRICE_1_VALUE = 130;
+    private static final int INTEGER_FIELD_PRICE_2_VALUE = 100;
+    private static final int INTEGER_FIELD_PRICE_3_VALUE = 200;
+    private static final int INTEGER_FIELD_PRICE_4_VALUE = 25;
+    private static final int INTEGER_FIELD_PRICE_5_VALUE = 30;
+    private static final int INTEGER_FIELD_PRICE_6_VALUE = 350;
     private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
@@ -399,6 +415,87 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess(
         }
     }

+    @SneakyThrows
+    public void testRequestCache_whenQueryReturnResults_thenSuccessful() {
+        String modelId = null;
+        try {
+            initializeIndexIfNotExist(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD);
+            modelId = prepareModel();
+            createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
+            MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(KEYWORD_FIELD_1, KEYWORD_FIELD_2_VALUE);
+            RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_PRICE).gte(10).lte(1000);
+
+            HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+            hybridQueryBuilder.add(matchQueryBuilder);
+            hybridQueryBuilder.add(rangeQueryBuilder);
+
+            // the first query with the cache flag executes normally, reading documents from the index
+            Map<String, Object> firstSearchResponseAsMap = search(
+                TEST_INDEX_WITH_KEYWORDS_ONE_SHARD,
+                hybridQueryBuilder,
+                null,
+                10,
+                Map.of("search_pipeline", SEARCH_PIPELINE, "request_cache", Boolean.TRUE.toString())
+            );
"request_cache", Boolean.TRUE.toString()) + ); + + assertEquals(6, getHitCount(firstSearchResponseAsMap)); + + List> hits1NestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(6, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // second query is served from the cache + Map secondSearchResponseAsMap = search( + TEST_INDEX_WITH_KEYWORDS_ONE_SHARD, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "request_cache", Boolean.TRUE.toString()) + ); + + assertEquals(6, getHitCount(secondSearchResponseAsMap)); + + List> hitsNestedListSecondQuery = getNestedHits(secondSearchResponseAsMap); + List idsSecondQuery = new ArrayList<>(); + List scoresSecondQuery = new ArrayList<>(); + for (Map oneHit : hitsNestedListSecondQuery) { + idsSecondQuery.add((String) oneHit.get("_id")); + scoresSecondQuery.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue( + IntStream.range(0, scoresSecondQuery.size() - 1) + .noneMatch(idx -> scoresSecondQuery.get(idx) < scoresSecondQuery.get(idx + 1)) + ); + // verify that all ids are unique + assertEquals(Set.copyOf(idsSecondQuery).size(), idsSecondQuery.size()); + + Map totalSecondQuery = getTotalHits(secondSearchResponseAsMap); + assertNotNull(totalSecondQuery.get("value")); + assertEquals(6, totalSecondQuery.get("value")); + assertNotNull(totalSecondQuery.get("relation")); + assertEquals(RELATION_EQUAL_TO, totalSecondQuery.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD, null, modelId, SEARCH_PIPELINE); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -490,6 +587,111 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) ); } + + if (TEST_INDEX_WITH_KEYWORDS_ONE_SHARD.equals(indexName) && !indexExists(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_PRICE), List.of(KEYWORD_FIELD_1), List.of(), 1), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_PRICE_1_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_1_VALUE), + List.of(), + List.of() + ); + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_PRICE_2_VALUE), + List.of(), + List.of(), + List.of(), + List.of() + ); + addKnnDoc( + indexName, + "3", + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_PRICE_3_VALUE), + 
+                List.of(KEYWORD_FIELD_1),
+                List.of(KEYWORD_FIELD_2_VALUE),
+                List.of(),
+                List.of()
+            );
+            addKnnDoc(
+                indexName,
+                "4",
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(INTEGER_FIELD_PRICE),
+                List.of(INTEGER_FIELD_PRICE_4_VALUE),
+                List.of(KEYWORD_FIELD_1),
+                List.of(KEYWORD_FIELD_3_VALUE),
+                List.of(),
+                List.of()
+            );
+            addKnnDoc(
+                indexName,
+                "5",
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(INTEGER_FIELD_PRICE),
+                List.of(INTEGER_FIELD_PRICE_5_VALUE),
+                List.of(KEYWORD_FIELD_1),
+                List.of(KEYWORD_FIELD_4_VALUE),
+                List.of(),
+                List.of()
+            );
+            addKnnDoc(
+                indexName,
+                "6",
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(),
+                List.of(INTEGER_FIELD_PRICE),
+                List.of(INTEGER_FIELD_PRICE_6_VALUE),
+                List.of(KEYWORD_FIELD_1),
+                List.of(KEYWORD_FIELD_4_VALUE),
+                List.of(),
+                List.of()
+            );
+        }
     }

     private void addDocsToIndex(final String testMultiDocIndexName) {
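Reviewer note: for anyone reproducing the integration scenario outside the REST test harness, here is a hedged sketch of the equivalent transport-level request. SearchRequest#requestCache(Boolean) is the programmatic counterpart of the request_cache URL parameter, and SearchRequest#pipeline(String) of the search_pipeline parameter; the pipeline id below is hypothetical, and this sketch is not part of the change.

import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;

// Sketch only: mirrors the REST call made by testRequestCache_whenQueryReturnResults_thenSuccessful.
public class CachedHybridSearchSketch {

    public static SearchRequest buildRequest() {
        // same two sub-queries as the integration test: match on the keyword field, range on the price field
        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
        hybridQueryBuilder.add(QueryBuilders.matchQuery("doc_keyword", "angry"));
        hybridQueryBuilder.add(QueryBuilders.rangeQuery("doc_price").gte(10).lte(1000));

        SearchRequest searchRequest = new SearchRequest("test-hybrid-keywords-single-shard-index");
        searchRequest.source(new SearchSourceBuilder().query(hybridQueryBuilder).size(10));
        searchRequest.pipeline("normalization-pipeline"); // hypothetical id of a pipeline with the normalization processor
        searchRequest.requestCache(Boolean.TRUE); // equivalent of request_cache=true
        return searchRequest;
    }
}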