From de860dde2e09e8985a64f5d8e2eea46b327bbd7d Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 30 Aug 2023 08:58:11 +0800 Subject: [PATCH] [2.x] Adding an integration test for redeploying a model (#1016) (#1264) * Adding a failing integ test for redeploy model and fix breaking changes from OpenSearch core * Adding model group ID changes for tests * Fixing tests for ImmutableMap copy * Commenting wait out task for model * Adding a failing integ test for redeploy model and fix breaking changes from OpenSearch core * Rebasing with 2.x * Adding logs to debug the test in GHA * GHA tests * Still debugging * Removing comment * Removing unnecessary changes * Removing logs --------- Signed-off-by: Sarat Vemulapalli Co-authored-by: Sarat Vemulapalli --- .../ml/rest/RestMLDeployModelActionIT.java | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java new file mode 100644 index 0000000000..fa542c7c0c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.utils.TestHelper; + +public class RestMLDeployModelActionIT extends MLCommonsRestTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + private MLRegisterModelInput registerModelInput; + private MLRegisterModelGroupInput mlRegisterModelGroupInput; + private String modelGroupId; + + @Before + public void setup() throws IOException { + mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder().name("testGroupID").description("This is test Group").build(); + registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + this.modelGroupId = (String) registerModelGroupResult.get("model_group_id"); + }); + registerModelInput = createRegisterModelInput(modelGroupId); + } + + public void testReDeployModel() throws InterruptedException, IOException { + // Register Model + String taskId = registerModel(TestHelper.toJsonString(registerModelInput)); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + String model_id = (String) response.get(MODEL_ID_FIELD); + try { + // Deploy Model + String taskId1 = deployModel(model_id); + getTask(client(), taskId1, innerResponse -> { assertEquals(model_id, innerResponse.get(MODEL_ID_FIELD)); }); + waitForTask(taskId1, MLTaskState.COMPLETED); + + // Undeploy Model + Map undeployresponse = undeployModel(model_id); + for (Map.Entry entry : undeployresponse.entrySet()) { + Map stats = (Map) ((Map) entry.getValue()).get("stats"); + assertEquals("undeployed", stats.get(model_id)); + } + + // Deploy Model again + taskId1 = deployModel(model_id); + getTask(client(), taskId1, innerResponse -> { logger.info("Re-Deploy model {}", innerResponse); }); + waitForTask(taskId1, MLTaskState.COMPLETED); + + getModel(client(), model_id, model -> { + logger.info("Get Model after re-deploy {}", model); + assertEquals("DEPLOYED", model.get("model_state")); + }); + + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } +}