Skip to content

Commit

Permalink
Move free vectorAddress from Java to JNI layer to reduce the memory f…
Browse files Browse the repository at this point in the history
…ootprint for Nmslib

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Apr 9, 2024
1 parent badbb1d commit 6719c99
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 82 deletions.
9 changes: 8 additions & 1 deletion jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
Expand Down Expand Up @@ -221,7 +225,10 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
Expand Down
6 changes: 6 additions & 0 deletions jni/src/nmslib_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
}
jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT);

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
//commons::freeVectorData(vectorsAddressJ);
delete inputVectors;

std::unique_ptr<similarity::Index<float>> index;
index.reset(similarity::MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset));
index->CreateIndex(similarity::AnyParams(indexParameters));
Expand Down
30 changes: 15 additions & 15 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ TEST(FaissCreateIndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -55,12 +55,12 @@ TEST(FaissCreateIndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

// Create the index
knn_jni::faiss_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong) &vectors, dim , (jstring)&indexPath,
(jlong) vectors, dim , (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
Expand All @@ -74,13 +74,13 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -99,15 +99,15 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

std::string spaceType = knn_jni::L2;
std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;

knn_jni::faiss_wrapper::CreateIndexFromTemplate(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)&vectors, dim, (jstring)&indexPath,
(jlong)vectors, dim, (jstring)&indexPath,
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
(jobject) &parametersMap
);
Expand Down Expand Up @@ -480,13 +480,13 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -505,12 +505,12 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

// Create the index
knn_jni::faiss_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)&vectors, dim, (jstring)&indexPath,
(jlong)vectors, dim, (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
Expand Down
10 changes: 5 additions & 5 deletions jni/tests/nmslib_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ TEST(NmslibCreateIndexTest, BasicAssertions) {
// Define index data
int numIds = 100;
std::vector<int> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -67,7 +67,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

EXPECT_CALL(mockJNIUtil,
GetJavaIntArrayLength(jniEnv, reinterpret_cast<jintArray>(&ids)))
Expand All @@ -76,7 +76,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) {
// Create the index
knn_jni::nmslib_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong) &vectors, dim, (jstring)&indexPath,
(jlong) vectors, dim, (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
Expand Down Expand Up @@ -111,67 +110,57 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
throws IOException {
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = null;
try {
pair = KNNCodecUtil.getFloats(values);
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
return;
}
long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
final KNNEngine knnEngine = getKNNEngine(field);
final String engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getVersion(),
field.name,
knnEngine.getExtension()
);
final String indexPath = Paths.get(
((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName
).toString();
KNNCodecUtil.Pair finalPair = pair;
NativeIndexCreator indexCreator;
// Create library index either from model or from scratch
if (field.attributes().containsKey(MODEL_ID)) {
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), finalPair, knnEngine, indexPath);
} else {
indexCreator = () -> createKNNIndexFromScratch(field, finalPair, knnEngine, indexPath);
}

if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
return;
}
long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
final KNNEngine knnEngine = getKNNEngine(field);
final String engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getVersion(),
field.name,
knnEngine.getExtension()
);
final String indexPath = Paths.get(
((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName
).toString();
NativeIndexCreator indexCreator;
// Create library index either from model or from scratch
if (field.attributes().containsKey(MODEL_ID)) {
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath);
} else {
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

if (isRefresh) {
recordRefreshStats();
}
if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
} finally {
// Freeing up the Native memory where vectors was stored. We added a try block here to ensure that even
// in case of exceptions we are freeing up the space to avoid memory leaks.
if (pair != null) {
JNICommons.freeVectorData(pair.getVectorAddress());
}
if (isRefresh) {
recordRefreshStats();
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
}

private void recordMergeStats(int length, long arraySize) {
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ class FaissService {
}

/**
* Create an index for the native library
* Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
* created the memory address and that should only free up the memory. We are tracking the proper fix for this on this
* <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a>
*
* @param ids array of ids mapping to the data passed in
* @param vectorsAddress address of native memory where vectors are stored
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
public class JNIService {

/**
* Create an index for the native library
* Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
* created the memory address and that should only free up the memory. We are tracking the proper fix for this on this
* <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a>
*
* @param ids array of ids mapping to the data passed in
* @param vectorsAddress address of native memory where vectors are stored
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/jni/NmslibService.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class NmslibService {
}

/**
* Create an index for the native library
* Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
* created the memory address and that should only free up the memory. We are tracking the proper fix for this on this
* <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a>
*
* @param ids array of ids mapping to the data passed in
* @param vectorsAddress address of native memory where vectors are stored
Expand Down

0 comments on commit 6719c99

Please sign in to comment.