Skip to content

Commit

Permalink
Multi vector support for Faiss HNSW
Browse files Browse the repository at this point in the history
Pass the parentId filter to the Faiss HNSW search method so that results are deduplicated by their parentId and k results are returned for documents with nested fields.

Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jan 3, 2024
1 parent 0b252e7 commit e7a5a42
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 50 deletions.
9 changes: 7 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.knn.index.KNNSettings;

import java.io.IOException;
Expand All @@ -33,20 +34,24 @@ public class KNNQuery extends Query {
@Getter
@Setter
private Query filterQuery;
@Getter
private BitSetProducer bitSetProducer;

public KNNQuery(String field, float[] queryVector, int k, String indexName) {
public KNNQuery(String field, float[] queryVector, int k, String indexName, final BitSetProducer bitSetProducer) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.bitSetProducer = bitSetProducer;
}

public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) {
public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery, BitSetProducer bitSetProducer) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.filterQuery = filterQuery;
this.bitSetProducer = bitSetProducer;
}

public String getField() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
final Query filterQuery = getFilterQuery(createQueryRequest);

BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) {
log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k);
return new KNNQuery(fieldName, vector, k, indexName, filterQuery);
return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter);
}
log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KNNQuery(fieldName, vector, k, indexName);
return new KNNQuery(fieldName, vector, k, indexName, parentFilter);
}

log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
} else if (VectorDataType.FLOAT == vectorDataType) {
Expand Down
36 changes: 24 additions & 12 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -119,7 +120,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
if (filterWeight != null && canDoExactSearch(filterIdsArray.length)) {
docIdsToScoreMap.putAll(doExactSearch(context, filterIdsArray));
} else {
Map<Integer, Float> annResults = doANNSearch(context, filterIdsArray);
Map<Integer, Float> annResults = doANNSearch(context, filterIdsArray, knnQuery.getBitSetProducer());
if (annResults == null) {
return null;
}
Expand Down Expand Up @@ -172,23 +173,33 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight));
}

/**
 * Converts the bit set produced by the parent filter for the given segment into an array of doc IDs.
 *
 * @param context      leaf reader context of the segment being searched
 * @param parentFilter producer of the parent bit set for a segment; may be {@code null}
 * @return the doc IDs set in the produced bit set, or {@code null} when there is no parent filter or
 *         the producer yields no bit set for this segment
 * @throws IOException if the bit set cannot be produced for the segment
 */
private int[] getParentIdsArray(final LeafReaderContext context, final BitSetProducer parentFilter) throws IOException {
    if (parentFilter == null) {
        return null;
    }
    // BitSetProducer#getBitSet may return null when no document in the segment matches;
    // guard against it so bitSetToIntArray is never handed a null bit set.
    final BitSet parentBitSet = parentFilter.getBitSet(context);
    if (parentBitSet == null) {
        return null;
    }
    return bitSetToIntArray(parentBitSet);
}

private int[] bitSetToIntArray(final BitSet bitSet) {
final int[] intArray = new int[bitSet.cardinality()];
int index = 0;
int docId = 0;
while (docId < filteredDocsBitSet.length()) {
docId = filteredDocsBitSet.nextSetBit(docId);
while (docId < bitSet.length()) {
docId = bitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
intArray[index] = docId;
index++;
docId++;
}
return filteredIds;
return intArray;
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException {
private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray, final BitSetProducer parentFilter) throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

Expand Down Expand Up @@ -265,13 +276,14 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

int[] parentIds = getParentIdsArray(context, parentFilter);
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine.getName(),
filterIdsArray
filterIdsArray,
parentIds
);

} catch (Exception e) {
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.jni;

import org.apache.commons.lang.ArrayUtils;
import org.apache.lucene.util.BitSet;
import org.opensearch.knn.index.query.KNNQueryResult;
import org.opensearch.knn.index.util.KNNEngine;

Expand Down Expand Up @@ -101,7 +102,7 @@ public static long loadIndex(String indexPath, Map<String, Object> parameters, S
* @param filteredIds array of ints that should be used for the search.
* @return KNNQueryResult array of k neighbors
*/
public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) {
public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds, int[] parentIds) {
if (KNNEngine.NMSLIB.getName().equals(engineName)) {
return NmslibService.queryIndex(indexPointer, queryVector, k);
}
Expand All @@ -112,9 +113,9 @@ public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector
// filterIds. If filterIds is empty, then this is the case where we need to search with the Faiss engine
// normally.
if (ArrayUtils.isNotEmpty(filteredIds)) {
return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, null);
return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds);
}
return FaissService.queryIndex(indexPointer, queryVector, k, null);
return FaissService.queryIndex(indexPointer, queryVector, k, parentIds);
}
throw new IllegalArgumentException("QueryIndex not supported for provided engine");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.index.mapper.MapperService;
Expand Down Expand Up @@ -162,14 +163,14 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception {

// query to verify distance for each of the field
IndexSearcher searcher = new IndexSearcher(reader);
float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score;
float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score;
float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null), 10).scoreDocs[0].score;
float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null), 10).scoreDocs[0].score;
assertEquals(1.0f / (1 + 25), score, 0.01f);
assertEquals(1.0f / (1 + 169), score1, 0.01f);

// query to determine the hits
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null)));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null)));

reader.close();
dir.close();
Expand Down Expand Up @@ -254,7 +255,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio
NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService);
float[] query = { 10.0f, 10.0f, 10.0f };
IndexSearcher searcher = new IndexSearcher(reader);
TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10);
TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10);

assertEquals(3, topDocs.scoreDocs[0].doc);
assertEquals(2, topDocs.scoreDocs[1].doc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ public static void assertLoadableByEngine(
);
int k = 2;
float[] queryVector = new float[dimension];
KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null);
KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine.getName(), null, null);
assertTrue(results.length > 0);
JNIService.free(indexPtr, knnEngine.getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void testIndexLoadStrategy_load() throws IOException {
// Confirm that the file was loaded by querying
float[] query = new float[dimension];
Arrays.fill(query, numVectors + 1);
KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null);
KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine.getName(), null, null);
assertTrue(results.length > 0);
}

Expand Down
31 changes: 16 additions & 15 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -154,10 +155,10 @@ public void testQueryScoreForFaissWithModel() throws IOException {
SpaceType spaceType = SpaceType.L2;
final Function<Float, Float> scoreTranslator = spaceType::scoreTranslation;
final String modelId = "modelId";
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any()))
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any()))
.thenReturn(getKNNQueryResults());

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
Expand Down Expand Up @@ -221,7 +222,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException {
SpaceType spaceType = SpaceType.L2;
final String modelId = "modelId";

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);

ModelDao modelDao = mock(ModelDao.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
Expand Down Expand Up @@ -253,7 +254,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException {

@SneakyThrows
public void testShardWithoutFiles() {
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
Expand Down Expand Up @@ -294,10 +295,10 @@ public void testShardWithoutFiles() {
@SneakyThrows
public void testEmptyQueryResults() {
final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {};
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any()))
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any()))
.thenReturn(knnQueryResults);

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
Expand Down Expand Up @@ -338,7 +339,7 @@ public void testEmptyQueryResults() {
public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
int k = 3;
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds)))
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any()))
.thenReturn(getFilteredKNNQueryResults());
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
Expand All @@ -351,7 +352,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
when(liveDocsBits.length()).thenReturn(1000);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null);
final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
Expand Down Expand Up @@ -395,7 +396,7 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() {
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());
jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds)));
jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), eq(filterDocIds), any()));

final List<Integer> actualDocIds = new ArrayList<>();
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
Expand All @@ -415,7 +416,7 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() {
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null);
final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
Expand Down Expand Up @@ -465,7 +466,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null);
final Weight filterQueryWeight = mock(Weight.class);
final Scorer filterScorer = mock(Scorer.class);
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
Expand Down Expand Up @@ -534,7 +535,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces

when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length));

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null);

final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);
final Map<String, String> attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name());
Expand Down Expand Up @@ -577,7 +578,7 @@ public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.empty());

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, FILTER_QUERY, null);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f, filterQueryWeight);

final FieldInfos fieldInfos = mock(FieldInfos.class);
Expand All @@ -598,10 +599,10 @@ private void testQueryScore(
final Set<String> segmentFiles,
final Map<String, String> fileAttributes
) throws IOException {
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any()))
jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), anyString(), any(), any()))
.thenReturn(getKNNQueryResults());

final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME);
final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null);
final KNNWeight knnWeight = new KNNWeight(query, 0.0f);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
Expand Down
Loading

0 comments on commit e7a5a42

Please sign in to comment.