Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing multiple issues reported in #497 #524

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,20 @@
SearchHits searchHits = fetchSearchResult.hits();
SearchHit[] searchHitArray = searchHits.getHits();
// validate that both collections are of the same size
if (Objects.isNull(searchHitArray) || searchHitArray.length != docIds.size()) {
throw new IllegalStateException("Score normalization processor cannot produce final query result");
if (Objects.isNull(searchHitArray)) {
throw new IllegalStateException(

Check warning on line 177 in src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java#L177

Added line #L177 was not covered by tests
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
"score normalization processor cannot produce final query result, fetch query phase returns empty results"
);
}
if (searchHitArray.length != docIds.size()) {
throw new IllegalStateException(
String.format(
Locale.ROOT,
"score normalization processor cannot produce final query result, the number of documents after fetch phase [%d] is different from number of documents from query phase [%d]",
searchHitArray.length,
docIds.size()
)
);
}
return searchHitArray;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -37,7 +39,7 @@

private final float[] subScores;

private final Map<Query, Integer> queryToIndex;
private final Map<Query, List<Integer>> queryToIndex;

public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
super(weight);
Expand Down Expand Up @@ -111,24 +113,43 @@
DisiWrapper topList = subScorersPQ.topList();
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
Scorer scorer = disiWrapper.scorer;
if (scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
float subScore = disiWrapper.scorer.score();
scores[queryToIndex.get(disiWrapper.scorer.getWeight().getQuery())] = subScore;
Query query = scorer.getWeight().getQuery();
List<Integer> indexes = queryToIndex.get(query);
// we need to find the index of the first sub-query whose score hasn't been set yet; such a score still has its initial value of "0.0"
int index = indexes.stream()
.mapToInt(idx -> idx)
.filter(idx -> Float.compare(scores[idx], 0.0f) == 0)
.findFirst()
.orElseThrow(
() -> new IllegalStateException(
String.format(

Check warning on line 129 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L128-L129

Added lines #L128 - L129 were not covered by tests
Locale.ROOT,
"cannot set score for one of hybrid search subquery [%s] and document [%d]",
query.toString(),
scorer.docID()

Check warning on line 133 in src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java#L132-L133

Added lines #L132 - L133 were not covered by tests
)
)
);
scores[index] = scorer.score();
}
return scores;
}

private Map<Query, Integer> mapQueryToIndex() {
Map<Query, Integer> queryToIndex = new HashMap<>();
private Map<Query, List<Integer>> mapQueryToIndex() {
Map<Query, List<Integer>> queryToIndex = new HashMap<>();
int idx = 0;
for (Scorer scorer : subScorers) {
if (scorer == null) {
idx++;
continue;
}
queryToIndex.put(scorer.getWeight().getQuery(), idx);
Query query = scorer.getWeight().getQuery();
queryToIndex.putIfAbsent(query, new ArrayList<>());
queryToIndex.get(query).add(idx);
idx++;
}
return queryToIndex;
Expand All @@ -137,7 +158,9 @@
private DisiPriorityQueue initializeSubScorersPQ() {
Objects.requireNonNull(queryToIndex, "should not be null");
Objects.requireNonNull(subScorers, "should not be null");
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(queryToIndex.size());
// we need to count this way in order to include all identical sub-queries
int numOfSubQueries = queryToIndex.values().stream().map(List::size).reduce(0, Integer::sum);
DisiPriorityQueue subScorersPQ = new DisiPriorityQueue(numOfSubQueries);
for (Scorer scorer : subScorers) {
if (scorer == null) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@ public HitsThresholdChecker(int totalHitsThreshold) {
if (totalHitsThreshold < 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be >= 0, got %d", totalHitsThreshold));
}
if (totalHitsThreshold == Integer.MAX_VALUE) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "totalHitsThreshold must be less than max integer value"));
}
this.totalHitsThreshold = totalHitsThreshold;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor;

import static org.hamcrest.Matchers.startsWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
Expand All @@ -15,6 +16,7 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -45,9 +47,13 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.fetch.QueryFetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
Expand Down Expand Up @@ -325,4 +331,172 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul

verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any());
}

// Happy path: when both query-phase and fetch-phase results are present and the fetch
// hit count matches the query-phase doc count (including the hybrid start/stop and
// delimiter marker docs), the processor must delegate to the normalization workflow.
public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() {
// spy the workflow so we can verify it was actually invoked by the processor
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
PROCESSOR_TAG,
DESCRIPTION,
new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
normalizationProcessorWorkflow
);

// build a real QueryPhaseResultConsumer; merge failures are accumulated (with
// suppression chaining) so an unexpected partial-merge error is not silently lost
SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
searchRequest.setBatchedReduceSize(4);
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
searchRequest,
executor,
new NoopCircuitBreaker(CircuitBreaker.REQUEST),
searchPhaseController,
SearchProgressListener.NOOP,
writableRegistry(),
10,
e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
curr.addSuppressed(prev);
return curr;
})
);
CountDownLatch partialReduceLatch = new CountDownLatch(5);
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
// query-phase result in the hybrid-search format: start/stop and delimiter marker
// ScoreDocs wrap the 4 real score docs (ids 0, 2, 4, 10)
QuerySearchResult querySearchResult = new QuerySearchResult();
TopDocs topDocs = new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),

new ScoreDoc[] {
createStartStopElementForHybridSearchResults(4),
createDelimiterElementForHybridSearchResults(4),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(4) }

);
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);

// fetch-phase result with 7 hits — the same count as the 7 ScoreDocs above
// (markers included), so the size-consistency check in the workflow passes
FetchSearchResult fetchSearchResult = new FetchSearchResult();
fetchSearchResult.setShardIndex(shardId);
fetchSearchResult.setSearchShardTarget(searchShardTarget);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(0, "10", Map.of(), Map.of()),
new SearchHit(2, "1", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(10, "3", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()) };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

// combine query + fetch into a single shard-level result and feed it to the consumer
QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext);

List<QuerySearchResult> querySearchResults = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());

// scores must be normalized and the workflow must have run exactly once
TestUtils.assertQueryResultScores(querySearchResults);
verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any());
}

// Negative path: when the fetch-phase hit count (5) differs from the query-phase
// doc count (7, markers included), processing must fail with IllegalStateException
// carrying the "score normalization processor cannot produce final query result" message.
public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() {
// spy is created for construction parity with the happy-path test; the failure
// happens before any workflow verification is relevant
NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy(
new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner())
);
NormalizationProcessor normalizationProcessor = new NormalizationProcessor(
PROCESSOR_TAG,
DESCRIPTION,
new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD),
new ScoreCombinationFactory().createCombination(COMBINATION_METHOD),
normalizationProcessorWorkflow
);

SearchRequest searchRequest = new SearchRequest(INDEX_NAME);
searchRequest.setBatchedReduceSize(4);
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(
searchRequest,
executor,
new NoopCircuitBreaker(CircuitBreaker.REQUEST),
searchPhaseController,
SearchProgressListener.NOOP,
writableRegistry(),
10,
e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
curr.addSuppressed(prev);
return curr;
})
);
CountDownLatch partialReduceLatch = new CountDownLatch(5);
int shardId = 0;
SearchShardTarget searchShardTarget = new SearchShardTarget(
"node",
new ShardId("index", "uuid", shardId),
null,
OriginalIndices.NONE
);
// query-phase result in hybrid format: 7 ScoreDocs total (2 start/stop markers,
// 1 delimiter marker, 4 real docs)
QuerySearchResult querySearchResult = new QuerySearchResult();
TopDocs topDocs = new TopDocs(
new TotalHits(4, TotalHits.Relation.EQUAL_TO),

new ScoreDoc[] {
createStartStopElementForHybridSearchResults(4),
createDelimiterElementForHybridSearchResults(4),
new ScoreDoc(0, 0.5f),
new ScoreDoc(2, 0.3f),
new ScoreDoc(4, 0.25f),
new ScoreDoc(10, 0.2f),
createStartStopElementForHybridSearchResults(4) }

);
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]);
querySearchResult.setSearchShardTarget(searchShardTarget);
querySearchResult.setShardIndex(shardId);

// fetch-phase result deliberately has only 5 hits — a mismatch against the
// 7 query-phase ScoreDocs, which must trigger the size-consistency failure
FetchSearchResult fetchSearchResult = new FetchSearchResult();
fetchSearchResult.setShardIndex(shardId);
fetchSearchResult.setSearchShardTarget(searchShardTarget);
SearchHit[] searchHitArray = new SearchHit[] {
new SearchHit(0, "10", Map.of(), Map.of()),
new SearchHit(2, "1", Map.of(), Map.of()),
new SearchHit(4, "2", Map.of(), Map.of()),
new SearchHit(10, "3", Map.of(), Map.of()),
new SearchHit(0, "10", Map.of(), Map.of()), };
SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 10);
fetchSearchResult.hits(searchHits);

QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult);
queryFetchSearchResult.setShardIndex(shardId);

queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown);

SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class);
// only the message prefix is asserted so the detailed counts appended by the
// implementation don't make the test brittle
IllegalStateException exception = expectThrows(
IllegalStateException.class,
() -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext)
);
org.hamcrest.MatcherAssert.assertThat(
exception.getMessage(),
startsWith("score normalization processor cannot produce final query result")
);
}
}
Loading
Loading