Skip to content

Commit

Permalink
Separates RegisterModelStep into RegisterLocalModelStep and RegisterR…
Browse files Browse the repository at this point in the history
…emoteModelStep. Adds GetMLTaskStep. Handles optional params (#155)

* added RegisterRemoteModelStep and tests

Signed-off-by: Joshua Palis <[email protected]>

* Adding RegisterLocalModelStep, fixing tests, adding input/ouput definitions to workflow step json

Signed-off-by: Joshua Palis <[email protected]>

* Fixing javadoc warnings, fixing log message

Signed-off-by: Joshua Palis <[email protected]>

* Addressing PR comments,making description field optional for RegisterRemoteModelStep and RegisterLocalModelStep

Signed-off-by: Joshua Palis <[email protected]>

* moving modelConfig builder before adding allConfig

Signed-off-by: Joshua Palis <[email protected]>

* handling optional description field for remote/local model

Signed-off-by: Joshua Palis <[email protected]>

* Removing poolingMode, modelMaxLenth, normalizeResult

Signed-off-by: Joshua Palis <[email protected]>

* adding modelType to required fields check

Signed-off-by: Joshua Palis <[email protected]>

* Fixing RegisterLocalModelStep to output a task ID instead of a model id

Signed-off-by: Joshua Palis <[email protected]>

* Adding GetMLTaskStep and tests

Signed-off-by: Joshua Palis <[email protected]>

* Adding todo for GetMLTask retry capability

Signed-off-by: Joshua Palis <[email protected]>

---------

Signed-off-by: Joshua Palis <[email protected]>
  • Loading branch information
joshpalis authored Nov 13, 2023
1 parent 56ccb1d commit 2142874
Show file tree
Hide file tree
Showing 12 changed files with 634 additions and 91 deletions.
16 changes: 16 additions & 0 deletions src/main/java/org/opensearch/flowframework/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ private CommonValue() {}
public static final String OUTPUT_FIELD_NAME = "output_field_name";
/** Model Id field */
public static final String MODEL_ID = "model_id";
/** Task Id field */
public static final String TASK_ID = "task_id";
/** Register Model Status field */
public static final String REGISTER_MODEL_STATUS = "register_model_status";
/** Function Name field */
public static final String FUNCTION_NAME = "function_name";
/** Name field */
Expand All @@ -95,8 +99,20 @@ private CommonValue() {}
public static final String CONNECTOR_ID = "connector_id";
/** Model format field */
public static final String MODEL_FORMAT = "model_format";
/** Model content hash value field */
public static final String MODEL_CONTENT_HASH_VALUE = "model_content_hash_value";
/** URL field */
public static final String URL = "url";
/** Model config field */
public static final String MODEL_CONFIG = "model_config";
/** Model type field */
public static final String MODEL_TYPE = "model_type";
/** Embedding dimension field */
public static final String EMBEDDING_DIMENSION = "embedding_dimension";
/** Framework type field */
public static final String FRAMEWORK_TYPE = "framework_type";
/** All config field */
public static final String ALL_CONFIG = "all_config";
/** Version field */
public static final String VERSION_FIELD = "version";
/** Connector protocol field */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL
}

}, e -> {
String errorMessage = "Failed to create global_context index";
String errorMessage = "Failed to create workflow_state index";
logger.error(errorMessage, e);
listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e)));
}));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.TASK_ID;

/**
* Step to retrieve an ML Task
*/
public class GetMLTaskStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(GetMLTaskStep.class);
private MachineLearningNodeClient mlClient;
static final String NAME = "get_ml_task";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public GetMLTaskStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> getMLTaskFuture = new CompletableFuture<>();

ActionListener<MLTask> actionListener = ActionListener.wrap(response -> {

// TODO : Add retry capability if response status is not COMPLETED :
// https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/158

logger.info("ML Task retrieval successful");
getMLTaskFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(MODEL_ID, response.getModelId()), Map.entry(REGISTER_MODEL_STATUS, response.getState().name()))
)
);
}, exception -> {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
});

String taskId = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case TASK_ID:
taskId = (String) content.get(TASK_ID);
break;
default:
break;
}
}
}

if (taskId == null) {
logger.error("Failed to retrieve ML Task");
getMLTaskFuture.completeExceptionally(new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST));
} else {
mlClient.getTask(taskId, actionListener);
}

return getMLTaskFuture;
}

@Override
public String getName() {
return NAME;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
*/
public class ModelGroupStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);
private static final Logger logger = LogManager.getLogger(ModelGroupStep.class);

private MachineLearningNodeClient mlClient;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig.TextEmbeddingModelConfigBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION;
import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE;
import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT;
import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID;
import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE;
import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
import static org.opensearch.flowframework.common.CommonValue.TASK_ID;
import static org.opensearch.flowframework.common.CommonValue.URL;
import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD;

/**
* Step to register a local model
*/
public class RegisterLocalModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterLocalModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "register_local_model";

/**
* Instantiate this class
* @param mlClient client to instantiate MLClient
*/
public RegisterLocalModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> registerLocalModelFuture = new CompletableFuture<>();

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
logger.info("Local Model registration task creation successful");
registerLocalModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry(TASK_ID, mlRegisterModelResponse.getTaskId()),
Map.entry(REGISTER_MODEL_STATUS, mlRegisterModelResponse.getStatus())
)
)
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to register local model");
registerLocalModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

String modelName = null;
String modelVersion = null;
String description = null;
MLModelFormat modelFormat = null;
String modelGroupId = null;
String modelContentHashValue = null;
String modelType = null;
String embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
String url = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelName = (String) content.get(NAME_FIELD);
break;
case VERSION_FIELD:
modelVersion = (String) content.get(VERSION_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case MODEL_FORMAT:
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT));
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case MODEL_TYPE:
modelType = (String) content.get(MODEL_TYPE);
break;
case EMBEDDING_DIMENSION:
embeddingDimension = (String) content.get(EMBEDDING_DIMENSION);
break;
case FRAMEWORK_TYPE:
frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE));
break;
case ALL_CONFIG:
allConfig = (String) content.get(ALL_CONFIG);
break;
case MODEL_CONTENT_HASH_VALUE:
modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE);
break;
case URL:
url = (String) content.get(URL);
break;
default:
break;

}
}
}

if (Stream.of(
modelName,
modelVersion,
modelFormat,
modelGroupId,
modelType,
embeddingDimension,
frameworkType,
modelContentHashValue,
url
).allMatch(x -> x != null)) {

// Create Model configudation
TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder()
.modelType(modelType)
.embeddingDimension(Integer.valueOf(embeddingDimension))
.frameworkType(frameworkType);
if (allConfig != null) {
modelConfigBuilder.allConfig(allConfig);
}
MLModelConfig modelConfig = modelConfigBuilder.build();

// Create register local model input
MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder()
.modelName(modelName)
.version(modelVersion)
.modelFormat(modelFormat)
.modelGroupId(modelGroupId)
.hashValue(modelContentHashValue)
.modelConfig(modelConfig)
.url(url);
if (description != null) {
mlInputBuilder.description(description);
}

MLRegisterModelInput mlInput = mlInputBuilder.build();

mlClient.register(mlInput, actionListener);
} else {
registerLocalModelFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
}

return registerLocalModelFuture;
}

@Override
public String getName() {
return NAME;
}
}
Loading

0 comments on commit 2142874

Please sign in to comment.