Skip to content

Commit

Permalink
Add integration and unit tests for missing RRF coverage (#997)
Browse files Browse the repository at this point in the history
* Initial unit test implementation

Signed-off-by: Ryan Bogan <[email protected]>

---------
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored Dec 3, 2024
1 parent 29fafd6 commit c64f09b
Show file tree
Hide file tree
Showing 4 changed files with 355 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Optional;

import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -98,7 +99,8 @@ public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
@VisibleForTesting
<Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
return true;
}
Expand All @@ -111,7 +113,8 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
@VisibleForTesting
boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
Expand All @@ -120,17 +123,16 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
<Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
@VisibleForTesting
<Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;

public class RRFProcessorIT extends BaseNeuralSearchIT {

private int currentDoc = 1;
private static final String RRF_INDEX_NAME = "rrf-index";
private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline";
private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline";

private static final int RRF_DIMENSION = 5;

@SneakyThrows
public void testRRF_whenValidInput_thenSucceed() {
try {
createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING);
prepareKnnIndex(
RRF_INDEX_NAME,
Collections.singletonList(new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE))
);
addDocuments();
createDefaultRRFSearchPipeline();

HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder();

Map<String, Object> results = search(
RRF_INDEX_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", RRF_SEARCH_PIPELINE)
);
Map<String, Object> hits = (Map<String, Object>) results.get("hits");
ArrayList<HashMap<String, Object>> hitsList = (ArrayList<HashMap<String, Object>>) hits.get("hits");
assertEquals(3, hitsList.size());
assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION);
assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION);
assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION);
} finally {
wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE);
}
}

private HybridQueryBuilder getHybridQueryBuilder() {
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco");
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder.Builder().fieldName("passage_embedding")
.k(5)
.vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f })
.build();

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(matchQueryBuilder);
hybridQueryBuilder.add(knnQueryBuilder);
return hybridQueryBuilder;
}

@SneakyThrows
private void addDocuments() {
addDocument(
"A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .",
"4319130149.jpg"
);
addDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg");
addDocument(
"People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .",
"2664027527.jpg"
);
addDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg");
addDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg");
}

@SneakyThrows
private void addDocument(String description, String imageText) {
addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.util.concurrent.AtomicArray;
import org.opensearch.core.common.Strings;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.AliasFilter;
import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RRFProcessorTests extends OpenSearchTestCase {

@Mock
private ScoreNormalizationTechnique mockNormalizationTechnique;
@Mock
private ScoreCombinationTechnique mockCombinationTechnique;
@Mock
private NormalizationProcessorWorkflow mockNormalizationWorkflow;
@Mock
private SearchPhaseResults<SearchPhaseResult> mockSearchPhaseResults;
@Mock
private SearchPhaseContext mockSearchPhaseContext;
@Mock
private QueryPhaseResultConsumer mockQueryPhaseResultConsumer;

private RRFProcessor rrfProcessor;
private static final String TAG = "tag";
private static final String DESCRIPTION = "description";

@Before
@SneakyThrows
public void setUp() {
super.setUp();
MockitoAnnotations.openMocks(this);
rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow);
}

@SneakyThrows
public void testGetType() {
assertEquals(RRFProcessor.TYPE, rrfProcessor.getType());
}

@SneakyThrows
public void testGetBeforePhase() {
assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase());
}

@SneakyThrows
public void testGetAfterPhase() {
assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase());
}

@SneakyThrows
public void testIsIgnoreFailure() {
assertFalse(rrfProcessor.isIgnoreFailure());
}

@SneakyThrows
public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() {
rrfProcessor.process(null, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() {
rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext);
verify(mockNormalizationWorkflow, never()).execute(any());
}

@SneakyThrows
public void testProcess_whenValidHybridInput_thenSucceed() {
QuerySearchResult result = createQuerySearchResult(true);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testProcess_whenValidNonHybridInput_thenSucceed() {
QuerySearchResult result = createQuerySearchResult(false);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow, never()).execute(any(NormalizationExecuteDTO.class));
}

@SneakyThrows
public void testGetTag() {
assertEquals(TAG, rrfProcessor.getTag());
}

@SneakyThrows
public void testGetDescription() {
assertEquals(DESCRIPTION, rrfProcessor.getDescription());
}

@SneakyThrows
public void testShouldSkipProcessor() {
assertTrue(rrfProcessor.shouldSkipProcessor(null));
assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults));

AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));

atomicArray.set(0, createQuerySearchResult(true));
assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer));
}

@SneakyThrows
public void testGetQueryPhaseSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(2);
atomicArray.set(0, createQuerySearchResult(true));
atomicArray.set(1, createQuerySearchResult(false));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

List<QuerySearchResult> results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer);
assertEquals(2, results.size());
assertNotNull(results.get(0));
assertNotNull(results.get(1));
}

@SneakyThrows
public void testGetFetchSearchResults() {
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, createQuerySearchResult(true));
when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

Optional<FetchSearchResult> result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer);
assertFalse(result.isPresent());
}

private QuerySearchResult createQuerySearchResult(boolean isHybrid) {
ShardId shardId = new ShardId("index", "uuid", 0);
OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed());
SearchRequest searchRequest = new SearchRequest("index");
searchRequest.source(new SearchSourceBuilder());
searchRequest.allowPartialSearchResults(true);

int numberOfShards = 1;
AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY);
float indexBoost = 1.0f;
long nowInMillis = System.currentTimeMillis();
String clusterAlias = null;
String[] indexRoutings = Strings.EMPTY_ARRAY;

ShardSearchRequest shardSearchRequest = new ShardSearchRequest(
originalIndices,
searchRequest,
shardId,
numberOfShards,
aliasFilter,
indexBoost,
nowInMillis,
clusterAlias,
indexRoutings
);

QuerySearchResult result = new QuerySearchResult(
new ShardSearchContextId("test", 1),
new SearchShardTarget("node1", shardId, clusterAlias, originalIndices),
shardSearchRequest
);
result.from(0).size(10);

ScoreDoc[] scoreDocs;
if (isHybrid) {
scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) };
} else {
scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) };
}

TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs);
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f);
result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]);

return result;
}
}
Loading

0 comments on commit c64f09b

Please sign in to comment.