diff --git a/CHANGELOG.md b/CHANGELOG.md index 310d8951c..3328dda52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements ### Bug Fixes * Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305) +* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318) ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 4e2dc38c6..b3a24d2f0 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import org.apache.commons.lang.StringUtils; import org.opensearch.Version; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MappingMetadata; @@ -26,6 +27,7 @@ import java.io.File; import java.util.Collections; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; @@ -61,6 +63,37 @@ public static int getFileSizeInKB(String filePath) { return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up } + /** + * This method retrieves the field mapping by a given field path from the index metadata. + * + * @param properties Index metadata mapping properties. + * @param fieldPath The field path string that make up the path to the field mapping. e.g. "a.b.field" or "field". + * The field path is applied and checked in OpenSearch, so it is guaranteed to be valid. + * + * @return The field mapping object if found, or null if the field is not found in the index metadata. + */ + private static Object getFieldMapping(final Map properties, final String fieldPath) { + String[] fieldPaths = fieldPath.split("\\."); + Object currentFieldMapping = properties; + + // Iterate through the field path list to retrieve the field mapping. + for (String path : fieldPaths) { + currentFieldMapping = ((Map) currentFieldMapping).get(path); + if (currentFieldMapping == null) { + return null; + } + + if (currentFieldMapping instanceof Map) { + Object possibleProperties = ((Map) currentFieldMapping).get("properties"); + if (possibleProperties instanceof Map) { + currentFieldMapping = possibleProperties; + } + } + } + + return currentFieldMapping; + } + /** * Validate that a field is a k-NN vector field and has the expected dimension * @@ -102,7 +135,13 @@ public static ValidationException validateKnnField( return exception; } - Object fieldMapping = properties.get(field); + // Check field path is valid + if (StringUtils.isEmpty(field)) { + exception.addValidationError(String.format(Locale.ROOT, "Field path is empty.")); + return exception; + } + + Object fieldMapping = getFieldMapping(properties, field); // Check field existence if (fieldMapping == null) { diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index c104e2df3..aeebae129 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.Consumer; public class VectorReader { @@ -180,13 +181,7 @@ public void onResponse(SearchResponse searchResponse) { // Either add the entire set of returned hits, or maxVectorCount - collectedVectorCount hits SearchHit[] hits = searchResponse.getHits().getHits(); int vectorsToAdd = Integer.min(maxVectorCount - collectedVectorCount, hits.length); - List trainingData = new ArrayList<>(); - - for (int i = 0; i < vectorsToAdd; i++) { - trainingData.add( - ((List) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new) - ); - } + List trainingData = extractVectorsFromHits(searchResponse, vectorsToAdd); this.collectedVectorCount += trainingData.size(); @@ -225,5 +220,42 @@ public void onFailure(Exception e) { listener.onFailure(e); } } + + /** + * Extracts vectors from the hits in a search response + * + * @param searchResponse Search response to extract vectors from + * @param vectorsToAdd number of vectors to extract + * @return list of vectors + */ + private List extractVectorsFromHits(SearchResponse searchResponse, int vectorsToAdd) { + SearchHit[] hits = searchResponse.getHits().getHits(); + List trainingData = new ArrayList<>(); + String[] fieldPath = fieldName.split("\\."); + int nullVectorCount = 0; + + for (int vector = 0; vector < vectorsToAdd; vector++) { + Map currentMap = hits[vector].getSourceAsMap(); + // The field name may be a nested field, so we need to split it and traverse the map. + // Example fieldName: "my_field" or "my_field.nested_field.nested_nested_field" + + for (int pathPart = 0; pathPart < fieldPath.length - 1; pathPart++) { + currentMap = (Map) currentMap.get(fieldPath[pathPart]); + } + + if (currentMap.get(fieldPath[fieldPath.length - 1]) instanceof List == false) { + nullVectorCount++; + continue; + } + + List fieldList = (List) currentMap.get(fieldPath[fieldPath.length - 1]); + + trainingData.add(fieldList.stream().map(Number::floatValue).toArray(Float[]::new)); + } + if (nullVectorCount > 0) { + logger.warn("Found {} documents with null vectors in field {}", nullVectorCount, fieldName); + } + return trainingData; + } } } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 792ebde69..323442fff 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -81,6 +81,35 @@ protected void createKnnIndexMapping(String indexName, String fieldName, Integer OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); } + /** + * Create simple k-NN mapping which can be nested. + * e.g. fieldPath = "a.b.c" will create mapping for "c" as knn_vector + */ + protected void createKnnNestedIndexMapping(String indexName, String fieldPath, Integer dimensions) throws IOException { + PutMappingRequest request = new PutMappingRequest(indexName); + String[] path = fieldPath.split("\\."); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); + for (int i = 0; i < path.length; i++) { + xContentBuilder.startObject(path[i]); + if (i == path.length - 1) { + xContentBuilder.field("type", "knn_vector").field("dimension", dimensions.toString()); + } else { + xContentBuilder.startObject("properties"); + } + } + for (int i = path.length - 1; i >= 0; i--) { + if (i != path.length - 1) { + xContentBuilder.endObject(); + } + xContentBuilder.endObject(); + } + xContentBuilder.endObject().endObject(); + + request.source(xContentBuilder); + + OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); + } + /** * Get default k-NN settings for test cases */ @@ -103,6 +132,31 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[] assertEquals(response.status(), RestStatus.CREATED); } + /** + * Add a k-NN doc to an index with nested knn_vector field + */ + protected void addKnnNestedDoc(String index, String docId, String fieldPath, Object[] vector) throws IOException, InterruptedException, + ExecutionException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + String[] fieldParts = fieldPath.split("\\."); + + for (int i = 0; i < fieldParts.length - 1; i++) { + builder.startObject(fieldParts[i]); + } + builder.field(fieldParts[fieldParts.length - 1], vector); + for (int i = fieldParts.length - 2; i >= 0; i--) { + builder.endObject(); + } + builder.endObject(); + IndexRequest indexRequest = new IndexRequest().index(index) + .id(docId) + .source(builder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + IndexResponse response = client().index(indexRequest).get(); + assertEquals(response.status(), RestStatus.CREATED); + } + /** * Add any document to index */ diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index 7013ef261..dc9c980e0 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -14,13 +14,18 @@ import com.google.common.collect.ImmutableMap; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.ValidationException; import org.opensearch.common.settings.Settings; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; import java.util.Map; +import java.util.Objects; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; @@ -67,4 +72,136 @@ public void testGetLoadParameters() { assertEquals(spaceType2.getValue(), loadParameters.get(SPACE_TYPE)); assertEquals(efSearchValue, loadParameters.get(HNSW_ALGO_EF_SEARCH)); } + + public void testValidateKnnField_NestedField() { + Map deepFieldValues = Map.of("type", "knn_vector", "dimension", 8); + Map deepField = Map.of("train-field", deepFieldValues); + Map deepFieldProperties = Map.of("properties", deepField); + Map nest_b = Map.of("b", deepFieldProperties); + Map nest_b_properties = Map.of("properties", nest_b); + Map nest_a = Map.of("a", nest_b_properties); + Map properties = Map.of("properties", nest_a); + + String field = "a.b.train-field"; + int dimension = 8; + + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + + assertNull(e); + } + + public void testValidateKnnField_NonNestedField() { + Map fieldValues = Map.of("type", "knn_vector", "dimension", 8); + Map top_level_field = Map.of("top_level_field", fieldValues); + Map properties = Map.of("properties", top_level_field); + String field = "top_level_field"; + int dimension = 8; + + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + + assertNull(e); + } + + public void testValidateKnnField_NonKnnField() { + Map fieldValues = Map.of("type", "text"); + Map top_level_field = Map.of("top_level_field", fieldValues); + Map properties = Map.of("properties", top_level_field); + String field = "top_level_field"; + int dimension = 8; + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + + assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;"); + } + + public void testValidateKnnField_WrongFieldPath() { + Map deepFieldValues = Map.of("type", "knn_vector", "dimension", 8); + Map deepField = Map.of("train-field", deepFieldValues); + Map deepFieldProperties = Map.of("properties", deepField); + Map nest_b = Map.of("b", deepFieldProperties); + Map nest_b_properties = Map.of("properties", nest_b); + Map nest_a = Map.of("a", nest_b_properties); + Map properties = Map.of("properties", nest_a); + String field = "a.train-field"; + int dimension = 8; + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + + assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;")); + } + + public void testValidateKnnField_EmptyField() { + Map deepFieldValues = Map.of("type", "knn_vector", "dimension", 8); + Map deepField = Map.of("train-field", deepFieldValues); + Map deepFieldProperties = Map.of("properties", deepField); + Map nest_b = Map.of("b", deepFieldProperties); + Map nest_b_properties = Map.of("properties", nest_b); + Map nest_a = Map.of("a", nest_b_properties); + Map properties = Map.of("properties", nest_a); + String field = ""; + int dimension = 8; + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + + System.out.println(Objects.requireNonNull(e).getMessage()); + + assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field path is empty.;")); + } + + public void testValidateKnnField_EmptyIndexMetadata() { + String field = "a.b.train-field"; + int dimension = 8; + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(null); + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class); + when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); + when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); + + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + + assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;")); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index 4b996be28..b2f429e2a 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -342,4 +342,74 @@ public void testTrainModel_success_noId() throws IOException, InterruptedExcepti assertTrainingSucceeds(modelId, 30, 1000); } + + // Test to checks when user tries to train a model with nested fields + public void testTrainModel_success_nestedField() throws Exception { + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String nestedFieldPath = "a.b.train-field"; + int dimension = 8; + + // Create a training index and randomly ingest data into it + String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath); + createKnnIndex(trainingIndexName, mapping); + int trainingDataCount = 200; + bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension); + + // Call the train API with this definition: + /* + { + "training_index": "train_index", + "training_field": "a.b.train_field", + "dimension": 8, + "description": "this should be allowed to be null", + "method": { + "name":"ivf", + "engine":"faiss", + "space_type": "l2", + "parameters":{ + "nlist":1, + "encoder":{ + "name":"pq", + "parameters":{ + "code_size":2, + "m": 2 + } + } + } + } + } + */ + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + Response trainResponse = trainModel(modelId, trainingIndexName, nestedFieldPath, dimension, method, "dummy description"); + + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + + Response getResponse = getModel(modelId, null); + String responseBody = EntityUtils.toString(getResponse.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); + + assertTrainingSucceeds(modelId, 30, 1000); + } } diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java index 74008a043..209c9cc73 100644 --- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java +++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java @@ -34,6 +34,7 @@ public class VectorReaderTests extends KNNSingleNodeTestCase { private final static int DEFAULT_LATCH_TIMEOUT = 100; private final static String DEFAULT_INDEX_NAME = "test-index"; private final static String DEFAULT_FIELD_NAME = "test-field"; + private final static String DEFAULT_NESTED_FIELD_PATH = "a.b.test-field"; private final static int DEFAULT_DIMENSION = 16; private final static int DEFAULT_NUM_VECTORS = 100; private final static int DEFAULT_MAX_VECTOR_COUNT = 10000; @@ -345,6 +346,46 @@ public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, Execut ); } + public void testRead_valid_NestedField() throws InterruptedException, ExecutionException, IOException { + createIndex(DEFAULT_INDEX_NAME); + createKnnNestedIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_NESTED_FIELD_PATH, DEFAULT_DIMENSION); + + // Create list of random vectors and ingest + Random random = new Random(); + List vectors = new ArrayList<>(); + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); + vectors.add(vector); + addKnnNestedDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_NESTED_FIELD_PATH, vector); + } + + // Configure VectorReader + ClusterService clusterService = node().injector().getInstance(ClusterService.class); + VectorReader vectorReader = new VectorReader(client()); + + // Read all vectors and confirm they match vectors + TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_NESTED_FIELD_PATH, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); + + List consumedVectors = testVectorConsumer.getVectorsConsumed(); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); + + List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); + } + private static class TestVectorConsumer implements Consumer> { List vectorsConsumed; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 678f490fa..6f08627ae 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -349,6 +349,38 @@ protected String createKnnIndexMapping(List fieldNames, List di return xContentBuilder.toString(); } + /** + * Utility to create a Knn Index Mapping with nested field + * + * @param dimensions dimension of the vector + * @param fieldPath path of the nested field, e.g. "my_nested_field.my_vector" + * @return mapping string for the nested field + */ + protected String createKnnIndexNestedMapping(Integer dimensions, String fieldPath) throws IOException { + String[] fieldPathArray = fieldPath.split("\\."); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); + + for (int i = 0; i < fieldPathArray.length; i++) { + xContentBuilder.startObject(fieldPathArray[i]); + if (i == fieldPathArray.length - 1) { + xContentBuilder.field("type", "knn_vector").field("dimension", dimensions.toString()); + } else { + xContentBuilder.startObject("properties"); + } + } + + for (int i = fieldPathArray.length - 1; i >= 0; i--) { + if (i != fieldPathArray.length - 1) { + xContentBuilder.endObject(); + } + xContentBuilder.endObject(); + } + + xContentBuilder.endObject().endObject(); + + return xContentBuilder.toString(); + } + /** * Get index mapping as map * @@ -416,6 +448,37 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[] assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + /** + * Add a single KNN Doc to an index with a nested vector field + * + * @param index name of the index + * @param docId id of the document + * @param nestedFieldPath path of the nested field, e.g. "my_nested_field.my_vector" + * @param vector vector to add + * + */ + protected void addKnnDocWithNestedField(String index, String docId, String nestedFieldPath, Object[] vector) throws IOException { + String[] fieldParts = nestedFieldPath.split("\\."); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldParts.length - 1; i++) { + builder.startObject(fieldParts[i]); + } + builder.field(fieldParts[fieldParts.length - 1], vector); + for (int i = fieldParts.length - 2; i >= 0; i--) { + builder.endObject(); + } + builder.endObject(); + + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + request.setJsonEntity(builder.toString()); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Add a single KNN Doc to an index with multiple fields */ @@ -825,6 +888,26 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV } + /** + * Bulk ingest random vectors with nested field + * + * @param indexName index name + * @param nestedFieldPath nested field path, e.g. "my_nested_field.my_vector_field" + * @param numVectors number of vectors + * @param dimension vector dimension + */ + public void bulkIngestRandomVectorsWithNestedField(String indexName, String nestedFieldPath, int numVectors, int dimension) + throws IOException { + for (int i = 0; i < numVectors; i++) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = randomFloat(); + } + + addKnnDocWithNestedField(indexName, String.valueOf(i + 1), nestedFieldPath, Floats.asList(vector).toArray()); + } + } + // Method that adds multiple documents into the index using Bulk API public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException { Request request = new Request("POST", "/_bulk");