Skip to content

Commit

Permalink
Refactor integ tests that access model index
Browse files Browse the repository at this point in the history
Refactors integration tests that directly access the model system index.
End users should not be directly accessing the model system index. It is
supposed to be an implementation detail. We have written restful
integration tests that directly access the model system index in order
to initialize the cluster state. However, we should not do this because
users should not be able to interact with it through restful APIs

That being said, some of this
implementation detail leaks out into the interface. For instance, in
k-NN stats we have a stat that is the model system index status. So, in
order to test this, we do need direct access to the system index.
Similarly, for search, we execute the search against the system index
and directly return the results. This is probably a bug - but we still
need to test it.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Jan 26, 2024
1 parent 47728ce commit f3302d2
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 106 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
* Refactor integ tests that access model index [#1423](https://github.com/opensearch-project/k-NN/pull/1423)
### Documentation
### Maintenance
* Update developer guide to include M1 Setup [#1222](https://github.com/opensearch-project/k-NN/pull/1222)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
Expand All @@ -43,9 +42,7 @@

public class RestDeleteModelHandlerIT extends KNNRestTestCase {

public void testDeleteModelExists() throws Exception {
createModelSystemIndex();

public void testDelete_whenModelExists_thenDeletionSucceeds() throws Exception {
String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
Expand Down Expand Up @@ -80,42 +77,8 @@ public void testDeleteModelExists() throws Exception {
assertTrue(ex.getMessage().contains(modelId));
}

public void testDeleteTrainingModel() throws Exception {
createModelSystemIndex();

String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
int dimension = 8;
String modelDescription = "dummy description";

createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
// we do not wait for training to be completed
ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription);

Response getModelResponse = getModel(modelId, List.of());
assertEquals(RestStatus.OK, RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode()));

String responseBody = EntityUtils.toString(getModelResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

String deleteModelRestURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId);
Request deleteModelRequest = new Request("DELETE", deleteModelRestURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(deleteModelRequest));
assertEquals(RestStatus.CONFLICT.getStatus(), ex.getResponse().getStatusLine().getStatusCode());

// need to wait for training operation as it's required for after test cleanup
assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC);
}

public void testDeleteModelFailsInvalid() throws Exception {
public void testDelete_whenModelIDIsInvalid_thenFail() {
String modelId = "invalid-model-id";
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId);
Request request = new Request("DELETE", restURI);

Expand All @@ -124,7 +87,7 @@ public void testDeleteModelFailsInvalid() throws Exception {
}

// Test Train Model -> Delete Model -> Train Model with same modelId
public void testTrainingDeletedModel() throws Exception {
public void testTraining_whenModelHasBeenDeleted_thenSucceedTrainingModelWithSameID() throws Exception {
String modelId = "test-model-id1";
String trainingIndexName1 = "train-index-1";
String trainingIndexName2 = "train-index-2";
Expand All @@ -141,8 +104,6 @@ public void testTrainingDeletedModel() throws Exception {
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

assertEquals(0, getDocCount(MODEL_INDEX_NAME));

// Train Model again with same ModelId
trainModel(modelId, trainingIndexName2, trainingFieldName, dimension);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.core.rest.RestStatus;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -44,9 +43,7 @@

public class RestGetModelHandlerIT extends KNNRestTestCase {

public void testGetModelExists() throws Exception {
createModelSystemIndex();

public void testGetModel_whenModelIdExists_thenSucceed() throws Exception {
String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
Expand Down Expand Up @@ -81,8 +78,7 @@ public void testGetModelExists() throws Exception {
assertEquals(L2.getValue(), responseMap.get(METHOD_PARAMETER_SPACE_TYPE));
}

public void testGetModelExistsWithFilter() throws Exception {
createModelSystemIndex();
public void testGetModel_whenFilterApplied_thenReturnExpectedFields() throws Exception {
String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
Expand Down Expand Up @@ -118,17 +114,15 @@ public void testGetModelExistsWithFilter() throws Exception {
assertFalse(responseMap.containsKey(MODEL_STATE));
}

public void testGetModelFailsInvalid() throws IOException {
createModelSystemIndex();
public void testGetModel_whenModelIDIsInValid_thenFail() {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id");
Request request = new Request("GET", restURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(ex.getMessage().contains("\"invalid-model-id\""));
}

public void testGetModelFailsBlank() throws IOException {
createModelSystemIndex();
public void testGetModel_whenIDIsBlank_thenFail() {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, " ");
Request request = new Request("GET", restURI);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,6 @@ public void testScriptStats_multipleShards() throws Exception {
}

public void testModelIndexHealthMetricsStats() throws Exception {
// Create request that filters only model index
String modelIndexStatusName = StatNames.MODEL_INDEX_STATUS.getName();
// index can be created in one of previous tests, and as we do not delete it each test the check below became optional
if (!systemIndexExists(MODEL_INDEX_NAME)) {
Expand All @@ -351,7 +350,11 @@ public void testModelIndexHealthMetricsStats() throws Exception {
// Check that model health status is null since model index is not created to system yet
assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName()));

createModelSystemIndex();
// Train a model so that the system index will get created
createBasicKnnIndex(TRAINING_INDEX, TRAINING_FIELD, DIMENSION);
bulkIngestRandomVectors(TRAINING_INDEX, TRAINING_FIELD, NUM_DOCS, DIMENSION);
trainKnnModel(TEST_MODEL_ID, TRAINING_INDEX, TRAINING_FIELD, DIMENSION, MODEL_DESCRIPTION);
validateModelCreated(TEST_MODEL_ID);
}

Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,18 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE;
Expand All @@ -47,12 +43,7 @@

public class RestSearchModelHandlerIT extends KNNRestTestCase {

private ModelMetadata getModelMetadata() {
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", "");
}

public void testNotSupportedParams() throws IOException {
createModelSystemIndex();
public void testSearch_whenUnSupportedParamsPassed_thenFail() {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search");
Map<String, String> invalidParams = new HashMap<>();
invalidParams.put("index", "index-name");
Expand All @@ -61,27 +52,30 @@ public void testNotSupportedParams() throws IOException {
expectThrows(ResponseException.class, () -> client().performRequest(request));
}

public void testNoModelExists() throws Exception {
createModelSystemIndex();
public void testSearch_whenNoModelExists_thenReturnEmptyResults() throws Exception {
// Currently, if the model index exists, we will return empty hits. If it does not exist, we will
// throw an exception. This is somewhat of a bug considering that the model index is supposed to be
// an implementation detail abstracted away from the user. However, in order to test, we need to handle
// the 2 different scenarios
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search");
Request request = new Request("GET", restURI);
request.setJsonEntity("{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}");

Response response = client().performRequest(request);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);
assertEquals(searchResponse.getHits().getHits().length, 0);

if (!systemIndexExists(MODEL_INDEX_NAME)) {
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(RestStatus.NOT_FOUND.getStatus(), ex.getResponse().getStatusLine().getStatusCode());
} else {
Response response = client().performRequest(request);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);
assertEquals(searchResponse.getHits().getHits().length, 0);
}
}

public void testSizeValidationFailsInvalidSize() throws IOException {
createModelSystemIndex();
public void testSearch_whenInvalidSizePassed_thenFail() {
for (Integer invalidSize : Arrays.asList(SEARCH_MODEL_MIN_SIZE - 1, SEARCH_MODEL_MAX_SIZE + 1)) {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search?" + PARAM_SIZE + "=" + invalidSize);
Request request = new Request("GET", restURI);
Expand All @@ -101,8 +95,7 @@ public void testSizeValidationFailsInvalidSize() throws IOException {

}

public void testSearchModelExists() throws Exception {
createModelSystemIndex();
public void testSearch_whenModelExists_thenSuccess() throws Exception {
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down Expand Up @@ -151,7 +144,6 @@ public void testSearchModelExists() throws Exception {
}

public void testSearchModelWithoutSource() throws Exception {
createModelSystemIndex();
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down Expand Up @@ -192,7 +184,6 @@ public void testSearchModelWithoutSource() throws Exception {
}

public void testSearchModelWithSourceFilteringIncludes() throws Exception {
createModelSystemIndex();
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down Expand Up @@ -244,7 +235,6 @@ public void testSearchModelWithSourceFilteringIncludes() throws Exception {
}

public void testSearchModelWithSourceFilteringExcludes() throws Exception {
createModelSystemIndex();
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down
20 changes: 0 additions & 20 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.knn;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import com.google.common.primitives.Floats;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
Expand All @@ -22,7 +20,6 @@
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
Expand Down Expand Up @@ -51,7 +48,6 @@

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
Expand All @@ -77,8 +73,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
Expand Down Expand Up @@ -781,20 +775,6 @@ protected String getIndexSettingByName(String indexName, String settingName, boo
}
}

protected void createModelSystemIndex() throws IOException {
URL url = ModelDao.class.getClassLoader().getResource(MODEL_INDEX_MAPPING_PATH);
if (url == null) {
throw new IllegalStateException("Unable to retrieve mapping for \"" + MODEL_INDEX_NAME + "\"");
}

String mapping = Resources.toString(url, Charsets.UTF_8);
mapping = mapping.substring(1, mapping.length() - 1);

if (!systemIndexExists(MODEL_INDEX_NAME)) {
createIndex(MODEL_INDEX_NAME, Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).build(), mapping);
}
}

/**
* Clear cache
* <p>
Expand Down

0 comments on commit f3302d2

Please sign in to comment.