diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index ae28066d5a..78f6f4ac60 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -191,7 +191,11 @@ public MLModel(StreamInput input) throws IOException{ modelContentSizeInBytes = input.readOptionalLong(); modelContentHash = input.readOptionalString(); if (input.readBoolean()) { - modelConfig = new TextEmbeddingModelConfig(input); + if (algorithm.equals(FunctionName.METRICS_CORRELATION)) { + modelConfig = new MetricsCorrelationModelConfig(input); + } else { + modelConfig = new TextEmbeddingModelConfig(input); + } } createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); diff --git a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java index 4f26e4b4d2..e1c9203cae 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; @@ -28,6 +29,10 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) { super(modelType, allConfig); } + public MetricsCorrelationModelConfig(StreamInput in) throws IOException{ + super(in); + } + @Override public String getWriteableName() { return PARSE_FIELD_NAME; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index a791739678..c778675ea8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -18,6 +18,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; @@ -137,7 +138,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelFormat = in.readEnum(MLModelFormat.class); } if (in.readBoolean()) { - this.modelConfig = new TextEmbeddingModelConfig(in); + if (this.functionName.equals(FunctionName.METRICS_CORRELATION)) { + this.modelConfig = new MetricsCorrelationModelConfig(in); + } else { + this.modelConfig = new TextEmbeddingModelConfig(in); + } } this.deployModel = in.readBoolean(); this.modelNodeIds = in.readOptionalStringArray(); diff --git a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java new file mode 100644 index 0000000000..4700039939 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class MetricsCorrelationModelConfigTests { + + MetricsCorrelationModelConfig config; + Function function; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + config = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + function = parser -> { + try { + return MetricsCorrelationModelConfig.parse(parser); + } catch (IOException e) { + throw new RuntimeException("Failed to parse MetricsCorrelationModelConfig", e); + } + }; + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, EMPTY_PARAMS); + String configContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + } + + @Test + public void nullFields_ModelType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model type is null"); + config = MetricsCorrelationModelConfig.builder() + .build(); + } + + @Test + public void parse() throws IOException { + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + TestHelper.testParseFromString(config, content, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(config); + } + + public void readInputStream(MetricsCorrelationModelConfig config) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + config.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MetricsCorrelationModelConfig parsedConfig = new MetricsCorrelationModelConfig(streamInput); + assertEquals(config.getModelType(), parsedConfig.getModelType()); + assertEquals(config.getAllConfig(), parsedConfig.getAllConfig()); + assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index c28d4ccb1e..52e4bec17c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -22,6 +22,7 @@ import org.opensearch.ml.common.connector.HttpConnectorTest; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.search.SearchModule; @@ -74,7 +75,6 @@ public void setUp() throws Exception { .deployModel(true) .modelNodeIds(new String[]{"modelNodeIds" }) .build(); - } @Test @@ -257,6 +257,59 @@ public void readInputStream_WithInternalConnector() throws IOException { }); } + @Test + public void testMCorrInput() throws IOException { + String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + + MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals(testString, jsonStr); + } + + @Test + public void readInputStream_MCorr() throws IOException { + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + + MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + readInputStream(mcorrInput, parsedInput -> { + assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType()); + assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig()); + assertEquals(parsedInput.getFunctionName(), FunctionName.METRICS_CORRELATION); + assertEquals(parsedInput.getModelName(), FunctionName.METRICS_CORRELATION.name()); + assertEquals(parsedInput.getModelGroupId(), modelGroupId); + }); + } + private void readInputStream(MLRegisterModelInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java index 8a7ce3aa54..4ae7e31449 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java @@ -125,7 +125,7 @@ public void close() { * @param modelId id of the model * @param modelName name of the model * @param version version of the model - * @param engine engine where model will be run. For now we are supporting only pytorch engine only. + * @param engine engine where model will be run. For now, we are supporting only pytorch engine only. */ private void loadModel(File modelZipFile, String modelId, String modelName, String version, String engine) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index fa6818efc4..d815356381 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -119,7 +119,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { if (modelId == null) { boolean hasModelGroupIndex = clusterService.state().getMetadata().hasIndex(ML_MODEL_GROUP_INDEX); - if (!hasModelGroupIndex) { // Create model group index if doesn't exist + if (!hasModelGroupIndex) { // Create model group index if it doesn't exist try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX).mapping(ML_MODEL_GROUP_INDEX_MAPPING); CreateIndexResponse createIndexResponse = client.admin().indices().create(request).actionGet(1000); @@ -162,7 +162,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { }, e-> { log.error("Failed to get model", e); }); - client.get(getModelRequest, ActionListener.runBefore(listener, () -> context.restore())); + client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); } } } else { @@ -177,10 +177,17 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { waitUntil(() -> { if (modelId != null) { MLModelState modelState = getModel(modelId).getModelState(); - return modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED; + if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED){ + log.info("Model deployed: " + modelState); + return true; + } else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) { + log.info("Model not deployed: " + modelState); + deployModel(modelId, ActionListener.wrap(deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e))); + return false; + } } return false; - }, 10, TimeUnit.SECONDS); + }, 120, TimeUnit.SECONDS); Output djlOutput; try { @@ -230,9 +237,7 @@ void registerModel(ActionListener listener) throws Inte log.error("Failed to Register Model", e); listener.onFailure(e); })); - }, e-> { - listener.onFailure(e); - }), () -> context.restore())); + }, listener::onFailure), context::restore)); } catch (IOException e) { throw new MLException(e); } @@ -300,6 +305,8 @@ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, } sum += timeInMillis; timeInMillis = Math.min(AWAIT_BUSY_THRESHOLD, timeInMillis * 2); + + log.info("Waiting... Time elapsed: " + sum + "ms"); } timeInMillis = maxTimeInMillis - sum; try { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index ce22db12fe..4230b3027a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.metrics_correlation; +import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Ignore; @@ -13,7 +14,17 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -23,6 +34,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; @@ -67,6 +79,7 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; +import org.opensearch.threadpool.ThreadPool; import java.io.File; import java.io.IOException; @@ -75,9 +88,12 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -92,24 +108,30 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; -//TODO: fix mockito error: Cannot mock/spy class org.opensearch.common.settings.Settings final class -@Ignore public class MetricsCorrelationTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Mock Client client; - @Mock Settings settings; + @Mock private ClusterService clusterService; + + @Mock + ThreadPool threadPool; + + ThreadContext threadContext; + @Mock SearchRequest searchRequest; SearchResponse searchResponse; @@ -142,6 +164,8 @@ public class MetricsCorrelationTest { private final String modelId = "modelId"; private final String modelGroupId = "modelGroupId"; + final String USER_STRING = "myuser|role1,role2|myTenant"; + MLTask mlTask; Map params = new HashMap<>(); @@ -186,6 +210,16 @@ public void setUp() throws IOException, URISyntaxException { MockitoAnnotations.openMocks(this); metricsCorrelation = spy(new MetricsCorrelation(client, settings, clusterService)); + + settings = Settings.builder().build(); + ClusterState testClusterState = setupTestClusterState(); + when(clusterService.state()).thenReturn(testClusterState); + + threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + List inputData = new ArrayList<>(); inputData.add(new float[]{-1.0f, 2.0f, 3.0f}); inputData.add(new float[]{-1.0f, 2.0f, 3.0f}); @@ -199,16 +233,26 @@ public void setUp() throws IOException, URISyntaxException { } + @Ignore @Test public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteException { - metricsCorrelation.initModel(model, params); MLModelGetResponse response = new MLModelGetResponse(model); ActionFuture mockedFuture = mock(ActionFuture.class); when(client.execute(any(MLModelGetAction.class), any(MLModelGetRequest.class))).thenReturn(mockedFuture); when(mockedFuture.actionGet(anyLong())).thenReturn(response); doAnswer(invocation -> { - MLModel smallModel = model.toBuilder().modelConfig(modelConfig).modelState(MLModelState.DEPLOYED).build(); + + MLModel smallModel = MLModel.builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name(FunctionName.METRICS_CORRELATION.name()) + .modelId(modelId) + .modelGroupId(modelGroupId) + .algorithm(FunctionName.METRICS_CORRELATION) + .version(MCORR_ML_VERSION) + .modelConfig(modelConfig) + .modelState(MLModelState.UNDEPLOYED) + .build(); MLModelGetResponse responseTemp = new MLModelGetResponse(smallModel); ActionFuture mockedFutureTemp = mock(ActionFuture.class); MLTaskGetResponse taskResponse = new MLTaskGetResponse(mlTask); @@ -216,8 +260,8 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio when(client.execute(any(MLTaskGetAction.class), any(MLTaskGetRequest.class))).thenReturn(mockedFutureResponse); when(mockedFutureResponse.actionGet(anyLong())).thenReturn(taskResponse); when(mockedFutureTemp.actionGet(anyLong())).thenReturn(responseTemp); - metricsCorrelation.initModel(smallModel, params); + smallModel.toBuilder().modelState(MLModelState.DEPLOYED).build(); return null; }).when(client).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), isA(ActionListener.class)); @@ -227,7 +271,7 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); } - + @Ignore @Test public void testExecuteWithModelInIndexAndEmptyOutput() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -290,7 +334,7 @@ public void testExecuteWithModelInIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - + @Ignore @Test public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -331,7 +375,7 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - + @Ignore @Test public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -379,7 +423,7 @@ public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws Execu } - + @Ignore @Test public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -422,6 +466,7 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, } + //working @Test public void testGetModel() { ActionFuture mockedFuture = mock(ActionFuture.class); @@ -453,6 +498,7 @@ public static XContentBuilder builder() throws IOException { return XContentBuilder.builder(XContentType.JSON.xContent()); } + //working @Test public void testSearchRequest() { String expectedIndex = CommonValue.ML_MODEL_INDEX; @@ -486,7 +532,7 @@ public void testSearchRequest() { assertEquals(MLModel.MODEL_VERSION_FIELD, versionQueryBuilder.fieldName()); } - + @Ignore @Test public void testRegisterModel() throws InterruptedException { doAnswer(invocation -> { @@ -510,7 +556,6 @@ public void testRegisterModel() throws InterruptedException { verify(mlRegisterModelResponseActionListener).onResponse(mlRegisterModelResponse); } - @Test public void testDeployModel() { doAnswer(invocation -> { @@ -525,7 +570,6 @@ public void testDeployModel() { verify(mlDeployModelResponseActionListener).onResponse(mlDeployModelResponse); } - @Test public void testDeployModelFail() { Exception ex = new ExecuteException("Testing"); @@ -538,14 +582,12 @@ public void testDeployModelFail() { verify(mlDeployModelResponseActionListener).onFailure(ex); } - @Test public void testWrongInput() throws ExecuteException { exceptionRule.expect(ExecuteException.class); metricsCorrelation.execute(mock(LocalSampleCalculatorInput.class)); } - @Test public void parseModelTensorOutput_NullOutput() { exceptionRule.expect(MLException.class); @@ -553,7 +595,6 @@ public void parseModelTensorOutput_NullOutput() { metricsCorrelation.parseModelTensorOutput(null, null); } - @Test public void initModel_NullModelZipFile() { exceptionRule.expect(IllegalArgumentException.class); @@ -563,7 +604,6 @@ public void initModel_NullModelZipFile() { metricsCorrelation.initModel(model, params); } - @Test public void initModel_NullModelHelper() throws URISyntaxException { exceptionRule.expect(IllegalArgumentException.class); @@ -573,7 +613,6 @@ public void initModel_NullModelHelper() throws URISyntaxException { metricsCorrelation.initModel(model, params); } - @Test public void initModel_NullMLEngine() throws URISyntaxException { exceptionRule.expect(IllegalArgumentException.class); @@ -584,7 +623,6 @@ public void initModel_NullMLEngine() throws URISyntaxException { metricsCorrelation.initModel(model, params); } - @Test public void initModel_NullModelId() throws URISyntaxException { exceptionRule.expect(IllegalArgumentException.class); @@ -594,7 +632,6 @@ public void initModel_NullModelId() throws URISyntaxException { metricsCorrelation.initModel(model, params); } - @Test public void initModel_WrongFunctionName() { exceptionRule.expect(IllegalArgumentException.class); @@ -663,4 +700,54 @@ private SearchResponse createEmptySearchModelResponse( SearchResponse.Clusters.EMPTY ); } + + public static ClusterState setupTestClusterState() { + Set roleSet = new HashSet<>(); + roleSet.add(DiscoveryNodeRole.DATA_ROLE); + DiscoveryNode node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT + ); + Metadata metadata = new Metadata.Builder() + .indices( + ImmutableMap + .builder() + .put( + ML_MODEL_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .put(ML_MODEL_GROUP_INDEX, IndexMetadata.builder(ML_MODEL_GROUP_INDEX) + .settings(Settings.builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id)) + .build()) + .build() + ) + .build(); + return new ClusterState( + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false + ); + } }