Skip to content

Commit

Permalink
avoid race condition in syncup model state refresh and handle NP of IsAutoDeployEnabled
Browse files Browse the repository at this point in the history

Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed May 3, 2024
1 parent 950f864 commit b19c415
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 44 deletions.
69 changes: 40 additions & 29 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.time.Instant;
import java.util.ArrayList;
Expand Down Expand Up @@ -41,8 +42,9 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -97,6 +99,9 @@ public void run() {
// gather running model/tasks on nodes
client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> {
List<MLSyncUpNodeResponse> responses = r.getNodes();
if (r.failures() != null && r.failures().size() != 0) {
log.debug("Received {} failures in the sync up response on nodes", r.failures().size());
}
// key is model id, value is set of worker node ids
Map<String, Set<String>> modelWorkerNodes = new HashMap<>();
// key is task id, value is set of worker node ids
Expand Down Expand Up @@ -143,7 +148,6 @@ public void run() {
if (modelWorkerNodes.containsKey(modelId)
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
// this model has expired in all the nodes
modelWorkerNodes.remove(modelId);
modelsToUndeploy.add(modelId);
}
}
Expand All @@ -168,37 +172,44 @@ public void run() {
MLSyncUpInput syncUpInput = inputBuilder.build();
MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, syncUpInput);
// sync up running model/tasks on nodes
client
.execute(
MLSyncUpAction.INSTANCE,
syncUpRequest,
ActionListener.wrap(re -> { log.debug("sync model routing job finished"); }, ex -> {
log.error("Failed to sync model routing", ex);
})
);
// Undeploy expired models
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);
client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(re -> {
log.debug("sync model routing job finished");
if (!modelsToUndeploy.isEmpty()) {
// Undeploy expired models
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes, deployingModels);
return;
}
// refresh model status
mlIndicesHandler
.initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
log.error("Failed to init model index", e);
}));
}, ex -> { log.error("Failed to sync model routing", ex); }));
}, e -> { log.error("Failed to sync model routing", e); }));
}

/**
 * Undeploys the given expired models on every node of the cluster and, once the
 * undeploy request has returned, refreshes the model state.
 *
 * <p>The refresh is chained inside the undeploy response callback, so it only starts
 * after the undeploy request has completed. If the undeploy request fails, the failure
 * is logged and no refresh is triggered from here.
 *
 * @param expiredModels    IDs of the models to undeploy
 * @param modelWorkerNodes forwarded unchanged to {@code refreshModelState}
 * @param deployingModels  forwarded unchanged to {@code refreshModelState}
 */
private void undeployExpiredModels(
    Set<String> expiredModels,
    Map<String, Set<String>> modelWorkerNodes,
    Map<String, Set<String>> deployingModels
) {
    String[] targetNodeIds = getAllNodes(clusterService);
    MLUndeployModelsRequest mlUndeployModelsRequest = new MLUndeployModelsRequest(
        expiredModels.toArray(new String[0]),
        targetNodeIds
    );

    client.execute(MLUndeployModelsAction.INSTANCE, mlUndeployModelsRequest, ActionListener.wrap(r -> {
        MLUndeployModelNodesResponse mlUndeployModelNodesResponse = r.getResponse();
        if (mlUndeployModelNodesResponse.failures() != null && !mlUndeployModelNodesResponse.failures().isEmpty()) {
            // The placeholder is required: without it the failures argument is never rendered in the log line.
            log.debug("Received failures in undeploying expired models: {}", mlUndeployModelNodesResponse.failures());
        }

        // refresh model status only after the undeploy request has returned
        mlIndicesHandler
            .initModelIndexIfAbsent(ActionListener.wrap(res -> { refreshModelState(modelWorkerNodes, deployingModels); }, e -> {
                log.error("Failed to init model index", e);
            }));
    }, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
}

private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
expiredModels.forEach(modelId -> {
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);

MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
targetNodeIds,
new String[] { modelId }
);
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
log.debug("model {} is un_deployed", modelId);
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
});
}, e -> { log.error("Failed to undeploy models {}", expiredModels, e); }));
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,14 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import org.apache.commons.lang3.ArrayUtils;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -102,24 +100,15 @@ MLUndeployModelsRequest getRequest(RestRequest request) throws IOException {
}
targetNodeIds = nodeIds;
} else {
targetNodeIds = getAllNodes();
targetNodeIds = getAllNodes(clusterService);
}
if (ArrayUtils.isNotEmpty(modelIds)) {
targetModelIds = modelIds;
}
} else {
targetNodeIds = getAllNodes();
targetNodeIds = getAllNodes(clusterService);
}

return new MLUndeployModelsRequest(targetModelIds, targetNodeIds);
}

/**
 * Collects the ID of every node listed in the current cluster state.
 *
 * @return the node IDs in the cluster state's iteration order; an empty array when no nodes are listed
 */
private String[] getAllNodes() {
    List<String> collectedIds = new ArrayList<>();
    // Drain the node iterator, keeping only each node's ID.
    clusterService.state().nodes().iterator().forEachRemaining(node -> collectedIds.add(node.getId()));
    return collectedIds.toArray(new String[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
}

private boolean checkModelAutoDeployEnabled(MLModel mlModel) {
if (mlModel.getDeploySetting() == null) {
if (mlModel.getDeploySetting() == null || mlModel.getDeploySetting().getIsAutoDeployEnabled() == null) {
return true;
}
return mlModel.getDeploySetting().getIsAutoDeployEnabled();
Expand Down

0 comments on commit b19c415

Please sign in to comment.