Add check to directly use ANN Search when filters match all docs. #2320

Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
8 changes: 8 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -129,6 +129,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
*/
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
final BitSet filterBitSet = getFilteredDocsBitSet(context);
final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look at where it should be placed. For now I feel this is a good
@@ -145,6 +146,12 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
/*
 * If the filters match all docs in this segment, then there is no need to do any
 * extra steps and we can directly do ANN search.
 */

Review comment (Contributor): Could you give more context in the comment? I would like to see why there is no need for extra steps in the all-docs-match case, and what passing new FixedBitSet(0) means. Basically, we want to save the bitset lookup cost when it is not required, as far as possible.
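A possible expansion of the comment, sketched as a suggestion; the reasoning below is inferred from the surrounding diff, not taken from the PR:

/*
 * When cardinality == maxDoc, the filter accepts every document in this
 * segment, so a filtered ANN search would visit the same graph as an
 * unfiltered one. Passing an empty bitset (new FixedBitSet(0)) presumably
 * signals the native layer that no per-document filter check is needed,
 * saving one bitset lookup per visited node.
 */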
if (filterWeight != null && cardinality == maxDoc) {
return new PerLeafResult(new FixedBitSet(0), doANNSearch(context, new FixedBitSet(0), 0, k));
Review comment (Contributor): I see lots of new FixedBitSet(0) in KNNWeight. Could you factor it out with a meaningful name as a static variable and then pass it down to the downstream?

private static final FixedBitSet ... = new FixedBitSet(0);
}
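A minimal sketch of that refactor; the constant name EMPTY_BITSET is a hypothetical choice, not the PR's final code:

// Hypothetical: one shared empty bitset instead of a new allocation per call site.
private static final FixedBitSet EMPTY_BITSET = new FixedBitSet(0);

if (filterWeight != null && cardinality == maxDoc) {
    return new PerLeafResult(EMPTY_BITSET, doANNSearch(context, EMPTY_BITSET, 0, k));
}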
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
@@ -320,6 +327,7 @@ private Map<Integer, Float> doANNSearch(
// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();
indexAllocation.incRef();

try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
101 changes: 97 additions & 4 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
@@ -671,7 +671,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
- when(reader.maxDoc()).thenReturn(filterDocIds.length);
+ when(reader.maxDoc()).thenReturn(filterDocIds.length + 1);
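Note (inferred, not stated in the PR): maxDoc is bumped above the filter cardinality here and in the similar tests below, presumably so these existing tests keep exercising the filtered ANN path rather than the new match-all shortcut.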
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
@@ -758,6 +758,97 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

@SneakyThrows
public void testANNWithFilterQuery_whenFiltersMatchAllDocs_thenSuccess() {
// Given
int k = 3;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length);
for (int docId : filterDocIds) {
filterBitSet.set(docId);
}

jniServiceMockedStatic.when(
() -> JNIService.queryIndex(
anyLong(),
eq(QUERY_VECTOR),
eq(k),
eq(HNSW_METHOD_PARAMETERS),
any(),
eq(new FixedBitSet(0).getBits()),
anyInt(),
any()
)
).thenReturn(getFilteredKNNQueryResults());

final Bits liveDocsBits = mock(Bits.class);
for (int filterDocId : filterDocIds) {
when(liveDocsBits.get(filterDocId)).thenReturn(true);
}
when(liveDocsBits.length()).thenReturn(1000);

final SegmentReader reader = mockSegmentReader();
when(reader.maxDoc()).thenReturn(filterDocIds.length);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(QUERY_VECTOR)
.k(k)
.indexName(INDEX_NAME)
.filterQuery(FILTER_QUERY)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// Just to make sure that we are not hitting the exact search condition
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length + 1));

final float boost = (float) randomDoubleBetween(0, 10, true);
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);

final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
final Map<String, String> attributesMap = ImmutableMap.of(
KNN_ENGINE,
KNNEngine.FAISS.getName(),
SPACE_TYPE,
SpaceType.L2.getValue()
);

when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(attributesMap);

// When
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);

// Then
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());

jniServiceMockedStatic.verify(
() -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(k), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()),
times(1)
);

final List<Integer> actualDocIds = new ArrayList<>();
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
}

private SegmentReader mockSegmentReader() {
Path path = mock(Path.class);

@@ -815,7 +906,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// scorer will return 1 document
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
- when(reader.maxDoc()).thenReturn(1);
+ when(reader.maxDoc()).thenReturn(2);
final Bits liveDocsBits = mock(Bits.class);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -891,6 +982,7 @@ public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(1);

final FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
@@ -968,7 +1060,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
// scorer will return 1 document
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(1));
- when(reader.maxDoc()).thenReturn(1);
+ when(reader.maxDoc()).thenReturn(2);
final Bits liveDocsBits = mock(Bits.class);
when(reader.getLiveDocs()).thenReturn(liveDocsBits);
when(liveDocsBits.get(filterDocId)).thenReturn(true);
@@ -1168,6 +1260,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);
when(reader.maxDoc()).thenReturn(1);

final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
@@ -1202,7 +1295,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() {
// We will have 0, 1 for filteredIds and 2 will be the parent id for both of them
final Scorer filterScorer = mock(Scorer.class);
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(2));
- when(reader.maxDoc()).thenReturn(2);
+ when(reader.maxDoc()).thenReturn(3);

// Query vector is {1.8f, 2.4f}, therefore, second vector {1.9f, 2.5f} should be returned in a result
final List<float[]> vectors = Arrays.asList(new float[] { 0.1f, 0.3f }, new float[] { 1.9f, 2.5f });