From 716507947224480574bd11920f9b2408036d20b8 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Fri, 12 Apr 2024 14:00:02 -0700 Subject: [PATCH] Added separate interface for creating and writing in a faiss index to reduce memory footprint for faiss Signed-off-by: Navneet Verma --- jni/include/faiss_wrapper.h | 7 + .../org_opensearch_knn_jni_FaissService.h | 17 +++ jni/src/faiss_wrapper.cpp | 98 ++++++++++++ .../org_opensearch_knn_jni_FaissService.cpp | 24 +++ jni/tests/faiss_wrapper_test.cpp | 58 ++++++- .../KNN80Codec/KNN80DocValuesConsumer.java | 141 ++++++++++++------ .../knn/index/codec/util/KNNCodecUtil.java | 2 +- .../util/KNNVectorSerializerFactory.java | 2 +- .../org/opensearch/knn/jni/FaissService.java | 10 ++ .../org/opensearch/knn/jni/JNIService.java | 23 +++ 10 files changed, 337 insertions(+), 45 deletions(-) diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 3e1adeac4..5d6d9ee60 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -22,6 +22,13 @@ namespace knn_jni { void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ); + // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. + // The index is serialized to indexPathJ. + long long CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong indexAddressJ, jobject parametersJ); + + void writeIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,jlong indexAddressJ, jstring indexPathJ, jobject parametersJ); + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 32b6f22f1..208fc9bba 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -26,6 +26,23 @@ extern "C" { JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createIndexIteratively + * Signature: ([IJIJLjava/util/Map;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexIteratively + (JNIEnv *, jclass, jintArray, jlong, jint, jlong, jobject); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: writeIndex + * Signature: (JLjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex + (JNIEnv *, jclass, jlong, jstring, jobject); + + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 817bdb816..17e5703bb 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -81,6 +81,104 @@ bool isIndexIVFPQL2(faiss::Index * index); // IndexIDMap which has member that will point to underlying index that stores the data faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); +void knn_jni::faiss_wrapper::writeIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env,jlong indexAddressJ, jstring indexPathJ, jobject parametersJ) { + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, + parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + auto *idMap = reinterpret_cast((long long)indexAddressJ); + faiss::write_index(idMap, indexPathCpp.c_str()); + // Deleting the internal index of the idMap index + delete idMap->index; + delete idMap; +} + +long long knn_jni::faiss_wrapper::CreateIndexIteratively(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, + jlong indexAddressJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Get space type for this index + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int) dimJ; + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if (numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexIDMap *idMap = nullptr; + long indexAddress = (long) indexAddressJ; + if(indexAddress == 0) { + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + faiss::Index *indexWriter = faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric); + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, + parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + if (parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); + SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter); + jniUtil->DeleteLocalRef(env, subParametersJ); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Check that the index does not need to be trained + if (!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + idMap = new faiss::IndexIDMap(indexWriter); + idMap->add_with_ids(numVectors, inputVectors->data(), idVector.data()); + } else { + idMap = reinterpret_cast(indexAddress); + idMap->add_with_ids(numVectors, inputVectors->data(), idVector.data()); + } + delete inputVectors; + return (long long)idMap; +} + + void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jobject parametersJ) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 3249ed872..fb8af44f4 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -17,6 +17,7 @@ #include "faiss_wrapper.h" #include "jni_util.h" +#include static knn_jni::JNIUtil jniUtil; static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1; @@ -50,6 +51,29 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE } } + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexIteratively(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jlong indexAddressJ, jobject parametersJ) +{ + try { + return (jlong)knn_jni::faiss_wrapper::CreateIndexIteratively(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexAddressJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return 0; +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, + jlong indexAddressJ, jstring indexPathJ, jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::writeIndex(&jniUtil, env, indexAddressJ, indexPathJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, jlong vectorsAddressJ, diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 05854f7ed..69998e50c 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -19,6 +19,7 @@ #include "test_util.h" #include "faiss/IndexHNSW.h" #include "faiss/IndexIVFPQ.h" +#include "faiss/utils/utils.h" using ::testing::NiceMock; using ::testing::Return; @@ -26,7 +27,7 @@ using ::testing::Return; float randomDataMin = -500.0; float randomDataMax = 500.0; -TEST(FaissCreateIndexTest, BasicAssertions) { +TEST(FaissCreateIndexIterativelyTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; @@ -52,6 +53,61 @@ TEST(FaissCreateIndexTest, BasicAssertions) { JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaObjectArrayLength( + jniEnv, reinterpret_cast(&vectors))) + .WillRepeatedly(Return(vectors->size())); + + // Create the index + long long indexAddress = knn_jni::faiss_wrapper::CreateIndexIteratively( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) vectors, dim, (jlong)0, + (jobject)¶metersMap); + knn_jni::faiss_wrapper::writeIndex( + &mockJNIUtil, jniEnv,(jlong)indexAddress, (jstring)&indexPath, + (jobject)¶metersMap); + + // Make sure index can be loaded + std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); + // Clean up + ids.clear(); + ids.shrink_to_fit(); + vectors->clear(); + vectors->shrink_to_fit(); + size_t mem_usage = faiss::get_mem_usage_kb() / (1 << 10); + + std::cout<<"======Memory Usage:[" << mem_usage << "mb]======" << std::endl; + // Clean up + std::remove(indexPath.c_str()); +} + + +TEST(FaissCreateIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + auto *vectors = new std::vector(); + int dim = 2; + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + std::string indexPath = test_util::RandomString(10, "/tmp/", ".faiss"); + std::string spaceType = knn_jni::L2; + std::string index_description = "HNSW32,Flat"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&index_description; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 096df817a..0c78e2a01 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -8,19 +8,23 @@ import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.util.BytesRef; import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.StopWatch; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.codec.util.KNNVectorSerializer; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; +import org.opensearch.knn.index.codec.util.SerializationMode; +import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; -import org.opensearch.knn.plugin.stats.KNNCounter; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.codecs.DocValuesConsumer; @@ -36,6 +40,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import java.io.ByteArrayInputStream; import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; @@ -46,14 +51,15 @@ import java.nio.file.StandardOpenOption; import java.security.AccessController; import java.security.PrivilegedAction; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; -import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; /** * This class writes the KNN docvalues to the segments @@ -108,19 +114,11 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, throws IOException { // Get values to be indexed BinaryDocValues values = valuesProducer.getBinary(field); - KNNCodecUtil.Pair pair = KNNCodecUtil.getFloats(values); - if (pair.getVectorAddress() == 0 || pair.docs.length == 0) { - logger.info("Skipping engine index creation as there are no vectors or docs in the segment"); + if (values == null) { + log.info("BinaryDocValues is null. Returning.."); return; } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); - if (isMerge) { - KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); - KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(pair.docs.length); - KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize); - } - // Increment counter for number of graph index requests - KNNCounter.GRAPH_INDEX_REQUESTS.increment(); + // Get the KNN engine final KNNEngine knnEngine = getKNNEngine(field); final String engineFileName = buildEngineFileName( state.segmentInfo.name, @@ -132,35 +130,99 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), engineFileName ).toString(); - NativeIndexCreator indexCreator; - // Create library index either from model or from scratch - if (field.attributes().containsKey(MODEL_ID)) { - String modelId = field.attributes().get(MODEL_ID); - Model model = ModelCache.getInstance().get(modelId); - if (model.getModelBlob() == null) { - throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); - } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); - } else { - indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); - } + Map parametersMap = getKNNIndexFromScratchParameters(field, knnEngine); - if (isMerge) { - recordMergeStats(pair.docs.length, arraySize); - } + long indexAddress = createIndex(values, knnEngine, parametersMap); - if (isRefresh) { - recordRefreshStats(); + if (indexAddress == 0) { + log.info("Index is not created. Returning.."); } - // This is a bit of a hack. We have to create an output here and then immediately close it to ensure that - // engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will - // not be marked as added to the directory. state.directory.createOutput(engineFileName, state.context).close(); - indexCreator.createIndex(); + // Now we can write the index + AccessController.doPrivileged((PrivilegedAction) () -> { + JNIService.writeIndex(indexAddress, parametersMap, indexPath, knnEngine); + return null; + }); writeFooter(indexPath, engineFileName); } + private long createIndex(BinaryDocValues values, final KNNEngine knnEngine, final Map parametersMap) + throws IOException { + List vectorList = new ArrayList<>(); + List docIdList = new ArrayList<>(); + int dimension = 0; + SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS; + + long totalLiveDocs = KNNCodecUtil.getTotalLiveDocsCount(values); + long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes(); + long vectorsPerTransfer = Integer.MIN_VALUE; + Long indexAddress = 0L; + + for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) { + BytesRef bytesref = values.binaryValue(); + try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) { + final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); + final float[] vector = vectorSerializer.byteToFloatArray(byteStream); + dimension = vector.length; + + if (vectorsPerTransfer == Integer.MIN_VALUE) { + vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit; + if (vectorsPerTransfer == 0) { + vectorsPerTransfer = totalLiveDocs; + } + } + if (vectorList.size() == vectorsPerTransfer) { + final long vectorAddress = JNICommons.storeVectorData( + 0, + vectorList.toArray(new float[][] {}), + (long) vectorList.size() * dimension + ); + List docIdList2 = docIdList; + int finalDimension = dimension; + long indexAddress2 = indexAddress; + indexAddress = AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.buildIndex( + docIdList2.stream().mapToInt(Integer::intValue).toArray(), + vectorAddress, + indexAddress2, + finalDimension, + parametersMap, + knnEngine + ) + ); + + // We should probably come up with a better way to reuse the vectorList memory which we have + // created. Problem here is doing like this can lead to a lot of list memory which is of no use and + // will be garbage collected later on, but it creates pressure on JVM. We should revisit this. + vectorList = new ArrayList<>(); + docIdList = new ArrayList<>(); + // JNICommons.freeVectorData(vectorAddress); + } + vectorList.add(vector); + } + docIdList.add(doc); + } + if (vectorList.isEmpty() == false) { + long vectorAddress = JNICommons.storeVectorData(0, vectorList.toArray(new float[][] {}), (long) vectorList.size() * dimension); + List docIdList2 = docIdList; + int finalDimension = dimension; + long indexAddress2 = indexAddress; + indexAddress = AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.buildIndex( + docIdList2.stream().mapToInt(Integer::intValue).toArray(), + vectorAddress, + indexAddress2, + finalDimension, + parametersMap, + knnEngine + ) + ); + // JNICommons.freeVectorData(vectorAddress); + } + return indexAddress; + } + private void recordMergeStats(int length, long arraySize) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement(); KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length); @@ -193,8 +255,7 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN }); } - private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) - throws IOException { + private Map getKNNIndexFromScratchParameters(FieldInfo fieldInfo, KNNEngine knnEngine) throws IOException { Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); @@ -225,11 +286,7 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa // Used to determine how many threads to use when indexing parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); - // Pass the path for the nms library to save the file - AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine); - return null; - }); + return parameters; } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index e05962608..8eb82dc28 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -127,7 +127,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) { return String.format("_%s%s", fieldName, extension); } - private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { + public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) { long totalLiveDocs; if (binaryDocValues instanceof KNN80BinaryDocValues) { totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs(); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java index 5c1e4ca9b..bf2f5c6be 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java @@ -56,7 +56,7 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn return getSerializerBySerializationMode(serializationMode); } - static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { + public static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) { int numberOfAvailableBytesInStream = byteStream.available(); if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) { return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS); diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 32516ef9d..29feb7932 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -49,6 +49,16 @@ class FaissService { }); } + public static native long createIndexIteratively( + int[] ids, + long vectorsAddress, + int dim, + long indexAddress, + Map parameters + ); + + public static native void writeIndex(long indexAddress, String indexPath, Map parameters); + /** * Create an index for the native library The memory occupied by the vectorsAddress will be freed up during the * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 5a5b6794a..40678e65c 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -57,6 +57,29 @@ public static void createIndex( throw new IllegalArgumentException(String.format("CreateIndex not supported for provided engine : %s", knnEngine.getName())); } + public static long buildIndex( + int[] ids, + long vectorsAddress, + long indexAddress, + int dim, + Map parameters, + KNNEngine knnEngine + ) { + if (KNNEngine.FAISS == knnEngine) { + return FaissService.createIndexIteratively(ids, vectorsAddress, dim, indexAddress, parameters); + } + throw new IllegalArgumentException(String.format("buildIndex not supported for provided engine : %s", knnEngine.getName())); + + } + + public static void writeIndex(long indexAddress, Map parameters, String indexPath, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + FaissService.writeIndex(indexAddress, indexPath, parameters); + return; + } + throw new IllegalArgumentException(String.format("writeIndex not supported for provided engine : %s", knnEngine.getName())); + } + /** * Create an index for the native library with a provided template index *