diff --git a/CHANGELOG.md b/CHANGELOG.md index 310d8951c1..cec57d3f81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements ### Bug Fixes * Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305) +* Fixed field value from nested mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318) ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 574e4a9773..a38291fcbc 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -28,6 +28,7 @@ import java.util.HashMap; import java.util.Map; +import static org.opensearch.index.mapper.MapperService.INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING; import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; @@ -59,6 +60,39 @@ public static int getFileSizeInKB(String filePath) { return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up } + /** + * This method retrieves a specified field mapping from a set of mapping properties. + * + * @param properties A map representing properties, where each key is a property name and + * the value is either a sub-map of properties or the property value itself. + * @param fieldPaths The field path list that make up the path to the field mapping. + * @return The value of the field if found, or null if the field is not present in the map. + */ + public static Object getFieldMapping(final Map properties, final String[] fieldPaths) { + Object currentFieldMapping = properties; + + for (String path : fieldPaths) { + if (!(currentFieldMapping instanceof Map)) { + return null; + } + + currentFieldMapping = ((Map) currentFieldMapping).get(path); + if (currentFieldMapping == null) { + return null; + } + + if (currentFieldMapping instanceof Map) { + Object possibleProperties = ((Map) currentFieldMapping).get("properties"); + if (possibleProperties instanceof Map) { + currentFieldMapping = possibleProperties; + } + } + } + + return currentFieldMapping; + } + + /** * Validate that a field is a k-NN vector field and has the expected dimension * @@ -100,7 +134,24 @@ public static ValidationException validateKnnField( return exception; } - Object fieldMapping = properties.get(field); + // Check field path is valid + if (field.isEmpty()) { + exception.addValidationError("Field path is empty"); + return exception; + } + + String[] fieldPaths = field.split("\\."); + + Long nestedFieldMaxLimit = INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING.get(indexMetadata.getSettings()); + + // Check filed path length is valid + if (fieldPaths.length == 0 || fieldPaths.length > nestedFieldMaxLimit) { + exception.addValidationError(String.format("Field path length \"%s\" is invalid, it should > 0 and <= %d", + field, nestedFieldMaxLimit)); + return exception; + } + + Object fieldMapping = getFieldMapping(properties, fieldPaths); // Check field existence if (fieldMapping == null) { diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index c104e2df36..ef935d870e 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -29,6 +29,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.Consumer; public class VectorReader { @@ -182,9 +183,21 @@ public void onResponse(SearchResponse searchResponse) { int vectorsToAdd = Integer.min(maxVectorCount - collectedVectorCount, hits.length); List trainingData = new ArrayList<>(); - for (int i = 0; i < vectorsToAdd; i++) { + for (int vector = 0; vector < vectorsToAdd; vector++) { + Map sourceMap = hits[vector].getSourceAsMap(); + // The field name may be a nested field, so we need to split it and traverse the map. + // Example fieldName: "my_field" or "my_field.nested_field.nested_nested_field" + String[] fieldPath = fieldName.split("\\."); + Map currentMap = sourceMap; + + for (int pathPart = 0; pathPart < fieldPath.length - 1; pathPart++) { + currentMap = (Map) currentMap.get(fieldPath[pathPart]); + } + + List fieldList = (List) currentMap.get(fieldPath[fieldPath.length - 1]); + trainingData.add( - ((List) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new) + fieldList.stream().map(Number::floatValue).toArray(Float[]::new) ); } diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index 38037a7fe3..bf4341d3a8 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -89,6 +89,36 @@ protected void createKnnIndexMapping(String indexName, String fieldName, Integer request.source(fieldName, "type=knn_vector,dimension=" + dimensions); OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); } + /** + * Create simple k-NN mapping which can be nested. + * e.g. fieldPath = "a.b.c" will create mapping for "c" as knn_vector + */ + protected void createKnnNestedIndexMapping(String indexName, String fieldPath, Integer dimensions) + throws IOException { + PutMappingRequest request = new PutMappingRequest(indexName); + String[] path = fieldPath.split("\\."); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); + for (int i = 0; i < path.length; i++) { + xContentBuilder.startObject(path[i]); + if (i == path.length - 1) { + xContentBuilder.field("type", "knn_vector") + .field("dimension", dimensions.toString()); + } else { + xContentBuilder.startObject("properties"); + } + } + for (int i = path.length - 1; i >= 0; i--) { + if (i != path.length - 1) { + xContentBuilder.endObject(); + } + xContentBuilder.endObject(); + } + xContentBuilder.endObject().endObject(); + + request.source(xContentBuilder); + + OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet()); + } /** * Get default k-NN settings for test cases @@ -112,6 +142,31 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[] assertEquals(response.status(), RestStatus.CREATED); } + /** + * Add a k-NN doc to an index with nested knn_vector field + */ + protected void addKnnNestedDoc(String index, String docId, String fieldPath, Object[] vector) throws IOException, InterruptedException, + ExecutionException { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + String[] fieldParts = fieldPath.split("\\."); + + for (int i = 0; i < fieldParts.length - 1; i++) { + builder.startObject(fieldParts[i]); + } + builder.field(fieldParts[fieldParts.length - 1], vector); + for (int i = fieldParts.length - 2; i >= 0; i--) { + builder.endObject(); + } + builder.endObject(); + IndexRequest indexRequest = new IndexRequest().index(index) + .id(docId) + .source(builder) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + IndexResponse response = client().index(indexRequest).get(); + assertEquals(response.status(), RestStatus.CREATED); + } + /** * Add any document to index */ diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index 7013ef261a..4efe281130 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -20,6 +20,8 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.util.KNNEngine; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import static org.mockito.ArgumentMatchers.anyString; @@ -67,4 +69,34 @@ public void testGetLoadParameters() { assertEquals(spaceType2.getValue(), loadParameters.get(SPACE_TYPE)); assertEquals(efSearchValue, loadParameters.get(HNSW_ALGO_EF_SEARCH)); } + + public void testGetFieldMappingNonNestedField() { + Map fieldValues = Map.of("type", "knn_vector", "dimension", 8); + Map properties = Map.of("top_level_field", fieldValues); + String[] field = {"top_level_field"}; + + Object result = IndexUtil.getFieldMapping(properties, field); + assertEquals(fieldValues, result); + } + + public void testGetFieldMappingNestedField() { + Map deepFieldValues = Map.of("type", "knn_vector", "dimension", 8); + Map deepField = Map.of("train-field", deepFieldValues); + Map deepFieldProperties = Map.of("properties", deepField); + Map nest_b = Map.of("b", deepFieldProperties); + Map nest_b_properties = Map.of("properties", nest_b); + Map nest_a = Map.of("a", nest_b_properties); + String[] field = {"a", "b", "train-field"}; + + Object deepResult = IndexUtil.getFieldMapping(nest_a, field); + assertEquals(deepFieldValues, deepResult); + } + + public void testGetFieldMappingEmptyProperties() { + Map properties = Collections.emptyMap(); + String[] field = {"top_level_field"}; + + Object result = IndexUtil.getFieldMapping(properties, field); + assertNull(result); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index 545fa7cc7f..1180dcf0a3 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -341,4 +341,74 @@ public void testTrainModel_success_noId() throws Exception { assertTrainingSucceeds(modelId, 30, 1000); } + + // Test to checks when user tries to train a model with nested fields + public void testTrainModel_success_nestedField() throws Exception { + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String nestedFieldPath = "a.b.train-field"; + int dimension = 8; + + // Create a training index and randomly ingest data into it + String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath); + createKnnIndex(trainingIndexName, mapping); + int trainingDataCount = 200; + bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension); + + // Call the train API with this definition: + /* + { + "training_index": "train_index", + "training_field": "a.b.train_field", + "dimension": 8, + "description": "this should be allowed to be null", + "method": { + "name":"ivf", + "engine":"faiss", + "space_type": "l2", + "parameters":{ + "nlist":1, + "encoder":{ + "name":"pq", + "parameters":{ + "code_size":2, + "m": 2 + } + } + } + } + } + */ + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + Response trainResponse = trainModel(modelId, trainingIndexName, nestedFieldPath, dimension, method, "dummy description"); + + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + + Response getResponse = getModel(modelId, null); + String responseBody = EntityUtils.toString(getResponse.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); + + assertTrainingSucceeds(modelId, 30, 1000); + } } diff --git a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java index 74008a0435..b7372e6086 100644 --- a/src/test/java/org/opensearch/knn/training/VectorReaderTests.java +++ b/src/test/java/org/opensearch/knn/training/VectorReaderTests.java @@ -34,6 +34,7 @@ public class VectorReaderTests extends KNNSingleNodeTestCase { private final static int DEFAULT_LATCH_TIMEOUT = 100; private final static String DEFAULT_INDEX_NAME = "test-index"; private final static String DEFAULT_FIELD_NAME = "test-field"; + private final static String DEFAULT_NESTED_FIELD_PATH = "a.b.test-field"; private final static int DEFAULT_DIMENSION = 16; private final static int DEFAULT_NUM_VECTORS = 100; private final static int DEFAULT_MAX_VECTOR_COUNT = 10000; @@ -345,6 +346,46 @@ public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, Execut ); } + public void testRead_valid_NestedField() throws InterruptedException, ExecutionException, IOException { + createIndex(DEFAULT_INDEX_NAME); + createKnnNestedIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_NESTED_FIELD_PATH, DEFAULT_DIMENSION); + + // Create list of random vectors and ingest + Random random = new Random(); + List vectors = new ArrayList<>(); + for (int i = 0; i < DEFAULT_NUM_VECTORS; i++) { + Float[] vector = random.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new); + vectors.add(vector); + addKnnNestedDoc(DEFAULT_INDEX_NAME, Integer.toString(i), DEFAULT_NESTED_FIELD_PATH, vector); + } + + // Configure VectorReader + ClusterService clusterService = node().injector().getInstance(ClusterService.class); + VectorReader vectorReader = new VectorReader(client()); + + // Read all vectors and confirm they match vectors + TestVectorConsumer testVectorConsumer = new TestVectorConsumer(); + final CountDownLatch inProgressLatch = new CountDownLatch(1); + vectorReader.read( + clusterService, + DEFAULT_INDEX_NAME, + DEFAULT_NESTED_FIELD_PATH, + DEFAULT_MAX_VECTOR_COUNT, + DEFAULT_SEARCH_SIZE, + testVectorConsumer, + createOnSearchResponseCountDownListener(inProgressLatch) + ); + + assertLatchDecremented(inProgressLatch); + + List consumedVectors = testVectorConsumer.getVectorsConsumed(); + assertEquals(DEFAULT_NUM_VECTORS, consumedVectors.size()); + + List flatVectors = vectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + List flatConsumedVectors = consumedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toList()); + assertEquals(new HashSet<>(flatVectors), new HashSet<>(flatConsumedVectors)); + } + private static class TestVectorConsumer implements Consumer> { List vectorsConsumed; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 642b8bfb92..15f3a4c160 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -356,6 +356,41 @@ protected String createKnnIndexMapping(List fieldNames, List di return xContentBuilder.toString(); } + /** + * Utility to create a Knn Index Mapping with nested field + * + * @param dimensions dimension of the vector + * @param fieldPath path of the nested field, e.g. "my_nested_field.my_vector" + * @return mapping string for the nested field + */ + protected String createKnnIndexNestedMapping(Integer dimensions, String fieldPath) throws IOException { + String[] fieldPathArray = fieldPath.split("\\."); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties"); + + for (int i = 0; i < fieldPathArray.length; i++) { + xContentBuilder.startObject(fieldPathArray[i]); + if (i == fieldPathArray.length - 1) { + xContentBuilder.field("type", "knn_vector") + .field("dimension", dimensions.toString()); + } else { + xContentBuilder.startObject("properties"); + } + } + + for (int i = fieldPathArray.length - 1; i >= 0; i--) { + if (i != fieldPathArray.length - 1) { + xContentBuilder.endObject(); + } + xContentBuilder.endObject(); + } + + xContentBuilder.endObject().endObject(); + + return xContentBuilder.toString(); + } + + + /** * Get index mapping as map * @@ -423,6 +458,37 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[] assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } + /** + * Add a single KNN Doc to an index with a nested vector field + * + * @param index name of the index + * @param docId id of the document + * @param nestedFieldPath path of the nested field, e.g. "my_nested_field.my_vector" + * @param vector vector to add + * + */ + protected void addKnnDocWithNestedField(String index, String docId, String nestedFieldPath, Object[] vector) throws IOException { + String[] fieldParts = nestedFieldPath.split("\\."); + + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldParts.length - 1; i++) { + builder.startObject(fieldParts[i]); + } + builder.field(fieldParts[fieldParts.length - 1], vector); + for (int i = fieldParts.length - 2; i >= 0; i--) { + builder.endObject(); + } + builder.endObject(); + + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + request.setJsonEntity(builder.toString()); + client().performRequest(request); + + request = new Request("POST", "/" + index + "/_refresh"); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + /** * Add a single KNN Doc to an index with multiple fields */ @@ -853,6 +919,25 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV } + /** + * Bulk ingest random vectors with nested field + * + * @param indexName index name + * @param nestedFieldPath nested field path, e.g. "my_nested_field.my_vector_field" + * @param numVectors number of vectors + * @param dimension vector dimension + */ + public void bulkIngestRandomVectorsWithNestedField(String indexName, String nestedFieldPath, int numVectors, int dimension) throws IOException { + for (int i = 0; i < numVectors; i++) { + float[] vector = new float[dimension]; + for (int j = 0; j < dimension; j++) { + vector[j] = randomFloat(); + } + + addKnnDocWithNestedField(indexName, String.valueOf(i + 1), nestedFieldPath, Floats.asList(vector).toArray()); + } + } + // Method that adds multiple documents into the index using Bulk API public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException { Request request = new Request("POST", "/_bulk");