From fab6b5d61476fd5933eabeb02547f1c2a96ac6ef Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 4 Apr 2023 18:30:21 -0700 Subject: [PATCH] Support .opensearch-knn-model index as system index with security enabled (#827) * Add support for integ tests on secured cluster Signed-off-by: Martin Gaievski (cherry picked from commit b94b030afc74efc1fae5c5c7528edc25974a3fec) Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + build.gradle | 1 + .../java/org/opensearch/knn/bwc/ModelIT.java | 5 - .../org/opensearch/knn/indices/ModelDao.java | 147 ++++++++++++------ .../action/RestDeleteModelHandlerIT.java | 89 ++++++----- .../plugin/action/RestGetModelHandlerIT.java | 84 +++++----- .../plugin/action/RestKNNStatsHandlerIT.java | 22 +-- .../action/RestLegacyKNNStatsHandlerIT.java | 9 +- .../action/RestSearchModelHandlerIT.java | 93 ++++++----- src/test/resources/security/sample.pem | 28 ++++ src/test/resources/security/test-kirk.jks | Bin 0 -> 3874 bytes .../org/opensearch/knn/KNNRestTestCase.java | 89 ++++++++++- .../org/opensearch/knn/ODFERestTestCase.java | 136 ++++++++++++++-- .../java/org/opensearch/knn/TestUtils.java | 2 + 14 files changed, 514 insertions(+), 192 deletions(-) create mode 100644 src/test/resources/security/sample.pem create mode 100644 src/test/resources/security/test-kirk.jks diff --git a/CHANGELOG.md b/CHANGELOG.md index 91744f50b8..6d954ace3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add 2.6.0 to BWC Version Matrix ([#810](https://github.com/opensearch-project/k-NN/pull/810)) * Update BWC Version with OpenSearch Version Bump ([#813](https://github.com/opensearch-project/k-NN/pull/813)) * Bump numpy version from 1.22.x to 1.24.2 ([#811](https://github.com/opensearch-project/k-NN/pull/811)) +* Support .opensearch-knn-model index as system index with security enabled ([#827](https://github.com/opensearch-project/k-NN/pull/827)) ### Documentation ### Maintenance ### Refactoring diff --git a/build.gradle b/build.gradle index f8a87db4e2..7b7c548092 100644 --- a/build.gradle +++ b/build.gradle @@ -178,6 +178,7 @@ dependencies { testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.14.3' testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.2' testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.14.3' + api "org.opensearch:common-utils:${version}" } diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index f8e832c50c..e052f6dcfd 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -182,11 +182,6 @@ public static void wipeAllModels() throws IOException { deleteKNNModel(TEST_MODEL_ID); deleteKNNModel(TEST_MODEL_ID_DEFAULT); deleteKNNModel(TEST_MODEL_ID_TRAINING); - - Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME); - - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0d5d75d30f..cf0dd18908 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -13,6 +13,7 @@ import com.google.common.base.Charsets; import com.google.common.io.Resources; +import lombok.SneakyThrows; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -42,6 +43,7 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.plugin.transport.DeleteModelResponse; @@ -49,10 +51,10 @@ import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; -import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; import java.io.IOException; import java.net.URL; @@ -62,6 +64,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; import static java.util.Objects.isNull; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; @@ -216,14 +219,21 @@ public void create(ActionListener actionListener) throws IO if (isCreated()) { return; } - CreateIndexRequest request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) - .settings( - Settings.builder() - .put("index.hidden", true) - .put("index.number_of_shards", this.numberOfShards) - .put("index.number_of_replicas", this.numberOfReplicas) - ); - client.admin().indices().create(request, actionListener); + runWithStashedThreadContext(() -> { + CreateIndexRequest request; + try { + request = new CreateIndexRequest(MODEL_INDEX_NAME).mapping(getMapping()) + .settings( + Settings.builder() + .put("index.hidden", true) + .put("index.number_of_shards", this.numberOfShards) + .put("index.number_of_replicas", this.numberOfReplicas) + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + client.admin().indices().create(request, actionListener); + }); } @Override @@ -293,8 +303,9 @@ private void putInternal(Model model, ActionListener listener, Do parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model); } - IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME); - + final IndexRequestBuilder indexRequestBuilder = ModelDao.runWithStashedThreadContext( + () -> client.prepareIndex(MODEL_INDEX_NAME) + ); indexRequestBuilder.setId(model.getModelID()); indexRequestBuilder.setSource(parameters); @@ -304,8 +315,8 @@ private void putInternal(Model model, ActionListener listener, Do // After metadata update finishes, remove item from every node's cache if necessary. If no model id is // passed then nothing needs to be removed from the cache ActionListener onMetaListener; - onMetaListener = ActionListener.wrap( - indexResponse -> client.execute( + onMetaListener = ActionListener.wrap(indexResponse -> { + client.execute( RemoveModelFromCacheAction.INSTANCE, new RemoveModelFromCacheRequest(model.getModelID()), ActionListener.wrap(removeModelFromCacheResponse -> { @@ -318,9 +329,8 @@ private void putInternal(Model model, ActionListener listener, Do listener.onFailure(new RuntimeException(failureMessage)); }, listener::onFailure) - ), - listener::onFailure - ); + ); + }, listener::onFailure); // After the model is indexed, update metadata only if the model is in CREATED state ActionListener onIndexListener; @@ -357,16 +367,30 @@ private ActionListener getUpdateModelMetadataListener( ); } + @SneakyThrows @Override - public Model get(String modelId) throws ExecutionException, InterruptedException { + public Model get(String modelId) { /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) - .setPreference("_local"); - GetResponse getResponse = getRequestBuilder.execute().get(); - Map responseMap = getResponse.getSourceAsMap(); - return Model.getModelFromSourceMap(responseMap); + try { + return ModelDao.runWithStashedThreadContext(() -> { + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); + GetResponse getResponse; + try { + getResponse = getRequestBuilder.execute().get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + Map responseMap = getResponse.getSourceAsMap(); + return Model.getModelFromSourceMap(responseMap); + }); + } catch (RuntimeException runtimeException) { + // we need to use RuntimeException as container for real exception to keep signature + // of runWithStashedThreadContext generic + throw runtimeException.getCause(); + } } /** @@ -380,20 +404,22 @@ public void get(String modelId, ActionListener actionListener) /* GET //?_local */ - GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) - .setPreference("_local"); - - getRequestBuilder.execute(ActionListener.wrap(response -> { - if (response.isSourceEmpty()) { - String errorMessage = String.format("Model \" %s \" does not exist", modelId); - actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); - return; - } - final Map responseMap = response.getSourceAsMap(); - Model model = Model.getModelFromSourceMap(responseMap); - actionListener.onResponse(new GetModelResponse(model)); + ModelDao.runWithStashedThreadContext(() -> { + GetRequestBuilder getRequestBuilder = new GetRequestBuilder(client, GetAction.INSTANCE, MODEL_INDEX_NAME).setId(modelId) + .setPreference("_local"); + + getRequestBuilder.execute(ActionListener.wrap(response -> { + if (response.isSourceEmpty()) { + String errorMessage = String.format("Model \" %s \" does not exist", modelId); + actionListener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); + return; + } + final Map responseMap = response.getSourceAsMap(); + Model model = Model.getModelFromSourceMap(responseMap); + actionListener.onResponse(new GetModelResponse(model)); - }, actionListener::onFailure)); + }, actionListener::onFailure)); + }); } /** @@ -404,8 +430,10 @@ public void get(String modelId, ActionListener actionListener) */ @Override public void search(SearchRequest request, ActionListener actionListener) { - request.indices(MODEL_INDEX_NAME); - client.search(request, actionListener); + ModelDao.runWithStashedThreadContext(() -> { + request.indices(MODEL_INDEX_NAME); + client.search(request, actionListener); + }); } @Override @@ -505,16 +533,17 @@ public void delete(String modelId, ActionListener listener) ); // Setup delete model request - DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); - deleteRequestBuilder.setId(modelId); - deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - // On model metadata removal, delete the model from the index - clearModelMetadataStep.whenComplete( - acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), - listener::onFailure - ); - + ModelDao.runWithStashedThreadContext(() -> { + DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); + deleteRequestBuilder.setId(modelId); + deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + // On model metadata removal, delete the model from the index + clearModelMetadataStep.whenComplete( + acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), + listener::onFailure + ); + }); deleteModelFromIndexStep.whenComplete(deleteResponse -> { // If model is not deleted, remove modelId from model graveyard and return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { @@ -653,4 +682,26 @@ private String buildRemoveModelErrorMessage(String modelId, RemoveModelFromCache return stringBuilder.toString(); } } + + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function runnable that needs to be executed after thread context has been stashed, accepts and returns nothing + */ + private static void runWithStashedThreadContext(Runnable function) { + try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) { + function.run(); + } + } + + /** + * Set the thread context to default, this is needed to allow actions on model system index + * when security plugin is enabled + * @param function supplier function that needs to be executed after thread context has been stashed, return object + */ + private static T runWithStashedThreadContext(Supplier function) { + try (ThreadContext.StoredContext context = OpenSearchKNNModelDao.client.threadPool().getThreadContext().stashContext()) { + return function.get(); + } + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java index aaa64625e3..86f9ad0a88 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -15,19 +15,15 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.rest.RestStatus; -import java.io.IOException; +import java.util.List; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; @@ -48,59 +44,74 @@ public class RestDeleteModelHandlerIT extends KNNRestTestCase { - private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); - } - - public void testDeleteModelExists() throws IOException { + public void testDeleteModelExists() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - assertEquals(getDocCount(MODEL_INDEX_NAME), 1); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); - Request request = new Request("DELETE", restURI); + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + Response getModelResponse = getModel(modelId, List.of()); + assertEquals(RestStatus.OK, RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())); assertEquals(0, getDocCount(MODEL_INDEX_NAME)); } - public void testDeleteTrainingModel() throws IOException { + public void testDeleteTrainingModel() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - testModelMetadata.setState(ModelState.TRAINING); - - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - assertEquals(1, getDocCount(MODEL_INDEX_NAME)); - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); - Request request = new Request("DELETE", restURI); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + // we do not wait for training to be completed + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription); - assertEquals(1, getDocCount(MODEL_INDEX_NAME)); + Response getModelResponse = getModel(modelId, List.of()); + assertEquals(RestStatus.OK, RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())); - String responseBody = EntityUtils.toString(response.getEntity()); + String responseBody = EntityUtils.toString(getModelResponse.getEntity()); assertNotNull(responseBody); Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - assertEquals(testModelID, responseMap.get(MODEL_ID)); + assertEquals(modelId, responseMap.get(MODEL_ID)); + + String deleteModelRestURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); + Request deleteModelRequest = new Request("DELETE", deleteModelRestURI); + + Response deleteModelResponse = client().performRequest(deleteModelRequest); + assertEquals( + deleteModelRequest.getEndpoint() + ": failed", + RestStatus.OK, + RestStatus.fromCode(deleteModelResponse.getStatusLine().getStatusCode()) + ); + + responseBody = EntityUtils.toString(deleteModelResponse.getEntity()); + assertNotNull(responseBody); + + responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT)); - String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", testModelID); + String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); assertEquals(errorMessage, responseMap.get(DeleteModelResponse.ERROR_MSG)); + + // need to wait for training operation as it's required for after test cleanup + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } - public void testDeleteModelFailsInvalid() throws IOException { + public void testDeleteModelFailsInvalid() throws Exception { String modelId = "invalid-model-id"; createModelSystemIndex(); String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); @@ -111,7 +122,7 @@ public void testDeleteModelFailsInvalid() throws IOException { } // Test Train Model -> Delete Model -> Train Model with same modelId - public void testTrainingDeletedModel() throws IOException, InterruptedException { + public void testTrainingDeletedModel() throws Exception, InterruptedException { String modelId = "test-model-id1"; String trainingIndexName1 = "train-index-1"; String trainingIndexName2 = "train-index-2"; @@ -134,7 +145,7 @@ public void testTrainingDeletedModel() throws IOException, InterruptedException trainModel(modelId, trainingIndexName2, trainingFieldName, dimension); } - private void trainModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension) throws IOException, + private void trainModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension) throws Exception, InterruptedException { // Create a training index and randomly ingest data into it diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java index b6853e8bb9..092ca31e30 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestGetModelHandlerIT.java @@ -18,10 +18,6 @@ import org.opensearch.client.ResponseException; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.util.KNNEngine; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.rest.RestStatus; @@ -39,6 +35,8 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; +import static org.opensearch.knn.index.SpaceType.L2; +import static org.opensearch.knn.index.util.KNNEngine.FAISS; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestGetModelHandler} @@ -46,19 +44,28 @@ public class RestGetModelHandlerIT extends KNNRestTestCase { - private ModelMetadata getModelMetadata() { - return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); - } - - public void testGetModelExists() throws IOException { + public void testGetModelExists() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + + ingestDataAndTrainModel( + modelId, + trainingIndexName, + trainingFieldName, + dimension, + modelDescription, + xContentBuilderToMap(getModelMethodBuilder()) + ); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); Request request = new Request("GET", restURI); Response response = client().performRequest(request); @@ -68,30 +75,30 @@ public void testGetModelExists() throws IOException { assertNotNull(responseBody); Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - - assertEquals(testModelID, responseMap.get(MODEL_ID)); - assertEquals(testModelMetadata.getDescription(), responseMap.get(MODEL_DESCRIPTION)); - assertEquals(testModelMetadata.getDimension(), responseMap.get(DIMENSION)); - assertEquals(testModelMetadata.getError(), responseMap.get(MODEL_ERROR)); - assertEquals(testModelMetadata.getKnnEngine().getName(), responseMap.get(KNN_ENGINE)); - assertEquals(testModelMetadata.getSpaceType().getValue(), responseMap.get(METHOD_PARAMETER_SPACE_TYPE)); - assertEquals(testModelMetadata.getState().getName(), responseMap.get(MODEL_STATE)); - assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP)); + assertEquals(modelId, responseMap.get(MODEL_ID)); + assertEquals(modelDescription, responseMap.get(MODEL_DESCRIPTION)); + assertEquals(FAISS.getName(), responseMap.get(KNN_ENGINE)); + assertEquals(L2.getValue(), responseMap.get(METHOD_PARAMETER_SPACE_TYPE)); } - public void testGetModelExistsWithFilter() throws IOException { + public void testGetModelExistsWithFilter() throws Exception { createModelSystemIndex(); - String testModelID = "test-model-id"; - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - - addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); - - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; + + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + Map method = xContentBuilderToMap(getModelMethodBuilder()); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); Request request = new Request("GET", restURI); - List filterdPath = Arrays.asList(MODEL_ID, MODEL_DESCRIPTION, MODEL_TIMESTAMP, KNN_ENGINE); - request.addParameter("filter_path", Strings.join(filterdPath, ",")); + List filteredPath = Arrays.asList(MODEL_ID, MODEL_DESCRIPTION, MODEL_TIMESTAMP, KNN_ENGINE); + request.addParameter("filter_path", Strings.join(filteredPath, ",")); Response response = client().performRequest(request); assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -101,11 +108,10 @@ public void testGetModelExistsWithFilter() throws IOException { Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - assertTrue(responseMap.size() == filterdPath.size()); - assertEquals(testModelID, responseMap.get(MODEL_ID)); - assertEquals(testModelMetadata.getDescription(), responseMap.get(MODEL_DESCRIPTION)); - assertEquals(testModelMetadata.getTimestamp(), responseMap.get(MODEL_TIMESTAMP)); - assertEquals(testModelMetadata.getKnnEngine().getName(), responseMap.get(KNN_ENGINE)); + assertTrue(responseMap.size() == filteredPath.size()); + assertEquals(modelId, responseMap.get(MODEL_ID)); + assertEquals(modelDescription, responseMap.get(MODEL_DESCRIPTION)); + assertEquals(FAISS.getName(), responseMap.get(KNN_ENGINE)); assertFalse(responseMap.containsKey(DIMENSION)); assertFalse(responseMap.containsKey(MODEL_ERROR)); assertFalse(responseMap.containsKey(METHOD_PARAMETER_SPACE_TYPE)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index a1756cbf16..6ec699d87c 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -48,6 +48,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -341,20 +342,23 @@ public void testScriptStats_multipleShards() throws Exception { public void testModelIndexHealthMetricsStats() throws IOException { // Create request that filters only model index String modelIndexStatusName = StatNames.MODEL_INDEX_STATUS.getName(); + // index can be created in one of previous tests, and as we do not delete it each test the check below became optional + if (!systemIndexExists(MODEL_INDEX_NAME)) { - Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); - String responseBody = EntityUtils.toString(response.getEntity()); - Map statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + final Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); + final String responseBody = EntityUtils.toString(response.getEntity()); + final Map statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - // Check that model health status is null since model index is not created to system yet - assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName())); + // Check that model health status is null since model index is not created to system yet + assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName())); - createModelSystemIndex(); + createModelSystemIndex(); + } - response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); + Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName)); - responseBody = EntityUtils.toString(response.getEntity()); - statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + final String responseBody = EntityUtils.toString(response.getEntity()); + final Map statsMap = createParser(XContentType.JSON.xContent(), responseBody).map(); // Check that model health status is not null assertNotNull(statsMap.get(modelIndexStatusName)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java index a4243537d1..0d900cfbea 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNStatsHandlerIT.java @@ -319,10 +319,15 @@ public void testScriptStats_multipleShards() throws Exception { // Useful settings when debugging to prevent timeouts @Override protected Settings restClientSettings() { + final Settings.Builder builder = Settings.builder(); if (isDebuggingTest || isDebuggingRemoteCluster) { - return Settings.builder().put(CLIENT_SOCKET_TIMEOUT, TimeValue.timeValueMinutes(10)).build(); + builder.put(CLIENT_SOCKET_TIMEOUT, TimeValue.timeValueMinutes(10)); } else { - return super.restClientSettings(); + if (System.getProperty("tests.rest.client_path_prefix") != null) { + builder.put(CLIENT_PATH_PREFIX, System.getProperty("tests.rest.client_path_prefix")); + } } + builder.put("strictDeprecationMode", false); + return builder.build(); } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java index 609fe7f09a..7364d24924 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestSearchModelHandlerIT.java @@ -16,7 +16,6 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; -import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; @@ -39,6 +38,8 @@ import static org.opensearch.knn.common.KNNConstants.PARAM_SIZE; import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MAX_SIZE; import static org.opensearch.knn.common.KNNConstants.SEARCH_MODEL_MIN_SIZE; +import static org.opensearch.knn.index.SpaceType.L2; +import static org.opensearch.knn.index.util.KNNEngine.FAISS; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestSearchModelHandler} @@ -98,13 +99,23 @@ public void testSizeValidationFailsInvalidSize() throws IOException { public void testSearchModelExists() throws IOException { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + String modelDescription = "dummy description"; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + for (String modelId : testModelID) { + ingestDataAndTrainModel( + modelId, + trainingIndex, + trainingFieldName, + dimension, + modelDescription, + xContentBuilderToMap(getModelMethodBuilder()) + ); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -128,21 +139,25 @@ public void testSearchModelExists() throws IOException { for (SearchHit hit : searchResponse.getHits().getHits()) { assertTrue(testModelID.contains(hit.getId())); Model model = Model.getModelFromSourceMap(hit.getSourceAsMap()); - assertEquals(getModelMetadata(), model.getModelMetadata()); - assertArrayEquals(testModelBlob, model.getModelBlob()); + assertEquals(modelDescription, model.getModelMetadata().getDescription()); + assertEquals(FAISS, model.getModelMetadata().getKnnEngine()); + assertEquals(L2, model.getModelMetadata().getSpaceType()); } } } public void testSearchModelWithoutSource() throws IOException { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); - List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + + List testModelIds = Arrays.asList("test-modelid1", "test-modelid2"); + for (String modelId : testModelIds) { + String modelDescription = "dummy description"; + ingestDataAndTrainModel(modelId, trainingIndex, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -163,10 +178,10 @@ public void testSearchModelWithoutSource() throws IOException { assertNotNull(searchResponse); // returns only model from ModelIndex - assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + assertEquals(searchResponse.getHits().getHits().length, testModelIds.size()); for (SearchHit hit : searchResponse.getHits().getHits()) { - assertTrue(testModelID.contains(hit.getId())); + assertTrue(testModelIds.contains(hit.getId())); assertNull(hit.getSourceAsMap()); } } @@ -174,13 +189,16 @@ public void testSearchModelWithoutSource() throws IOException { public void testSearchModelWithSourceFilteringIncludes() throws IOException { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); - List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + + List testModelIds = Arrays.asList("test-modelid1", "test-modelid2"); + for (String modelId : testModelIds) { + String modelDescription = "dummy description"; + ingestDataAndTrainModel(modelId, trainingIndex, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -208,10 +226,10 @@ public void testSearchModelWithSourceFilteringIncludes() throws IOException { assertNotNull(searchResponse); // returns only model from ModelIndex - assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + assertEquals(searchResponse.getHits().getHits().length, testModelIds.size()); for (SearchHit hit : searchResponse.getHits().getHits()) { - assertTrue(testModelID.contains(hit.getId())); + assertTrue(testModelIds.contains(hit.getId())); Map sourceAsMap = hit.getSourceAsMap(); assertFalse(sourceAsMap.containsKey("model_blob")); assertTrue(sourceAsMap.containsKey("state")); @@ -223,13 +241,16 @@ public void testSearchModelWithSourceFilteringIncludes() throws IOException { public void testSearchModelWithSourceFilteringExcludes() throws IOException { createModelSystemIndex(); - createIndex("irrelevant-index", Settings.EMPTY); - addDocWithBinaryField("irrelevant-index", "id1", "field-name", "value"); - List testModelID = Arrays.asList("test-modelid1", "test-modelid2"); - byte[] testModelBlob = "hello".getBytes(); - ModelMetadata testModelMetadata = getModelMetadata(); - for (String modelID : testModelID) { - addModelToSystemIndex(modelID, testModelMetadata, testModelBlob); + String trainingIndex = "irrelevant-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + createBasicKnnIndex(trainingIndex, trainingFieldName, dimension); + + List testModelIds = Arrays.asList("test-modelid1", "test-modelid2"); + for (String modelId : testModelIds) { + String modelDescription = "dummy description"; + ingestDataAndTrainModel(modelId, trainingIndex, trainingFieldName, dimension, modelDescription); + assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC); } String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); @@ -257,10 +278,10 @@ public void testSearchModelWithSourceFilteringExcludes() throws IOException { assertNotNull(searchResponse); // returns only model from ModelIndex - assertEquals(searchResponse.getHits().getHits().length, testModelID.size()); + assertEquals(searchResponse.getHits().getHits().length, testModelIds.size()); for (SearchHit hit : searchResponse.getHits().getHits()) { - assertTrue(testModelID.contains(hit.getId())); + assertTrue(testModelIds.contains(hit.getId())); Map sourceAsMap = hit.getSourceAsMap(); assertFalse(sourceAsMap.containsKey("model_blob")); assertTrue(sourceAsMap.containsKey("state")); diff --git a/src/test/resources/security/sample.pem b/src/test/resources/security/sample.pem new file mode 100644 index 0000000000..fa785ca10f --- /dev/null +++ b/src/test/resources/security/sample.pem @@ -0,0 +1,28 @@ +-----BEGIN CERTIFICATE----- +MIIEyTCCA7GgAwIBAgIGAWLrc1O2MA0GCSqGSIb3DQEBCwUAMIGPMRMwEQYKCZIm +iZPyLGQBGRYDY29tMRcwFQYKCZImiZPyLGQBGRYHZXhhbXBsZTEZMBcGA1UECgwQ +RXhhbXBsZSBDb20gSW5jLjEhMB8GA1UECwwYRXhhbXBsZSBDb20gSW5jLiBSb290 +IENBMSEwHwYDVQQDDBhFeGFtcGxlIENvbSBJbmMuIFJvb3QgQ0EwHhcNMTgwNDIy +MDM0MzQ3WhcNMjgwNDE5MDM0MzQ3WjBeMRIwEAYKCZImiZPyLGQBGRYCZGUxDTAL +BgNVBAcMBHRlc3QxDTALBgNVBAoMBG5vZGUxDTALBgNVBAsMBG5vZGUxGzAZBgNV +BAMMEm5vZGUtMC5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC +AQoCggEBAJa+f476vLB+AwK53biYByUwN+40D8jMIovGXm6wgT8+9Sbs899dDXgt +9CE1Beo65oP1+JUz4c7UHMrCY3ePiDt4cidHVzEQ2g0YoVrQWv0RedS/yx/DKhs8 +Pw1O715oftP53p/2ijD5DifFv1eKfkhFH+lwny/vMSNxellpl6NxJTiJVnQ9HYOL +gf2t971ITJHnAuuxUF48HcuNovW4rhtkXef8kaAN7cE3LU+A9T474ULNCKkEFPIl +ZAKN3iJNFdVsxrTU+CUBHzk73Do1cCkEvJZ0ZFjp0Z3y8wLY/gqWGfGVyA9l2CUq +eIZNf55PNPtGzOrvvONiui48vBKH1LsCAwEAAaOCAVkwggFVMIG8BgNVHSMEgbQw +gbGAFJI1DOAPHitF9k0583tfouYSl0BzoYGVpIGSMIGPMRMwEQYKCZImiZPyLGQB +GRYDY29tMRcwFQYKCZImiZPyLGQBGRYHZXhhbXBsZTEZMBcGA1UECgwQRXhhbXBs +ZSBDb20gSW5jLjEhMB8GA1UECwwYRXhhbXBsZSBDb20gSW5jLiBSb290IENBMSEw +HwYDVQQDDBhFeGFtcGxlIENvbSBJbmMuIFJvb3QgQ0GCAQEwHQYDVR0OBBYEFKyv +78ZmFjVKM9g7pMConYH7FVBHMAwGA1UdEwEB/wQCMAAwDgYDVR0PAQH/BAQDAgXg +MCAGA1UdJQEB/wQWMBQGCCsGAQUFBwMBBggrBgEFBQcDAjA1BgNVHREELjAsiAUq +AwQFBYISbm9kZS0wLmV4YW1wbGUuY29tgglsb2NhbGhvc3SHBH8AAAEwDQYJKoZI +hvcNAQELBQADggEBAIOKuyXsFfGv1hI/Lkpd/73QNqjqJdxQclX57GOMWNbOM5H0 +5/9AOIZ5JQsWULNKN77aHjLRr4owq2jGbpc/Z6kAd+eiatkcpnbtbGrhKpOtoEZy +8KuslwkeixpzLDNISSbkeLpXz4xJI1ETMN/VG8ZZP1bjzlHziHHDu0JNZ6TnNzKr +XzCGMCohFfem8vnKNnKUneMQMvXd3rzUaAgvtf7Hc2LTBlf4fZzZF1EkwdSXhaMA +1lkfHiqOBxtgeDLxCHESZ2fqgVqsWX+t3qHQfivcPW6txtDyrFPRdJOGhiMGzT/t +e/9kkAtQRgpTb3skYdIOOUOV0WGQ60kJlFhAzIs= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/src/test/resources/security/test-kirk.jks b/src/test/resources/security/test-kirk.jks new file mode 100644 index 0000000000000000000000000000000000000000..174dbda656f41b10341adb78ab91a46afaae8a1c GIT binary patch literal 3874 zcmY+GcQhM}zs8e@RFtSn>{Y9_XzfvZl*TSQG6F9FNd(wu zFab99cRY+FKt2CO5E21u`*&mo0s{VisDB9%$qk|Z?*;}S1PKGv1*1XCugHHEKp;B6 z69Saqd|bch;ZdXcj@o48Or^T{VjiQWQ)um?koax&EW2Jd6%cmO+99&?<0M#TkhMY0 z>TOc9NNj$5o%GwnI2>ZpA<-syd;YVlrkqVstJxqe_w8#F0dlKW!#D3WVDWfwaN@uX z{)l!>hgv`=r)M_tPedAH8wS zrMCsCM3^vbf3iWkdUoK)O(h9`bxp3s^zq4CU5%IJN;Y04OLiLfXPS%;Duo}L?EKtE z$4DyO?uRf+Ovm@OBmMKYjcI;;3k(jA`wJ`_W&){Es6Nv(A-s;NYZhfPTZJ%tBZ{1@ zc|_(P(o|Du6c{sJ4@Q6w- zF)*aVb&dDqmGoH8(8Y;T2S?DR9+P|nUT>q8177|so}DjY7IWc!jB(9r?rJ%YyVvh5 z4`BJLeFX6F2g1N^WT?dWin3^|1>$*MQP~CSqFMgQ4m&bJp``1>I(!5Pe9&NB7{wXc z+p)Bs6Durb104tWmIOYRkBU~Waz;l#k`+@Fye00vbTIQq3dY*R{KBH-UF3%r{=+v` zqu(DD1~xv;*N0vqhN9l+bCm(5u37KF+&JF&or0qB&J%}ZmdviHekDmr#GlPK60J4Q zJ#vSZYt1pSxEPM~S27`bL-X}ig&?t1ubwy1&P?lEwQUs|t?a7>dqM7^&@^5tSL9pMp+&5H?jk>BGMj!JcQ+3*rxFcY4MY2z z4C?1*^xq&(g`+u7JnXS-Yuq8?$%DG-Zs#VDo=cTmcJRfEFTG1T4~(u1j$Snc+7Cs; zyB9?mE4rqbq_*xqj?#OlN%@YGt*PgH+-~Fy+blur5jn zu_S?>vGKl_57zp6>#CW5Q&HHKl|qVToNrM`8!zz5n*{CQ+r2#n4{2tk@;0m{ zM8pbY25rVQv1<0iw2CPT?uG+>NVZVLalVoRSZQdC(&M@`0$mC@6l?zxF&LAM8XHR1Ah3S zb?4&7@N$w<+PVC^0ws=h2pqrozQ!=b!?Zy2@uQjFh1)BEPT$JlDa9Q8(%YHT_r)w# z<4bW`j)gX^ktonho#Uf=U=ZH5QT!;ug%qe!Fi?N(OjphEVY3YTU5B*j^ZMOg+XmnL zPpT%`zoHjGCw~=w|5zC`KWOFwsF`=Jjwez^hwA2rgTt^ z^10Gp<3*%@mI37QZ>P3$*PX4;4LpFQqK9AnvMxAg!|B)unEQ{13w`0LO;;mgV22L5 z=Y8bwo8Fch2UFgZEqeTdMGZMKmz)4Uzb#-R)&H4zUC45?<4&g?`6XX-=`F2|(~Esf z4P+-+Y;J{*hV8L55?o`K^wL+ zE>e|WH7ZW48)vi%Zq4nbkLikeTd&2pCr5A#jJC9jypS>*@uF<#i}Xp$3X7~b0>bXQ zd@CV7FY-$A{IR_m5uZie z+ckdOpNC4bjck=wZ@3lTl5+`W3~_4oPuGx4#mk-f?CsbGulgu|BAb)LTI|hBYM==Q zPLdu6@x)I_O{qq^{%cI*Q`-C+WZjpp^GjGiWv(#7Vr(pZ@A532u&Rn|3@4+xgKqNc zMhtgDOn)7lv}KZc^U}jD!KU{3;=7as(>uBwDx5}ii8iIz!F(WDlbe(V`WH5PS-XhZ zPJFI;eV}4{aJ?&?Sv%?zMZJ9SRFL%?ZZ0C(FdozY2R@i=1>&&E< z<(hauSRE!6;QE6ujbYrYrWNm9;!ixJV`}*=J$7wZ^0l>rTb7|)`olK^*^m3Ex%nq2 zL({r^1)T=Q7qM>-F~1lC817t!PNhq1c&?{#kiAuiMtlDELuI?Ut6LMQ6()675@U5L z_g(P7&7MR-N3z!C5a+qZ$!xmrg0qbsQn*7vqc!v-^yqc6`tlc%aQl-Fe+IYP5Pe^K z^%zx2w*a+^&+F*;<~HZ&=XwRTB6z)Uec2XkH=^cl)cHs|VxGqSQStks&td*NQbTPW z@??ewN#dRVCH?t{p-$)JDIxkVF$#9Q?iS!Qqby9p zttQuw3k2_4Hs9`5TG}3Jwk97Nste6#I!jG)f$b(~xI#)Bs7nQ7es#6RzYPh=8vCY$@K;aE z0JYYxSm&6)?GS&eI-ibs8vhi$EXK)Yhv7%bHy2C$czjfz?F4J+b%lJkXj+1&h?Ti_R;#D>}h%qh-ltN3^kJE=J$q9lGN z97&*c`aeQNBG8(G3ADz4#|D3&4&?Ix=oLK>L?VDUkp%Gi|FbTdf2<2!*X4kUenR(; zb%6=se)ca%eZ zOyn3`1eb66NoONNlb!Qgq|BuMxwULjnW>4u2iuhj(ZUV8fC!eY=nsZF*}w6V0(LxJ zVJ|ew^cV0%UizR_Y1yOEtM1}iw*f#fPAX(#E)%*G)QD7W7O$XT5e!*pv0krMED!yw zv)_h?54B@8<=GZ6ukEmkmrx<@jaUud2Y%EQU-vBcCChZ&9Xf`1Rw3w4G=@{y>I<<5 zr)BfiiXe`(Z@ksE4@BqB5d!$>pA(N&9b7XX5GBfr?j{H(J6=OSr*~9Ff8Zh0^d;HS3|V9O<+-Py zxI&YAI-gM^t2+X1O6JyQ*^8SfuZ5{?m1F14fGg;0aeF|P)4c8tw{C;?*J)`bjV2~qOsSjk^$@gQ1{3jw}OGfYhan!3#Y zHIQX-5|4fmT69zTvDd3aW(AkQqj4t}?Md}bd>>Q>N!29V@klLOr#L%^gPrlgw8ASS>!fstf*6i;ka?xLu@MUq>?r_mf*HCZ0jHy2N^B`x>Y90Tt5-jn7*G)Ai~?r^6!i zChFK}Z-Np|s#K(ct1NYcNSoxM%p~ng6bf7}uXm#_v&(wHHp4Tljgd6EW$Kg0xZkkr zi&o;({o`MC#=#JXFx-Py14vyFMbGypX`-a>1F9n21b`MXKk|zU$zEO&>l1Rjkx$4Vg-UeUetqM3xCVt2 z#4}QY$t__sQxkuq9U8E_JbjM8#9JvlSK48A@`?q^I*~JnT-!@f$l49YlT>fpGqYJ9 zr+k*tw-oT8l~Dr<$GT8lt$6D+{n7Af1%CX7h0*}>N)s;I);DZqq{57a method = xContentBuilderToMap(builder); + ingestDataAndTrainModel(modelId, trainingIndexName, trainingFieldName, dimension, modelDescription, method); + } + + protected void ingestDataAndTrainModel( + String modelId, + String trainingIndexName, + String trainingFieldName, + int dimension, + String modelDescription, + Map method + ) throws Exception { + int trainingDataCount = 40; + bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, modelDescription); + + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + } + + protected XContentBuilder getModelMethodBuilder() throws IOException { + XContentBuilder modelMethodBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, FAISS.getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, L2.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + return modelMethodBuilder; + } + /** * We need to be able to dump the jacoco coverage before cluster is shut down. * The new internal testing framework removed some of the gradle tasks we were listening to diff --git a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java index 5f174b964a..097fe014de 100644 --- a/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/ODFERestTestCase.java @@ -5,13 +5,6 @@ package org.opensearch.knn; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - import org.apache.http.Header; import org.apache.http.HttpHost; import org.apache.http.auth.AuthScope; @@ -21,23 +14,54 @@ import org.apache.http.impl.client.BasicCredentialsProvider; import org.apache.http.message.BasicHeader; import org.apache.http.ssl.SSLContextBuilder; +import org.apache.http.util.EntityUtils; +import org.junit.After; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; +import org.opensearch.common.Strings; +import org.opensearch.common.io.PathUtils; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.rest.SecureRestClientBuilder; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.rest.RestStatus; +import org.opensearch.search.SearchHit; import org.opensearch.test.rest.OpenSearchRestTestCase; -import org.junit.After; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD; +import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH; import static org.opensearch.knn.TestUtils.KNN_BWC_PREFIX; import static org.opensearch.knn.TestUtils.OPENDISTRO_SECURITY; +import static org.opensearch.knn.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; +import static org.opensearch.knn.TestUtils.SECURITY_AUDITLOG_PREFIX; import static org.opensearch.knn.TestUtils.SKIP_DELETE_MODEL_INDEX; +import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; /** @@ -45,6 +69,8 @@ */ public abstract class ODFERestTestCase extends OpenSearchRestTestCase { + private final Set IMMUTABLE_INDEX_PREFIXES = Set.of(KNN_BWC_PREFIX, SECURITY_AUDITLOG_PREFIX, OPENSEARCH_SYSTEM_INDEX_PREFIX); + protected boolean isHttps() { boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false); if (isHttps) { @@ -66,7 +92,22 @@ protected String getProtocol() { protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException { RestClientBuilder builder = RestClient.builder(hosts); if (isHttps()) { - configureHttpsClient(builder, settings); + String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH); + if (Objects.nonNull(keystore)) { + URI uri; + try { + uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + Path configPath = PathUtils.get(uri).getParent().toAbsolutePath(); + return new SecureRestClientBuilder(settings, configPath).build(); + } else { + configureHttpsClient(builder, settings); + boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true); + builder.setStrictDeprecationMode(strictDeprecationMode); + return builder.build(); + } } else { configureClient(builder, settings); } @@ -120,8 +161,8 @@ protected boolean preserveIndicesUponCompletion() { @SuppressWarnings("unchecked") @After - protected void wipeAllODFEIndices() throws IOException { - Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); + protected void wipeAllODFEIndices() throws Exception { + Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); XContentType xContentType = XContentType.fromMediaType(response.getEntity().getContentType().getValue()); try ( XContentParser parser = xContentType.xContent() @@ -140,7 +181,11 @@ protected void wipeAllODFEIndices() throws IOException { } for (Map index : parserList) { - String indexName = (String) index.get("index"); + final String indexName = (String) index.get("index"); + if (isIndexCleanupRequired(indexName)) { + wipeIndexContent(indexName); + continue; + } if (!skipDeleteIndex(indexName)) { adminClient().performRequest(new Request("DELETE", "/" + indexName)); } @@ -148,6 +193,57 @@ protected void wipeAllODFEIndices() throws IOException { } } + private boolean isIndexCleanupRequired(final String index) { + return MODEL_INDEX_NAME.equals(index) && !getSkipDeleteModelIndexFlag(); + } + + private void wipeIndexContent(String indexName) throws IOException { + deleteModels(getModelIds()); + deleteAllDocs(indexName); + } + + private List getModelIds() throws IOException { + final String restURIGetModels = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "_search"); + final Response response = adminClient().performRequest(new Request("GET", restURIGetModels)); + + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + final String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + final XContentParser parser = createParser(XContentType.JSON.xContent(), responseBody); + final SearchResponse searchResponse = SearchResponse.fromXContent(parser); + + return Arrays.stream(searchResponse.getHits().getHits()).map(SearchHit::getId).collect(Collectors.toList()); + } + + private void deleteModels(final List modelIds) throws IOException { + for (final String testModelID : modelIds) { + final String restURIGetModel = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + final Response getModelResponse = adminClient().performRequest(new Request("GET", restURIGetModel)); + if (RestStatus.OK != RestStatus.fromCode(getModelResponse.getStatusLine().getStatusCode())) { + continue; + } + final String restURIDeleteModel = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + adminClient().performRequest(new Request("DELETE", restURIDeleteModel)); + } + } + + private void deleteAllDocs(final String indexName) throws IOException { + final String restURIDeleteByQuery = String.join("/", indexName, "_delete_by_query"); + final Request request = new Request("POST", restURIDeleteByQuery); + final XContentBuilder matchAllDocsQuery = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("match_all") + .endObject() + .endObject() + .endObject(); + + request.setJsonEntity(Strings.toString(matchAllDocsQuery)); + adminClient().performRequest(request); + } + private boolean getSkipDeleteModelIndexFlag() { return Boolean.parseBoolean(System.getProperty(SKIP_DELETE_MODEL_INDEX, "false")); } @@ -159,11 +255,25 @@ private boolean skipDeleteModelIndex(String indexName) { private boolean skipDeleteIndex(String indexName) { if (indexName != null && !OPENDISTRO_SECURITY.equals(indexName) - && !indexName.startsWith(KNN_BWC_PREFIX) + && IMMUTABLE_INDEX_PREFIXES.stream().noneMatch(indexName::startsWith) && !skipDeleteModelIndex(indexName)) { return false; } return true; } + + @Override + protected Settings restAdminSettings() { + return Settings.builder() + // disable the warning exception for admin client since it's only used for cleanup. + .put("strictDeprecationMode", false) + .put("http.port", 9200) + .put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps()) + .put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit") + .put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit") + .build(); + } } diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index f179eef362..0843176e7a 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -95,6 +95,8 @@ public class TestUtils { public static final String ROLLING_UPGRADE_FIRST_ROUND = "tests.rest.first_round"; public static final String SKIP_DELETE_MODEL_INDEX = "tests.skip_delete_model_index"; public static final String UPGRADED_CLUSTER = "upgraded_cluster"; + public static final String SECURITY_AUDITLOG_PREFIX = "security-auditlog"; + public static final String OPENSEARCH_SYSTEM_INDEX_PREFIX = ".opensearch"; // Generating vectors using random function with a seed which makes these vectors standard and generate same vectors for each run. public static float[][] randomlyGenerateStandardVectors(int numVectors, int dimensions, int seed) {