Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor integ tests that access model index #1423

Merged
merged 8 commits into from
Jan 26, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
* Refactor integ tests that access model index [#1423](https://github.com/opensearch-project/k-NN/pull/1423)
### Documentation
### Maintenance
* Update developer guide to include M1 Setup [#1222](https://github.com/opensearch-project/k-NN/pull/1222)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.opensearch.knn.plugin.action;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
Expand All @@ -30,7 +31,6 @@
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
Expand All @@ -43,9 +43,8 @@

public class RestDeleteModelHandlerIT extends KNNRestTestCase {

public void testDeleteModelExists() throws Exception {
createModelSystemIndex();

@SneakyThrows
public void testDelete_whenModelExists_thenDeletionSucceeds() {
String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
Expand Down Expand Up @@ -80,42 +79,8 @@ public void testDeleteModelExists() throws Exception {
assertTrue(ex.getMessage().contains(modelId));
}

public void testDeleteTrainingModel() throws Exception {
createModelSystemIndex();

String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
int dimension = 8;
String modelDescription = "dummy description";

createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
// we do not wait for training to be completed
ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription);

Response getModelResponse = getModel(modelId, List.of());
assertEquals(RestStatus.OK, RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode()));

String responseBody = EntityUtils.toString(getModelResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

String deleteModelRestURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId);
Request deleteModelRequest = new Request("DELETE", deleteModelRestURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(deleteModelRequest));
assertEquals(RestStatus.CONFLICT.getStatus(), ex.getResponse().getStatusLine().getStatusCode());

// need to wait for training operation as it's required for after test cleanup
assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC);
}

public void testDeleteModelFailsInvalid() throws Exception {
public void testDelete_whenModelIDIsInvalid_thenFail() {
String modelId = "invalid-model-id";
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId);
Request request = new Request("DELETE", restURI);

Expand All @@ -124,7 +89,8 @@ public void testDeleteModelFailsInvalid() throws Exception {
}

// Test Train Model -> Delete Model -> Train Model with same modelId
public void testTrainingDeletedModel() throws Exception {
@SneakyThrows
public void testTraining_whenModelHasBeenDeleted_thenSucceedTrainingModelWithSameID() {
String modelId = "test-model-id1";
String trainingIndexName1 = "train-index-1";
String trainingIndexName2 = "train-index-2";
Expand All @@ -141,8 +107,6 @@ public void testTrainingDeletedModel() throws Exception {
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

assertEquals(0, getDocCount(MODEL_INDEX_NAME));

// Train Model again with same ModelId
trainModel(modelId, trainingIndexName2, trainingFieldName, dimension);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.plugin.action;

import joptsimple.internal.Strings;
import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
Expand All @@ -21,7 +22,6 @@
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.core.rest.RestStatus;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -44,9 +44,8 @@

public class RestGetModelHandlerIT extends KNNRestTestCase {

public void testGetModelExists() throws Exception {
createModelSystemIndex();

@SneakyThrows
public void testGetModel_whenModelIdExists_thenSucceed() {
String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
Expand Down Expand Up @@ -81,8 +80,8 @@ public void testGetModelExists() throws Exception {
assertEquals(L2.getValue(), responseMap.get(METHOD_PARAMETER_SPACE_TYPE));
}

public void testGetModelExistsWithFilter() throws Exception {
createModelSystemIndex();
@SneakyThrows
public void testGetModel_whenFilterApplied_thenReturnExpectedFields() {
String modelId = "test-model-id";
String trainingIndexName = "train-index";
String trainingFieldName = "train-field";
Expand Down Expand Up @@ -118,17 +117,15 @@ public void testGetModelExistsWithFilter() throws Exception {
assertFalse(responseMap.containsKey(MODEL_STATE));
}

public void testGetModelFailsInvalid() throws IOException {
createModelSystemIndex();
public void testGetModel_whenModelIDIsInValid_thenFail() {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id");
Request request = new Request("GET", restURI);

ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertTrue(ex.getMessage().contains("\"invalid-model-id\""));
}

public void testGetModelFailsBlank() throws IOException {
createModelSystemIndex();
public void testGetModel_whenIDIsBlank_thenFail() {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, " ");
Request request = new Request("GET", restURI);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,6 @@ public void testScriptStats_multipleShards() throws Exception {
}

public void testModelIndexHealthMetricsStats() throws Exception {
// Create request that filters only model index
String modelIndexStatusName = StatNames.MODEL_INDEX_STATUS.getName();
// index can be created in one of previous tests, and as we do not delete it each test the check below became optional
if (!systemIndexExists(MODEL_INDEX_NAME)) {
Expand All @@ -351,7 +350,11 @@ public void testModelIndexHealthMetricsStats() throws Exception {
// Check that model health status is null since model index is not created to system yet
assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName()));

createModelSystemIndex();
// Train a model so that the system index will get created
createBasicKnnIndex(TRAINING_INDEX, TRAINING_FIELD, DIMENSION);
bulkIngestRandomVectors(TRAINING_INDEX, TRAINING_FIELD, NUM_DOCS, DIMENSION);
trainKnnModel(TEST_MODEL_ID, TRAINING_INDEX, TRAINING_FIELD, DIMENSION, MODEL_DESCRIPTION);
validateModelCreated(TEST_MODEL_ID);
}

Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

package org.opensearch.knn.plugin.action;

import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Request;
Expand All @@ -19,22 +20,18 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE;
import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE;
Expand All @@ -47,12 +44,7 @@

public class RestSearchModelHandlerIT extends KNNRestTestCase {

private ModelMetadata getModelMetadata() {
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", "", "");
}

public void testNotSupportedParams() throws IOException {
createModelSystemIndex();
public void testSearch_whenUnSupportedParamsPassed_thenFail() {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search");
Map<String, String> invalidParams = new HashMap<>();
invalidParams.put("index", "index-name");
Expand All @@ -61,27 +53,31 @@ public void testNotSupportedParams() throws IOException {
expectThrows(ResponseException.class, () -> client().performRequest(request));
}

public void testNoModelExists() throws Exception {
createModelSystemIndex();
@SneakyThrows
public void testSearch_whenNoModelExists_thenReturnEmptyResults() {
// Currently, if the model index exists, we will return empty hits. If it does not exist, we will
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we create an issue to fix this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, #1425

// throw an exception. This is somewhat of a bug considering that the model index is supposed to be
// an implementation detail abstracted away from the user. However, in order to test, we need to handle
// the 2 different scenarios
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search");
Request request = new Request("GET", restURI);
request.setJsonEntity("{\n" + " \"query\": {\n" + " \"match_all\": {}\n" + " }\n" + "}");

Response response = client().performRequest(request);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);
assertEquals(searchResponse.getHits().getHits().length, 0);

if (!systemIndexExists(MODEL_INDEX_NAME)) {
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(RestStatus.NOT_FOUND.getStatus(), ex.getResponse().getStatusLine().getStatusCode());
} else {
Response response = client().performRequest(request);
assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);
assertEquals(searchResponse.getHits().getHits().length, 0);
}
}

public void testSizeValidationFailsInvalidSize() throws IOException {
createModelSystemIndex();
public void testSearch_whenInvalidSizePassed_thenFail() {
for (Integer invalidSize : Arrays.asList(SEARCH_MODEL_MIN_SIZE - 1, SEARCH_MODEL_MAX_SIZE + 1)) {
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search?" + PARAM_SIZE + "=" + invalidSize);
Request request = new Request("GET", restURI);
Expand All @@ -101,8 +97,8 @@ public void testSizeValidationFailsInvalidSize() throws IOException {

}

public void testSearchModelExists() throws Exception {
createModelSystemIndex();
@SneakyThrows
public void testSearch_whenModelExists_thenSuccess() {
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down Expand Up @@ -151,7 +147,6 @@ public void testSearchModelExists() throws Exception {
}

public void testSearchModelWithoutSource() throws Exception {
createModelSystemIndex();
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down Expand Up @@ -192,7 +187,6 @@ public void testSearchModelWithoutSource() throws Exception {
}

public void testSearchModelWithSourceFilteringIncludes() throws Exception {
createModelSystemIndex();
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down Expand Up @@ -244,7 +238,6 @@ public void testSearchModelWithSourceFilteringIncludes() throws Exception {
}

public void testSearchModelWithSourceFilteringExcludes() throws Exception {
createModelSystemIndex();
String trainingIndex = "irrelevant-index";
String trainingFieldName = "train-field";
int dimension = 8;
Expand Down
20 changes: 0 additions & 20 deletions src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.knn;

import com.google.common.base.Charsets;
import com.google.common.io.Resources;
import com.google.common.primitives.Floats;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.StringUtils;
Expand All @@ -22,7 +20,6 @@
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
Expand Down Expand Up @@ -51,7 +48,6 @@

import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
Expand All @@ -77,8 +73,6 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
Expand Down Expand Up @@ -781,20 +775,6 @@ protected String getIndexSettingByName(String indexName, String settingName, boo
}
}

protected void createModelSystemIndex() throws IOException {
URL url = ModelDao.class.getClassLoader().getResource(MODEL_INDEX_MAPPING_PATH);
if (url == null) {
throw new IllegalStateException("Unable to retrieve mapping for \"" + MODEL_INDEX_NAME + "\"");
}

String mapping = Resources.toString(url, Charsets.UTF_8);
mapping = mapping.substring(1, mapping.length() - 1);

if (!systemIndexExists(MODEL_INDEX_NAME)) {
createIndex(MODEL_INDEX_NAME, Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).build(), mapping);
}
}

/**
* Clear cache
* <p>
Expand Down
Loading