diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index b11fe7afc2..da39166742 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.time.Instant; import java.util.ArrayList; @@ -41,8 +42,9 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; -import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; -import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; @@ -97,6 +99,14 @@ public void run() { // gather running model/tasks on nodes client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> { List responses = r.getNodes(); + if (r.failures() != null && r.failures().size() != 0) { + log + .debug( + "Received {} failures in the sync up response on nodes. Error messages are {}", + r.failures().size(), + r.failures().stream().map(Exception::getMessage).collect(Collectors.joining(", ")) + ); + } // key is model id, value is set of worker node ids Map> modelWorkerNodes = new HashMap<>(); // key is task id, value is set of worker node ids @@ -143,7 +153,6 @@ public void run() { if (modelWorkerNodes.containsKey(modelId) && expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) { // this model has expired in all the nodes - modelWorkerNodes.remove(modelId); modelsToUndeploy.add(modelId); } } @@ -168,37 +177,44 @@ public void run() { MLSyncUpInput syncUpInput = inputBuilder.build(); MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput); // sync up running model/tasks on nodes - client - .execute( - MLSyncUpAction.INSTANCE, - syncUpRequest, - ActionListener.wrap(re -> { log.debug("sync model routing job finished"); }, ex -> { - log.error("Failed to sync model routing", ex); - }) - ); - // Undeploy expired models - undeployExpiredModels(modelsToUndeploy, modelWorkerNodes); + client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(re -> { + log.debug("sync model routing job finished"); + if (!modelsToUndeploy.isEmpty()) { + // Undeploy expired models + undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels); + return; + } + // refresh model status + mlIndicesHandler + .initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> { + log.error("Failed to init model index", e); + })); + }, ex -> { log.error("Failed to sync model routing", ex); })); + }, e -> { log.error("Failed to sync model routing", e); })); + } + + private void undeployExpiredModels( + Set expiredModels, + Map> modelWorkerNodes, + Map> deployingModels + ) { + String[] targetNodeIds = getAllNodes(clusterService); + MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest( + expiredModels.toArray(new String[expiredModels.size()]), + targetNodeIds + ); + + client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> { + MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse(); + if (mlUndeployModelNodesResponse.failures() != null && mlUndeployModelNodesResponse.failures().size() != 0) { + log.debug("Received failures in undeploying expired models", mlUndeployModelNodesResponse.failures()); + } - // refresh model status mlIndicesHandler .initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> { log.error("Failed to init model index", e); })); - }, e -> { log.error("Failed to sync model routing", e); })); - } - - private void undeployExpiredModels(Set expiredModels, Map> modelWorkerNodes) { - expiredModels.forEach(modelId -> { - String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]); - - MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest( - targetNodeIds, - new String[] { modelId } - ); - client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { - log.debug("model {} is un_deployed", modelId); - }, e -> { log.error("Failed to undeploy model {}", modelId, e); })); - }); + }, e -> { log.error("Failed to undeploy models {}", expiredModels, e); })); } @VisibleForTesting diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java index c895163e1c..0cc30752df 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUndeployModelAction.java @@ -9,16 +9,14 @@ import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getAllNodes; import java.io.IOException; -import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.Locale; import org.apache.commons.lang3.ArrayUtils; import org.opensearch.client.node.NodeClient; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentParser; @@ -102,24 +100,15 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException { } targetNodeIds = nodeIds; } else { - targetNodeIds = getAllNodes(); + targetNodeIds = getAllNodes(clusterService); } if (ArrayUtils.isNotEmpty(modelIds)) { targetModelIds = modelIds; } } else { - targetNodeIds = getAllNodes(); + targetNodeIds = getAllNodes(clusterService); } return new MLUndeployModelsRequest(targetModelIds, targetNodeIds); } - - private String[] getAllNodes() { - Iterator iterator = clusterService.state().nodes().iterator(); - List nodeIds = new ArrayList<>(); - while (iterator.hasNext()) { - nodeIds.add(iterator.next().getId()); - } - return nodeIds.toArray(new String[0]); - } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 4841cb2b35..101d9c9244 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -263,7 +263,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener