diff --git a/common/src/main/java/org/opensearch/ml/common/FunctionName.java b/common/src/main/java/org/opensearch/ml/common/FunctionName.java index 2c02b3e13d..a2c900f6cc 100644 --- a/common/src/main/java/org/opensearch/ml/common/FunctionName.java +++ b/common/src/main/java/org/opensearch/ml/common/FunctionName.java @@ -52,4 +52,12 @@ public static FunctionName from(String value) { public static boolean isDLModel(FunctionName functionName) { return DL_MODELS.contains(functionName); } + + public static boolean needDeployFirst(FunctionName functionName) { + return DL_MODELS.contains(functionName) || functionName == REMOTE; + } + + public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) { + return autoDeploymentEnabled && functionName == FunctionName.REMOTE; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java index bf8c81756b..665eed4079 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelOnNodeAction.java @@ -222,7 +222,7 @@ private void deployModel( try { log.debug("start deploying model {}", modelId); mlModelManager - .deployModel(modelId, modelContentHash, functionName, deployToAllNodes, mlTask, ActionListener.runBefore(listener, () -> { + .deployModel(modelId, modelContentHash, functionName, deployToAllNodes, false, mlTask, ActionListener.runBefore(listener, () -> { if (!coordinatingNodeId.equals(localNodeId)) { mlTaskManager.remove(mlTask.getTaskId()); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java index 01dfa690bb..276ce1774e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/forward/TransportForwardAction.java @@ -20,12 +20,14 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; @@ -163,7 +165,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListenerwrap(response -> { + if (response.status() == RestStatus.OK) { + log.debug("Updated ML model successfully: {}, model id: {}", response.status(), modelId); + } else { + log.error("Failed to update ML model {}, status: {}", modelId, response.status()); + } + }, e -> { log.error("Failed to update ML model: " + modelId, e); }); + mlModelManager.updateModel(modelId, updateFields, ActionListener.runBefore(updateModelListener, () -> { + mlModelManager.removeAutoDeployModel(modelId); + })); } listener.onResponse(new MLForwardResponse("ok", null)); break; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 99ccc9cce1..2af2763acd 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -32,10 +32,12 @@ @Log4j2 public class MLModelCacheHelper { private final Map modelCaches; + private final Map autoDeployModels; private volatile Long maxRequestCount; public MLModelCacheHelper(ClusterService clusterService, Settings settings) { this.modelCaches = new ConcurrentHashMap<>(); + this.autoDeployModels = new ConcurrentHashMap<>(); maxRequestCount = ML_COMMONS_MONITORING_REQUEST_COUNT.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MONITORING_REQUEST_COUNT, it -> maxRequestCount = it); @@ -358,6 +360,7 @@ public void removeModel(String modelId) { modelCache.clear(); modelCaches.remove(modelId); } + autoDeployModels.remove(modelId); } /** @@ -590,4 +593,18 @@ private MLModelCache getOrCreateModelCache(String modelId) { return modelCaches.computeIfAbsent(modelId, it -> new MLModelCache()); } + public MLModel addModelToAutoDeployCache(String modelId, MLModel model) { + MLModel addedModel = autoDeployModels.computeIfAbsent(modelId, key -> model); + if (addedModel == model) { + log.info("Add model {} to auto deploy cache", modelId); + } + return addedModel; + } + + public void removeAutoDeployModel(String modelId) { + MLModel removedModel = autoDeployModels.remove(modelId); + if (removedModel != null) { + log.info("Remove model {} from auto deploy cache", modelId); + } + } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 24f7375934..6396e58e74 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -934,16 +934,18 @@ public void deployModel( String modelContentHash, FunctionName functionName, boolean deployToAllNodes, + boolean autoDeployModel, MLTask mlTask, ActionListener listener ) { + log.debug("Auto deploy model : {}", autoDeployModel); mlStats.createCounterStatIfAbsent(functionName, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); mlStats.createModelCounterStatIfAbsent(modelId, ActionName.DEPLOY, ML_ACTION_REQUEST_COUNT).increment(); List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelDeployed(modelId)) { - if (workerNodes != null && workerNodes.size() > 0) { + if (!autoDeployModel && workerNodes != null && workerNodes.size() > 0) { log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId); modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes); modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes); @@ -958,8 +960,13 @@ public void deployModel( int eligibleNodeCount = workerNodes.size(); modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> { + context.restore(); + modelCacheHelper.removeAutoDeployModel(modelId); + }); + if (!autoDeployModel) { + checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); + } this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -1821,4 +1828,11 @@ public boolean isModelRunningOnNode(String modelId) { return modelCacheHelper.isModelRunningOnNode(modelId); } + public MLModel addModelToAutoDeployCache(String modelId, MLModel model) { + return modelCacheHelper.addModelToAutoDeployCache(modelId, model); + } + + public void removeAutoDeployModel(String modelId) { + modelCacheHelper.removeAutoDeployModel(modelId); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 57423da475..b5e435680e 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -523,7 +523,8 @@ public Collection createComponents( xContentRegistry, mlModelManager, nodeHelper, - mlEngine + mlEngine, + settings ); mlTrainAndPredictTaskRunner = new MLTrainAndPredictTaskRunner( threadPool, @@ -870,6 +871,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN, MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE, + MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES, MLCommonsSettings.ML_COMMONS_MODEL_AUTO_REDEPLOY_SUCCESS_RATIO, MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 33c5b4b554..3d2e841a91 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -85,6 +85,9 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MODEL_AUTO_REDEPLOY_ENABLE = Setting .boolSetting("plugins.ml_commons.model_auto_redeploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE = Setting + .boolSetting("plugins.ml_commons.model_auto_deploy.enable", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MODEL_AUTO_REDEPLOY_LIFETIME_RETRY_TIMES = Setting .intSetting("plugins.ml_commons.model_auto_redeploy.lifetime_retry_times", 3, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 92e05a5ba9..cd09710c65 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -11,8 +11,10 @@ import static org.opensearch.ml.permission.AccessController.checkUserPermissions; import static org.opensearch.ml.permission.AccessController.getUserContext; import static org.opensearch.ml.plugin.MachineLearningPlugin.PREDICT_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; import java.time.Instant; +import java.util.Arrays; import java.util.UUID; import org.opensearch.OpenSearchException; @@ -24,6 +26,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; @@ -44,6 +47,8 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; +import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.engine.MLEngine; @@ -75,6 +80,7 @@ public class MLPredictTaskRunner extends MLTaskRunner autoDeploymentEnabled = it); } @Override @@ -133,7 +144,31 @@ public void dispatchTask( }, e -> { listener.onFailure(e); }); String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true); if (workerNodes == null || workerNodes.length == 0) { - if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) { + if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlModelManager.getModel(modelId, ActionListener.runBefore(ActionListener.wrap(model -> { + String[] planningWorkerNodes = model.getPlanningWorkerNodes(); + MLModel modelBeingAutoDeployed = mlModelManager.addModelToAutoDeployCache(modelId, model); + if (modelBeingAutoDeployed == model) { + log.info("Automatically deploy model {}", modelId); + MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, planningWorkerNodes, false, true, false); + client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(r -> { + log.info("Auto deployment action triggered for model {}", modelId); + }, e-> { + log.error("Auto deployment action failed for model " + modelId, e); + })); + } + if (planningWorkerNodes == null || planningWorkerNodes.length == 0) { + planningWorkerNodes = nodeHelper.getEligibleNodeIds(functionName); + } + mlTaskDispatcher.dispatchPredictTask(planningWorkerNodes, actionListener); + }, e -> { + log.error("Failed to get model " + modelId, e); + listener.onFailure(e); + }), context::restore)); + } + return; + } else if (FunctionName.needDeployFirst(functionName)) { listener .onFailure( new IllegalArgumentException( @@ -144,6 +179,8 @@ public void dispatchTask( } else { workerNodes = nodeHelper.getEligibleNodeIds(functionName); } + } else { + mlModelManager.removeAutoDeployModel(modelId); } mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener); } catch (Exception e) { @@ -210,7 +247,42 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe mlTask.setState(MLTaskState.RUNNING); mlTaskManager.add(mlTask); - FunctionName algorithm = mlInput.getAlgorithm(); + FunctionName functionName = mlInput.getFunctionName(); + + Predictable predictor = mlModelManager.getPredictor(modelId); + boolean modelReady = predictor != null && predictor.isModelReady(); + if (!modelReady && FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) { + log.info("Auto deploy model {} to local node", modelId); + Instant now = Instant.now(); + MLTask mlDeployTask = MLTask.builder() + .taskId(UUID.randomUUID().toString()) + .functionName(functionName) + .async(false) + .taskType(MLTaskType.DEPLOY_MODEL) + .createTime(now) + .lastUpdateTime(now) + .state(MLTaskState.RUNNING) + .workerNodes(Arrays.asList(clusterService.localNode().getId())) + .build(); + mlModelManager.deployModel(modelId, + null, + functionName, + false, + true, + mlDeployTask, + ActionListener.wrap(s -> { + runPredict(modelId, mlTask, mlInput, functionName, internalListener); + }, e -> { + log.error("Failed to auto deploy model " + modelId, e); + internalListener.onFailure(e); + })); + return; + } + + runPredict(modelId, mlTask, mlInput, functionName, internalListener); + } + + private void runPredict(String modelId, MLTask mlTask, MLInput mlInput, FunctionName algorithm, ActionListener internalListener) { // run predict if (modelId != null) { Predictable predictor = mlModelManager.getPredictor(modelId); @@ -233,7 +305,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe handlePredictFailure(mlTask, internalListener, e, false, modelId); return; } - } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + } else if (FunctionName.needDeployFirst(algorithm)) { throw new IllegalArgumentException("Model not ready to be used: " + modelId); }