From ddba6a3bd4ef9c0729f570ba986a7841d6bf9912 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 22 Sep 2023 18:49:40 -0700 Subject: [PATCH] Implemented Register Model Step Signed-off-by: Owais Kazi --- build.gradle | 2 +- .../flowframework/FlowFrameworkPlugin.java | 38 ++++- .../RegisterModel/RegisterModelStep.java | 145 ++++++++++++++++++ .../workflow/UploadModel/UploadModelStep.java | 22 --- .../RegisterModel/RegisterModelTests.java | 98 ++++++++++++ 5 files changed, 281 insertions(+), 24 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java delete mode 100644 src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java diff --git a/build.gradle b/build.gradle index aa20423ee..691cebb6b 100644 --- a/build.gradle +++ b/build.gradle @@ -107,7 +107,7 @@ dependencies { implementation 'org.junit.jupiter:junit-jupiter:5.10.0' implementation "com.google.code.gson:gson:2.10.1" compileOnly "com.google.guava:guava:32.1.2-jre" - api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" + api group: 'org.opensearch', name:'opensearch-ml-client', version: "2.10.0.0-SNAPSHOT" configurations.all { resolutionStrategy { diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index f810767eb..ed0d80805 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -8,11 +8,47 @@ */ package org.opensearch.flowframework; +import com.google.common.collect.ImmutableList; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.flowframework.workflow.RegisterModel.RegisterModelStep; import org.opensearch.plugins.Plugin; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.watcher.ResourceWatcherService; + +import java.util.Collection; +import java.util.function.Supplier; /** * An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch. */ public class FlowFrameworkPlugin extends Plugin { - // Implement the relevant Plugin Interfaces here + + private Client client; + + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { + this.client = client; + RegisterModelStep uploadModelStep = new RegisterModelStep(client); + return ImmutableList.of(uploadModelStep); + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java new file mode 100644 index 000000000..906acfe74 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelStep.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.RegisterModel; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; +import org.opensearch.flowframework.client.MLClient; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +public class RegisterModelStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); + + private Client client; + private final String NAME = "register_model_step"; + + private static final String FUNCTION_NAME = "function_name"; + private static final String MODEL_NAME = "model_name"; + private static final String MODEL_VERSION = "model_version"; + private static final String MODEL_GROUP_ID = "model_group_id"; + private static final String MODEL_URL = "url"; + private static final String MODEL_FORMAT = "model_format"; + private static final String MODEL_CONFIG = "model_config"; + private static final String DEPLOY_MODEL = "deploy_model"; + private static final String MODEL_NODES_IDS = "model_nodes_ids"; + + public RegisterModelStep(Client client) { + this.client = client; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture registerModelFuture = new CompletableFuture<>(); + + FunctionName functionName = null; + String modelName = null; + String modelVersion = null; + String modelGroupId = null; + String modelUrl = null; + MLModelFormat modelFormat = null; + String modelConfig = null; + Boolean deployModel = null; + String[] modelNodesId = null; + + for (WorkflowData workflowData : data) { + Map parameters = workflowData.getParams(); + Map content = workflowData.getContent(); + logger.info("Previous step sent params: {}, content: {}", parameters, content); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case FUNCTION_NAME: + functionName = (FunctionName) content.get(FUNCTION_NAME); + break; + case MODEL_NAME: + modelName = (String) content.get(MODEL_NAME); + break; + case MODEL_VERSION: + modelVersion = (String) content.get(MODEL_VERSION); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_URL: + modelUrl = (String) content.get(MODEL_URL); + break; + case MODEL_FORMAT: + modelFormat = (MLModelFormat) content.get(MODEL_FORMAT); + break; + case MODEL_CONFIG: + modelConfig = (String) content.get(MODEL_CONFIG); + break; + case DEPLOY_MODEL: + deployModel = (Boolean) content.get(DEPLOY_MODEL); + break; + case MODEL_NODES_IDS: + modelNodesId = (String[]) content.get(MODEL_NODES_IDS); + default: + break; + + } + } + } + + if (Stream.of(functionName, modelName, modelVersion, modelGroupId, modelConfig, modelFormat, deployModel, modelNodesId) + .allMatch(x -> x != null)) { + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); + // TODO: Add model Config and type cast correctly + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName(modelName) + .version(modelVersion) + .modelGroupId(modelGroupId) + .url(modelUrl) + .modelFormat(modelFormat) + .deployModel(deployModel) + .modelNodeIds(modelNodesId) + .build(); + + MLRegisterModelResponse mlRegisterModelResponse = machineLearningNodeClient.register(mlInput).actionGet(); + + registerModelFuture.complete(new WorkflowData() { + @Override + public Map getContent() { + return Map.ofEntries( + Map.entry("taskId", mlRegisterModelResponse.getTaskId()), + Map.entry("status", mlRegisterModelResponse.getStatus()) + ); + } + }); + + } else { + logger.error("Failed to register model"); + registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + } + return registerModelFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java deleted file mode 100644 index 7d27cf8b3..000000000 --- a/src/main/java/org/opensearch/flowframework/workflow/UploadModel/UploadModelStep.java +++ /dev/null @@ -1,22 +0,0 @@ -package org.opensearch.flowframework.workflow.UploadModel; - -import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.flowframework.workflow.WorkflowStep; - -import java.util.List; -import java.util.concurrent.CompletableFuture; - -public class UploadModelStep implements WorkflowStep { - - private final String NAME = "upload_model_step"; - - @Override - public CompletableFuture execute(List data) { - return null; - } - - @Override - public String getName() { - return NAME; - } -} diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java new file mode 100644 index 000000000..80839fee6 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModel/RegisterModelTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.RegisterModel; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; + +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.Mockito.*; + +public class RegisterModelTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock(answer = RETURNS_DEEP_STUBS) + NodeClient client; + + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData() { + @Override + public Map getContent() { + return Map.ofEntries( + Map.entry("function_name", FunctionName.KMEANS), + Map.entry("model_name", "bedrock"), + Map.entry("model_version", "1.0.0"), + Map.entry("model_group_id", "1.0"), + Map.entry("url", "url"), + Map.entry("model_format", MLModelFormat.TORCH_SCRIPT), + Map.entry("deploy_model", true), + Map.entry("model_nodes_ids", new String[] { "foo", "bar", "baz" }) + ); + } + }; + + machineLearningNodeClient = mock(MachineLearningNodeClient.class); + + } + + public void testRegisterModel() { + + FunctionName functionName = FunctionName.KMEANS; + + MLModelConfig config = TextEmbeddingModelConfig.builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLRegisterModelInput mlInput = MLRegisterModelInput.builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); + + RegisterModelStep registerModelStep = new RegisterModelStep(client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + CompletableFuture future = registerModelStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient, times(1)).register(mlInput); + assertEquals("1", (argumentCaptor.getValue()).getTaskId()); + + assertTrue(future.isDone()); + } + +}