diff --git a/CHANGELOG.md b/CHANGELOG.md index b6f37f6d1..22e493f32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Improved the logic to switch to exact search for restrictive filters search for better recall. [#1059](https://github.com/opensearch-project/k-NN/pull/1059) * Added max distance computation logic to enhance the switch to exact search in filtered Nearest Neighbor Search. [#1066](https://github.com/opensearch-project/k-NN/pull/1066) ### Bug Fixes +* Update Faiss parameter construction to allow HNSW+PQ to work [#1074](https://github.com/opensearch-project/k-NN/pull/1074) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 6c091da03..71eed404a 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -18,6 +18,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.function.Function; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; @@ -64,9 +65,7 @@ class Faiss extends NativeLibrary { Collections.emptyMap() ); - // TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now, - // we do not have a way to base validation off of dimension. Failure will happen during training in JNI. - private final static Map encoderComponents = ImmutableMap.of( + private final static Map COMMON_ENCODERS = ImmutableMap.of( KNNConstants.ENCODER_FLAT, MethodComponent.Builder.builder(KNNConstants.ENCODER_FLAT) .setMapGenerator( @@ -76,62 +75,111 @@ class Faiss extends NativeLibrary { methodComponentContext ).build()) ) - .build(), - KNNConstants.ENCODER_PQ, - MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) - .addParameter( - ENCODER_PARAMETER_PQ_M, - new Parameter.IntegerParameter( - ENCODER_PARAMETER_PQ_M, - ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT - ) - ) - .addParameter( - ENCODER_PARAMETER_PQ_CODE_SIZE, - new Parameter.IntegerParameter( - ENCODER_PARAMETER_PQ_CODE_SIZE, - ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, - v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT - ) - ) - .setRequiresTraining(true) - .setMapGenerator( - ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( - FAISS_PQ_DESCRIPTION, - methodComponent, - methodComponentContext - ).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build()) + .build() + ); + + // TODO: To think about in future: for PQ, if dimension is not divisible by code count, PQ will fail. Right now, + // we do not have a way to base validation off of dimension. Failure will happen during training in JNI. + // Define methods supported by faiss. See issue here: https://github.com/opensearch-project/k-NN/issues/1075 + private final static Map HNSW_ENCODERS = ImmutableMap.builder() + .putAll( + ImmutableMap.of( + KNNConstants.ENCODER_PQ, + MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) + .addParameter( + ENCODER_PARAMETER_PQ_M, + new Parameter.IntegerParameter( + ENCODER_PARAMETER_PQ_M, + ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + ) + ) + .addParameter( + ENCODER_PARAMETER_PQ_CODE_SIZE, + new Parameter.IntegerParameter( + ENCODER_PARAMETER_PQ_CODE_SIZE, + ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, + v -> Objects.equals(v, ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT) + ) + ) + .setRequiresTraining(true) + .setMapGenerator( + ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + FAISS_PQ_DESCRIPTION, + methodComponent, + methodComponentContext + ).addParameter(ENCODER_PARAMETER_PQ_M, "", "").build()) + ) + .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { + int codeSize = ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT; + return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1; + }) + .build() ) - .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { - // Size estimate formula: (4 * d * 2^code_size) / 1024 + 1 - - // Get value of code size passed in by user - Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); - - // If not specified, get default value of code size - if (codeSizeObject == null) { - Parameter codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); - if (codeSizeParameter == null) { - throw new IllegalStateException( - String.format("%s is not a valid parameter. This is a bug.", ENCODER_PARAMETER_PQ_CODE_SIZE) - ); - } + ) + .putAll(COMMON_ENCODERS) + .build(); - codeSizeObject = codeSizeParameter.getDefaultValue(); - } + private final static Map IVF_ENCODERS = ImmutableMap.builder() + .putAll( + ImmutableMap.of( + KNNConstants.ENCODER_PQ, + MethodComponent.Builder.builder(KNNConstants.ENCODER_PQ) + .addParameter( + ENCODER_PARAMETER_PQ_M, + new Parameter.IntegerParameter( + ENCODER_PARAMETER_PQ_M, + ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT, + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_COUNT_LIMIT + ) + ) + .addParameter( + ENCODER_PARAMETER_PQ_CODE_SIZE, + new Parameter.IntegerParameter( + ENCODER_PARAMETER_PQ_CODE_SIZE, + ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT, + v -> v > 0 && v < ENCODER_PARAMETER_PQ_CODE_SIZE_LIMIT + ) + ) + .setRequiresTraining(true) + .setMapGenerator( + ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( + FAISS_PQ_DESCRIPTION, + methodComponent, + methodComponentContext + ).addParameter(ENCODER_PARAMETER_PQ_M, "", "").addParameter(ENCODER_PARAMETER_PQ_CODE_SIZE, "x", "").build()) + ) + .setOverheadInKBEstimator((methodComponent, methodComponentContext, dimension) -> { + // Size estimate formula: (4 * d * 2^code_size) / 1024 + 1 - if (!(codeSizeObject instanceof Integer)) { - throw new IllegalStateException(String.format("%s must be an integer.", ENCODER_PARAMETER_PQ_CODE_SIZE)); - } + // Get value of code size passed in by user + Object codeSizeObject = methodComponentContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); - int codeSize = (Integer) codeSizeObject; - return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1; - }) - .build() - ); + // If not specified, get default value of code size + if (codeSizeObject == null) { + Parameter codeSizeParameter = methodComponent.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE); + if (codeSizeParameter == null) { + throw new IllegalStateException( + String.format("%s is not a valid parameter. This is a bug.", ENCODER_PARAMETER_PQ_CODE_SIZE) + ); + } + + codeSizeObject = codeSizeParameter.getDefaultValue(); + } + + if (!(codeSizeObject instanceof Integer)) { + throw new IllegalStateException(String.format("%s must be an integer.", ENCODER_PARAMETER_PQ_CODE_SIZE)); + } + + int codeSize = (Integer) codeSizeObject; + return ((4L * (1L << codeSize) * dimension) / BYTES_PER_KILOBYTES) + 1; + }) + .build() + ) + ) + .putAll(COMMON_ENCODERS) + .build(); - // Define methods supported by faiss private final static Map METHODS = ImmutableMap.of( METHOD_HNSW, KNNMethod.Builder.builder( @@ -158,7 +206,7 @@ class Faiss extends NativeLibrary { ) .addParameter( METHOD_ENCODER_PARAMETER, - new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents) + new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, HNSW_ENCODERS) ) .setMapGenerator( ((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder( @@ -190,7 +238,7 @@ class Faiss extends NativeLibrary { ) .addParameter( METHOD_ENCODER_PARAMETER, - new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, encoderComponents) + new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, ENCODER_DEFAULT, IVF_ENCODERS) ) .setRequiresTraining(true) .setMapGenerator( diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 2d4dc8053..5db24ea30 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -37,7 +37,12 @@ import java.util.TreeMap; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; @@ -62,7 +67,8 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); } - public void testEndToEnd_fromMethod() throws Exception { + @SneakyThrows + public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { String indexName = "test-index-1"; String fieldName = "test-field-1"; @@ -150,6 +156,117 @@ public void testEndToEnd_fromMethod() throws Exception { fail("Graphs are not getting evicted"); } + @SneakyThrows + public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { + String indexName = "test-index"; + String fieldName = "test-field"; + String trainingIndexName = "training-index"; + String trainingFieldName = "training-field"; + + String modelId = "test-model"; + String modelDescription = "test model"; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + List pqMValues = ImmutableList.of(2, 4, 8); + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ + int trainingDataCount = 256; + + SpaceType spaceType = SpaceType.L2; + + Integer dimension = testData.indexData.vectors[0].length; + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMValues.get(random().nextInt(pqMValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, in, trainingDataCount); + assertTrainingSucceeds(modelId, 180, 1000); + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + + // Delete index + deleteKNNIndex(indexName); + deleteModel(modelId); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + public void testDocUpdate() throws IOException { String indexName = "test-index-1"; String fieldName = "test-field-1"; diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index a3011cef5..0b12980ab 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -26,6 +26,7 @@ import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT; import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; @@ -140,22 +141,54 @@ public void testRequiresTraining() { assertTrue(knnMethodContext.isTrainingRequired()); } - public void testEstimateOverheadInKB() { + public void testEstimateOverheadInKB_whenMethodIsHNSWFlatNmslib_thenSizeIsExpectedValue() { // For HNSW no encoding we expect 0 MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - KNNMethodContext knnMethodContextNmslib = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, hnswMethod); - KNNMethodContext knnMethodContextFaiss = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, hnswMethod); - assertEquals(0, knnMethodContextNmslib.estimateOverheadInKB(1000)); - assertEquals(0, knnMethodContextFaiss.estimateOverheadInKB(168)); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, hnswMethod); + assertEquals(0, knnMethodContext.estimateOverheadInKB(1000)); + + } + + public void testEstimateOverheadInKB_whenMethodIsHNSWFlatFaiss_thenSizeIsExpectedValue() { + // For HNSW no encoding we expect 0 + MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, hnswMethod); + assertEquals(0, knnMethodContext.estimateOverheadInKB(168)); + + } + + public void testEstimateOverheadInKB_whenMethodIsHNSWPQFaiss_thenSizeIsExpectedValue() { + int dimension = 768; + int codeSize = ENCODER_PARAMETER_PQ_CODE_SIZE_DEFAULT; + + // For HNSWPQ, we expect 4 * d * 2^code_size / 1024 + 1 + int expectedHnswPq = 4 * dimension * (1 << codeSize) / BYTES_PER_KILOBYTES + 1; + + MethodComponentContext pqMethodContext = new MethodComponentContext(ENCODER_PQ, ImmutableMap.of()); + + MethodComponentContext hnswMethodPq = new MethodComponentContext( + METHOD_HNSW, + ImmutableMap.of(METHOD_ENCODER_PARAMETER, pqMethodContext) + ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, hnswMethodPq); + assertEquals(expectedHnswPq, knnMethodContext.estimateOverheadInKB(dimension)); + } + public void testEstimateOverheadInKB_whenMethodIsIVFFlatFaiss_thenSizeIsExpectedValue() { // For IVF, we expect 4 * nlist * d / 1024 + 1 int dimension = 768; int nlists = 1024; int expectedIvf = 4 * nlists * dimension / BYTES_PER_KILOBYTES + 1; MethodComponentContext ivfMethod = new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists)); - knnMethodContextFaiss = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, ivfMethod); - assertEquals(expectedIvf, knnMethodContextFaiss.estimateOverheadInKB(dimension)); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, ivfMethod); + assertEquals(expectedIvf, knnMethodContext.estimateOverheadInKB(dimension)); + } + + public void testEstimateOverheadInKB_whenMethodIsIVFPQFaiss_thenSizeIsExpectedValue() { + int dimension = 768; + int nlists = 1024; + int expectedIvf = 4 * nlists * dimension / BYTES_PER_KILOBYTES + 1; // For IVFPQ twe expect 4 * nlist * d / 1024 + 1 + 4 * d * 2^code_size / 1024 + 1 int codeSize = 16; @@ -171,17 +204,8 @@ public void testEstimateOverheadInKB() { METHOD_IVF, ImmutableMap.of(METHOD_PARAMETER_NLIST, nlists, METHOD_ENCODER_PARAMETER, pqMethodContext) ); - knnMethodContextFaiss = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, ivfMethodPq); - assertEquals(expectedIvfPq, knnMethodContextFaiss.estimateOverheadInKB(dimension)); - - // For HNSWPQ, we expect 4 * d * 2^code_size / 1024 + 1 - int expectedHnswPq = expectedFromPq; - MethodComponentContext hnswMethodPq = new MethodComponentContext( - METHOD_HNSW, - ImmutableMap.of(METHOD_ENCODER_PARAMETER, pqMethodContext) - ); - knnMethodContextFaiss = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, hnswMethodPq); - assertEquals(expectedHnswPq, knnMethodContextFaiss.estimateOverheadInKB(dimension)); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, ivfMethodPq); + assertEquals(expectedIvfPq, knnMethodContext.estimateOverheadInKB(dimension)); } /** diff --git a/src/test/java/org/opensearch/knn/index/util/FaissTests.java b/src/test/java/org/opensearch/knn/index/util/FaissTests.java index 01841363d..35f06027c 100644 --- a/src/test/java/org/opensearch/knn/index/util/FaissTests.java +++ b/src/test/java/org/opensearch/knn/index/util/FaissTests.java @@ -8,6 +8,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponent; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.Parameter; @@ -16,12 +17,122 @@ import java.util.HashMap; import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; public class FaissTests extends KNNTestCase { + public void testGetMethodAsMap_whenMethodIsHNSWFlat_thenCreateCorrectIndexDescription() throws IOException { + int mParam = 65; + String expectedIndexDescription = String.format("HNSW%d,Flat", mParam); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mParam) + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + + Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + + assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); + assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); + } + + public void testGetMethodAsMap_whenMethodIsHNSWPQ_thenCreateCorrectIndexDescription() throws IOException { + int hnswMParam = 65; + int pqMParam = 17; + String expectedIndexDescription = String.format("HNSW%d,PQ%d", hnswMParam, pqMParam); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, hnswMParam) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMParam) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + + Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + + assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); + assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); + } + + public void testGetMethodAsMap_whenMethodIsIVFFlat_thenCreateCorrectIndexDescription() throws IOException { + int nlists = 88; + String expectedIndexDescription = String.format("IVF%d,Flat", nlists); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, nlists) + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + + Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + + assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); + assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); + } + + public void testGetMethodAsMap_whenMethodIsIVFPQ_thenCreateCorrectIndexDescription() throws IOException { + int ivfNlistsParam = 88; + int pqMParam = 17; + int pqCodeSizeParam = 53; + String expectedIndexDescription = String.format("IVF%d,PQ%dx%d", ivfNlistsParam, pqMParam, pqCodeSizeParam); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, ivfNlistsParam) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMParam) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSizeParam) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + + Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + + assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); + assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); + } + public void testMethodAsMapBuilder() throws IOException { String methodName = "test-method"; String methodDescription = "test-description"; diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 6d52e5544..d1a5be741 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -14,6 +14,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.BeforeClass; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; @@ -36,9 +38,13 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.INDEX_THREAD_QTY; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; public class JNIServiceTests extends KNNTestCase { @@ -765,28 +771,94 @@ public void testTransferVectors() { JNIService.freeVectors(trainPointer1); } - public void testTrain() { + public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOException { + long trainPointer = transferVectors(10); + int ivfNlistParam = 16; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, ivfNlistParam) + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, FAISS_NAME); + + assertNotEquals(0, faissIndex.length); + JNIService.freeVectors(trainPointer); + } + + public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException { + long trainPointer = transferVectors(10); + int ivfNlistParam = 16; + int pqMParam = 4; + int pqCodeSizeParam = 4; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, ivfNlistParam) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMParam) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSizeParam) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + + byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, FAISS_NAME); + + assertNotEquals(0, faissIndex.length); + JNIService.freeVectors(trainPointer); + } + + public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException { + long trainPointer = transferVectors(10); + int pqMParam = 4; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqMParam) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + + byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, FAISS_NAME); + + assertNotEquals(0, faissIndex.length); + JNIService.freeVectors(trainPointer); + } + + private long transferVectors(int numDuplicates) { long trainPointer1 = JNIService.transferVectors(0, testData.indexData.vectors); assertNotEquals(0, trainPointer1); long trainPointer2; - for (int i = 0; i < 10; i++) { + for (int i = 0; i < numDuplicates; i++) { trainPointer2 = JNIService.transferVectors(trainPointer1, testData.indexData.vectors); assertEquals(trainPointer1, trainPointer2); } - Map parameters = ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, - "IVF16,PQ4", - KNNConstants.SPACE_TYPE, - SpaceType.L2.getValue() - ); - - byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer1, FAISS_NAME); - - assertNotEquals(0, faissIndex.length); - JNIService.freeVectors(trainPointer1); + return trainPointer1; } public void testCreateIndexFromTemplate() throws IOException { diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 626e3fe0b..011ec712c 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -1154,6 +1154,24 @@ public Response getModel(String modelId, List filters) throws IOExceptio return client().performRequest(request); } + /** + * Delete the model + * + * @param modelId Id of model to be retrieved + * @throws IOException if request cannot be performed + */ + public void deleteModel(String modelId) throws IOException { + if (modelId == null) { + modelId = ""; + } else { + modelId = "/" + modelId; + } + + Request request = new Request("DELETE", "/_plugins/_knn/models" + modelId); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + public void assertTrainingSucceeds(String modelId, int attempts, int delayInMillis) throws InterruptedException, Exception { int attemptNum = 0; Response response; @@ -1254,6 +1272,18 @@ protected void ingestDataAndTrainModel( Map method ) throws Exception { int trainingDataCount = 40; + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method, trainingDataCount); + } + + protected void ingestDataAndTrainModel( + String modelId, + String trainingIndexName, + String trainingFieldName, + int dimension, + String modelDescription, + Map method, + int trainingDataCount + ) throws Exception { bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, modelDescription);