Skip to content

Commit

Permalink
avoid throwing duplicate deploy model task
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Mar 18, 2024
1 parent f7530e8 commit bca2b19
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,29 @@ public MLModelCacheHelper(ClusterService clusterService, Settings settings) {
* @param state model state
* @param functionName function name
*/
public synchronized void initModelState(
String modelId,
MLModelState state,
FunctionName functionName,
List<String> targetWorkerNodes,
boolean deployToAllNodes
) {
initModelState(modelId, state, functionName, targetWorkerNodes, deployToAllNodes, false);
}

public synchronized void initModelState(
String modelId,
MLModelState state,
FunctionName functionName,
List<String> targetWorkerNodes,
boolean deployToAllNodes
boolean deployToAllNodes,
boolean autoDeployModel
) {
if (isModelRunningOnNode(modelId)) {
if (!autoDeployModel && isModelRunningOnNode(modelId)) {
throw new MLLimitExceededException("Duplicate deploy model task");
}
log.debug("init model state for model {}, state: {}", modelId, state);
MLModelCache modelCache = new MLModelCache();
MLModelCache modelCache = autoDeployModel ? modelCaches.computeIfAbsent(modelId, key -> new MLModelCache()) : new MLModelCache();
modelCache.setModelState(state);
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ public void deployModel(
return;
}
int eligibleNodeCount = workerNodes.size();
modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes);
modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes, autoDeployModel);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> {
context.restore();
Expand Down

0 comments on commit bca2b19

Please sign in to comment.