Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move free vectorAddress from Java to JNI layer to reduce the memory footprint for Nmslib. #1602

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
Expand Down Expand Up @@ -221,7 +225,10 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface *
auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data());

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
delete inputVectors;
// Write the index to disk
std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
faiss::write_index(&idMap, indexPathCpp.c_str());
Expand Down
6 changes: 6 additions & 0 deletions jni/src/nmslib_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J
}
jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT);

// Releasing the vectorsAddressJ memory as that is not required once we have created the index.
// This is not the ideal approach, please refer this gh issue for long term solution:
// https://github.com/opensearch-project/k-NN/issues/1600
//commons::freeVectorData(vectorsAddressJ);
delete inputVectors;

std::unique_ptr<similarity::Index<float>> index;
index.reset(similarity::MethodFactoryRegistry<float>::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset));
index->CreateIndex(similarity::AnyParams(indexParameters));
Expand Down
30 changes: 15 additions & 15 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ TEST(FaissCreateIndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -55,12 +55,12 @@ TEST(FaissCreateIndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

// Create the index
knn_jni::faiss_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong) &vectors, dim , (jstring)&indexPath,
(jlong) vectors, dim , (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
Expand All @@ -74,13 +74,13 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 100;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -99,15 +99,15 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

std::string spaceType = knn_jni::L2;
std::unordered_map<std::string, jobject> parametersMap;
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType;

knn_jni::faiss_wrapper::CreateIndexFromTemplate(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)&vectors, dim, (jstring)&indexPath,
(jlong)vectors, dim, (jstring)&indexPath,
reinterpret_cast<jbyteArray>(&(vectorIoWriter.data)),
(jobject) &parametersMap
);
Expand Down Expand Up @@ -480,13 +480,13 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -505,12 +505,12 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

// Create the index
knn_jni::faiss_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)&vectors, dim, (jstring)&indexPath,
(jlong)vectors, dim, (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
Expand Down
10 changes: 5 additions & 5 deletions jni/tests/nmslib_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ TEST(NmslibCreateIndexTest, BasicAssertions) {
// Define index data
int numIds = 100;
std::vector<int> ids;
std::vector<float> vectors;
auto *vectors = new std::vector<float>();
int dim = 2;
vectors.reserve(dim * numIds);
vectors->reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -67,7 +67,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors.size()));
.WillRepeatedly(Return(vectors->size()));

EXPECT_CALL(mockJNIUtil,
GetJavaIntArrayLength(jniEnv, reinterpret_cast<jintArray>(&ids)))
Expand All @@ -76,7 +76,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) {
// Create the index
knn_jni::nmslib_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong) &vectors, dim, (jstring)&indexPath,
(jlong) vectors, dim, (jstring)&indexPath,
(jobject)&parametersMap);

// Make sure index can be loaded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
Expand Down Expand Up @@ -111,67 +110,57 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
throws IOException {
// Get values to be indexed
BinaryDocValues values = valuesProducer.getBinary(field);
KNNCodecUtil.Pair pair = null;
try {
pair = KNNCodecUtil.getFloats(values);
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
return;
}
long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
final KNNEngine knnEngine = getKNNEngine(field);
final String engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getVersion(),
field.name,
knnEngine.getExtension()
);
final String indexPath = Paths.get(
((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName
).toString();
KNNCodecUtil.Pair finalPair = pair;
NativeIndexCreator indexCreator;
// Create library index either from model or from scratch
if (field.attributes().containsKey(MODEL_ID)) {
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), finalPair, knnEngine, indexPath);
} else {
indexCreator = () -> createKNNIndexFromScratch(field, finalPair, knnEngine, indexPath);
}

if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values);
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
logger.info("Skipping engine index creation as there are no vectors or docs in the segment");
return;
}
long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode);
if (isMerge) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length);
KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
}
// Increment counter for number of graph index requests
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
final KNNEngine knnEngine = getKNNEngine(field);
final String engineFileName = buildEngineFileName(
state.segmentInfo.name,
knnEngine.getVersion(),
field.name,
knnEngine.getExtension()
);
final String indexPath = Paths.get(
((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(),
engineFileName
).toString();
NativeIndexCreator indexCreator;
// Create library index either from model or from scratch
if (field.attributes().containsKey(MODEL_ID)) {
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath);
} else {
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
}

if (isRefresh) {
recordRefreshStats();
}
if (isMerge) {
recordMergeStats(pair.docs.length, arraySize);
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
} finally {
// Freeing up the Native memory where vectors was stored. We added a try block here to ensure that even
// in case of exceptions we are freeing up the space to avoid memory leaks.
if (pair != null) {
JNICommons.freeVectorData(pair.getVectorAddress());
}
if (isRefresh) {
recordRefreshStats();
}

// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
// not be marked as added to the directory.
state.directory.createOutput(engineFileName, state.context).close();
indexCreator.createIndex();
writeFooter(indexPath, engineFileName);
}

private void recordMergeStats(int length, long arraySize) {
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ class FaissService {
}

/**
* Create an index for the native library
* Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
* created the memory address and that should only free up the memory. We are tracking the proper fix for this on this
* <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a>
*
* @param ids array of ids mapping to the data passed in
* @param vectorsAddress address of native memory where vectors are stored
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
public class JNIService {

/**
* Create an index for the native library
* Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
* created the memory address and that should only free up the memory. We are tracking the proper fix for this on this
* <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a>
*
* @param ids array of ids mapping to the data passed in
* @param vectorsAddress address of native memory where vectors are stored
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/jni/NmslibService.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class NmslibService {
}

/**
* Create an index for the native library
* Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the
* function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer
* created the memory address and that should only free up the memory. We are tracking the proper fix for this on this
* <a href="https://github.com/opensearch-project/k-NN/issues/1600">issue</a>
*
* @param ids array of ids mapping to the data passed in
* @param vectorsAddress address of native memory where vectors are stored
Expand Down
Loading
Loading