From bca2b19ec6a34cf0c3fb360bbf0de08b3cd1f10e Mon Sep 17 00:00:00 2001
From: Yaliang Wu
Date: Mon, 18 Mar 2024 16:16:49 -0700
Subject: [PATCH] avoid throwing duplicate deploy model task

Signed-off-by: Yaliang Wu
---
 .../opensearch/ml/model/MLModelCacheHelper.java | 17 ++++++++++++++---
 .../org/opensearch/ml/model/MLModelManager.java |  2 +-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
index 2af2763acd..b9c5defe11 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
@@ -50,18 +50,29 @@ public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
      * @param state model state
      * @param functionName function name
      */
+    public synchronized void initModelState(
+        String modelId,
+        MLModelState state,
+        FunctionName functionName,
+        List targetWorkerNodes,
+        boolean deployToAllNodes
+    ) {
+        initModelState(modelId, state, functionName, targetWorkerNodes, deployToAllNodes, false);
+    }
+
     public synchronized void initModelState(
         String modelId,
         MLModelState state,
         FunctionName functionName,
         List targetWorkerNodes,
-        boolean deployToAllNodes
+        boolean deployToAllNodes,
+        boolean autoDeployModel
     ) {
-        if (isModelRunningOnNode(modelId)) {
+        if (!autoDeployModel && isModelRunningOnNode(modelId)) {
             throw new MLLimitExceededException("Duplicate deploy model task");
         }
         log.debug("init model state for model {}, state: {}", modelId, state);
-        MLModelCache modelCache = new MLModelCache();
+        MLModelCache modelCache = autoDeployModel ? modelCaches.computeIfAbsent(modelId, key -> new MLModelCache()) : new MLModelCache();
         modelCache.setModelState(state);
         modelCache.setFunctionName(functionName);
         modelCache.setTargetWorkerNodes(targetWorkerNodes);
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
index 6396e58e74..663777d27c 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
@@ -958,7 +958,7 @@ public void deployModel(
             return;
         }
         int eligibleNodeCount = workerNodes.size();
-        modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes);
+        modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes, autoDeployModel);
         try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
             ActionListener wrappedListener = ActionListener.runBefore(listener, () -> {
                 context.restore();