Commit f7530e8: auto deploy remote model
Signed-off-by: Yaliang Wu <[email protected]>
ylwu-amzn committed Mar 18, 2024
1 parent bf0a595 commit f7530e8
Showing 8 changed files with 137 additions and 10 deletions.
@@ -52,4 +52,12 @@ public static FunctionName from(String value) {
public static boolean isDLModel(FunctionName functionName) {
return DL_MODELS.contains(functionName);
}

public static boolean needDeployFirst(FunctionName functionName) {
return DL_MODELS.contains(functionName) || functionName == REMOTE;
}

public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) {
return autoDeploymentEnabled && functionName == FunctionName.REMOTE;
}
}
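
Note: a minimal sketch (not part of this commit) of how the two new helpers are intended to be combined by callers such as MLPredictTaskRunner below; the boolean would come from the new plugins.ml_commons.model_auto_deploy.enable setting.

boolean autoDeployEnabled = true; // e.g. read from plugins.ml_commons.model_auto_deploy.enable
FunctionName functionName = FunctionName.REMOTE;
if (FunctionName.isAutoDeployEnabled(autoDeployEnabled, functionName)) {
    // remote model with auto deploy turned on: deploy transparently, then predict
} else if (FunctionName.needDeployFirst(functionName)) {
    // DL model, or remote model with auto deploy turned off: fail until the model is deployed explicitly
}
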
@@ -222,7 +222,7 @@ private void deployModel(
try {
log.debug("start deploying model {}", modelId);
mlModelManager
.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(listener, () -> {
.deployModel(modelId, modelContentHash, functionName, deployToAllNodes, false, mlTask, ActionListener.runBefore(listener, () -> {
if (!coordinatingNodeId.equals(localNodeId)) {
mlTaskManager.remove(mlTask.getTaskId());
}
@@ -20,12 +20,14 @@
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
@@ -163,7 +165,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
updateFields.put(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, 0);
}
log.info("deploy model done with state: {}, model id: {}", modelState, modelId);
mlModelManager.updateModel(modelId, updateFields);
ActionListener updateModelListener = ActionListener.<UpdateResponse>wrap(response -> {
if (response.status() == RestStatus.OK) {
log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId);
} else {
log.error("Failed to update ML model {}, status: {}", modelId, response.status());
}
}, e -> { log.error("Failed to update ML model: " + modelId, e); });
mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> {
mlModelManager.removeAutoDeployModel(modelId);
}));
}
listener.onResponse(new MLForwardResponse("ok", null));
break;
@@ -32,10 +32,12 @@
@Log4j2
public class MLModelCacheHelper {
private final Map<String, MLModelCache> modelCaches;
private final Map<String, MLModel> autoDeployModels;
private volatile Long maxRequestCount;

public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
this.modelCaches = new ConcurrentHashMap<>();
this.autoDeployModels = new ConcurrentHashMap<>();

maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it);
@@ -358,6 +360,7 @@ public void removeModel(String modelId) {
modelCache.clear();
modelCaches.remove(modelId);
}
autoDeployModels.remove(modelId);
}

/**
@@ -590,4 +593,18 @@ private MLModelCache getOrCreateModelCache(String modelId) {
return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache());
}

public MLModel addModelToAutoDeployCache(String modelId, MLModel model) {
MLModel addedModel = autoDeployModels.computeIfAbsent(modelId, key -> model);
if (addedModel == model) {
log.info("Add model {} to auto deploy cache", modelId);
}
return addedModel;
}

public void removeAutoDeployModel(String modelId) {
MLModel removedModel = autoDeployModels.remove(modelId);
if (removedModel != null) {
log.info("Remove model {} from auto deploy cache", modelId);
}
}
}
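
Note: addModelToAutoDeployCache leans on ConcurrentHashMap.computeIfAbsent returning the value that actually won the race, so callers can use a reference comparison to tell whether they were the first to ask for deployment. A self-contained sketch of that dedup pattern (class and method names are illustrative, not from this commit):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class AutoDeployCacheSketch {
    private final Map<String, Object> autoDeployModels = new ConcurrentHashMap<>();

    // Returns true only for the caller whose model instance ends up cached,
    // mirroring the "addedModel == model" check used in the diff.
    public boolean tryMarkDeploying(String modelId, Object model) {
        Object winner = autoDeployModels.computeIfAbsent(modelId, key -> model);
        return winner == model;
    }

    public static void main(String[] args) {
        AutoDeployCacheSketch cache = new AutoDeployCacheSketch();
        System.out.println(cache.tryMarkDeploying("model-1", new Object())); // true: first caller triggers deploy
        System.out.println(cache.tryMarkDeploying("model-1", new Object())); // false: deploy already in flight
    }
}
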
20 changes: 17 additions & 3 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
@@ -934,16 +934,18 @@ public void deployModel(
String modelContentHash,
FunctionName functionName,
boolean deployToAllNodes,
boolean autoDeployModel,
MLTask mlTask,
ActionListener<String> listener
) {
log.debug("Auto deploy model : {}", autoDeployModel);
mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createModelCounterStatIfAbsent(modelId, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment();
List<String> workerNodes = mlTask.getWorkerNodes();
if (modelCacheHelper.isModelDeployed(modelId)) {
if (workerNodes != null && workerNodes.size() > 0) {
if (!autoDeployModel && workerNodes != null && workerNodes.size() > 0) {
log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId);
modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes);
modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes);
@@ -958,8 +960,13 @@ public void deployModel(
int eligibleNodeCount = workerNodes.size();
modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, context::restore);
checkAndAddRunningTask(mlTask, maxDeployTasksPerNode);
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> {
context.restore();
modelCacheHelper.removeAutoDeployModel(modelId);
});
if (!autoDeployModel) {
checkAndAddRunningTask(mlTask, maxDeployTasksPerNode);
}
this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> {
modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled());
if (FunctionName.REMOTE == mlModel.getAlgorithm()
@@ -1821,4 +1828,11 @@ public boolean isModelRunningOnNode(String modelId) {
return modelCacheHelper.isModelRunningOnNode(modelId);
}

public MLModel addModelToAutoDeployCache(String modelId, MLModel model) {
return modelCacheHelper.addModelToAutoDeployCache(modelId, model);
}

public void removeAutoDeployModel(String modelId) {
modelCacheHelper.removeAutoDeployModel(modelId);
}
}
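
Note: a hedged sketch of how the new autoDeployModel flag is exercised, mirroring the MLPredictTaskRunner.predict change further down; it assumes the surrounding class already has mlModelManager, clusterService and log fields. Passing true skips the per-node running-task check and leaves existing target worker nodes untouched, and the model is dropped from the auto-deploy cache when the listener completes.

MLTask deployTask = MLTask.builder()
    .taskId(UUID.randomUUID().toString())
    .functionName(FunctionName.REMOTE)
    .taskType(MLTaskType.DEPLOY_MODEL)
    .workerNodes(Arrays.asList(clusterService.localNode().getId()))
    .build();
// autoDeployModel = true (fifth argument): local, caller-initiated deployment on demand.
mlModelManager.deployModel(modelId, null, FunctionName.REMOTE, false, true, deployTask,
    ActionListener.wrap(
        state -> log.info("auto deploy finished for model {}", modelId),
        e -> log.error("auto deploy failed for model " + modelId, e)));
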
@@ -523,7 +523,8 @@ public Collection<Object> createComponents(
xContentRegistry,
mlModelManager,
nodeHelper,
mlEngine
mlEngine,
settings
);
mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner(
threadPool,
@@ -870,6 +871,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN,
MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES,
MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_SUCCESS_RATIO,
MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL,
@@ -85,6 +85,9 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE = Setting
.boolSetting("plugins.ml_commons.model_auto_redeploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE = Setting
.boolSetting("plugins.ml_commons.model_auto_deploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES = Setting
.intSetting("plugins.ml_commons.model_auto_redeploy.lifetime_retry_times", 3, Setting.Property.NodeScope, Setting.Property.Dynamic);

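Note: since the new setting is registered with Setting.Property.Dynamic, it can be flipped at runtime without a restart. A sketch of doing that from Java via the cluster-settings update API (it assumes a Client instance and a logger are available):

ClusterUpdateSettingsRequest request = new ClusterUpdateSettingsRequest();
request.persistentSettings(Settings.builder().put("plugins.ml_commons.model_auto_deploy.enable", false));
client.admin().cluster().updateSettings(request, ActionListener.wrap(
    response -> log.info("auto deploy disabled cluster-wide"),
    e -> log.error("failed to update cluster settings", e)));
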
@@ -11,8 +11,10 @@
import static org.opensearch.ml.permission.AccessController.checkUserPermissions;
import static org.opensearch.ml.permission.AccessController.getUserContext;
import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;

import java.time.Instant;
import java.util.Arrays;
import java.util.UUID;

import org.opensearch.OpenSearchException;
@@ -24,6 +26,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
@@ -44,6 +47,8 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.engine.MLEngine;
@@ -75,6 +80,7 @@ public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, M
private final MLModelManager mlModelManager;
private final DiscoveryNodeHelper nodeHelper;
private final MLEngine mlEngine;
private volatile boolean autoDeploymentEnabled;

public MLPredictTaskRunner(
ThreadPool threadPool,
@@ -88,7 +94,8 @@ public MLPredictTaskRunner(
NamedXContentRegistry xContentRegistry,
MLModelManager mlModelManager,
DiscoveryNodeHelper nodeHelper,
MLEngine mlEngine
MLEngine mlEngine,
Settings settings
) {
super(mlTaskManager, mlStats, nodeHelper, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
this.threadPool = threadPool;
@@ -99,6 +106,10 @@ public MLPredictTaskRunner(
this.mlModelManager = mlModelManager;
this.nodeHelper = nodeHelper;
this.mlEngine = mlEngine;
autoDeploymentEnabled = ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> autoDeploymentEnabled = it);
}

@Override
@@ -133,7 +144,31 @@ public void dispatchTask(
}, e -> { listener.onFailure(e); });
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
if (workerNodes == null || workerNodes.length == 0) {
if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) {
if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> {
String[] planningWorkerNodes = model.getPlanningWorkerNodes();
MLModel modelBeingAutoDeployed = mlModelManager.addModelToAutoDeployCache(modelId, model);
if (modelBeingAutoDeployed == model) {
log.info("Automatically deploy model {}", modelId);
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, planningWorkerNodes, false, true, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(r -> {
log.info("Auto deployment action triggered for model {}", modelId);
}, e-> {
log.error("Auto deployment action failed for model " + modelId, e);
}));
}
if (planningWorkerNodes == null || planningWorkerNodes.length == 0) {
planningWorkerNodes = nodeHelper.getEligibleNodeIds(functionName);
}
mlTaskDispatcher.dispatchPredictTask(planningWorkerNodes, actionListener);
}, e -> {
log.error("Failed to get model " + modelId, e);
listener.onFailure(e);
}), context::restore));
}
return;
} else if (FunctionName.needDeployFirst(functionName)) {
listener
.onFailure(
new IllegalArgumentException(
Expand All @@ -144,6 +179,8 @@ public void dispatchTask(
} else {
workerNodes = nodeHelper.getEligibleNodeIds(functionName);
}
} else {
mlModelManager.removeAutoDeployModel(modelId);
}
mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener);
} catch (Exception e) {
@@ -210,7 +247,42 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
mlTask.setState(MLTaskState.RUNNING);
mlTaskManager.add(mlTask);

FunctionName algorithm = mlInput.getAlgorithm();
FunctionName functionName = mlInput.getFunctionName();

Predictable predictor = mlModelManager.getPredictor(modelId);
boolean modelReady = predictor != null && predictor.isModelReady();
if (!modelReady && FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
log.info("Auto deploy model {} to local node", modelId);
Instant now = Instant.now();
MLTask mlDeployTask = MLTask.builder()
.taskId(UUID.randomUUID().toString())
.functionName(functionName)
.async(false)
.taskType(MLTaskType.DEPLOY_MODEL)
.createTime(now)
.lastUpdateTime(now)
.state(MLTaskState.RUNNING)
.workerNodes(Arrays.asList(clusterService.localNode().getId()))
.build();
mlModelManager.deployModel(modelId,
null,
functionName,
false,
true,
mlDeployTask,
ActionListener.wrap(s -> {
runPredict(modelId, mlTask, mlInput, functionName, internalListener);
}, e -> {
log.error("Failed to auto deploy model " + modelId, e);
internalListener.onFailure(e);
}));
return;
}

runPredict(modelId, mlTask, mlInput, functionName, internalListener);
}

private void runPredict(String modelId, MLTask mlTask, MLInput mlInput, FunctionName algorithm, ActionListener<MLTaskResponse> internalListener) {
// run predict
if (modelId != null) {
Predictable predictor = mlModelManager.getPredictor(modelId);
@@ -233,7 +305,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
handlePredictFailure(mlTask, internalListener, e, false, modelId);
return;
}
} else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
} else if (FunctionName.needDeployFirst(algorithm)) {
throw new IllegalArgumentException("Model not ready to be used: " + modelId);
}

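Note: the net effect of this commit is that a predict request against an undeployed remote model no longer fails with "Model not ready" while plugins.ml_commons.model_auto_deploy.enable is true; the model is deployed on demand and the prediction then proceeds. A rough caller-side sketch; the input and request construction below are assumed shapes for illustration, not code from this commit.

MLInput mlInput = MLInput.builder()
    .algorithm(FunctionName.REMOTE)
    .inputDataset(RemoteInferenceInputDataSet.builder()
        .parameters(Map.of("prompt", "hello")) // illustrative payload
        .build())
    .build();
MLPredictionTaskRequest request = new MLPredictionTaskRequest(modelId, mlInput); // assumed constructor shape
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(
    response -> log.info("prediction succeeded; model was auto deployed if needed"),
    e -> log.error("prediction failed for model " + modelId, e)));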
