Skip to content

Commit

Permalink
Added separate interface for creating and writing in a faiss index to…
Browse files Browse the repository at this point in the history
… reduce memory footprint for faiss

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Apr 13, 2024
1 parent cee100f commit 7165079
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 45 deletions.
7 changes: 7 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ namespace knn_jni {
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ);

// Adds the given ids and vectors to an in-memory faiss index and returns the native address of that index.
// When indexAddressJ is 0 a new index is created first, configured by the values in the Java map parametersJ;
// otherwise the vectors are appended to the index that already lives at indexAddressJ.
// Nothing is written to disk by this function: pass the returned address to writeIndex to serialize the index.
// On success the vector buffer at vectorsAddressJ is freed.
long long CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jlong indexAddressJ, jobject parametersJ);

// Serializes the faiss index located at indexAddressJ to the file at indexPathJ and then frees the index, so
// the address must not be used again after this call. parametersJ may carry the thread count to apply.
void writeIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,jlong indexAddressJ, jstring indexPathJ, jobject parametersJ);

// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
Expand Down
17 changes: 17 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ extern "C" {
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexIteratively
* Signature: ([IJIJLjava/util/Map;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexIteratively
(JNIEnv *, jclass, jintArray, jlong, jint, jlong, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: writeIndex
* Signature: (JLjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex
(JNIEnv *, jclass, jlong, jstring, jobject);


/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand Down
98 changes: 98 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,104 @@ bool isIndexIVFPQL2(faiss::Index * index);
// IndexIDMap which has member that will point to underlying index that stores the data
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);

// Serializes the in-memory faiss index located at indexAddressJ to the file at indexPathJ and releases it.
//
//   jniUtil       - JNI helper used for every Java <-> C++ conversion.
//   env           - JNI environment of the calling thread.
//   indexAddressJ - address of a faiss::IndexIDMap, as returned by CreateIndexIteratively.
//   indexPathJ    - Java string holding the destination file path.
//   parametersJ   - Java Map<String, Object>; only knn_jni::INDEX_THREAD_QUANTITY is read here.
//
// Throws std::runtime_error when indexAddressJ is not a positive address. Anything thrown by the conversions
// or by faiss::write_index propagates to the caller.
//
// The index (the id map and the index it wraps) is freed when this function returns, on success and on failure
// alike, so the address must never be reused afterwards.
void knn_jni::faiss_wrapper::writeIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,jlong indexAddressJ, jstring indexPathJ, jobject parametersJ) {
    // A zero address would be dereferenced below, so reject it up front.
    if (indexAddressJ <= 0) {
        throw std::runtime_error("Index address cannot be less than or equal to 0");
    }

    auto *idMap = reinterpret_cast<faiss::IndexIDMap *>(indexAddressJ);

    // Scope guard: releases the index on every exit path, including exceptions thrown by the conversions or by
    // faiss::write_index. The wrapped index is deleted explicitly, exactly as before; NOTE(review): this assumes
    // the id map does not own the index it wraps - confirm if the map is ever created differently.
    struct IndexReleaser {
        faiss::IndexIDMap *target;
        ~IndexReleaser() {
            delete target->index;
            delete target;
        }
    } releaser{idMap};

    // parametersJ is a Java Map<String, Object>. ConvertJavaMapToCppMap converts it to a c++ map<string, jobject>
    // so that it is easier to access.
    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
    const auto threadCountIt = parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY);
    if (threadCountIt != parametersCpp.end()) {
        auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, threadCountIt->second);
        omp_set_num_threads(threadCount);
    }
    jniUtil->DeleteLocalRef(env, parametersJ);

    std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
    faiss::write_index(idMap, indexPathCpp.c_str());
}

// Adds a batch of vectors, keyed by the given ids, to an in-memory faiss index and returns the index address.
//
//   jniUtil         - JNI helper used for every Java <-> C++ conversion.
//   env             - JNI environment of the calling thread.
//   idsJ            - Java int[] with one id per vector.
//   vectorsAddressJ - address of a heap allocated std::vector<float> holding the vectors back to back.
//   dimJ            - dimension of each vector.
//   indexAddressJ   - 0 to create a new index, otherwise the address returned by a previous call.
//   parametersJ     - Java Map<String, Object>. knn_jni::SPACE_TYPE is always required;
//                     knn_jni::INDEX_DESCRIPTION is required when a new index is created;
//                     knn_jni::INDEX_THREAD_QUANTITY and knn_jni::PARAMETERS are optional.
//
// Returns the address of the faiss::IndexIDMap that now contains the vectors. The index stays in memory until
// it is handed to writeIndex.
//
// Throws std::runtime_error on invalid arguments, when the number of ids and vectors differ, or when the newly
// created index would require training. An index created by the failing call is freed before the exception
// leaves this function; an index passed in through indexAddressJ is left untouched.
//
// The vector buffer at vectorsAddressJ is freed only when the call succeeds.
// NOTE(review): on failure the buffer is left alone, as before - confirm the caller releases it in that case.
long long knn_jni::faiss_wrapper::CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
                                                         jlong indexAddressJ, jobject parametersJ) {
    if (idsJ == nullptr) {
        throw std::runtime_error("IDs cannot be null");
    }

    if (vectorsAddressJ <= 0) {
        throw std::runtime_error("VectorsAddress cannot be less than 0");
    }

    if (dimJ <= 0) {
        throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
    }

    if (parametersJ == nullptr) {
        throw std::runtime_error("Parameters cannot be null");
    }

    // parametersJ is a Java Map<String, Object>. ConvertJavaMapToCppMap converts it to a c++ map<string, jobject>
    // so that it is easier to access.
    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

    // Get space type for this index
    jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE);
    std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ));
    faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp);

    // Read vectors from memory address
    auto *inputVectors = reinterpret_cast<std::vector<float>*>(vectorsAddressJ);
    const int dim = static_cast<int>(dimJ);
    // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value
    const int numVectors = static_cast<int>(inputVectors->size() / static_cast<uint64_t>(dim));
    if (numVectors == 0) {
        throw std::runtime_error("Number of vectors cannot be 0");
    }

    const int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
    if (numIds != numVectors) {
        throw std::runtime_error("Number of IDs does not match number of vectors");
    }
    auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);

    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current
    // thread, so it is applied on every call (not just the one that creates the index): later batches may be
    // added from a different thread.
    const auto threadCountIt = parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY);
    if (threadCountIt != parametersCpp.end()) {
        auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, threadCountIt->second);
        omp_set_num_threads(threadCount);
    }

    faiss::IndexIDMap *idMap = nullptr;
    // Non-null only when this call built the underlying index itself, i.e. when this call is responsible for
    // freeing it should anything below throw.
    faiss::Index *createdIndex = nullptr;
    try {
        // The address is compared and converted as a jlong; narrowing it to long would truncate the pointer on
        // platforms where long is 32 bits wide.
        if (indexAddressJ == 0) {
            // Create faiss index
            jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION);
            std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ));
            createdIndex = faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric);

            // Add extra parameters that can't be configured with the index factory
            const auto subParametersIt = parametersCpp.find(knn_jni::PARAMETERS);
            if (subParametersIt != parametersCpp.end()) {
                jobject subParametersJ = subParametersIt->second;
                auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ);
                SetExtraParameters(jniUtil, env, subParametersCpp, createdIndex);
                jniUtil->DeleteLocalRef(env, subParametersJ);
            }
            jniUtil->DeleteLocalRef(env, parametersJ);

            // Check that the index does not need to be trained
            if (!createdIndex->is_trained) {
                throw std::runtime_error("Index is not trained");
            }
            idMap = new faiss::IndexIDMap(createdIndex);
        } else {
            // Keep appending to the index created by an earlier call
            idMap = reinterpret_cast<faiss::IndexIDMap *>(indexAddressJ);
        }

        idMap->add_with_ids(numVectors, inputVectors->data(), idVector.data());
    } catch (...) {
        // Only an index built by this call is ours to free; its address was never handed to the caller.
        // An index that was passed in by address stays owned by the caller.
        if (createdIndex != nullptr) {
            delete createdIndex;
            delete idMap;
        }
        throw;
    }

    // The vectors now live inside the index, so the input buffer is no longer needed.
    delete inputVectors;
    return reinterpret_cast<long long>(idMap);
}


void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
jstring indexPathJ, jobject parametersJ) {

Expand Down
24 changes: 24 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "faiss_wrapper.h"
#include "jni_util.h"
#include <iostream>

static knn_jni::JNIUtil jniUtil;
static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1;
Expand Down Expand Up @@ -50,6 +51,29 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE
}
}


// JNI entry point for FaissService.createIndexIteratively.
// Forwards to knn_jni::faiss_wrapper::CreateIndexIteratively and hands the resulting native index address back
// to Java. Any C++ exception is passed to jniUtil.CatchCppExceptionAndThrowJava; in that case 0 is returned.
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexIteratively(
    JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddressJ,
    jobject parametersJ)
{
    // Stays 0 unless the wrapper call completes without throwing.
    jlong nativeIndexAddress = 0;
    try {
        nativeIndexAddress = static_cast<jlong>(
            knn_jni::faiss_wrapper::CreateIndexIteratively(&jniUtil, env, idsJ, vectorsAddressJ, dimJ,
                                                           indexAddressJ, parametersJ));
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    return nativeIndexAddress;
}

// JNI entry point for FaissService.writeIndex.
// Forwards to knn_jni::faiss_wrapper::writeIndex, which serializes the index at indexAddressJ to indexPathJ.
// Any C++ exception is passed to jniUtil.CatchCppExceptionAndThrowJava instead of escaping into the JVM.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(
    JNIEnv * env, jclass cls, jlong indexAddressJ, jstring indexPathJ, jobject parametersJ)
{
    try {
        knn_jni::faiss_wrapper::writeIndex(
            &jniUtil, env, indexAddressJ, indexPathJ, parametersJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls,
jintArray idsJ,
jlong vectorsAddressJ,
Expand Down
58 changes: 57 additions & 1 deletion jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@
#include "test_util.h"
#include "faiss/IndexHNSW.h"
#include "faiss/IndexIVFPQ.h"
#include "faiss/utils/utils.h"

using ::testing::NiceMock;
using ::testing::Return;

float randomDataMin = -500.0;
float randomDataMax = 500.0;

TEST(FaissCreateIndexTest, BasicAssertions) {
TEST(FaissCreateIndexIterativelyTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
Expand All @@ -52,6 +53,61 @@ TEST(FaissCreateIndexTest, BasicAssertions) {
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors->size()));

// Create the index
long long indexAddress = knn_jni::faiss_wrapper::CreateIndexIteratively(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong) vectors, dim, (jlong)0,
(jobject)&parametersMap);
knn_jni::faiss_wrapper::writeIndex(
&mockJNIUtil, jniEnv,(jlong)indexAddress, (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
std::unique_ptr<faiss::Index> index(test_util::FaissLoadIndex(indexPath));
// Clean up
ids.clear();
ids.shrink_to_fit();
vectors->clear();
vectors->shrink_to_fit();
size_t mem_usage = faiss::get_mem_usage_kb() / (1 << 10);

std::cout<<"======Memory Usage:[" << mem_usage << "mb]======" << std::endl;
// Clean up
std::remove(indexPath.c_str());
}


TEST(FaissCreateIndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

std::string indexPath = test_util::RandomString(10, "/tmp/", ".faiss");
std::string spaceType = knn_jni::L2;
std::string index_description = "HNSW32,Flat";

std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType;
parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&index_description;

// Set up jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;

EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
Expand Down
Loading

0 comments on commit 7165079

Please sign in to comment.