-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support of multi vector in jni (#1364)
Signed-off-by: Heemin Kim <[email protected]>
- Loading branch information
Showing
19 changed files
with
1,022 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
69 changes: 69 additions & 0 deletions
69
jni/include/knn_extension/faiss/MultiVectorResultCollector.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <faiss/impl/ResultCollector.h> | ||
#include <faiss/MetricType.h> | ||
#include "knn_extension/faiss/utils/BitSet.h" | ||
#include <unordered_map> | ||
|
||
namespace os_faiss { | ||
|
||
using idx_t = faiss::idx_t; | ||
/** | ||
* Implementation of ResultCollector to support multi vector | ||
* | ||
* Only supports HNSW algorithm | ||
* | ||
* Example: | ||
* When there is two lucene document with two nested fields, the parent_bit_set value of 100100 is provided where | ||
* parent doc ids are 2, and 5. Doc id for nested fields of parent document 2 are 0, and 1. Doc id for nested fields | ||
* of parent document 5 are 3, and 4. For faiss, only nested fields are stored. Therefore corresponding doc ids for | ||
* nested fields 0, 1, 3, 4 is 0, 1, 2, 3 in faiss. This mapping data is stored in id_map parameter. | ||
* | ||
* When collect method is called | ||
* 1. It switches from faiss id to lucene id and look for its parent id. | ||
* 2. See if the parent id already exist in heap using either parent_id_to_id or parent_id_to_index. | ||
* 3. If it does not exist, add the parent id and distance value in the heap(bh_ids, bh_val) and update parent_id_to_id, and parent_id_to_index. | ||
* 4. If it does exist, update the distance value(bh_val), parent_id_to_id, and parent_id_to_index. | ||
* | ||
* When post_process method is called | ||
* 1. Convert lucene parent ID to faiss doc ID using parent_id_to_id | ||
*/ | ||
struct MultiVectorResultCollector:faiss::ResultCollector { | ||
// BitSet of lucene parent doc ID | ||
const BitSet* parent_bit_set; | ||
|
||
// Mapping data from Faiss doc ID to Lucene doc ID | ||
const std::vector<int64_t>* id_map; | ||
|
||
// Lucene parent doc ID to to Faiss doc ID | ||
// Lucene parent doc ID to index in heap(bh_val, bh_ids) | ||
std::unordered_map<idx_t, idx_t> parent_id_to_id; | ||
std::unordered_map<idx_t, size_t> parent_id_to_index; | ||
MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector<int64_t>* id_map); | ||
|
||
/** | ||
* | ||
* @param k max size of bh_val, and bh_ids | ||
* @param nres number of results in bh_val, and bh_ids | ||
* @param bh_val binary heap storing values (For this case distance from query to result) | ||
* @param bh_ids binary heap storing document IDs | ||
* @param val a new value to add in bh_val | ||
* @param ids a new doc id to add in bh_ids | ||
*/ | ||
void collect( | ||
int k, | ||
int& nres, | ||
float* bh_val, | ||
int64_t* bh_ids, | ||
float val, | ||
int64_t ids) override; | ||
void post_process(int64_t nres, int64_t* bh_ids) override; | ||
}; | ||
|
||
} // namespace os_faiss | ||
|
26 changes: 26 additions & 0 deletions
26
jni/include/knn_extension/faiss/MultiVectorResultCollectorFactory.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <faiss/impl/ResultCollectorFactory.h> | ||
#include "knn_extension/faiss/utils/BitSet.h" | ||
|
||
namespace os_faiss { | ||
/** | ||
* Create MultiVectorResultCollector for single query request | ||
* | ||
* Creating new collector is required because MultiVectorResultCollector has instance variables | ||
* which should be isolated for each query. | ||
*/ | ||
struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory { | ||
BitSet* parent_bit_set; | ||
|
||
MultiVectorResultCollectorFactory(BitSet* parent_bit_set); | ||
faiss::ResultCollector* new_collector() override; | ||
void delete_collector(faiss::ResultCollector* resultCollector) override; | ||
}; | ||
|
||
} // namespace os_faiss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <faiss/MetricType.h> | ||
#include <faiss/impl/platform_macros.h> | ||
#include <limits> | ||
|
||
using idx_t = faiss::idx_t; | ||
|
||
struct BitSet { | ||
const int NO_MORE_DOCS = std::numeric_limits<int>::max(); | ||
/** | ||
* Returns the index of the first set bit starting at the index specified. | ||
* NO_MORE_DOCS is returned if there are no more set bits. | ||
*/ | ||
virtual idx_t next_set_bit(idx_t index) const = 0; | ||
virtual ~BitSet() = default; | ||
}; | ||
|
||
|
||
/** | ||
* BitSet of fixed length (numBits), implemented using an array of unit64. | ||
* See https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java | ||
* | ||
* Here a block is 64 bit. However, for simplicity let's assume its size is 8 bits. | ||
* Then, if have an array of 3, 7, and 10, it will be represented in bitmap as follow. | ||
* [0] [1] | ||
* bitmap: 10001000 00000100 | ||
* | ||
* for next_set_bit call with 4 | ||
* 1. it looks for bitmap[0] | ||
* 2. bitmap[0] >> 4 | ||
* 3. count trailing zero of the result from step 2 which is 3 | ||
* 4. return 4(current index) + 3(result from step 3) | ||
*/ | ||
struct FixedBitSet : public BitSet { | ||
// Length of bitmap | ||
size_t numBits; | ||
|
||
// Pointer to an array of uint64_t | ||
// Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h | ||
uint64_t* bitmap; | ||
|
||
FixedBitSet(const int* int_array, const int length); | ||
idx_t next_set_bit(idx_t index) const; | ||
~FixedBitSet(); | ||
}; |
Oops, something went wrong.