Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.x] Adding an integration test for redeploying a model #1016

Merged
merged 13 commits into from
Jul 5, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.rest;

import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;

import java.io.IOException;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.utils.TestHelper;

public class RestMLDeployModelActionIT extends MLCommonsRestTestCase {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
private MLRegisterModelInput registerModelInput;
private MLRegisterModelGroupInput mlRegisterModelGroupInput;
private String modelGroupId;

@Before
public void setup() throws IOException {
mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder().name("testGroupID").description("This is test Group").build();
registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> {
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id");
});
registerModelInput = createRegisterModelInput(modelGroupId);
}

public void testReDeployModel() throws InterruptedException, IOException {
// Register Model
String taskId = registerModel(TestHelper.toJsonString(registerModelInput));
waitForTask(taskId, MLTaskState.COMPLETED);
getTask(client(), taskId, response -> {
String model_id = (String) response.get(MODEL_ID_FIELD);
try {
// Deploy Model
String taskId1 = deployModel(model_id);
getTask(client(), taskId1, innerResponse -> { assertEquals(model_id, innerResponse.get(MODEL_ID_FIELD)); });
waitForTask(taskId1, MLTaskState.COMPLETED);

// Undeploy Model
Map<String, Object> undeployresponse = undeployModel(model_id);
for (Map.Entry<String, Object> entry : undeployresponse.entrySet()) {
Map stats = (Map) ((Map) entry.getValue()).get("stats");
assertEquals("undeployed", stats.get(model_id));
}

// Deploy Model again
taskId1 = deployModel(model_id);
getTask(client(), taskId1, innerResponse -> { logger.info("Re-Deploy model {}", innerResponse); });
waitForTask(taskId1, MLTaskState.COMPLETED);

getModel(client(), model_id, model -> {
logger.info("Get Model after re-deploy {}", model);
assertEquals("DEPLOYED", model.get("model_state"));
});

} catch (Exception e) {
throw new RuntimeException(e);
}
});
}
}