Skip to content

Commit

Permalink
Added exact search for cases when filteredIds < k to improve the reca…
Browse files Browse the repository at this point in the history
…ll for exact search (#928)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Jun 8, 2023
1 parent 1135d79 commit 96aaab7
Showing 1 changed file with 113 additions and 58 deletions.
171 changes: 113 additions & 58 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
Expand All @@ -35,10 +40,12 @@
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -97,6 +104,79 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
if (filterWeight != null && filterIdsArray.length == 0) {
return KNNScorer.emptyScorer(this);
}
final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();

/*
* The idea for this optimization is to get K results, we need to atleast look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
* This improves the recall.
*/
if (filterWeight != null && filterIdsArray.length <= knnQuery.getK()) {
docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray));
} else {
final Map<Integer, Float> annResults = doANNSearch(context, filterIdsArray);
if (annResults == null) {
return null;
}
docIdsToScoreMap.putAll(annResults);
}
return convertSearchResponseToScorer(docIdsToScoreMap);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException {
final Bits liveDocs = ctx.reader().getLiveDocs();
final int maxDoc = ctx.reader().maxDoc();

final Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return new FixedBitSet(0);
}

final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
// TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the
// distance for K vectors. This can avoid calls to native layer and save some latency.
final int cost = acceptDocs.cardinality();
log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost);
return acceptDocs;
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return ((BitSetIterator) filteredDocIdsIterator).getBitSet();
}
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
return BitSet.of(filterIterator, maxDoc);
}

private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException {
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (true) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
log.debug("Docs in filtered docs id set is : {}", docId);
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
docId++;
}
return filteredIds;
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

Expand Down Expand Up @@ -200,70 +280,45 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return null;
}

Map<Integer, Float> scores = Arrays.stream(results)
return Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
int maxDoc = Collections.max(scores.keySet()) + 1;
DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);

// The docIdSetIterator will contain the docids of the returned results. So, before adding results to
// the builder, we can grow to results.length
DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(results.length);
Arrays.stream(results).forEach(result -> setAdder.add(result.getId()));
DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, scores, boost);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx, final Weight filterWeight) throws IOException {
final Bits liveDocs = ctx.reader().getLiveDocs();
final int maxDoc = ctx.reader().maxDoc();

final Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
return new FixedBitSet(0);
}

final BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
// TODO: Based on this cost shift to exact search, because even in ANN search you have to calculate the
// distance for K vectors. This can avoid calls to native layer and save some latency.
final int cost = acceptDocs.cardinality();
log.debug("Number of docs valid for filter is = Cost for filtered k-nn is : {}", cost);
return acceptDocs;
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
return ((BitSetIterator) filteredDocIdsIterator).getBitSet();
}
// Create a new BitSet from matching and live docs
FilteredDocIdSetIterator filterIterator = new FilteredDocIdSetIterator(filteredDocIdsIterator) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
try {
final BinaryDocValues values = DocValues.getBinary(leafReaderContext.reader(), fieldInfo.name);
final SpaceType spaceType = SpaceType.getSpace(fieldInfo.getAttribute(SPACE_TYPE));

final Map<Integer, Float> docToScore = new HashMap<>();
for (int j : filterIdsArray) {
int docId = values.advance(j);
BytesRef value = values.binaryValue();
ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
// making min score as high score as this is closest to the vector
float score = spaceType.getVectorSimilarityFunction().compare(queryVector, vector);
docToScore.put(docId, score);
}
};
return BitSet.of(filterIterator, maxDoc);
return docToScore;
} catch (Exception e) {
log.error("Error while getting the doc values to do the k-NN Search for query : {}", this.knnQuery);
}
return Collections.emptyMap();
}

private int[] getFilterIdsArray(final LeafReaderContext context) throws IOException {
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (true) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
log.debug("Docs in filtered docs id set is : {}", docId);
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
docId++;
}
return filteredIds;
private Scorer convertSearchResponseToScorer(final Map<Integer, Float> docsToScore) throws IOException {
final int maxDoc = Collections.max(docsToScore.keySet()) + 1;
final DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
// The docIdSetIterator will contain the docids of the returned results. So, before adding results to
// the builder, we can grow to docsToScore.size()
final DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(docsToScore.size());
docsToScore.keySet().forEach(setAdder::add);
final DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, docsToScore, boost);
}

@Override
Expand Down

0 comments on commit 96aaab7

Please sign in to comment.