Skip to content

Commit

Permalink
Implemented Register Model Step
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Sep 23, 2023
1 parent 0fbd6bc commit ddba6a3
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 24 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ dependencies {
implementation 'org.junit.jupiter:junit-jupiter:5.10.0'
implementation "com.google.code.gson:gson:2.10.1"
compileOnly "com.google.guava:guava:32.1.2-jre"
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
api group: 'org.opensearch', name:'opensearch-ml-client', version: "2.10.0.0-SNAPSHOT"

configurations.all {
resolutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,47 @@
*/
package org.opensearch.flowframework;

import com.google.common.collect.ImmutableList;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.flowframework.workflow.RegisterModel.RegisterModelStep;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import java.util.Collection;
import java.util.function.Supplier;

/**
* An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch.
*/
public class FlowFrameworkPlugin extends Plugin {
// Implement the relevant Plugin Interfaces here

private Client client;

@Override
public Collection<Object> createComponents(
Client client,
ClusterService clusterService,
ThreadPool threadPool,
ResourceWatcherService resourceWatcherService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
Environment environment,
NodeEnvironment nodeEnvironment,
NamedWriteableRegistry namedWriteableRegistry,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
this.client = client;
RegisterModelStep uploadModelStep = new RegisterModelStep(client);
return ImmutableList.of(uploadModelStep);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow.RegisterModel;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStep;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

public class RegisterModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private Client client;
private final String NAME = "register_model_step";

private static final String FUNCTION_NAME = "function_name";
private static final String MODEL_NAME = "model_name";
private static final String MODEL_VERSION = "model_version";
private static final String MODEL_GROUP_ID = "model_group_id";
private static final String MODEL_URL = "url";
private static final String MODEL_FORMAT = "model_format";
private static final String MODEL_CONFIG = "model_config";
private static final String DEPLOY_MODEL = "deploy_model";
private static final String MODEL_NODES_IDS = "model_nodes_ids";

public RegisterModelStep(Client client) {
this.client = client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> registerModelFuture = new CompletableFuture<>();

FunctionName functionName = null;
String modelName = null;
String modelVersion = null;
String modelGroupId = null;
String modelUrl = null;
MLModelFormat modelFormat = null;
String modelConfig = null;
Boolean deployModel = null;
String[] modelNodesId = null;

for (WorkflowData workflowData : data) {
Map<String, String> parameters = workflowData.getParams();
Map<String, Object> content = workflowData.getContent();
logger.info("Previous step sent params: {}, content: {}", parameters, content);

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case FUNCTION_NAME:
functionName = (FunctionName) content.get(FUNCTION_NAME);
break;
case MODEL_NAME:
modelName = (String) content.get(MODEL_NAME);
break;
case MODEL_VERSION:
modelVersion = (String) content.get(MODEL_VERSION);
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case MODEL_URL:
modelUrl = (String) content.get(MODEL_URL);
break;
case MODEL_FORMAT:
modelFormat = (MLModelFormat) content.get(MODEL_FORMAT);
break;
case MODEL_CONFIG:
modelConfig = (String) content.get(MODEL_CONFIG);
break;
case DEPLOY_MODEL:
deployModel = (Boolean) content.get(DEPLOY_MODEL);
break;
case MODEL_NODES_IDS:
modelNodesId = (String[]) content.get(MODEL_NODES_IDS);
default:
break;

}
}
}

if (Stream.of(functionName, modelName, modelVersion, modelGroupId, modelConfig, modelFormat, deployModel, modelNodesId)
.allMatch(x -> x != null)) {
MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client);
// TODO: Add model Config and type cast correctly
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.version(modelVersion)
.modelGroupId(modelGroupId)
.url(modelUrl)
.modelFormat(modelFormat)
.deployModel(deployModel)
.modelNodeIds(modelNodesId)
.build();

MLRegisterModelResponse mlRegisterModelResponse = machineLearningNodeClient.register(mlInput).actionGet();

registerModelFuture.complete(new WorkflowData() {
@Override
public Map<String, Object> getContent() {
return Map.ofEntries(
Map.entry("taskId", mlRegisterModelResponse.getTaskId()),
Map.entry("status", mlRegisterModelResponse.getStatus())
);
}
});

} else {
logger.error("Failed to register model");
registerModelFuture.completeExceptionally(new IOException("Failed to register model "));
}
return registerModelFuture;
}

@Override
public String getName() {
return NAME;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow.RegisterModel;

import org.opensearch.client.node.NodeClient;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.mockito.ArgumentCaptor;
import org.mockito.Mock;

import static org.mockito.Answers.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.*;

public class RegisterModelTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock(answer = RETURNS_DEEP_STUBS)
NodeClient client;

MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

inputData = new WorkflowData() {
@Override
public Map<String, Object> getContent() {
return Map.ofEntries(
Map.entry("function_name", FunctionName.KMEANS),
Map.entry("model_name", "bedrock"),
Map.entry("model_version", "1.0.0"),
Map.entry("model_group_id", "1.0"),
Map.entry("url", "url"),
Map.entry("model_format", MLModelFormat.TORCH_SCRIPT),
Map.entry("deploy_model", true),
Map.entry("model_nodes_ids", new String[] { "foo", "bar", "baz" })
);
}
};

machineLearningNodeClient = mock(MachineLearningNodeClient.class);

}

public void testRegisterModel() {

FunctionName functionName = FunctionName.KMEANS;

MLModelConfig config = TextEmbeddingModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.build();

MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.deployModel(true)
.modelNodeIds(new String[] { "modelNodeIds" })
.build();

RegisterModelStep registerModelStep = new RegisterModelStep(client);

ArgumentCaptor<MLRegisterModelResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class);
CompletableFuture<WorkflowData> future = registerModelStep.execute(List.of(inputData));

verify(machineLearningNodeClient, times(1)).register(mlInput);
assertEquals("1", (argumentCaptor.getValue()).getTaskId());

assertTrue(future.isDone());
}

}

0 comments on commit ddba6a3

Please sign in to comment.