Skip to content

Commit

Permalink
Allowed using knn field path when train model
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Nov 29, 2023
1 parent 5e2f899 commit 32b0e7d
Show file tree
Hide file tree
Showing 8 changed files with 351 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
* Fix use-after-free case on nmslib search path [#1305](https://github.com/opensearch-project/k-NN/pull/1305)
* Fix retrieval of field values from nested mappings when training a model [#1318](https://github.com/opensearch-project/k-NN/pull/1318)
### Infrastructure
* Upgrade gradle to 8.4 [#1289](https://github.com/opensearch-project/k-NN/pull/1289)
### Documentation
Expand Down
53 changes: 52 additions & 1 deletion src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.index.mapper.MapperService.INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING;
import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
Expand Down Expand Up @@ -59,6 +60,39 @@ public static int getFileSizeInKB(String filePath) {
return Math.toIntExact((file.length() / BYTES_PER_KILOBYTES) + 1L); // Add one so that integer division rounds up
}

/**
 * Resolves a field mapping by walking a path of field names through a tree of mapping properties.
 * <p>
 * After each path segment is looked up, if the value found is itself a map that contains a
 * {@code "properties"} entry which is also a map, the walk continues from (or, on the last
 * segment, returns) that {@code "properties"} map rather than the enclosing map.
 *
 * @param properties The mapping properties to start from; each key is a field name and each value
 *                   is either a sub-map or a plain value.
 * @param fieldPaths The ordered path segments leading to the requested field mapping.
 * @return The value reached at the end of the path, or null if any segment is missing or the walk
 *         hits a non-map value before the path is exhausted. An empty path yields
 *         {@code properties} itself.
 */
public static Object getFieldMapping(final Map<String, Object> properties, final String[] fieldPaths) {
    Object node = properties;

    for (int depth = 0; depth < fieldPaths.length; depth++) {
        // Cannot descend any further through a non-map (or null) value.
        if (!(node instanceof Map<?, ?>)) {
            return null;
        }

        final Object child = ((Map<?, ?>) node).get(fieldPaths[depth]);
        if (child == null) {
            return null;
        }

        // Object-like fields keep their sub-fields under "properties"; step into it when present.
        final Object nested = (child instanceof Map<?, ?>) ? ((Map<?, ?>) child).get("properties") : null;
        node = (nested instanceof Map<?, ?>) ? nested : child;
    }

    return node;
}


/**
* Validate that a field is a k-NN vector field and has the expected dimension
*
Expand Down Expand Up @@ -100,7 +134,24 @@ public static ValidationException validateKnnField(
return exception;
}

Object fieldMapping = properties.get(field);
// Check field path is valid
if (field.isEmpty()) {
exception.addValidationError("Field path is empty");
return exception;
}

String[] fieldPaths = field.split("\\.");

Long nestedFieldMaxLimit = INDEX_MAPPING_NESTED_FIELDS_LIMIT_SETTING.get(indexMetadata.getSettings());

// Check field path length is valid
if (fieldPaths.length == 0 || fieldPaths.length > nestedFieldMaxLimit) {
exception.addValidationError(String.format("Field path length \"%s\" is invalid, it should > 0 and <= %d",
field, nestedFieldMaxLimit));
return exception;
}

Object fieldMapping = getFieldMapping(properties, fieldPaths);

// Check field existence
if (fieldMapping == null) {
Expand Down
17 changes: 15 additions & 2 deletions src/main/java/org/opensearch/knn/training/VectorReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

public class VectorReader {
Expand Down Expand Up @@ -182,9 +183,21 @@ public void onResponse(SearchResponse searchResponse) {
int vectorsToAdd = Integer.min(maxVectorCount - collectedVectorCount, hits.length);
List<Float[]> trainingData = new ArrayList<>();

for (int i = 0; i < vectorsToAdd; i++) {
for (int vector = 0; vector < vectorsToAdd; vector++) {
Map<String, Object> sourceMap = hits[vector].getSourceAsMap();
// The field name may be a nested field, so we need to split it and traverse the map.
// Example fieldName: "my_field" or "my_field.nested_field.nested_nested_field"
String[] fieldPath = fieldName.split("\\.");
Map<String, Object> currentMap = sourceMap;

for (int pathPart = 0; pathPart < fieldPath.length - 1; pathPart++) {
currentMap = (Map<String, Object>) currentMap.get(fieldPath[pathPart]);
}

List<Number> fieldList = (List<Number>) currentMap.get(fieldPath[fieldPath.length - 1]);

trainingData.add(
((List<Number>) hits[i].getSourceAsMap().get(fieldName)).stream().map(Number::floatValue).toArray(Float[]::new)
fieldList.stream().map(Number::floatValue).toArray(Float[]::new)
);
}

Expand Down
55 changes: 55 additions & 0 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,36 @@ protected void createKnnIndexMapping(String indexName, String fieldName, Integer
request.source(fieldName, "type=knn_vector,dimension=" + dimensions);
OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet());
}
/**
 * Puts a k-NN mapping on an existing index where the vector field may sit below nested objects.
 * Every path segment except the last becomes an object with its own "properties" section, and
 * the last segment is mapped as a knn_vector with the given dimension,
 * e.g. fieldPath = "a.b.c" maps "c" as knn_vector underneath "a" and "b".
 */
protected void createKnnNestedIndexMapping(String indexName, String fieldPath, Integer dimensions)
    throws IOException {
    final String[] segments = fieldPath.split("\\.");
    final int leafIndex = segments.length - 1;

    // Root object plus the top-level "properties" section.
    XContentBuilder mapping = XContentFactory.jsonBuilder().startObject().startObject("properties");

    // Open each intermediate object together with its "properties" wrapper.
    for (int depth = 0; depth < leafIndex; depth++) {
        mapping.startObject(segments[depth]).startObject("properties");
    }

    // The leaf carries the vector definition and is closed right away.
    if (leafIndex >= 0) {
        mapping.startObject(segments[leafIndex])
            .field("type", "knn_vector")
            .field("dimension", dimensions.toString())
            .endObject();
    }

    // Close the "properties" wrapper and the object of every intermediate segment.
    for (int depth = 0; depth < leafIndex; depth++) {
        mapping.endObject().endObject();
    }

    // Close the top-level "properties" section and the root object.
    mapping.endObject().endObject();

    PutMappingRequest putMappingRequest = new PutMappingRequest(indexName);
    putMappingRequest.source(mapping);

    OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(putMappingRequest).actionGet());
}

/**
* Get default k-NN settings for test cases
Expand All @@ -112,6 +142,31 @@ protected void addKnnDoc(String index, String docId, String fieldName, Object[]
assertEquals(response.status(), RestStatus.CREATED);
}

/**
 * Add a k-NN doc to an index with a nested knn_vector field.
 * <p>
 * The dot-separated {@code fieldPath} is expanded into nested JSON objects: every segment except
 * the last becomes an enclosing object and the last segment holds the vector, e.g. "a.b.c"
 * produces {"a": {"b": {"c": [...]}}}. The request uses an immediate refresh policy, and the
 * method asserts that the index response status is CREATED.
 *
 * @param index     name of the index to write to
 * @param docId     id of the document
 * @param fieldPath dot-separated path of the vector field
 * @param vector    vector values stored under the last path segment
 */
protected void addKnnNestedDoc(String index, String docId, String fieldPath, Object[] vector) throws IOException, InterruptedException,
    ExecutionException {
    XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
    String[] fieldParts = fieldPath.split("\\.");
    int leafIndex = fieldParts.length - 1;

    // Open one enclosing object per intermediate path segment.
    for (int i = 0; i < leafIndex; i++) {
        builder.startObject(fieldParts[i]);
    }
    builder.field(fieldParts[leafIndex], vector);
    // Close the enclosing objects again, then the root object.
    for (int i = 0; i < leafIndex; i++) {
        builder.endObject();
    }
    builder.endObject();

    IndexRequest indexRequest = new IndexRequest().index(index)
        .id(docId)
        .source(builder)
        .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

    IndexResponse response = client().index(indexRequest).get();
    // Expected value goes first so that a failure message reports expected/actual correctly.
    assertEquals(RestStatus.CREATED, response.status());
}

/**
* Add any document to index
*/
Expand Down
32 changes: 32 additions & 0 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -67,4 +69,34 @@ public void testGetLoadParameters() {
assertEquals(spaceType2.getValue(), loadParameters.get(SPACE_TYPE));
assertEquals(efSearchValue, loadParameters.get(HNSW_ALGO_EF_SEARCH));
}

/**
 * A single-segment path must resolve to the mapping stored under that top-level field.
 */
public void testGetFieldMappingNonNestedField() {
    final String[] path = new String[] { "top_level_field" };
    final Map<String, Object> expectedMapping = Map.of("type", "knn_vector", "dimension", 8);
    final Map<String, Object> indexProperties = Map.of("top_level_field", expectedMapping);

    final Object actualMapping = IndexUtil.getFieldMapping(indexProperties, path);

    assertEquals(expectedMapping, actualMapping);
}

/**
 * A multi-segment path must be followed through each level's "properties" map down to the
 * vector field's own mapping.
 */
public void testGetFieldMappingNestedField() {
    final Map<String, Object> vectorMapping = Map.of("type", "knn_vector", "dimension", 8);

    // Build a -> properties -> b -> properties -> train-field from the inside out.
    final Map<String, Object> propertiesOfB = Map.of("train-field", vectorMapping);
    final Map<String, Object> mappingOfB = Map.of("properties", propertiesOfB);
    final Map<String, Object> propertiesOfA = Map.of("b", mappingOfB);
    final Map<String, Object> mappingOfA = Map.of("properties", propertiesOfA);
    final Map<String, Object> rootProperties = Map.of("a", mappingOfA);

    final Object resolved = IndexUtil.getFieldMapping(rootProperties, new String[] { "a", "b", "train-field" });

    assertEquals(vectorMapping, resolved);
}

/**
 * Looking up any field in an empty properties map must yield null.
 */
public void testGetFieldMappingEmptyProperties() {
    final Map<String, Object> noProperties = Collections.emptyMap();

    final Object resolved = IndexUtil.getFieldMapping(noProperties, new String[] { "top_level_field" });

    assertNull(resolved);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -341,4 +341,74 @@ public void testTrainModel_success_noId() throws Exception {

assertTrainingSucceeds(modelId, 30, 1000);
}

// Checks that a model can be trained when the training field is given as a nested field path
public void testTrainModel_success_nestedField() throws Exception {
    final int dimension = 8;
    final int trainingDataCount = 200;
    final String modelId = "test-model-id";
    final String trainingIndexName = "train-index";
    final String nestedFieldPath = "a.b.train-field";

    // Create the training index with the nested mapping and fill it with random vectors
    final String nestedMapping = createKnnIndexNestedMapping(dimension, nestedFieldPath);
    createKnnIndex(trainingIndexName, nestedMapping);
    bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension);

    // The train call below is made with these arguments:
    //   index "train-index", field "a.b.train-field", dimension 8,
    //   description "dummy description", and the following method definition:
    /*
        {
            "name":"ivf",
            "engine":"faiss",
            "space_type": "l2",
            "parameters":{
                "nlist":1,
                "encoder":{
                    "name":"pq",
                    "parameters":{
                        "code_size":2,
                        "m": 2
                    }
                }
            }
        }
    */
    XContentBuilder methodBuilder = XContentFactory.jsonBuilder().startObject();
    methodBuilder.field(NAME, "ivf").field(KNN_ENGINE, "faiss").field(METHOD_PARAMETER_SPACE_TYPE, "l2");
    // method parameters
    methodBuilder.startObject(PARAMETERS).field(METHOD_PARAMETER_NLIST, 1);
    // encoder and its parameters
    methodBuilder.startObject(METHOD_ENCODER_PARAMETER).field(NAME, "pq");
    methodBuilder.startObject(PARAMETERS).field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2).field(ENCODER_PARAMETER_PQ_M, 2).endObject();
    methodBuilder.endObject(); // encoder
    methodBuilder.endObject(); // method parameters
    methodBuilder.endObject(); // method definition
    final Map<String, Object> method = xContentBuilderToMap(methodBuilder);

    // Kick off training and make sure the request was accepted
    final Response trainResponse = trainModel(modelId, trainingIndexName, nestedFieldPath, dimension, method, "dummy description");
    final RestStatus trainStatus = RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode());
    assertEquals(RestStatus.OK, trainStatus);

    // The model must be retrievable under the requested id
    final Response getResponse = getModel(modelId, null);
    final String responseBody = EntityUtils.toString(getResponse.getEntity());
    assertNotNull(responseBody);

    final Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
    assertEquals(modelId, responseMap.get(MODEL_ID));

    assertTrainingSucceeds(modelId, 30, 1000);
}
}
41 changes: 41 additions & 0 deletions src/test/java/org/opensearch/knn/training/VectorReaderTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class VectorReaderTests extends KNNSingleNodeTestCase {
// Default values used by the tests in this class.
// NOTE(review): the time unit of the latch timeout is not visible here — confirm at the usage site.
private final static int DEFAULT_LATCH_TIMEOUT = 100;
private final static String DEFAULT_INDEX_NAME = "test-index";
private final static String DEFAULT_FIELD_NAME = "test-field";
// Dot-separated path used by the nested-field test: the vector field "test-field" sits under "a" then "b".
private final static String DEFAULT_NESTED_FIELD_PATH = "a.b.test-field";
// Dimension of the vectors and number of vectors ingested by the nested-field test.
private final static int DEFAULT_DIMENSION = 16;
private final static int DEFAULT_NUM_VECTORS = 100;
// Passed to VectorReader.read as the maximum vector count in the nested-field test.
private final static int DEFAULT_MAX_VECTOR_COUNT = 10000;
Expand Down Expand Up @@ -345,6 +346,46 @@ public void testRead_invalid_fieldIsNotKnn() throws InterruptedException, Execut
);
}

/**
 * Ingests random vectors under a nested field path and checks that VectorReader hands back the
 * same number of vectors with the same set of values.
 */
public void testRead_valid_NestedField() throws InterruptedException, ExecutionException, IOException {
    createIndex(DEFAULT_INDEX_NAME);
    createKnnNestedIndexMapping(DEFAULT_INDEX_NAME, DEFAULT_NESTED_FIELD_PATH, DEFAULT_DIMENSION);

    // Generate random vectors, remembering each one while indexing it under the nested path
    final Random rng = new Random();
    final List<Float[]> ingestedVectors = new ArrayList<>();
    int docId = 0;
    while (docId < DEFAULT_NUM_VECTORS) {
        final Float[] randomVector = rng.doubles(DEFAULT_DIMENSION).boxed().map(Double::floatValue).toArray(Float[]::new);
        ingestedVectors.add(randomVector);
        addKnnNestedDoc(DEFAULT_INDEX_NAME, Integer.toString(docId), DEFAULT_NESTED_FIELD_PATH, randomVector);
        docId++;
    }

    // Set up the reader
    final ClusterService clusterService = node().injector().getInstance(ClusterService.class);
    final VectorReader reader = new VectorReader(client());

    // Read everything back through the nested path and wait for the listener to fire
    final TestVectorConsumer consumer = new TestVectorConsumer();
    final CountDownLatch readFinished = new CountDownLatch(1);
    reader.read(
        clusterService,
        DEFAULT_INDEX_NAME,
        DEFAULT_NESTED_FIELD_PATH,
        DEFAULT_MAX_VECTOR_COUNT,
        DEFAULT_SEARCH_SIZE,
        consumer,
        createOnSearchResponseCountDownListener(readFinished)
    );
    assertLatchDecremented(readFinished);

    final List<Float[]> readVectors = consumer.getVectorsConsumed();
    assertEquals(DEFAULT_NUM_VECTORS, readVectors.size());

    // Compare the flattened component values as sets
    final HashSet<Float> expectedValues = ingestedVectors.stream().flatMap(Arrays::stream).collect(Collectors.toCollection(HashSet::new));
    final HashSet<Float> actualValues = readVectors.stream().flatMap(Arrays::stream).collect(Collectors.toCollection(HashSet::new));
    assertEquals(expectedValues, actualValues);
}

private static class TestVectorConsumer implements Consumer<List<Float[]>> {

List<Float[]> vectorsConsumed;
Expand Down
Loading

0 comments on commit 32b0e7d

Please sign in to comment.