-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Owais Kazi <[email protected]>
- Loading branch information
1 parent
0fbd6bc
commit ddba6a3
Showing
5 changed files
with
281 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
145 changes: 145 additions & 0 deletions
145
src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow.RegisterModel; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.client.node.NodeClient; | ||
import org.opensearch.flowframework.client.MLClient; | ||
import org.opensearch.flowframework.workflow.WorkflowData; | ||
import org.opensearch.flowframework.workflow.WorkflowStep; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.model.MLModelFormat; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelInput; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Map.Entry; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Stream; | ||
|
||
public class RegisterModelStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); | ||
|
||
private Client client; | ||
private final String NAME = "register_model_step"; | ||
|
||
private static final String FUNCTION_NAME = "function_name"; | ||
private static final String MODEL_NAME = "model_name"; | ||
private static final String MODEL_VERSION = "model_version"; | ||
private static final String MODEL_GROUP_ID = "model_group_id"; | ||
private static final String MODEL_URL = "url"; | ||
private static final String MODEL_FORMAT = "model_format"; | ||
private static final String MODEL_CONFIG = "model_config"; | ||
private static final String DEPLOY_MODEL = "deploy_model"; | ||
private static final String MODEL_NODES_IDS = "model_nodes_ids"; | ||
|
||
public RegisterModelStep(Client client) { | ||
this.client = client; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) { | ||
|
||
CompletableFuture<WorkflowData> registerModelFuture = new CompletableFuture<>(); | ||
|
||
FunctionName functionName = null; | ||
String modelName = null; | ||
String modelVersion = null; | ||
String modelGroupId = null; | ||
String modelUrl = null; | ||
MLModelFormat modelFormat = null; | ||
String modelConfig = null; | ||
Boolean deployModel = null; | ||
String[] modelNodesId = null; | ||
|
||
for (WorkflowData workflowData : data) { | ||
Map<String, String> parameters = workflowData.getParams(); | ||
Map<String, Object> content = workflowData.getContent(); | ||
logger.info("Previous step sent params: {}, content: {}", parameters, content); | ||
|
||
for (Entry<String, Object> entry : content.entrySet()) { | ||
switch (entry.getKey()) { | ||
case FUNCTION_NAME: | ||
functionName = (FunctionName) content.get(FUNCTION_NAME); | ||
break; | ||
case MODEL_NAME: | ||
modelName = (String) content.get(MODEL_NAME); | ||
break; | ||
case MODEL_VERSION: | ||
modelVersion = (String) content.get(MODEL_VERSION); | ||
break; | ||
case MODEL_GROUP_ID: | ||
modelGroupId = (String) content.get(MODEL_GROUP_ID); | ||
break; | ||
case MODEL_URL: | ||
modelUrl = (String) content.get(MODEL_URL); | ||
break; | ||
case MODEL_FORMAT: | ||
modelFormat = (MLModelFormat) content.get(MODEL_FORMAT); | ||
break; | ||
case MODEL_CONFIG: | ||
modelConfig = (String) content.get(MODEL_CONFIG); | ||
break; | ||
case DEPLOY_MODEL: | ||
deployModel = (Boolean) content.get(DEPLOY_MODEL); | ||
break; | ||
case MODEL_NODES_IDS: | ||
modelNodesId = (String[]) content.get(MODEL_NODES_IDS); | ||
default: | ||
break; | ||
|
||
} | ||
} | ||
} | ||
|
||
if (Stream.of(functionName, modelName, modelVersion, modelGroupId, modelConfig, modelFormat, deployModel, modelNodesId) | ||
.allMatch(x -> x != null)) { | ||
MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); | ||
// TODO: Add model Config and type cast correctly | ||
MLRegisterModelInput mlInput = MLRegisterModelInput.builder() | ||
.functionName(functionName) | ||
.modelName(modelName) | ||
.version(modelVersion) | ||
.modelGroupId(modelGroupId) | ||
.url(modelUrl) | ||
.modelFormat(modelFormat) | ||
.deployModel(deployModel) | ||
.modelNodeIds(modelNodesId) | ||
.build(); | ||
|
||
MLRegisterModelResponse mlRegisterModelResponse = machineLearningNodeClient.register(mlInput).actionGet(); | ||
|
||
registerModelFuture.complete(new WorkflowData() { | ||
@Override | ||
public Map<String, Object> getContent() { | ||
return Map.ofEntries( | ||
Map.entry("taskId", mlRegisterModelResponse.getTaskId()), | ||
Map.entry("status", mlRegisterModelResponse.getStatus()) | ||
); | ||
} | ||
}); | ||
|
||
} else { | ||
logger.error("Failed to register model"); | ||
registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); | ||
} | ||
return registerModelFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
} |
22 changes: 0 additions & 22 deletions
22
src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java
This file was deleted.
Oops, something went wrong.
98 changes: 98 additions & 0 deletions
98
src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow.RegisterModel; | ||
|
||
import org.opensearch.client.node.NodeClient; | ||
import org.opensearch.flowframework.workflow.WorkflowData; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.model.MLModelConfig; | ||
import org.opensearch.ml.common.model.MLModelFormat; | ||
import org.opensearch.ml.common.model.TextEmbeddingModelConfig; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelInput; | ||
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
|
||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Mock; | ||
|
||
import static org.mockito.Answers.RETURNS_DEEP_STUBS; | ||
import static org.mockito.Mockito.*; | ||
|
||
public class RegisterModelTests extends OpenSearchTestCase { | ||
private WorkflowData inputData = WorkflowData.EMPTY; | ||
|
||
@Mock(answer = RETURNS_DEEP_STUBS) | ||
NodeClient client; | ||
|
||
MachineLearningNodeClient machineLearningNodeClient; | ||
|
||
@Override | ||
public void setUp() throws Exception { | ||
super.setUp(); | ||
|
||
inputData = new WorkflowData() { | ||
@Override | ||
public Map<String, Object> getContent() { | ||
return Map.ofEntries( | ||
Map.entry("function_name", FunctionName.KMEANS), | ||
Map.entry("model_name", "bedrock"), | ||
Map.entry("model_version", "1.0.0"), | ||
Map.entry("model_group_id", "1.0"), | ||
Map.entry("url", "url"), | ||
Map.entry("model_format", MLModelFormat.TORCH_SCRIPT), | ||
Map.entry("deploy_model", true), | ||
Map.entry("model_nodes_ids", new String[] { "foo", "bar", "baz" }) | ||
); | ||
} | ||
}; | ||
|
||
machineLearningNodeClient = mock(MachineLearningNodeClient.class); | ||
|
||
} | ||
|
||
public void testRegisterModel() { | ||
|
||
FunctionName functionName = FunctionName.KMEANS; | ||
|
||
MLModelConfig config = TextEmbeddingModelConfig.builder() | ||
.modelType("testModelType") | ||
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") | ||
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) | ||
.embeddingDimension(100) | ||
.build(); | ||
|
||
MLRegisterModelInput mlInput = MLRegisterModelInput.builder() | ||
.functionName(functionName) | ||
.modelName("testModelName") | ||
.version("testModelVersion") | ||
.modelGroupId("modelGroupId") | ||
.url("url") | ||
.modelFormat(MLModelFormat.ONNX) | ||
.modelConfig(config) | ||
.deployModel(true) | ||
.modelNodeIds(new String[] { "modelNodeIds" }) | ||
.build(); | ||
|
||
RegisterModelStep registerModelStep = new RegisterModelStep(client); | ||
|
||
ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); | ||
CompletableFuture<WorkflowData> future = registerModelStep.execute(List.of(inputData)); | ||
|
||
verify(machineLearningNodeClient, times(1)).register(mlInput); | ||
assertEquals("1", (argumentCaptor.getValue()).getTaskId()); | ||
|
||
assertTrue(future.isDone()); | ||
} | ||
|
||
} |