From 3efb65f088fa621dba4067f25c4b0bc7bd4be3b4 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Fri, 6 Oct 2023 14:18:33 -0700 Subject: [PATCH] fixing metrics correlation algorithm (#1448) * fixing metrics correlation algorithm Signed-off-by: Dhrubo Saha --- .../org/opensearch/ml/common/MLModel.java | 6 +- .../model/MetricsCorrelationModelConfig.java | 5 + .../register/MLRegisterModelInput.java | 7 +- .../MetricsCorrelationModelConfigTests.java | 84 ++++++++ .../register/MLRegisterModelInputTest.java | 55 +++++- .../ml/engine/algorithms/DLModelExecute.java | 2 +- .../MetricsCorrelation.java | 23 ++- .../MetricsCorrelationTest.java | 183 ++++++++++++++---- 8 files changed, 317 insertions(+), 48 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index ae28066d5a..78f6f4ac60 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -191,7 +191,11 @@ public MLModel(StreamInput input) throws IOException{ modelContentSizeInBytes = input.readOptionalLong(); modelContentHash = input.readOptionalString(); if (input.readBoolean()) { - modelConfig = new TextEmbeddingModelConfig(input); + if (algorithm.equals(FunctionName.METRICS_CORRELATION)) { + modelConfig = new MetricsCorrelationModelConfig(input); + } else { + modelConfig = new TextEmbeddingModelConfig(input); + } } createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); diff --git a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java index 4f26e4b4d2..e1c9203cae 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; @@ -28,6 +29,10 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) { super(modelType, allConfig); } + public MetricsCorrelationModelConfig(StreamInput in) throws IOException{ + super(in); + } + @Override public String getWriteableName() { return PARSE_FIELD_NAME; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index e79a09c5b2..71fd8afa27 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -18,6 +18,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; @@ -137,7 +138,11 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.modelFormat = in.readEnum(MLModelFormat.class); } if (in.readBoolean()) { - this.modelConfig = new TextEmbeddingModelConfig(in); + if (this.functionName.equals(FunctionName.METRICS_CORRELATION)) { + this.modelConfig = new MetricsCorrelationModelConfig(in); + } else { + this.modelConfig = new TextEmbeddingModelConfig(in); + } } this.deployModel = in.readBoolean(); this.modelNodeIds = in.readOptionalStringArray(); diff --git a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java new file mode 100644 index 0000000000..4700039939 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class MetricsCorrelationModelConfigTests { + + MetricsCorrelationModelConfig config; + Function function; + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + config = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + function = parser -> { + try { + return MetricsCorrelationModelConfig.parse(parser); + } catch (IOException e) { + throw new RuntimeException("Failed to parse MetricsCorrelationModelConfig", e); + } + }; + } + + @Test + public void toXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, EMPTY_PARAMS); + String configContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + } + + @Test + public void nullFields_ModelType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("model type is null"); + config = MetricsCorrelationModelConfig.builder() + .build(); + } + + @Test + public void parse() throws IOException { + String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + TestHelper.testParseFromString(config, content, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(config); + } + + public void readInputStream(MetricsCorrelationModelConfig config) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + config.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MetricsCorrelationModelConfig parsedConfig = new MetricsCorrelationModelConfig(streamInput); + assertEquals(config.getModelType(), parsedConfig.getModelType()); + assertEquals(config.getAllConfig(), parsedConfig.getAllConfig()); + assertEquals(config.getWriteableName(), parsedConfig.getWriteableName()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 24a409bd44..ee95447f62 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -22,6 +22,7 @@ import org.opensearch.ml.common.connector.HttpConnectorTest; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.search.SearchModule; @@ -74,7 +75,6 @@ public void setUp() throws Exception { .deployModel(true) .modelNodeIds(new String[]{"modelNodeIds" }) .build(); - } @Test @@ -257,6 +257,59 @@ public void readInputStream_WithInternalConnector() throws IOException { }); } + @Test + public void testMCorrInput() throws IOException { + String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + + MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = builder.toString(); + assertEquals(testString, jsonStr); + } + + @Test + public void readInputStream_MCorr() throws IOException { + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); + + MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[]{"modelNodeIds" }) + .build(); + readInputStream(mcorrInput, parsedInput -> { + assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType()); + assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig()); + assertEquals(parsedInput.getFunctionName(), FunctionName.METRICS_CORRELATION); + assertEquals(parsedInput.getModelName(), FunctionName.METRICS_CORRELATION.name()); + assertEquals(parsedInput.getModelGroupId(), modelGroupId); + }); + } + private void readInputStream(MLRegisterModelInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java index 4e6bd6bd69..fab052da19 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java @@ -120,7 +120,7 @@ public void close() { * @param modelId id of the model * @param modelName name of the model * @param version version of the model - * @param engine engine where model will be run. For now we are supporting only pytorch engine only. + * @param engine engine where model will be run. For now, we are supporting only pytorch engine only. */ private void loadModel(File modelZipFile, String modelId, String modelName, String version, String engine) { try { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index 96aad3168e..1c96a048fc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -121,7 +121,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { if (modelId == null) { boolean hasModelGroupIndex = clusterService.state().getMetadata().hasIndex(ML_MODEL_GROUP_INDEX); - if (!hasModelGroupIndex) { // Create model group index if doesn't exist + if (!hasModelGroupIndex) { // Create model group index if it doesn't exist try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { CreateIndexRequest request = new CreateIndexRequest(ML_MODEL_GROUP_INDEX).mapping(ML_MODEL_GROUP_INDEX_MAPPING); CreateIndexResponse createIndexResponse = client.admin().indices().create(request).actionGet(1000); @@ -175,8 +175,10 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { ) ); } - }, e -> { log.error("Failed to get model", e); }); - client.get(getModelRequest, ActionListener.runBefore(listener, () -> context.restore())); + }, e-> { + log.error("Failed to get model", e); + }); + client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); } } } else { @@ -197,10 +199,17 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { waitUntil(() -> { if (modelId != null) { MLModelState modelState = getModel(modelId).getModelState(); - return modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED; + if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED){ + log.info("Model deployed: " + modelState); + return true; + } else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) { + log.info("Model not deployed: " + modelState); + deployModel(modelId, ActionListener.wrap(deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e))); + return false; + } } return false; - }, 10, TimeUnit.SECONDS); + }, 120, TimeUnit.SECONDS); Output djlOutput; try { @@ -253,7 +262,7 @@ void registerModel(ActionListener listener) throws Inte log.error("Failed to Register Model", e); listener.onFailure(e); })); - }, e -> { listener.onFailure(e); }), () -> context.restore())); + }, listener::onFailure), context::restore)); } catch (IOException e) { throw new MLException(e); } @@ -322,6 +331,8 @@ public static boolean waitUntil(BooleanSupplier breakSupplier, long maxWaitTime, } sum += timeInMillis; timeInMillis = Math.min(AWAIT_BUSY_THRESHOLD, timeInMillis * 2); + + log.info("Waiting... Time elapsed: " + sum + "ms"); } timeInMillis = maxTimeInMillis - sum; try { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 02132687e3..36a1fe9998 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -5,36 +5,7 @@ package org.opensearch.ml.engine.algorithms.metrics_correlation; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; -import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; -import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; -import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; -import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; - -import java.io.File; -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; - +import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Ignore; @@ -43,16 +14,27 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; @@ -97,19 +79,59 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.suggest.Suggest; +import org.opensearch.threadpool.ThreadPool; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; +import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; +import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; -//TODO: fix mockito error: Cannot mock/spy class org.opensearch.common.settings.Settings final class -@Ignore public class MetricsCorrelationTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Mock Client client; - @Mock Settings settings; + @Mock private ClusterService clusterService; + + @Mock + ThreadPool threadPool; + + ThreadContext threadContext; + @Mock SearchRequest searchRequest; SearchResponse searchResponse; @@ -142,6 +164,8 @@ public class MetricsCorrelationTest { private final String modelId = "modelId"; private final String modelGroupId = "modelGroupId"; + final String USER_STRING = "myuser|role1,role2|myTenant"; + MLTask mlTask; Map params = new HashMap<>(); @@ -180,6 +204,16 @@ public void setUp() throws IOException, URISyntaxException { MockitoAnnotations.openMocks(this); metricsCorrelation = spy(new MetricsCorrelation(client, settings, clusterService)); + + settings = Settings.builder().build(); + ClusterState testClusterState = setupTestClusterState(); + when(clusterService.state()).thenReturn(testClusterState); + + threadContext = new ThreadContext(settings); + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + List inputData = new ArrayList<>(); inputData.add(new float[] { -1.0f, 2.0f, 3.0f }); inputData.add(new float[] { -1.0f, 2.0f, 3.0f }); @@ -261,16 +295,27 @@ public void setUp() throws IOException, URISyntaxException { extendedInput = MetricsCorrelationInput.builder().inputData(extendedInputData).build(); } + + @Ignore @Test public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteException { - metricsCorrelation.initModel(model, params); MLModelGetResponse response = new MLModelGetResponse(model); ActionFuture mockedFuture = mock(ActionFuture.class); when(client.execute(any(MLModelGetAction.class), any(MLModelGetRequest.class))).thenReturn(mockedFuture); when(mockedFuture.actionGet(anyLong())).thenReturn(response); doAnswer(invocation -> { - MLModel smallModel = model.toBuilder().modelConfig(modelConfig).modelState(MLModelState.DEPLOYED).build(); + + MLModel smallModel = MLModel.builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name(FunctionName.METRICS_CORRELATION.name()) + .modelId(modelId) + .modelGroupId(modelGroupId) + .algorithm(FunctionName.METRICS_CORRELATION) + .version(MCORR_ML_VERSION) + .modelConfig(modelConfig) + .modelState(MLModelState.UNDEPLOYED) + .build(); MLModelGetResponse responseTemp = new MLModelGetResponse(smallModel); ActionFuture mockedFutureTemp = mock(ActionFuture.class); MLTaskGetResponse taskResponse = new MLTaskGetResponse(mlTask); @@ -278,8 +323,8 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio when(client.execute(any(MLTaskGetAction.class), any(MLTaskGetRequest.class))).thenReturn(mockedFutureResponse); when(mockedFutureResponse.actionGet(anyLong())).thenReturn(taskResponse); when(mockedFutureTemp.actionGet(anyLong())).thenReturn(responseTemp); - metricsCorrelation.initModel(smallModel, params); + smallModel.toBuilder().modelState(MLModelState.DEPLOYED).build(); return null; }).when(client).execute(any(MLDeployModelAction.class), any(MLDeployModelRequest.class), isA(ActionListener.class)); @@ -289,6 +334,8 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); } + + @Ignore @Test public void testExecuteWithModelInIndexAndEmptyOutput() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -349,6 +396,7 @@ public void testExecuteWithModelInIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + @Ignore @Test public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -389,6 +437,8 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + + @Ignore @Test public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -435,6 +485,8 @@ public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws Execu assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + + @Ignore @Test public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, URISyntaxException { Map params = new HashMap<>(); @@ -476,6 +528,8 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } + + //working @Test public void testGetModel() { ActionFuture mockedFuture = mock(ActionFuture.class); @@ -508,6 +562,7 @@ public static XContentBuilder builder() throws IOException { return XContentBuilder.builder(XContentType.JSON.xContent()); } + //working @Test public void testSearchRequest() { String expectedIndex = CommonValue.ML_MODEL_INDEX; @@ -546,6 +601,8 @@ public void testSearchRequest() { assertEquals(MLModel.MODEL_VERSION_FIELD, versionQueryBuilder.fieldName()); } + + @Ignore @Test public void testRegisterModel() throws InterruptedException { doAnswer(invocation -> { @@ -711,4 +768,54 @@ private SearchResponse createEmptySearchModelResponse() throws IOException { SearchResponse.Clusters.EMPTY ); } + + public static ClusterState setupTestClusterState() { + Set roleSet = new HashSet<>(); + roleSet.add(DiscoveryNodeRole.DATA_ROLE); + DiscoveryNode node = new DiscoveryNode( + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT + ); + Metadata metadata = new Metadata.Builder() + .indices( + ImmutableMap + .builder() + .put( + ML_MODEL_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .put(ML_MODEL_GROUP_INDEX, IndexMetadata.builder(ML_MODEL_GROUP_INDEX) + .settings(Settings.builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id)) + .build()) + .build() + ) + .build(); + return new ClusterState( + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false + ); + } }