Skip to content

Commit

Permalink
Multi vector support for Faiss HNSW (#1371)
Browse files Browse the repository at this point in the history
Apply the parentId filter to the Faiss HNSW search method. This ensures that documents are deduplicated based on their parentId, and the method returns k results for documents with nested fields.

Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Jan 16, 2024
1 parent ead4411 commit d0dc3f9
Show file tree
Hide file tree
Showing 17 changed files with 1,355 additions and 85 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/test_security.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ jobs:
with:
submodules: true

# Git functionality in CMAKE file does not work with given ubuntu image. Therefore, handling it here.
- name: Apply Git Patch
# Deleting file at the end to skip `git apply` inside CMAKE file
run: |
cd jni/external/faiss
git apply --ignore-space-change --ignore-whitespace --3way ../../patches/faiss/0001-Custom-patch-to-support-multi-vector.patch
rm ../../patches/faiss/0001-Custom-patch-to-support-multi-vector.patch
working-directory: ${{ github.workspace }}

- name: Setup Java ${{ matrix.java }}
uses: actions/setup-java@v1
with:
Expand Down
15 changes: 9 additions & 6 deletions jni/include/knn_extension/faiss/utils/BitSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,21 @@ struct BitSet {
* bitmap: 10001000 00000100
*
* for next_set_bit call with 4
* 1. it looks for bitmap[0]
* 2. bitmap[0] >> 4
* 1. it looks for words[0]
* 2. words[0] >> 4
* 3. count trailing zero of the result from step 2 which is 3
* 4. return 4(current index) + 3(result from step 3)
*/
struct FixedBitSet : public BitSet {
// Length of bitmap
size_t numBits;
// The number of bits in use
idx_t num_bits;

// Pointer to an array of uint64_t
// The exact number of longs needed to hold num_bits
size_t num_words;

// Array of uint64_t holding the bits
// Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h
uint64_t* bitmap;
uint64_t* words;

FixedBitSet(const int* int_array, const int length);
idx_t next_set_bit(idx_t index) const;
Expand Down
6 changes: 4 additions & 2 deletions jni/include/knn_extension/faiss/utils/Heap.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ static inline void down_heap(

/**
* Push the value to the max heap
* As the heap contains only one value per group id, pushing a value of existing group id
* will break the data integrity. For existing group id, use maxheap_update instead.
* The parent_id should not exist in in bh_ids, parent_id_to_id, and parent_id_to_index.
*
* @param nres number of values in the binary heap(bh_val, and bh_ids)
Expand All @@ -152,7 +154,7 @@ inline void maxheap_push(
std::unordered_map<int64_t, size_t>* parent_id_to_index,
int64_t parent_id) {

assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap");
assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap");

up_heap<faiss::CMax<T, int64_t>>(
bh_val,
Expand Down Expand Up @@ -189,7 +191,7 @@ inline void maxheap_replace_top(
std::unordered_map<int64_t, size_t>* parent_id_to_index,
int64_t parent_id) {

assert(parent_id_to_index->find(parent_id) != parent_id_to_index->end() && "parent id should not exist in the binary heap");
assert(parent_id_to_index->find(parent_id) == parent_id_to_index->end() && "parent id should not exist in the binary heap");

parent_id_to_id->erase(bh_ids[0]);
parent_id_to_index->erase(bh_ids[0]);
Expand Down
23 changes: 14 additions & 9 deletions jni/src/knn_extension/faiss/utils/BitSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,30 @@
FixedBitSet::FixedBitSet(const int* int_array, const int length){
assert(int_array && "int_array should not be null");
const int* maxValue = std::max_element(int_array, int_array + length);
this->numBits = (*maxValue >> 6) + 1; // div by 64
this->bitmap = new uint64_t[this->numBits]();
this->num_bits = *maxValue + 1;
this->num_words = (num_bits >> 6) + 1; // div by 64
this->words = new uint64_t[this->num_words]();
for(int i = 0 ; i < length ; i ++) {
int value = int_array[i];
int bitsetArrayIndex = value >> 6;
this->bitmap[bitsetArrayIndex] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64)
int bitset_array_index = value >> 6;
this->words[bitset_array_index] |= 1ULL << (value & 63); // Equivalent of 1ULL << (value % 64)
}
}

idx_t FixedBitSet::next_set_bit(idx_t index) const {
idx_t i = index >> 6; // div by 64
uint64_t word = this->bitmap[i] >> (index & 63); // Equivalent of bitmap[i] >> (index % 64)
assert(index >= 0 && "index shouldn't be less than zero");
assert(index < this->num_bits && "index should be less than total number of bits");

idx_t i = index >> 6; // div by 64
uint64_t word = this->words[i] >> (index & 63); // Equivalent of words[i] >> (index % 64)
// word is non zero after right shift, it means, next set bit is in current word
// The index of set bit is "given index" + "trailing zero in the right shifted word"
if (word != 0) {
return index + __builtin_ctzll(word);
}

while (++i < this->numBits) {
word = this->bitmap[i];
while (++i < this->num_words) {
word = this->words[i];
if (word != 0) {
return (i << 6) + __builtin_ctzll(word);
}
Expand All @@ -38,5 +43,5 @@ idx_t FixedBitSet::next_set_bit(idx_t index) const {
}

FixedBitSet::~FixedBitSet() {
delete this->bitmap;
delete this->words;
}
24 changes: 21 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.knn.index.KNNSettings;

import java.io.IOException;
Expand All @@ -33,20 +34,37 @@ public class KNNQuery extends Query {
@Getter
@Setter
private Query filterQuery;

public KNNQuery(String field, float[] queryVector, int k, String indexName) {
@Getter
private BitSetProducer parentsFilter;

public KNNQuery(
final String field,
final float[] queryVector,
final int k,
final String indexName,
final BitSetProducer parentsFilter
) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.parentsFilter = parentsFilter;
}

public KNNQuery(String field, float[] queryVector, int k, String indexName, Query filterQuery) {
public KNNQuery(
final String field,
final float[] queryVector,
final int k,
final String indexName,
final Query filterQuery,
final BitSetProducer parentsFilter
) {
this.field = field;
this.queryVector = queryVector;
this.k = k;
this.indexName = indexName;
this.filterQuery = filterQuery;
this.parentsFilter = parentsFilter;
}

public String getField() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
final Query filterQuery = getFilterQuery(createQueryRequest);

BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) {
log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k);
return new KNNQuery(fieldName, vector, k, indexName, filterQuery);
return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter);
}
log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
return new KNNQuery(fieldName, vector, k, indexName);
return new KNNQuery(fieldName, vector, k, indexName, parentFilter);
}

log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
BitSetProducer parentFilter = createQueryRequest.context == null ? null : createQueryRequest.context.getParentFilter();
if (VectorDataType.BYTE == vectorDataType) {
return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
} else if (VectorDataType.FLOAT == vectorDataType) {
Expand Down Expand Up @@ -205,9 +205,6 @@ static class CreateQueryRequest {
private VectorDataType vectorDataType;
@Getter
private int k;
// can be null in cases filter not passed with the knn query
@Getter
private BitSetProducer parentFilter;
private QueryBuilder filter;
// can be null in cases filter not passed with the knn query
private QueryShardContext context;
Expand Down
41 changes: 25 additions & 16 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,28 @@ private int[] getFilterIdsArray(final LeafReaderContext context) throws IOExcept
if (filterWeight == null) {
return new int[0];
}
final BitSet filteredDocsBitSet = getFilteredDocsBitSet(context, this.filterWeight);
final int[] filteredIds = new int[filteredDocsBitSet.cardinality()];
int filteredIdsIndex = 0;
int docId = 0;
while (docId < filteredDocsBitSet.length()) {
docId = filteredDocsBitSet.nextSetBit(docId);
if (docId == DocIdSetIterator.NO_MORE_DOCS || docId + 1 == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
filteredIds[filteredIdsIndex] = docId;
filteredIdsIndex++;
docId++;
return bitSetToIntArray(getFilteredDocsBitSet(context, this.filterWeight));
}

private int[] getParentIdsArray(final LeafReaderContext context) throws IOException {
if (knnQuery.getParentsFilter() == null) {
return null;
}
return filteredIds;
return bitSetToIntArray(knnQuery.getParentsFilter().getBitSet(context));
}

private int[] bitSetToIntArray(final BitSet bitSet) {
final int cardinality = bitSet.cardinality();
final int[] intArray = new int[cardinality];
final BitSetIterator bitSetIterator = new BitSetIterator(bitSet, cardinality);
int index = 0;
int docId = bitSetIterator.nextDoc();
while (docId != DocIdSetIterator.NO_MORE_DOCS) {
assert index < intArray.length;
intArray[index++] = docId;
docId = bitSetIterator.nextDoc();
}
return intArray;
}

private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final int[] filterIdsArray) throws IOException {
Expand Down Expand Up @@ -265,13 +273,14 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

int[] parentIds = getParentIdsArray(context);
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine.getName(),
filterIdsArray
filterIdsArray,
parentIds
);

} catch (Exception e) {
Expand All @@ -296,7 +305,7 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final i
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) {
private Map<Integer, Float> doExactSearch(final LeafReaderContext leafReaderContext, final int[] filterIdsArray) throws IOException {
final SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader());
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
float[] queryVector = this.knnQuery.getQueryVector();
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,14 @@ public static long loadIndex(String indexPath, Map<String, Object> parameters, S
* @param filteredIds array of ints on which should be used for search.
* @return KNNQueryResult array of k neighbors
*/
public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, String engineName, int[] filteredIds) {
public static KNNQueryResult[] queryIndex(
long indexPointer,
float[] queryVector,
int k,
String engineName,
int[] filteredIds,
int[] parentIds
) {
if (KNNEngine.NMSLIB.getName().equals(engineName)) {
return NmslibService.queryIndex(indexPointer, queryVector, k);
}
Expand All @@ -112,9 +119,9 @@ public static KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector
// filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine
// normally.
if (ArrayUtils.isNotEmpty(filteredIds)) {
return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, null);
return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, parentIds);
}
return FaissService.queryIndex(indexPointer, queryVector, k, null);
return FaissService.queryIndex(indexPointer, queryVector, k, parentIds);
}
throw new IllegalArgumentException("QueryIndex not supported for provided engine");
}
Expand Down
35 changes: 26 additions & 9 deletions src/test/java/org/opensearch/knn/index/NestedSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.NestedKnnDocBuilder;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.K;
Expand Down Expand Up @@ -56,7 +53,7 @@ public final void cleanUp() {
}

@SneakyThrows
public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {
public void testNestedSearchWithLucene_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.LUCENE.getName());

String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
Expand All @@ -73,12 +70,32 @@ public void testNestedSearch_whenKIsTwo_thenReturnTwoResults() {

Float[] queryVector = { 1f, 1f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
String entity = EntityUtils.toString(response.getEntity());
assertEquals(2, parseHits(entity));
assertEquals(2, parseTotalSearchHits(entity));
}

@SneakyThrows
public void testNestedSearchWithFaiss_whenKIsTwo_thenReturnTwoResults() {
createKnnIndex(2, KNNEngine.FAISS.getName());

List<Object> hits = (List<Object>) ((Map<String, Object>) createParser(
MediaTypeRegistry.getDefaultMediaType().xContent(),
EntityUtils.toString(response.getEntity())
).map().get("hits")).get("hits");
assertEquals(2, hits.size());
String doc1 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.addVectors(FIELD_NAME_VECTOR, new Float[] { 1f, 1f }, new Float[] { 1f, 1f })
.build();
addKnnDoc(INDEX_NAME, "1", doc1);

String doc2 = NestedKnnDocBuilder.create(FIELD_NAME_NESTED)
.addVectors(FIELD_NAME_VECTOR, new Float[] { 2f, 2f }, new Float[] { 2f, 2f })
.build();
addKnnDoc(INDEX_NAME, "2", doc2);

refreshIndex(INDEX_NAME);

Float[] queryVector = { 1f, 1f };
Response response = queryNestedField(INDEX_NAME, 2, queryVector);
String entity = EntityUtils.toString(response.getEntity());
assertEquals(2, parseHits(entity));
assertEquals(2, parseTotalSearchHits(entity));
}

/**
Expand Down
17 changes: 12 additions & 5 deletions src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.index.mapper.MapperService;
Expand Down Expand Up @@ -162,14 +163,20 @@ public void testMultiFieldsKnnIndex(Codec codec) throws Exception {

// query to verify distance for each of the field
IndexSearcher searcher = new IndexSearcher(reader);
float score = searcher.search(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy"), 10).scoreDocs[0].score;
float score1 = searcher.search(new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy"), 10).scoreDocs[0].score;
float score = searcher.search(
new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null),
10
).scoreDocs[0].score;
float score1 = searcher.search(
new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null),
10
).scoreDocs[0].score;
assertEquals(1.0f / (1 + 25), score, 0.01f);
assertEquals(1.0f / (1 + 169), score1, 0.01f);

// query to determine the hits
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy")));
assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null)));
assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null)));

reader.close();
dir.close();
Expand Down Expand Up @@ -254,7 +261,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio
NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService);
float[] query = { 10.0f, 10.0f, 10.0f };
IndexSearcher searcher = new IndexSearcher(reader);
TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy"), 10);
TopDocs topDocs = searcher.search(new KNNQuery(fieldName, query, 4, "dummy", (BitSetProducer) null), 10);

assertEquals(3, topDocs.scoreDocs[0].doc);
assertEquals(2, topDocs.scoreDocs[1].doc);
Expand Down
Loading

0 comments on commit d0dc3f9

Please sign in to comment.