Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of multi vector in jni #1364

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,21 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
set(FAISS_ENABLE_PYTHON OFF)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL)

add_library(${TARGET_LIB_FAISS} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp)
add_library(
${TARGET_LIB_FAISS} SHARED
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/utils/BitSet.cpp
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollector.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp)
target_link_libraries(${TARGET_LIB_FAISS} faiss ${TARGET_LIB_COMMON} OpenMP::OpenMP_CXX)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss/utils
$ENV{JAVA_HOME}/include
$ENV{JAVA_HOME}/include/${JVM_OS_TYPE}
${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES POSITION_INDEPENDENT_CODE ON)

Expand Down Expand Up @@ -198,7 +210,12 @@ if ("${WIN32}" STREQUAL "")
jni_test
tests/faiss_wrapper_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp)
tests/test_util.cpp
tests/knn_extension/faiss/utils/BitSetTest.cpp
tests/knn_extension/faiss/utils/HeapTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp
)

target_link_libraries(
jni_test
Expand Down
5 changes: 2 additions & 3 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define OPENSEARCH_KNN_FAISS_WRAPPER_H

#include "jni_util.h"

#include <jni.h>

namespace knn_jni {
Expand All @@ -38,13 +37,13 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
69 changes: 69 additions & 0 deletions jni/include/knn_extension/faiss/MultiVectorResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/impl/ResultCollector.h>
#include <faiss/MetricType.h>
#include "knn_extension/faiss/utils/BitSet.h"
#include <unordered_map>

namespace os_faiss {

using idx_t = faiss::idx_t;
/**
* Implementation of ResultCollector to support multi vector
*
* Only supports HNSW algorithm
*
* Example:
* When there is two lucene document with two nested fields, the parent_bit_set value of 100100 is provided where
* parent doc ids are 2, and 5. Doc id for nested fields of parent document 2 are 0, and 1. Doc id for nested fields
* of parent document 5 are 3, and 4. For faiss, only nested fields are stored. Therefore corresponding doc ids for
* nested fields 0, 1, 3, 4 is 0, 1, 2, 3 in faiss. This mapping data is stored in id_map parameter.
*
* When collect method is called
* 1. It switches from faiss id to lucene id and look for its parent id.
* 2. See if the parent id already exist in heap using either parent_id_to_id or parent_id_to_index.
* 3. If it does not exist, add the parent id and distance value in the heap(bh_ids, bh_val) and update parent_id_to_id, and parent_id_to_index.
* 4. If it does exist, update the distance value(bh_val), parent_id_to_id, and parent_id_to_index.
*
* When post_process method is called
* 1. Convert lucene parent ID to faiss doc ID using parent_id_to_id
*/
struct MultiVectorResultCollector:faiss::ResultCollector {
// BitSet of lucene parent doc ID
const BitSet* parent_bit_set;

// Mapping data from Faiss doc ID to Lucene doc ID
const std::vector<int64_t>* id_map;

// Lucene parent doc ID to to Faiss doc ID
// Lucene parent doc ID to index in heap(bh_val, bh_ids)
std::unordered_map<idx_t, idx_t> parent_id_to_id;
std::unordered_map<idx_t, size_t> parent_id_to_index;
MultiVectorResultCollector(const BitSet* parent_bit_set, const std::vector<int64_t>* id_map);

/**
*
* @param k max size of bh_val, and bh_ids
* @param nres number of results in bh_val, and bh_ids
* @param bh_val binary heap storing values (For this case distance from query to result)
* @param bh_ids binary heap storing document IDs
* @param val a new value to add in bh_val
* @param ids a new doc id to add in bh_ids
*/
void collect(
int k,
int& nres,
float* bh_val,
int64_t* bh_ids,
float val,
int64_t ids) override;
void post_process(int64_t nres, int64_t* bh_ids) override;
heemin32 marked this conversation as resolved.
Show resolved Hide resolved
};

} // namespace os_faiss

Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/impl/ResultCollectorFactory.h>
#include "knn_extension/faiss/utils/BitSet.h"

namespace os_faiss {
/**
* Create MultiVectorResultCollector for single query request
*
* Creating new collector is required because MultiVectorResultCollector has instance variables
* which should be isolated for each query.
*/
struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory {
BitSet* parent_bit_set;

MultiVectorResultCollectorFactory(BitSet* parent_bit_set);
faiss::ResultCollector* new_collector() override;
void delete_collector(faiss::ResultCollector* resultCollector) override;
};

} // namespace os_faiss
51 changes: 51 additions & 0 deletions jni/include/knn_extension/faiss/utils/BitSet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/MetricType.h>
#include <faiss/impl/platform_macros.h>
#include <limits>

using idx_t = faiss::idx_t;
heemin32 marked this conversation as resolved.
Show resolved Hide resolved

struct BitSet {
const int NO_MORE_DOCS = std::numeric_limits<int>::max();
/**
* Returns the index of the first set bit starting at the index specified.
* NO_MORE_DOCS is returned if there are no more set bits.
*/
virtual idx_t next_set_bit(idx_t index) const = 0;
virtual ~BitSet() = default;
};


/**
* BitSet of fixed length (numBits), implemented using an array of unit64.
* See https://github.com/apache/lucene/blob/main/lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
*
* Here a block is 64 bit. However, for simplicity let's assume its size is 8 bits.
* Then, if have an array of 3, 7, and 10, it will be represented in bitmap as follow.
* [0] [1]
* bitmap: 10001000 00000100
*
* for next_set_bit call with 4
* 1. it looks for bitmap[0]
* 2. bitmap[0] >> 4
* 3. count trailing zero of the result from step 2 which is 3
* 4. return 4(current index) + 3(result from step 3)
*/
struct FixedBitSet : public BitSet {
// Length of bitmap
size_t numBits;

// Pointer to an array of uint64_t
// Using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h
uint64_t* bitmap;
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
heemin32 marked this conversation as resolved.
Show resolved Hide resolved

FixedBitSet(const int* int_array, const int length);
idx_t next_set_bit(idx_t index) const;
~FixedBitSet();
};
Loading
Loading