Skip to content

Commit

Permalink
Allow nested knn field mapping when train model (#1318) (#1339)
Browse files Browse the repository at this point in the history
(cherry picked from commit 2e3ab95)

Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei authored Dec 7, 2023
1 parent d5db2ae commit 06d52d5
Show file tree
Hide file tree
Showing 8 changed files with 465 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Allow nested knn field mapping when train model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
### Documentation
Expand Down
41 changes: 40 additions & 1 deletion src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import org.apache.commons.lang.StringUtils;
import org.opensearch.Version;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
Expand All @@ -26,6 +27,7 @@
import java.io.File;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
Expand Down Expand Up @@ -61,6 +63,37 @@ public static int getFileSizeInKB(String filePath) {
return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up
}

/**
* This method retrieves the field mapping by a given field path from the index metadata.
*
* @param properties Index metadata mapping properties.
* @param fieldPath The field path string that make up the path to the field mapping. e.g. "a.b.field" or "field".
* The field path is applied and checked in OpenSearch, so it is guaranteed to be valid.
*
* @return The field mapping object if found, or null if the field is not found in the index metadata.
*/
private static Object getFieldMapping(final Map<String, Object> properties, final String fieldPath) {
String[] fieldPaths = fieldPath.split("\\.");
Object currentFieldMapping = properties;

// Iterate through the field path list to retrieve the field mapping.
for (String path : fieldPaths) {
currentFieldMapping = ((Map<String, Object>) currentFieldMapping).get(path);
if (currentFieldMapping == null) {
return null;
}

if (currentFieldMapping instanceof Map<?, ?>) {
Object possibleProperties = ((Map<String, Object>) currentFieldMapping).get("properties");
if (possibleProperties instanceof Map<?, ?>) {
currentFieldMapping = possibleProperties;
}
}
}

return currentFieldMapping;
}

/**
* Validate that a field is a k-NN vector field and has the expected dimension
*
Expand Down Expand Up @@ -102,7 +135,13 @@ public static ValidationException validateKnnField(
return exception;
}

Object fieldMapping = properties.get(field);
// Check field path is valid
if (StringUtils.isEmpty(field)) {
exception.addValidationError(String.format(Locale.ROOT, "Field path is empty."));
return exception;
}

Object fieldMapping = getFieldMapping(properties, field);

// Check field existence
if (fieldMapping == null) {
Expand Down
46 changes: 39 additions & 7 deletions src/main/java/org/opensearch/knn/training/VectorReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

public class VectorReader {
Expand Down Expand Up @@ -180,13 +181,7 @@ public void onResponse(SearchResponse searchResponse) {
// Either add the entire set of returned hits, or maxVectorCount - collectedVectorCount hits
SearchHit[] hits = searchResponse.getHits().getHits();
int vectorsToAdd = Integer.min(maxVectorCount - collectedVectorCount, hits.length);
List<Float[]> trainingData = new ArrayList<>();

for (int i = 0; i < vectorsToAdd; i++) {
trainingData.add(
((List<Number>) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new)
);
}
List<Float[]> trainingData = extractVectorsFromHits(searchResponse, vectorsToAdd);

this.collectedVectorCount += trainingData.size();

Expand Down Expand Up @@ -225,5 +220,42 @@ public void onFailure(Exception e) {
listener.onFailure(e);
}
}

/**
* Extracts vectors from the hits in a search response
*
* @param searchResponse Search response to extract vectors from
* @param vectorsToAdd number of vectors to extract
* @return list of vectors
*/
private List<Float[]> extractVectorsFromHits(SearchResponse searchResponse, int vectorsToAdd) {
SearchHit[] hits = searchResponse.getHits().getHits();
List<Float[]> trainingData = new ArrayList<>();
String[] fieldPath = fieldName.split("\\.");
int nullVectorCount = 0;

for (int vector = 0; vector < vectorsToAdd; vector++) {
Map<String, Object> currentMap = hits[vector].getSourceAsMap();
// The field name may be a nested field, so we need to split it and traverse the map.
// Example fieldName: "my_field" or "my_field.nested_field.nested_nested_field"

for (int pathPart = 0; pathPart < fieldPath.length - 1; pathPart++) {
currentMap = (Map<String, Object>) currentMap.get(fieldPath[pathPart]);
}

if (currentMap.get(fieldPath[fieldPath.length - 1]) instanceof List<?> == false) {
nullVectorCount++;
continue;
}

List<Number> fieldList = (List<Number>) currentMap.get(fieldPath[fieldPath.length - 1]);

trainingData.add(fieldList.stream().map(Number::floatValue).toArray(Float[]::new));
}
if (nullVectorCount > 0) {
logger.warn("Found {} documents with null vectors in field {}", nullVectorCount, fieldName);
}
return trainingData;
}
}
}
54 changes: 54 additions & 0 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,35 @@ protected void createKnnIndexMapping(String indexName, String fieldName, Integer
OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet());
}

/**
* Create simple k-NN mapping which can be nested.
* e.g. fieldPath = "a.b.c" will create mapping for "c" as knn_vector
*/
protected void createKnnNestedIndexMapping(String indexName, String fieldPath, Integer dimensions) throws IOException {
PutMappingRequest request = new PutMappingRequest(indexName);
String[] path = fieldPath.split("\\.");
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("properties");
for (int i = 0; i < path.length; i++) {
xContentBuilder.startObject(path[i]);
if (i == path.length - 1) {
xContentBuilder.field("type", "knn_vector").field("dimension", dimensions.toString());
} else {
xContentBuilder.startObject("properties");
}
}
for (int i = path.length - 1; i >= 0; i--) {
if (i != path.length - 1) {
xContentBuilder.endObject();
}
xContentBuilder.endObject();
}
xContentBuilder.endObject().endObject();

request.source(xContentBuilder);

OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet());
}

/**
* Get default k-NN settings for test cases
*/
Expand All @@ -103,6 +132,31 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[]
assertEquals(response.status(), RestStatus.CREATED);
}

/**
* Add a k-NN doc to an index with nested knn_vector field
*/
protected void addKnnNestedDoc(String index, String docId, String fieldPath, Object[] vector) throws IOException, InterruptedException,
ExecutionException {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
String[] fieldParts = fieldPath.split("\\.");

for (int i = 0; i < fieldParts.length - 1; i++) {
builder.startObject(fieldParts[i]);
}
builder.field(fieldParts[fieldParts.length - 1], vector);
for (int i = fieldParts.length - 2; i >= 0; i--) {
builder.endObject();
}
builder.endObject();
IndexRequest indexRequest = new IndexRequest().index(index)
.id(docId)
.source(builder)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

IndexResponse response = client().index(indexRequest).get();
assertEquals(response.status(), RestStatus.CREATED);
}

/**
* Add any document to index
*/
Expand Down
137 changes: 137 additions & 0 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
import com.google.common.collect.ImmutableMap;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.ValidationException;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;

import java.util.Map;
import java.util.Objects;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -67,4 +72,136 @@ public void testGetLoadParameters() {
assertEquals(spaceType2.getValue(), loadParameters.get(SPACE_TYPE));
assertEquals(efSearchValue, loadParameters.get(HNSW_ALGO_EF_SEARCH));
}

public void testValidateKnnField_NestedField() {
Map<String, Object> deepFieldValues = Map.of("type", "knn_vector", "dimension", 8);
Map<String, Object> deepField = Map.of("train-field", deepFieldValues);
Map<String, Object> deepFieldProperties = Map.of("properties", deepField);
Map<String, Object> nest_b = Map.of("b", deepFieldProperties);
Map<String, Object> nest_b_properties = Map.of("properties", nest_b);
Map<String, Object> nest_a = Map.of("a", nest_b_properties);
Map<String, Object> properties = Map.of("properties", nest_a);

String field = "a.b.train-field";
int dimension = 8;

MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assertNull(e);
}

public void testValidateKnnField_NonNestedField() {
Map<String, Object> fieldValues = Map.of("type", "knn_vector", "dimension", 8);
Map<String, Object> top_level_field = Map.of("top_level_field", fieldValues);
Map<String, Object> properties = Map.of("properties", top_level_field);
String field = "top_level_field";
int dimension = 8;

MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assertNull(e);
}

public void testValidateKnnField_NonKnnField() {
Map<String, Object> fieldValues = Map.of("type", "text");
Map<String, Object> top_level_field = Map.of("top_level_field", fieldValues);
Map<String, Object> properties = Map.of("properties", top_level_field);
String field = "top_level_field";
int dimension = 8;
MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;");
}

public void testValidateKnnField_WrongFieldPath() {
Map<String, Object> deepFieldValues = Map.of("type", "knn_vector", "dimension", 8);
Map<String, Object> deepField = Map.of("train-field", deepFieldValues);
Map<String, Object> deepFieldProperties = Map.of("properties", deepField);
Map<String, Object> nest_b = Map.of("b", deepFieldProperties);
Map<String, Object> nest_b_properties = Map.of("properties", nest_b);
Map<String, Object> nest_a = Map.of("a", nest_b_properties);
Map<String, Object> properties = Map.of("properties", nest_a);
String field = "a.train-field";
int dimension = 8;
MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;"));
}

public void testValidateKnnField_EmptyField() {
Map<String, Object> deepFieldValues = Map.of("type", "knn_vector", "dimension", 8);
Map<String, Object> deepField = Map.of("train-field", deepFieldValues);
Map<String, Object> deepFieldProperties = Map.of("properties", deepField);
Map<String, Object> nest_b = Map.of("b", deepFieldProperties);
Map<String, Object> nest_b_properties = Map.of("properties", nest_b);
Map<String, Object> nest_a = Map.of("a", nest_b_properties);
Map<String, Object> properties = Map.of("properties", nest_a);
String field = "";
int dimension = 8;
MappingMetadata mappingMetadata = mock(MappingMetadata.class);
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

System.out.println(Objects.requireNonNull(e).getMessage());

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field path is empty.;"));
}

public void testValidateKnnField_EmptyIndexMetadata() {
String field = "a.b.train-field";
int dimension = 8;
IndexMetadata indexMetadata = mock(IndexMetadata.class);
when(indexMetadata.mapping()).thenReturn(null);
ModelDao modelDao = mock(ModelDao.class);
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;"));
}
}
Loading

0 comments on commit 06d52d5

Please sign in to comment.