From 36d65605c1dea8f5d41e2e21facdd98cadca1016 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Wed, 10 Apr 2024 14:04:05 -0700 Subject: [PATCH] Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices (#1604) Changes include: 1. Add the interface for streaming the vectors from java to jni layer with initial capacity (#1586) 2. Integrating storeVectors interfaces with createIndex and createIndexTemplate functions. (#1588) 3. Update KNN80BinaryDocValues reader count live docs and use live docs as initial capacity to initialize vector address(#1595) 4. Move free vectorAddress from Java to JNI layer to reduce the memory footprint for Nmslib (#1602) Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + jni/CMakeLists.txt | 3 +- jni/include/commons.h | 37 +++ jni/include/faiss_wrapper.h | 4 +- jni/include/jni_util.h | 4 + jni/include/nmslib_wrapper.h | 2 +- .../org_opensearch_knn_jni_FaissService.h | 24 +- .../org_opensearch_knn_jni_JNICommons.h | 40 ++++ .../org_opensearch_knn_jni_NmslibService.h | 4 +- jni/src/commons.cpp | 41 ++++ jni/src/faiss_wrapper.cpp | 57 +++-- jni/src/jni_util.cpp | 11 +- jni/src/nmslib_wrapper.cpp | 45 ++-- .../org_opensearch_knn_jni_FaissService.cpp | 39 +--- jni/src/org_opensearch_knn_jni_JNICommons.cpp | 60 +++++ .../org_opensearch_knn_jni_NmslibService.cpp | 4 +- jni/tests/commons_test.cpp | 73 ++++++ jni/tests/faiss_wrapper_test.cpp | 39 ++-- jni/tests/nmslib_wrapper_test.cpp | 15 +- jni/tests/test_util.cpp | 8 + jni/tests/test_util.h | 2 + .../knn/TransferVectorsBenchmarks.java | 24 +- .../opensearch/knn/common/KNNConstants.java | 2 + .../org/opensearch/knn/index/KNNSettings.java | 22 +- .../KNN80Codec/KNN80BinaryDocValues.java | 18 +- .../KNN80Codec/KNN80DocValuesConsumer.java | 48 ++-- .../KNN80Codec/KNN80DocValuesReader.java | 45 +++- .../knn/index/codec/util/KNNCodecUtil.java | 75 ++++-- .../index/memory/NativeMemoryAllocation.java | 3 +- .../org/opensearch/knn/jni/FaissService.java | 42 ++-- .../org/opensearch/knn/jni/JNICommons.java | 62 +++++ .../org/opensearch/knn/jni/JNIService.java | 65 +++--- .../org/opensearch/knn/jni/NmslibService.java | 10 +- .../plugin-metadata/plugin-security.policy | 1 + .../org/opensearch/knn/index/FaissIT.java | 104 +++++++++ .../KNN80DocValuesConsumerTests.java | 3 +- .../memory/NativeMemoryAllocationTests.java | 4 +- .../memory/NativeMemoryLoadStrategyTests.java | 4 +- .../opensearch/knn/jni/JNICommonsTest.java | 40 ++++ .../opensearch/knn/jni/JNIServiceTests.java | 217 ++++++++++-------- .../knn/training/TrainingJobTests.java | 10 +- .../java/org/opensearch/knn/TestUtils.java | 26 ++- 42 files changed, 970 insertions(+), 368 deletions(-) create mode 100644 jni/include/commons.h create mode 100644 jni/include/org_opensearch_knn_jni_JNICommons.h create mode 100644 jni/src/commons.cpp create mode 100644 jni/src/org_opensearch_knn_jni_JNICommons.cpp create mode 100644 jni/tests/commons_test.cpp create mode 100644 src/main/java/org/opensearch/knn/jni/JNICommons.java create mode 100644 src/test/java/org/opensearch/knn/jni/JNICommonsTest.java diff --git a/CHANGELOG.md b/CHANGELOG.md index e82376a6d..65e39b77b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) * Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) +* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604) ### Bug Fixes ### Infrastructure * Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583) diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 60321ed1b..4f32c87b9 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -61,7 +61,7 @@ endif() # ---------------------------------------------------------------------------- # ---------------------------------- COMMON ---------------------------------- -add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp) +add_library(${TARGET_LIB_COMMON} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/jni_util.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_JNICommons.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/commons.cpp) target_include_directories(${TARGET_LIB_COMMON} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE}) set_target_properties(${TARGET_LIB_COMMON} PROPERTIES SUFFIX ${LIB_EXT}) set_target_properties(${TARGET_LIB_COMMON} PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -236,6 +236,7 @@ if ("${WIN32}" STREQUAL "") tests/faiss_util_test.cpp tests/nmslib_wrapper_test.cpp tests/test_util.cpp + tests/commons_test.cpp ) target_link_libraries( diff --git a/jni/include/commons.h b/jni/include/commons.h new file mode 100644 index 000000000..05367a693 --- /dev/null +++ b/jni/include/commons.h @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +#include "jni_util.h" +#include +namespace knn_jni { + namespace commons { + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)} + * + * @param memoryAddress address to be freed. + */ + void freeVectorData(jlong); + } +} diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 2629eea43..3e1adeac4 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -19,13 +19,13 @@ namespace knn_jni { namespace faiss_wrapper { // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. - void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, + void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, jbyteArray templateIndexJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); // Load an index from indexPathJ into memory. diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 52b08a202..b3d55f1c1 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -69,6 +69,9 @@ namespace knn_jni { virtual std::vector Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) = 0; + virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect ) = 0; + virtual std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0; // -------------------------------------------------------------------------- @@ -164,6 +167,7 @@ namespace knn_jni { void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode); void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); + void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); private: std::unordered_map cachedClasses; diff --git a/jni/include/nmslib_wrapper.h b/jni/include/nmslib_wrapper.h index 9b555580a..08494644f 100644 --- a/jni/include/nmslib_wrapper.h +++ b/jni/include/nmslib_wrapper.h @@ -25,7 +25,7 @@ namespace knn_jni { namespace nmslib_wrapper { // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. - void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jobjectArray vectorsJ, + void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddress, jint dim, jstring indexPathJ, jobject parametersJ); // Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index ec1f46bc3..32b6f22f1 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -21,18 +21,18 @@ extern "C" { /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndex - * Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex - (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate - * Signature: ([I[[FLjava/lang/String;[BLjava/util/Map;)V + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate - (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jbyteArray, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); /* * Class: org_opensearch_knn_jni_FaissService @@ -122,22 +122,6 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors (JNIEnv *, jclass, jlong, jobjectArray); -/* - * Class: org_opensearch_knn_jni_FaissService - * Method: transferVectorsV2 - * Signature: (J[[F)J - */ -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2 - (JNIEnv *, jclass, jlong, jobjectArray); - -/* - * Class: org_opensearch_knn_jni_FaissService - * Method: freeVectors - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors - (JNIEnv *, jclass, jlong); - #ifdef __cplusplus } #endif diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h new file mode 100644 index 000000000..d0758d7c8 --- /dev/null +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_opensearch_knn_jni_JNICommons */ + +#ifndef _Included_org_opensearch_knn_jni_JNICommons +#define _Included_org_opensearch_knn_jni_JNICommons +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: storeVectorData + * Signature: (J[[FJJ) + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData + (JNIEnv *, jclass, jlong, jobjectArray, jlong); + +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: freeVectorData + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData + (JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/jni/include/org_opensearch_knn_jni_NmslibService.h b/jni/include/org_opensearch_knn_jni_NmslibService.h index 02f58d20f..31422955f 100644 --- a/jni/include/org_opensearch_knn_jni_NmslibService.h +++ b/jni/include/org_opensearch_knn_jni_NmslibService.h @@ -21,10 +21,10 @@ extern "C" { /* * Class: org_opensearch_knn_jni_NmslibService * Method: createIndex - * Signature: ([I[[FLjava/lang/String;Ljava/util/Map;)V + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex - (JNIEnv *, jclass, jintArray, jobjectArray, jstring, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); /* * Class: org_opensearch_knn_jni_NmslibService diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp new file mode 100644 index 000000000..3c03ac49d --- /dev/null +++ b/jni/src/commons.cpp @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +#ifndef OPENSEARCH_KNN_COMMONS_H +#define OPENSEARCH_KNN_COMMONS_H +#include + +#include + +#include "jni_util.h" +#include "commons.h" + +jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, + jobjectArray dataJ, jlong initialCapacityJ) { + std::vector *vect; + if ((long) memoryAddressJ == 0) { + vect = new std::vector(); + vect->reserve((long)initialCapacityJ); + } else { + vect = reinterpret_cast*>(memoryAddressJ); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, dataJ); + jniUtil->Convert2dJavaObjectArrayAndStoreToFloatVector(env, dataJ, dim, vect); + + return (jlong) vect; +} + +void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { + if (memoryAddressJ != 0) { + auto *vect = reinterpret_cast*>(memoryAddressJ); + delete vect; + } +} +#endif //OPENSEARCH_KNN_COMMONS_H \ No newline at end of file diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index a7075740e..817bdb816 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -81,15 +81,19 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); -void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { +void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } - if (vectorsJ == nullptr) { - throw std::runtime_error("Vectors cannot be null"); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } if (indexPathJ == nullptr) { @@ -109,16 +113,20 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); - // Read data set - int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); - auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); - // Create faiss index jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); @@ -148,22 +156,30 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, dataset.data(), idVector.data()); + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); // Write the index to disk std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); faiss::write_index(&idMap, indexPathCpp.c_str()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } - if (vectorsJ == nullptr) { - throw std::runtime_error("Vectors cannot be null"); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } if (indexPathJ == nullptr) { @@ -183,15 +199,15 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil->DeleteLocalRef(env, parametersJ); // Read data set - int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); - auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); - // Get vector of bytes from jbytearray int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); @@ -208,8 +224,11 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, dataset.data(), idVector.data()); - + idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; // Write the index to disk std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); faiss::write_index(&idMap, indexPathCpp.c_str()); diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index a0c1d5733..a1faa4894 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -223,6 +223,13 @@ int knn_jni::JNIUtil::ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ std::vector knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) { + std::vector vect; + Convert2dJavaObjectArrayAndStoreToFloatVector(env, array2dJ, dim, &vect); + return vect; +} + +void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect) { if (array2dJ == nullptr) { throw std::runtime_error("Array cannot be null"); @@ -231,7 +238,6 @@ std::vector knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN int numVectors = env->GetArrayLength(array2dJ); this->HasExceptionInStack(env); - std::vector floatVectorCpp; for (int i = 0; i < numVectors; ++i) { auto vectorArray = (jfloatArray)env->GetObjectArrayElement(array2dJ, i); this->HasExceptionInStack(env, "Unable to get object array element"); @@ -247,13 +253,12 @@ std::vector knn_jni::JNIUtil::Convert2dJavaObjectArrayToCppFloatVector(JN } for(int j = 0; j < dim; ++j) { - floatVectorCpp.push_back(vector[j]); + vect->push_back(vector[j]); } env->ReleaseFloatArrayElements(vectorArray, vector, JNI_ABORT); } this->HasExceptionInStack(env); env->DeleteLocalRef(array2dJ); - return floatVectorCpp; } std::vector knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) { diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index f63fd2b01..6ea80d727 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -32,14 +32,19 @@ std::string TranslateSpaceType(const std::string& spaceType); const similarity::LabelType DEFAULT_LABEL = -1; void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) { + jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } - if (vectorsJ == nullptr) { - throw std::runtime_error("Vectors cannot be null"); + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } if (indexPathJ == nullptr) { @@ -91,12 +96,18 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J space.reset(similarity::SpaceFactoryRegistry::Instance().CreateSpace(spaceTypeCpp,similarity::AnyParams())); // Get number of ids and vectors and dimension - int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ); + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) ( inputVectors->size() / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); if (numIds != numVectors) { throw std::runtime_error("Number of IDs does not match number of vectors"); } - int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); // Read dataset similarity::ObjectVector dataset; @@ -105,10 +116,12 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J try { // Read in data set idsCpp = jniUtil->GetIntArrayElements(env, idsJ, nullptr); - - float* floatArrayCpp; - jfloatArray floatArrayJ; size_t vectorSizeInBytes = dim*sizeof(float); + // vectorPointer needs to be unsigned long long, this will ensure that out of range doesn't happen for this pointer + // when the values of numVectors * dim becomes very large. + // Example: for 10M vectors of 1536 dim vectorPointer max value will be ~15.3B which is already > range of ints. + // keeping it unsigned long long we will never go above the range. + unsigned long long vectorPointer = 0; // Allocate a large buffer that will contain all the vectors. Allocating the objects in one large buffer as // opposed to individually will prevent heap fragmentation. We have observed that allocating individual @@ -134,18 +147,18 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, J memcpy(ptr, &vectorSizeInBytes, similarity::DATALENGTH_SIZE); ptr += similarity::DATALENGTH_SIZE; - floatArrayJ = (jfloatArray)jniUtil->GetObjectArrayElement(env, vectorsJ, i); - if (dim != jniUtil->GetJavaFloatArrayLength(env, floatArrayJ)) { - throw std::runtime_error("Dimension of vectors is inconsistent"); - } - - floatArrayCpp = jniUtil->GetFloatArrayElements(env, floatArrayJ, nullptr); - memcpy(ptr, floatArrayCpp, vectorSizeInBytes); - jniUtil->ReleaseFloatArrayElements(env, floatArrayJ, floatArrayCpp, JNI_ABORT); + memcpy(ptr, &(inputVectors->at(vectorPointer)), vectorSizeInBytes); ptr += vectorSizeInBytes; + vectorPointer += dim; } jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + //commons::freeVectorData(vectorsAddressJ); + delete inputVectors; + std::unique_ptr> index; index.reset(similarity::MethodFactoryRegistry::Instance().CreateMethod(false, "hnsw", spaceTypeCpp, *(space), dataset)); index->CreateIndex(similarity::AnyParams(indexParameters)); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 3d9624c25..3249ed872 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -13,7 +13,6 @@ #include -#include #include #include "faiss_wrapper.h" @@ -41,11 +40,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, - jobject parametersJ) + jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsJ, indexPathJ, parametersJ); + knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -53,13 +52,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, - jobjectArray vectorsJ, + jlong vectorsAddressJ, + jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsJ, indexPathJ, templateIndexJ, parametersJ); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -190,30 +190,3 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors return (jlong) vect; } - -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectorsV2(JNIEnv * env, jclass cls, -jlong vectorsPointerJ, - jobjectArray vectorsJ) -{ - std::vector *vect; - if ((long) vectorsPointerJ == 0) { - vect = new std::vector; - } else { - vect = reinterpret_cast*>(vectorsPointerJ); - } - - int dim = jniUtil.GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ); - auto dataset = jniUtil.Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim); - vect->insert(vect->end(), dataset.begin(), dataset.end()); - - return (jlong) vect; -} - -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIEnv * env, jclass cls, - jlong vectorsPointerJ) -{ - if (vectorsPointerJ != 0) { - auto *vect = reinterpret_cast*>(vectorsPointerJ); - delete vect; - } -} diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp new file mode 100644 index 000000000..ccdd11882 --- /dev/null +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#include "org_opensearch_knn_jni_JNICommons.h" + +#include +#include "commons.h" +#include "jni_util.h" + +static knn_jni::JNIUtil jniUtil; +static const jint KNN_JNICOMMONS_JNI_VERSION = JNI_VERSION_1_1; + +jint JNI_OnLoad(JavaVM* vm, void* reserved) { + // Obtain the JNIEnv from the VM and confirm JNI_VERSION + JNIEnv* env; + if (vm->GetEnv((void**)&env, KNN_JNICOMMONS_JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + jniUtil.Initialize(env); + + return KNN_JNICOMMONS_JNI_VERSION; +} + +void JNI_OnUnload(JavaVM *vm, void *reserved) { + JNIEnv* env; + vm->GetEnv((void**)&env, KNN_JNICOMMONS_JNI_VERSION); + jniUtil.Uninitialize(env); +} + + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData(JNIEnv * env, jclass cls, +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) + +{ + try { + return knn_jni::commons::storeVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (long)memoryAddressJ; +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls, + jlong memoryAddressJ) +{ + try { + return knn_jni::commons::freeVectorData(memoryAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} diff --git a/jni/src/org_opensearch_knn_jni_NmslibService.cpp b/jni/src/org_opensearch_knn_jni_NmslibService.cpp index 11dd885b1..d037d3337 100644 --- a/jni/src/org_opensearch_knn_jni_NmslibService.cpp +++ b/jni/src/org_opensearch_knn_jni_NmslibService.cpp @@ -38,11 +38,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jobjectArray vectorsJ, jstring indexPathJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) { try { - knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsJ, indexPathJ, parametersJ); + knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp new file mode 100644 index 000000000..09323f0fb --- /dev/null +++ b/jni/tests/commons_test.cpp @@ -0,0 +1,73 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + + +#include "test_util.h" +#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" +#include "commons.h" + +TEST(CommonsTests, BasicAssertions) { + long dim = 3; + long totalNumberOfVector = 5; + std::vector> data; + for(int i = 0 ; i < totalNumberOfVector - 1 ; i++) { + std::vector vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((float)j); + } + data.push_back(vector); + } + JNIEnv *jniEnv = nullptr; + + testing::NiceMock mockJNIUtil; + + jlong memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, (jlong)0, + reinterpret_cast(&data), (jlong)(totalNumberOfVector * dim)); + ASSERT_NE(memoryAddress, 0); + auto *vect = reinterpret_cast*>(memoryAddress); + ASSERT_EQ(vect->size(), data.size() * dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Check by inserting more vectors at same memory location + jlong oldMemoryAddress = memoryAddress; + std::vector> data2; + std::vector vector; + for(int j = 0 ; j < dim ; j ++) { + vector.push_back((float)j); + } + data2.push_back(vector); + memoryAddress = knn_jni::commons::storeVectorData(&mockJNIUtil, jniEnv, memoryAddress, + reinterpret_cast(&data2), (jlong)(totalNumberOfVector * dim)); + ASSERT_NE(memoryAddress, 0); + ASSERT_EQ(memoryAddress, oldMemoryAddress); + vect = reinterpret_cast*>(memoryAddress); + int currentIndex = 0; + ASSERT_EQ(vect->size(), totalNumberOfVector*dim); + ASSERT_EQ(vect->capacity(), totalNumberOfVector * dim); + + // Validate if all vectors data are at correct location + for(auto & i : data) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } + + for(auto & i : data2) { + for(float j : i) { + ASSERT_FLOAT_EQ(vect->at(currentIndex), j); + currentIndex++; + } + } +} diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 2b1684cfb..05854f7ed 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -30,17 +30,14 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - std::vector> vectors; + auto *vectors = new std::vector(); int dim = 2; + vectors->reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); @@ -58,12 +55,12 @@ TEST(FaissCreateIndexTest, BasicAssertions) { EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors.size())); + .WillRepeatedly(Return(vectors->size())); // Create the index knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong) vectors, dim , (jstring)&indexPath, (jobject)¶metersMap); // Make sure index can be loaded @@ -77,17 +74,14 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 100; std::vector ids; - std::vector> vectors; + auto *vectors = new std::vector(); int dim = 2; + vectors->reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); @@ -105,7 +99,7 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors.size())); + .WillRepeatedly(Return(vectors->size())); std::string spaceType = knn_jni::L2; std::unordered_map parametersMap; @@ -113,7 +107,7 @@ TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { knn_jni::faiss_wrapper::CreateIndexFromTemplate( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong)vectors, dim, (jstring)&indexPath, reinterpret_cast(&(vectorIoWriter.data)), (jobject) ¶metersMap ); @@ -486,17 +480,14 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - std::vector> vectors; + auto *vectors = new std::vector(); int dim = 2; + vectors->reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); @@ -514,12 +505,12 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors.size())); + .WillRepeatedly(Return(vectors->size())); // Create the index knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong)vectors, dim, (jstring)&indexPath, (jobject)¶metersMap); // Make sure index can be loaded diff --git a/jni/tests/nmslib_wrapper_test.cpp b/jni/tests/nmslib_wrapper_test.cpp index 3a21e7401..1fd9471b0 100644 --- a/jni/tests/nmslib_wrapper_test.cpp +++ b/jni/tests/nmslib_wrapper_test.cpp @@ -39,17 +39,14 @@ TEST(NmslibCreateIndexTest, BasicAssertions) { // Define index data int numIds = 100; std::vector ids; - std::vector> vectors; + auto *vectors = new std::vector(); int dim = 2; - for (int i = 0; i < numIds; ++i) { + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); - - std::vector vect; - vect.reserve(dim); for (int j = 0; j < dim; ++j) { - vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); } - vectors.push_back(vect); } std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); @@ -70,7 +67,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) { EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors.size())); + .WillRepeatedly(Return(vectors->size())); EXPECT_CALL(mockJNIUtil, GetJavaIntArrayLength(jniEnv, reinterpret_cast(&ids))) @@ -79,7 +76,7 @@ TEST(NmslibCreateIndexTest, BasicAssertions) { // Create the index knn_jni::nmslib_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - reinterpret_cast(&vectors), (jstring)&indexPath, + (jlong) vectors, dim, (jstring)&indexPath, (jobject)¶metersMap); // Make sure index can be loaded diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 89b19f9aa..92532b9e2 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -45,6 +45,14 @@ test_util::MockJNIUtil::MockJNIUtil() { return data; }); + ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToFloatVector) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) { + for (const auto &v : + (*reinterpret_cast> *>(array2dJ))) + for (auto item : v) data->push_back(item); + }); + + // arrayJ is re-interpreted as std::vector * ON_CALL(*this, ConvertJavaIntArrayToCppIntVector) .WillByDefault([this](JNIEnv *env, jintArray arrayJ) { diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 1e32ad3c3..8e73a8ab0 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -44,6 +44,8 @@ namespace test_util { // TODO: Figure out why this cant use "new" MOCK_METHOD MOCK_METHOD(std::vector, Convert2dJavaObjectArrayToCppFloatVector, (JNIEnv * env, jobjectArray array2dJ, int dim)); + MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToFloatVector, + (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); MOCK_METHOD(std::vector, ConvertJavaIntArrayToCppIntVector, (JNIEnv * env, jintArray arrayJ)); MOCK_METHOD2(ConvertJavaMapToCppMap, diff --git a/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java b/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java index ad1076484..2bce54ee6 100644 --- a/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java +++ b/micro-benchmarks/src/main/java/org/opensearch/knn/TransferVectorsBenchmarks.java @@ -23,7 +23,7 @@ import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; -import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.jni.JNICommons; import java.util.ArrayList; import java.util.List; @@ -42,9 +42,9 @@ @State(Scope.Benchmark) public class TransferVectorsBenchmarks { private static final Random random = new Random(1212121212); - private static final int TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED = 1000000; + private static final long TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED = 1000000; - @Param({ "128", "256", "384", "512" }) + @Param({ "128", "256", "384", "512", "960", "1024", "1536" }) private int dimension; @Param({ "100000", "500000", "1000000" }) @@ -61,20 +61,30 @@ public void setup() { } @Benchmark - public void transferVectors() { + public void transferVectors_withCapacity() { long vectorsAddress = 0; List vectorToTransfer = new ArrayList<>(); + long startingIndex = 0; for (float[] floats : vectorList) { if (vectorToTransfer.size() == vectorsPerTransfer) { - vectorsAddress = JNIService.transferVectorsV2(vectorsAddress, vectorToTransfer.toArray(new float[][] {})); + vectorsAddress = JNICommons.storeVectorData( + vectorsAddress, + vectorToTransfer.toArray(new float[][] {}), + dimension * TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED + ); + startingIndex += vectorsPerTransfer; vectorToTransfer = new ArrayList<>(); } vectorToTransfer.add(floats); } if (!vectorToTransfer.isEmpty()) { - vectorsAddress = JNIService.transferVectorsV2(vectorsAddress, vectorToTransfer.toArray(new float[][] {})); + vectorsAddress = JNICommons.storeVectorData( + vectorsAddress, + vectorToTransfer.toArray(new float[][] {}), + dimension * TOTAL_NUMBER_OF_VECTOR_TO_BE_TRANSFERRED + ); } - JNIService.freeVectors(vectorsAddress); + JNICommons.freeVectorData(vectorsAddress); } private float[] generateRandomVector(int dimensions) { diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 46d82b2bf..6c92afabc 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -72,6 +72,7 @@ public class KNNConstants { // nmslib specific constants public static final String NMSLIB_NAME = "nmslib"; + public static final String COMMONS_NAME = "common"; public static final String SPACE_TYPE = "spaceType"; // used as field info key public static final String HNSW_ALGO_M = "M"; public static final String HNSW_ALGO_EF_CONSTRUCTION = "efConstruction"; @@ -120,6 +121,7 @@ public class KNNConstants { public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME; public static final String FAISS_AVX2_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME + "_avx2"; public static final String NMSLIB_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + NMSLIB_NAME; + public static final String COMMON_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + COMMONS_NAME; // Filtered Search Constants // Please refer this github issue for more details for choosing this value: diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 9ac0bf216..8d0ad91e0 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -68,6 +68,7 @@ public class KNNSettings { public static final String KNN_ALGO_PARAM_INDEX_THREAD_QTY = "knn.algo_param.index_thread_qty"; public static final String KNN_MEMORY_CIRCUIT_BREAKER_ENABLED = "knn.memory.circuit_breaker.enabled"; public static final String KNN_MEMORY_CIRCUIT_BREAKER_LIMIT = "knn.memory.circuit_breaker.limit"; + public static final String KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB = "knn.vector_streaming_memory.limit"; public static final String KNN_CIRCUIT_BREAKER_TRIGGERED = "knn.circuit_breaker.triggered"; public static final String KNN_CACHE_ITEM_EXPIRY_ENABLED = "knn.cache.item.expiry.enabled"; public static final String KNN_CACHE_ITEM_EXPIRY_TIME_MINUTES = "knn.cache.item.expiry.minutes"; @@ -93,6 +94,7 @@ public class KNNSettings { public static final Integer KNN_DEFAULT_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 10; // By default, set aside 10% of the JVM for the limit public static final Integer KNN_MAX_MODEL_CACHE_SIZE_LIMIT_PERCENTAGE = 25; // Model cache limit cannot exceed 25% of the JVM heap public static final String KNN_DEFAULT_MEMORY_CIRCUIT_BREAKER_LIMIT = "50%"; + public static final String KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT = "1%"; public static final Integer ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE = -1; @@ -100,6 +102,15 @@ public class KNNSettings { * Settings Definition */ + // This setting controls how much memory should be used to transfer vectors from Java to JNI Layer. The default + // 1% of the JVM heap + public static final Setting KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING = Setting.memorySizeSetting( + KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB, + KNN_DEFAULT_VECTOR_STREAMING_MEMORY_LIMIT_PCT, + Setting.Property.Dynamic, + Setting.Property.NodeScope + ); + public static final Setting INDEX_KNN_SPACE_TYPE = Setting.simpleString( KNN_SPACE_TYPE, INDEX_KNN_DEFAULT_SPACE_TYPE, @@ -354,6 +365,10 @@ private Setting getSetting(String key) { return KNN_FAISS_AVX2_DISABLED_SETTING; } + if (KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB.equals(key)) { + return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -371,7 +386,8 @@ public List> getSettings() { MODEL_INDEX_NUMBER_OF_REPLICAS_SETTING, MODEL_CACHE_SIZE_LIMIT_SETTING, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, - KNN_FAISS_AVX2_DISABLED_SETTING + KNN_FAISS_AVX2_DISABLED_SETTING, + KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING ); return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); } @@ -475,6 +491,10 @@ public void onFailure(Exception e) { }); } + public static ByteSizeValue getVectorStreamingMemoryLimit() { + return KNNSettings.state().getSettingValue(KNN_VECTOR_STREAMING_MEMORY_LIMIT_IN_MB); + } + /** * * @param index Name of the index diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java index 832737a6d..df26766b3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80BinaryDocValues.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.codec.KNN80Codec; +import lombok.Getter; import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocIDMerger; @@ -15,10 +16,13 @@ /** * A per-document kNN numeric value. */ -class KNN80BinaryDocValues extends BinaryDocValues { +public class KNN80BinaryDocValues extends BinaryDocValues { private DocIDMerger docIDMerger; + @Getter + private long totalLiveDocs; + KNN80BinaryDocValues(DocIDMerger docIdMerger) { this.docIDMerger = docIdMerger; } @@ -61,4 +65,14 @@ public long cost() { public BytesRef binaryValue() throws IOException { return current.getValues().binaryValue(); } -}; + + /** + * Builder pattern like setter for setting totalLiveDocs. We can use setter also. But this way the code is clean. + * @param totalLiveDocs int + * @return {@link KNN80BinaryDocValues} + */ + public KNN80BinaryDocValues setTotalLiveDocs(long totalLiveDocs) { + this.totalLiveDocs = totalLiveDocs; + return this; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 804c4f7bb..096df817a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -109,11 +109,11 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.vectors.length == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the documents"); + if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { + logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); return; } - long arraySize = calculateArraySize(pair.vectors, pair.serializationMode); + long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); if (isMerge) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); @@ -121,31 +121,27 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, } // Increment counter for number of graph index requests KNNCounter.GRAPH_INDEX_REQUESTS.increment(); - // Create library index either from model or from scratch - String engineFileName; - String indexPath; - NativeIndexCreator indexCreator; final KNNEngine knnEngine = getKNNEngine(field); + final String engineFileName = buildEngineFileName( + state.segmentInfo.name, + knnEngine.getVersion(), + field.name, + knnEngine.getExtension() + ); + final String indexPath = Paths.get( + ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), + engineFileName + ).toString(); + NativeIndexCreator indexCreator; + // Create library index either from model or from scratch if (field.attributes().containsKey(MODEL_ID)) { - String modelId = field.attributes().get(MODEL_ID); Model model = ModelCache.getInstance().get(modelId); - - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getVersion(), field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) - .toString(); - if (model.getModelBlob() == null) { - throw new RuntimeException("There is no trained model with id \"" + modelId + "\""); + throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); } else { - - engineFileName = buildEngineFileName(state.segmentInfo.name, knnEngine.getVersion(), field.name, knnEngine.getExtension()); - indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName) - .toString(); - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } @@ -184,7 +180,15 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndexFromTemplate(pair.docs, pair.vectors, indexPath, model, parameters, knnEngine); + JNIService.createIndexFromTemplate( + pair.docs, + pair.getVectorAddress(), + pair.getDimension(), + indexPath, + model, + parameters, + knnEngine + ); return null; }); } @@ -223,7 +227,7 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa // Pass the path for the nms library to save the file AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.vectors, indexPath, parameters, knnEngine); + JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); return null; }); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java index ccfaa68fc..16380c5d9 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesReader.java @@ -5,6 +5,10 @@ package org.opensearch.knn.index.codec.KNN80Codec; +import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.Bits; +import org.opensearch.common.StopWatch; import org.opensearch.knn.index.codec.util.BinaryDocValuesSub; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.index.BinaryDocValues; @@ -14,12 +18,14 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; +import java.io.IOException; import java.util.ArrayList; import java.util.List; /** * Reader for KNNDocValues from the segments */ +@Log4j2 class KNN80DocValuesReader extends EmptyDocValuesProducer { private final MergeState mergeState; @@ -30,6 +36,7 @@ class KNN80DocValuesReader extends EmptyDocValuesProducer { @Override public BinaryDocValues getBinary(FieldInfo field) { + long totalLiveDocs = 0; try { List subs = new ArrayList<>(this.mergeState.docValuesProducers.length); for (int i = 0; i < this.mergeState.docValuesProducers.length; i++) { @@ -41,13 +48,49 @@ public BinaryDocValues getBinary(FieldInfo field) { values = docValuesProducer.getBinary(readerFieldInfo); } if (values != null) { + totalLiveDocs = totalLiveDocs + getLiveDocsCount(values, this.mergeState.liveDocs[i]); + // docValues will be consumed when liveDocs are not null, hence resetting the docsValues + // pointer. + values = this.mergeState.liveDocs[i] != null ? docValuesProducer.getBinary(readerFieldInfo) : values; + subs.add(new BinaryDocValuesSub(mergeState.docMaps[i], values)); } } } - return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort)); + return new KNN80BinaryDocValues(DocIDMerger.of(subs, mergeState.needsIndexSort)).setTotalLiveDocs(totalLiveDocs); } catch (Exception e) { throw new RuntimeException(e); } } + + /** + * This function return the liveDocs count present in the BinaryDocValues. If the liveDocsBits is null, then we + * can use {@link BinaryDocValues#cost()} function to get max docIds. But if LiveDocsBits is not null, then we + * iterate over the BinaryDocValues and validate if the docId is present in the live docs bits or not. + * + * @param binaryDocValues {@link BinaryDocValues} + * @param liveDocsBits {@link Bits} + * @return total number of liveDocs. + * @throws IOException + */ + private long getLiveDocsCount(final BinaryDocValues binaryDocValues, final Bits liveDocsBits) throws IOException { + long liveDocs = 0; + if (liveDocsBits != null) { + int docId; + // This is not the right way to log the time. I create a github issue for adding an annotation to track + // the time. https://github.com/opensearch-project/k-NN/issues/1594 + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + for (docId = binaryDocValues.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = binaryDocValues.nextDoc()) { + if (liveDocsBits.get(docId)) { + liveDocs++; + } + } + stopWatch.stop(); + log.debug("Time taken to iterate over binary doc values: {} ms", stopWatch.totalTime().millis()); + } else { + liveDocs = binaryDocValues.cost(); + } + return liveDocs; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 02ab2d833..c5ae469e0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -5,18 +5,22 @@ package org.opensearch.knn.index.codec.util; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; +import org.opensearch.knn.jni.JNICommons; import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.ArrayList; +import java.util.List; public class KNNCodecUtil { - - public static final String HNSW_EXTENSION = ".hnsw"; - public static final String HNSW_COMPOUND_EXTENSION = ".hnswc"; // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; // References to objects are 4 bytes in size @@ -26,42 +30,63 @@ public class KNNCodecUtil { // Java rounds each array size up to multiples of 8 bytes public static final int JAVA_ROUNDING_NUMBER = 8; + @AllArgsConstructor public static final class Pair { - public Pair(int[] docs, float[][] vectors, SerializationMode serializationMode) { - this.docs = docs; - this.vectors = vectors; - this.serializationMode = serializationMode; - } - public int[] docs; - public float[][] vectors; + @Getter + @Setter + private long vectorAddress; + @Getter + @Setter + private int dimension; public SerializationMode serializationMode; + } public static KNNCodecUtil.Pair getFloats(BinaryDocValues values) throws IOException { - ArrayList vectorList = new ArrayList<>(); - ArrayList docIdList = new ArrayList<>(); + List vectorList = new ArrayList<>(); + List docIdList = new ArrayList<>(); + long vectorAddress = 0; + int dimension = 0; SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; + + long totalLiveDocs = getTotalLiveDocsCount(values); + long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes(); + long vectorsPerTransfer = Integer.MIN_VALUE; + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { BytesRef bytesref = values.binaryValue(); try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { serializationMode = KNNVectorSerializerFactory.serializerModeFromStream(byteStream); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + dimension = vector.length; + + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + } + if (vectorList.size() == vectorsPerTransfer) { + vectorAddress = JNICommons.storeVectorData( + vectorAddress, + vectorList.toArray(new float[][] {}), + totalLiveDocs * dimension + ); + // We should probably come up with a better way to reuse the vectorList memory which we have + // created. Problem here is doing like this can lead to a lot of list memory which is of no use and + // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. + vectorList = new ArrayList<>(); + } vectorList.add(vector); } docIdList.add(doc); } - return new KNNCodecUtil.Pair( - docIdList.stream().mapToInt(Integer::intValue).toArray(), - vectorList.toArray(new float[][] {}), - serializationMode - ); + if (vectorList.isEmpty() == false) { + vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + } + return new KNNCodecUtil.Pair(docIdList.stream().mapToInt(Integer::intValue).toArray(), vectorAddress, dimension, serializationMode); } - public static long calculateArraySize(float[][] vectors, SerializationMode serializationMode) { - int vectorLength = vectors[0].length; - int numVectors = vectors.length; + public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { if (serializationMode == SerializationMode.ARRAY) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { @@ -96,4 +121,14 @@ public static String buildEngineFilePrefix(String segmentName) { public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } + + private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { + long totalLiveDocs; + if (binaryDocValues instanceof KNN80BinaryDocValues) { + totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); + } else { + totalLiveDocs = binaryDocValues.cost(); + } + return totalLiveDocs; + } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 286e6265c..c8b56436b 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.watcher.FileWatcher; @@ -313,7 +314,7 @@ private void cleanup() { closed = true; if (this.memoryAddress != 0) { - JNIService.freeVectors(this.memoryAddress); + JNICommons.freeVectorData(this.memoryAddress); } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 4b5045359..32516ef9d 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -50,27 +50,33 @@ class FaissService { } /** - * Create an index for the native library + * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the + * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer + * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this + * issue * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param parameters parameters to build index */ - public static native void createIndex(int[] ids, float[][] data, String indexPath, Map parameters); + public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); /** * Create an index for the native library with a provided template index * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param templateIndex empty template index * @param parameters additional build time parameters */ public static native void createIndexFromTemplate( int[] ids, - float[][] data, + long vectorsAddress, + int dim, String indexPath, byte[] templateIndex, Map parameters @@ -173,33 +179,15 @@ public static native KNNQueryResult[] queryIndexWithFilter( public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); /** + *

+ * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + *

* Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well * @param trainingData data to be transferred * @return pointer to native memory location of training data */ + @Deprecated(since = "2.14.0", forRemoval = true) public static native long transferVectors(long vectorsPointer, float[][] trainingData); - - /** - * Transfer vectors from Java to native layer. This is the version 2 of transfer vector functionality. The - * difference between this and the version 1 is, this version puts vectors at the end rather than in front. - * Keeping this name as V2 for now, will come up with better name going forward. - *

- * TODO: Rename the function - *
- * TODO: Make this function native function and use a common cpp file to host these functions. - *

- * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param data data to be transferred - * @return pointer to native memory location for data - */ - public static native long transferVectorsV2(long vectorsPointer, float[][] data); - - /** - * Free vectors from memory - * - * @param vectorsPointer to be freed - */ - public static native void freeVectors(long vectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java new file mode 100644 index 000000000..90ad70c3d --- /dev/null +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.jni; + +import org.opensearch.knn.common.KNNConstants; + +import java.security.AccessController; +import java.security.PrivilegedAction; + +/** + * Common class for providing the JNI related functionality to various JNIServices. + */ +public class JNICommons { + + static { + AccessController.doPrivileged((PrivilegedAction) () -> { + System.loadLibrary(KNNConstants.COMMON_JNI_LIBRARY_NAME); + return null; + }); + } + + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D float array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} + * + *

+ * The function is not threadsafe. If multiple threads are trying to free up same memory location, then it can + * lead to errors. + *

+ * + * @param memoryAddress address to be freed. + */ + public static native void freeVectorData(long memoryAddress); +} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 80b56b173..5a5b6794a 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -23,22 +23,34 @@ public class JNIService { /** - * Create an index for the native library + * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the + * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer + * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this + * issue * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param parameters parameters to build index * @param knnEngine engine to build index for */ - public static void createIndex(int[] ids, float[][] data, String indexPath, Map parameters, KNNEngine knnEngine) { + public static void createIndex( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + Map parameters, + KNNEngine knnEngine + ) { + if (KNNEngine.NMSLIB == knnEngine) { - NmslibService.createIndex(ids, data, indexPath, parameters); + NmslibService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); return; } if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndex(ids, data, indexPath, parameters); + FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); return; } @@ -49,7 +61,8 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map< * Create an index for the native library with a provided template index * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of vectors to be indexed * @param indexPath path to save index file to * @param templateIndex empty template index * @param parameters parameters to build index @@ -57,14 +70,15 @@ public static void createIndex(int[] ids, float[][] data, String indexPath, Map< */ public static void createIndexFromTemplate( int[] ids, - float[][] data, + long vectorsAddress, + int dim, String indexPath, byte[] templateIndex, Map parameters, KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, data, indexPath, templateIndex, parameters); + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); return; } @@ -235,44 +249,17 @@ public static byte[] trainIndex(Map indexParameters, int dimensi } /** + *

+ * The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} + *

* Transfer vectors from Java to native * * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well * @param trainingData data to be transferred * @return pointer to native memory location of training data */ + @Deprecated(since = "2.14.0", forRemoval = true) public static long transferVectors(long vectorsPointer, float[][] trainingData) { return FaissService.transferVectors(vectorsPointer, trainingData); } - - /** - * Free vectors from memory - * - * @param vectorsPointer to be freed - */ - public static void freeVectors(long vectorsPointer) { - FaissService.freeVectors(vectorsPointer); - } - - /** - * Experimental: Transfer vectors from Java to native layer. This is the version 2 of transfer vector - * functionality. The difference between this and the version 1 is, this version puts vectors at the end rather - * than in front. Keeping this name as V2 for now, will come up with better name going forward. - *

- * This is not a production ready function for now. Adding this to ensure that we are able to run atleast 1 - * micro-benchmarks. - *

- *

- * TODO: Rename the function - *
- * TODO: Make this function native function and use a common cpp file to host these functions. - *

- * @param vectorsPointer pointer to vectors in native memory. Should be 0 to create vector as well - * @param data data to be transferred - * @return pointer to native memory location for data - * - */ - public static long transferVectorsV2(long vectorsPointer, float[][] data) { - return FaissService.transferVectorsV2(vectorsPointer, data); - } } diff --git a/src/main/java/org/opensearch/knn/jni/NmslibService.java b/src/main/java/org/opensearch/knn/jni/NmslibService.java index 77896822a..7fdc278d2 100644 --- a/src/main/java/org/opensearch/knn/jni/NmslibService.java +++ b/src/main/java/org/opensearch/knn/jni/NmslibService.java @@ -39,14 +39,18 @@ class NmslibService { } /** - * Create an index for the native library + * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the + * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer + * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this + * issue * * @param ids array of ids mapping to the data passed in - * @param data array of float arrays to be indexed + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed * @param indexPath path to save index file to * @param parameters parameters to build index */ - public static native void createIndex(int[] ids, float[][] data, String indexPath, Map parameters); + public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); /** * Load an index into memory diff --git a/src/main/plugin-metadata/plugin-security.policy b/src/main/plugin-metadata/plugin-security.policy index 91624613c..d5ab0be21 100644 --- a/src/main/plugin-metadata/plugin-security.policy +++ b/src/main/plugin-metadata/plugin-security.policy @@ -1,6 +1,7 @@ grant { permission java.lang.RuntimePermission "loadLibrary.opensearchknn_nmslib"; permission java.lang.RuntimePermission "loadLibrary.opensearchknn_faiss"; + permission java.lang.RuntimePermission "loadLibrary.opensearchknn_common"; permission java.lang.RuntimePermission "loadLibrary.opensearchknn_faiss_avx2"; permission java.net.SocketPermission "*", "connect,resolve"; permission java.lang.RuntimePermission "accessDeclaredMembers"; diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 72ffd2b66..1f9b423bc 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -34,10 +34,12 @@ import java.io.IOException; import java.net.URL; import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; @@ -172,6 +174,108 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final Set docIdsToBeDeleted = new HashSet<>(); + while (docIdsToBeDeleted.size() < 10) { + docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length)); + } + + for (Integer id : docIdsToBeDeleted) { + deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); + } + refreshAllNonSystemIndices(); + forceMergeKnnIndex(indexName, 3); + + assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); + + int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index cdf3109a4..2ce3a7c83 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -38,6 +38,7 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.knn.plugin.stats.KNNGraphValue; @@ -357,7 +358,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio modelBytes, modelId ); - JNIService.freeVectors(trainingPtr); + JNICommons.freeVectorData(trainingPtr); // Setup the model cache to return the correct model ModelDao modelDao = mock(ModelDao.class); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index fabffab46..7573a4394 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -52,7 +53,8 @@ public void testIndexAllocation_close() throws InterruptedException { Arrays.fill(vectors[i], 1f); } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); - JNIService.createIndex(ids, vectors, path, parameters, knnEngine); + long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); + JNIService.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); // Load index into memory long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 207b6373b..a84974202 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -16,6 +16,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.SpaceType; @@ -53,7 +54,8 @@ public void testIndexLoadStrategy_load() throws IOException { Arrays.fill(vectors[i], 1f); } Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); - JNIService.createIndex(ids, vectors, path, parameters, knnEngine); + long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); + JNIService.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); // Setup mock resource manager ResourceWatcherService resourceWatcherService = mock(ResourceWatcherService.class); diff --git a/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java b/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java new file mode 100644 index 000000000..bf27458b0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.jni; + +import org.opensearch.knn.KNNTestCase; + +public class JNICommonsTest extends KNNTestCase { + + public void testStoreVectorData_whenVaildInputThenSuccess() { + float[][] data = new float[2][2]; + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + data[i][j] = i + j; + } + } + long memoryAddress = JNICommons.storeVectorData(0, data, 8); + assertTrue(memoryAddress > 0); + assertEquals(memoryAddress, JNICommons.storeVectorData(memoryAddress, data, 8)); + } + + public void testFreeVectorData_whenValidInput_ThenSuccess() { + float[][] data = new float[2][2]; + for (int i = 0; i < 2; i++) { + for (int j = 0; j < 2; j++) { + data[i][j] = i + j; + } + } + long memoryAddress = JNICommons.storeVectorData(0, data, 8); + JNICommons.freeVectorData(memoryAddress); + } +} diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 36a3d93be..d6ae13e92 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -83,7 +83,8 @@ public void testCreateIndex_invalid_engineNotSupported() { IllegalArgumentException.class, () -> JNIService.createIndex( new int[] {}, - new float[][] {}, + 0, + 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.LUCENE @@ -96,7 +97,8 @@ public void testCreateIndex_invalid_engineNull() { Exception.class, () -> JNIService.createIndex( new int[] {}, - new float[][] {}, + 0, + 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null @@ -105,12 +107,12 @@ public void testCreateIndex_invalid_engineNull() { } public void testCreateIndex_nmslib_invalid_noSpaceType() { - expectThrows( Exception.class, () -> JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), "something", Collections.emptyMap(), KNNEngine.NMSLIB @@ -122,13 +124,14 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept int[] docIds = new int[] { 1, 2, 3 }; float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; - + long memoryAddress = JNICommons.storeVectorData(0, vectors1, vectors1.length * vectors1[0].length); Path tmpFile1 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors1, + memoryAddress, + vectors1[0].length, tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -136,13 +139,15 @@ public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOExcept ); float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; + long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); Path tmpFile2 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors2, + memoryAddress2, + vectors2[0].length, tmpFile2.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -154,13 +159,14 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { int[] docIds = new int[] {}; float[][] vectors = new float[][] {}; - + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( null, - vectors, + memoryAddress, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -171,7 +177,8 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - null, + 0, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -182,7 +189,8 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + 0, null, ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -191,14 +199,15 @@ public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) + () -> JNIService.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) ); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null @@ -210,13 +219,14 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { int[] docIds = new int[] { 1 }; float[][] vectors = new float[][] { { 2, 3 } }; - + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB @@ -224,28 +234,11 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { ); } - public void testCreateIndex_nmslib_invalid_inconsistentDimensions() throws IOException { - - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> JNIService.createIndex( - docIds, - vectors, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - ); - } - public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException { - int[] docIds = new int[] {}; - float[][] vectors = new float[][] {}; + int[] docIds = new int[] { 1 }; + float[][] vectors = new float[][] { { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length * vectors[0].length); Map parametersMap = ImmutableMap.of( KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, @@ -258,7 +251,8 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue(), KNNConstants.PARAMETERS, parametersMap), KNNEngine.NMSLIB @@ -273,7 +267,8 @@ public void testCreateIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.NMSLIB @@ -284,7 +279,8 @@ public void testCreateIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of( KNNConstants.SPACE_TYPE, @@ -301,15 +297,14 @@ public void testCreateIndex_nmslib_valid() throws IOException { } public void testCreateIndex_faiss_invalid_noSpaceType() { - int[] docIds = new int[] {}; - float[][] vectors = new float[][] {}; expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), "something", ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod), KNNEngine.FAISS @@ -321,13 +316,14 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti int[] docIds = new int[] { 1, 2, 3 }; float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; - + long memoryAddress = JNICommons.storeVectorData(0, vectors1, vectors1.length * vectors1[0].length); Path tmpFile1 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors1, + memoryAddress, + vectors1[0].length, tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -335,13 +331,14 @@ public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti ); float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; - + long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); Path tmpFile2 = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors2, + memoryAddress, + vectors2[0].length, tmpFile2.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -353,13 +350,15 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { int[] docIds = new int[] {}; float[][] vectors = new float[][] {}; + long memoryAddress = JNICommons.storeVectorData(0, vectors, 0); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( null, - vectors, + memoryAddress, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -370,7 +369,8 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - null, + 0, + 0, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -381,7 +381,8 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), null, ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -390,14 +391,22 @@ public void testCreateIndex_faiss_invalid_null() throws IOException { expectThrows( Exception.class, - () -> JNIService.createIndex(docIds, vectors, tmpFile.toAbsolutePath().toString(), null, KNNEngine.FAISS) + () -> JNIService.createIndex( + docIds, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + null, + KNNEngine.FAISS + ) ); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null @@ -409,13 +418,15 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { int[] docIds = new int[] { 1 }; float[][] vectors = new float[][] { { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.FAISS @@ -423,35 +434,19 @@ public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { ); } - public void testCreateIndex_faiss_invalid_inconsistentDimensions() throws IOException { - - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> JNIService.createIndex( - docIds, - vectors, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); - } - public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOException { int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -460,16 +455,16 @@ public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOExceptio } public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOException { - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3, 4 } }; - + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); expectThrows( Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "invalid", KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -482,6 +477,8 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { int[] docIds = new int[] { 1, 2 }; float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + String sqfp16InvalidIndexDescription = "HNSW16,SQfp1655"; Path tmpFile = createTempFile(); @@ -489,7 +486,8 @@ public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { Exception.class, () -> JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of( INDEX_DESCRIPTION_PARAMETER, @@ -508,11 +506,12 @@ public void testLoadIndex_faiss_sqfp16_valid() { int[] docIds = new int[] { 1, 2 }; float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; String sqfp16IndexDescription = "HNSW16,SQfp16"; - + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path tmpFile = createTempFile(); JNIService.createIndex( docIds, - vectors, + memoryAddress, + vectors[0].length, tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -528,11 +527,13 @@ public void testQueryIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; int k = 10; - + float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); + long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - truncateToFp16Range(testData.indexData.vectors), + memoryAddress, + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -596,7 +597,7 @@ public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOException { @@ -609,7 +610,8 @@ public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOExcept Exception.class, () -> JNIService.createIndex( docIds, - vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of( INDEX_DESCRIPTION_PARAMETER, @@ -634,7 +636,8 @@ public void testCreateIndex_faiss_valid() throws IOException { Path tmpFile1 = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile1.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS @@ -684,7 +687,8 @@ public void testLoadIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -719,7 +723,8 @@ public void testLoadIndex_faiss_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -745,7 +750,8 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -770,7 +776,8 @@ public void testQueryIndex_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.NMSLIB @@ -802,7 +809,8 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -826,7 +834,8 @@ public void testQueryIndex_faiss_valid() throws IOException { Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS @@ -867,7 +876,8 @@ public void testQueryIndex_faiss_parentIds() throws IOException { Path tmpFile = createTempFile(); JNIService.createIndex( testDataNested.indexData.docs, - testDataNested.indexData.vectors, + testData.loadDataToMemoryAddress(), + testDataNested.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS @@ -941,7 +951,8 @@ public void testFree_nmslib_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB @@ -964,7 +975,8 @@ public void testFree_faiss_valid() throws IOException { JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS @@ -987,7 +999,7 @@ public void testTransferVectors() { assertEquals(trainPointer1, trainPointer2); } - JNIService.freeVectors(trainPointer1); + JNICommons.freeVectorData(trainPointer1); } public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOException { @@ -1008,7 +1020,7 @@ public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOExceptio byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException { @@ -1038,7 +1050,7 @@ public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException { @@ -1065,7 +1077,7 @@ public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); } private long transferVectors(int numDuplicates) { @@ -1120,12 +1132,13 @@ public void testCreateIndexFromTemplate() throws IOException { byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer1, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer1); + JNICommons.freeVectorData(trainPointer1); Path tmpFile1 = createTempFile(); JNIService.createIndexFromTemplate( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile1.toAbsolutePath().toString(), faissIndex, ImmutableMap.of(INDEX_THREAD_QTY, 1), @@ -1255,11 +1268,12 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer); + JNICommons.freeVectorData(trainPointer); Path tmpFile = createTempFile(); JNIService.createIndexFromTemplate( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), faissIndex, ImmutableMap.of(INDEX_THREAD_QTY, 1), @@ -1274,7 +1288,8 @@ private String createFaissHNSWIndex(SpaceType spaceType) throws IOException { Path tmpFile = createTempFile(); JNIService.createIndex( testData.indexData.docs, - testData.indexData.vectors, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), tmpFile.toAbsolutePath().toString(), ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index e4c01a8ec..06b96c57c 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -25,6 +25,7 @@ import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import java.io.File; @@ -170,7 +171,7 @@ public void testRun_success() throws IOException, ExecutionException { when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); doAnswer(invocationOnMock -> { - JNIService.freeVectors(memoryAddress); + JNICommons.freeVectorData(memoryAddress); return null; }).when(nativeMemoryCacheManager).invalidate(tdataKey); @@ -197,11 +198,12 @@ public void testRun_success() throws IOException, ExecutionException { int[] ids = { 1, 2, 3, 4 }; float[][] vectors = new float[ids.length][dimension]; fillFloatArrayRandomly(vectors); - + long vectorsMemoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); Path indexPath = createTempFile(); JNIService.createIndexFromTemplate( ids, - vectors, + vectorsMemoryAddress, + vectors[0].length, indexPath.toString(), model.getModelBlob(), ImmutableMap.of(INDEX_THREAD_QTY, 1), @@ -456,7 +458,7 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation); doAnswer(invocationOnMock -> { - JNIService.freeVectors(memoryAddress); + JNICommons.freeVectorData(memoryAddress); return null; }).when(nativeMemoryCacheManager).invalidate(tdataKey); diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 1b5accae9..37e35f062 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -9,12 +9,15 @@ import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.knn.index.codec.util.KNNCodecUtil; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; import java.io.BufferedReader; import java.io.FileReader; import java.io.IOException; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.plugin.script.KNNScoringUtil; import java.util.Comparator; import java.util.Random; @@ -246,7 +249,7 @@ public static PriorityQueue insertWithOverflow(PriorityQueue idsList = new ArrayList<>(); List vectorsList = new ArrayList<>(); @@ -287,8 +290,7 @@ private KNNCodecUtil.Pair readIndexData(String path) throws IOException { vectorsArray[i][j] = vectorsList.get(i)[j]; } } - - return new KNNCodecUtil.Pair(idsArray, vectorsArray, SerializationMode.COLLECTION_OF_FLOATS); + return new Pair(idsArray, vectorsArray[0].length, SerializationMode.COLLECTION_OF_FLOATS, vectorsArray); } private float[][] readQueries(String path) throws IOException { @@ -319,5 +321,19 @@ private float[][] readQueries(String path) throws IOException { } return queryArray; } + + public long loadDataToMemoryAddress() { + return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length); + } + + @AllArgsConstructor + public static class Pair { + public int[] docs; + @Getter + @Setter + private int dimension; + public SerializationMode serializationMode; + public float[][] vectors; + } } }