Skip to content

Commit

Permalink
Allow method parameter override for training based indices (solves is…
Browse files Browse the repository at this point in the history
…sue #2246) (#2290)

* Allow method parameter override for training based indices

Signed-off-by: Sahil Buddharaju <[email protected]>

* Fixed code squashing imports

Signed-off-by: Sahil Buddharaju <[email protected]>

* Changed changelog

Signed-off-by: Sahil Buddharaju <[email protected]>

* spotlessApply styling

Signed-off-by: Sahil Buddharaju <[email protected]>

---------

Signed-off-by: Sahil Buddharaju <[email protected]>
Co-authored-by: Sahil Buddharaju <[email protected]>
  • Loading branch information
buddharajusahil and Sahil Buddharaju authored Dec 10, 2024
1 parent 9276c77 commit 19f045d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Introduced a writing layer in native engines that relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ && ensureSpaceTypeNotSet(topLevelSpaceType)) {
}

ensureAtleastOneSet(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode, COMPRESSION_LEVEL_PARAMETER, compressionLevel);
ensureMutualExclusion(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode);
ensureMutualExclusion(KNN_METHOD, knnMethodContext, COMPRESSION_LEVEL_PARAMETER, compressionLevel);

ensureSet(DIMENSION, dimension);
ensureSet(TRAIN_INDEX_PARAMETER, trainingIndex);
Expand Down
24 changes: 22 additions & 2 deletions src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,23 @@
import java.util.List;

import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST_DEFAULT;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
Expand Down Expand Up @@ -260,17 +266,31 @@ public void testCompressionIndexWithNonVectorFieldsSegment_whenValid_ThenSucceed
public void testTraining_whenInvalid_thenFail() {
setupTrainingIndex();
String modelId = "test";

XContentBuilder builder1 = XContentFactory.jsonBuilder()
.startObject()
.field(TRAIN_INDEX_PARAMETER, TRAINING_INDEX_NAME)
.field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME)
.field(KNNConstants.DIMENSION, DIMENSION)
.field(VECTOR_DATA_TYPE_FIELD, "float")
.field(MODEL_DESCRIPTION, "")
.field(MODE_PARAMETER, Mode.ON_DISK)
.field(COMPRESSION_LEVEL_PARAMETER, "16x")
.startObject(KNN_METHOD)
.field(NAME, METHOD_IVF)
.field(KNN_ENGINE, FAISS_NAME)
.field(METHOD_PARAMETER_SPACE_TYPE, "l2")
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_NLIST, 1)
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, "pq")
.startObject(PARAMETERS)
.field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2)
.field(ENCODER_PARAMETER_PQ_M, 8)
.endObject()
.endObject()
.endObject()
.endObject()
.field(MODEL_DESCRIPTION, "")
.field(MODE_PARAMETER, Mode.ON_DISK)
.endObject();
expectThrows(ResponseException.class, () -> trainModel(modelId, builder1));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.plugin.action;

import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -22,15 +23,22 @@

import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;

public class RestTrainModelHandlerIT extends KNNRestTestCase {

Expand Down Expand Up @@ -472,4 +480,75 @@ public void testTrainModel_success_nestedField() throws Exception {

assertTrainingSucceeds(modelId, 30, 1000);
}

/**
 * Verifies that a train request which sets a compression level and a mode alongside an
 * explicit method definition is accepted, and that the resulting model finishes training.
 */
public void testTrainModel_success_methodOverrideWithCompressionMode() throws Exception {
    final String modelId = "test-model-id";
    final String trainingIndexName = "train-index";
    final String nestedFieldPath = "a.b.train-field";
    final int dimension = 8;
    final int trainingDataCount = 200;

    // Build the training index on a nested vector field and fill it with random vectors.
    createKnnIndex(trainingIndexName, createKnnIndexNestedMapping(dimension, nestedFieldPath));
    bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension);

    // Explicit method definition: "ivf" with nlist = 16.
    XContentBuilder methodBuilder = XContentFactory.jsonBuilder().startObject();
    methodBuilder.field(NAME, "ivf");
    methodBuilder.startObject(PARAMETERS);
    methodBuilder.field(METHOD_PARAMETER_NLIST, 16);
    methodBuilder.endObject();
    methodBuilder.endObject();
    Map<String, Object> methodDefinition = xContentBuilderToMap(methodBuilder);

    // Train request body: compression level ("16x") and mode ("on_disk") are supplied
    // together with the method definition above.
    XContentBuilder trainRequestBody = XContentFactory.jsonBuilder().startObject();
    trainRequestBody.field(TRAIN_INDEX_PARAMETER, trainingIndexName);
    trainRequestBody.field(TRAIN_FIELD_PARAMETER, nestedFieldPath);
    trainRequestBody.field(DIMENSION, dimension);
    trainRequestBody.field(COMPRESSION_LEVEL_PARAMETER, "16x");
    trainRequestBody.field(MODE_PARAMETER, "on_disk");
    trainRequestBody.field(KNN_METHOD, methodDefinition);
    trainRequestBody.field(MODEL_DESCRIPTION, "dummy description");
    trainRequestBody.endObject();

    // Submit the train request for the model id.
    String trainEndpoint = "/_plugins/_knn/models/" + modelId + "/_train";
    Request trainRequest = new Request("POST", trainEndpoint);
    trainRequest.setJsonEntity(trainRequestBody.toString());
    Response trainResponse = client().performRequest(trainRequest);

    int trainStatusCode = trainResponse.getStatusLine().getStatusCode();
    assertEquals(RestStatus.OK, RestStatus.fromCode(trainStatusCode));

    // Fetch the model and confirm it is registered under the requested id.
    Response modelResponse = getModel(modelId, null);
    String modelJson = EntityUtils.toString(modelResponse.getEntity());
    assertNotNull(modelJson);

    Map<String, Object> modelFields = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), modelJson).map();
    assertEquals(modelId, modelFields.get(MODEL_ID));

    assertTrainingSucceeds(modelId, 30, 1000);
}
}

0 comments on commit 19f045d

Please sign in to comment.