Skip to content

Commit

Permalink
[Backport to main] add ML task timeout setting and clean up expired t…
Browse files Browse the repository at this point in the history
…asks from cache (opensearch-project#662) (opensearch-project#770)

* add ML task timeout setting and clean up expired tasks from cache (opensearch-project#662)

* add ML task timeout setting and clean up expired tasks from cache

Signed-off-by: Yaliang Wu <[email protected]>

* add log for corner case

Signed-off-by: Yaliang Wu <[email protected]>

* rollback setting name change to avoid breaking bwc

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>

* fix code format

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Mar 2, 2023
1 parent c56ca85 commit cc126e3
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String
}
log.info("Load model {} successfully on {} devices", modelId, devices.length);
return null;
} catch (Exception e) {
} catch (Throwable e) {
String errorMessage = "Failed to load model " + modelId;
log.error(errorMessage, e);
close();
Expand Down Expand Up @@ -296,7 +296,7 @@ protected ModelTensorOutput predictTextEmbedding(String modelId, MLInputDataset
}
return new ModelTensorOutput(tensorOutputs);
});
} catch (PrivilegedActionException e) {
} catch (Throwable e) {
String errorMsg = "Failed to inference text embedding";
log.error(errorMsg, e);
throw new MLException(errorMsg, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.text_embedding;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -268,8 +269,9 @@ public void initModel_WrongModelFile() throws URISyntaxException {
textEmbeddingModel.initModel(model, params);
} catch (Exception e) {
assertEquals(MLException.class, e.getClass());
assertEquals(IllegalArgumentException.class, e.getCause().getClass());
assertEquals("found multiple models", e.getCause().getMessage());
Throwable rootCause = ExceptionUtils.getRootCause(e);
assertEquals(IllegalArgumentException.class, rootCause.getClass());
assertEquals("found multiple models", rootCause.getMessage());
}
}

Expand Down Expand Up @@ -311,7 +313,7 @@ public void predict_NullModelId() {
@Test
public void predict_AfterModelClosed() {
exceptionRule.expect(MLException.class);
exceptionRule.expectMessage("model not loaded");
exceptionRule.expectMessage("Failed to inference text embedding");
textEmbeddingModel.initModel(model, params);
textEmbeddingModel.close();
textEmbeddingModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -128,6 +129,22 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
nodeIds.add(nodeId);
}
}
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
if (workerNodes != null && workerNodes.length > 0) {
Set<String> difference = new HashSet<String>(Arrays.asList(workerNodes));
difference.removeAll(Arrays.asList(targetNodeIds));
if (difference.size() > 0) {
listener
.onFailure(
new IllegalArgumentException(
"Model already deployed to these nodes: "
+ Arrays.toString(difference.toArray(new String[0]))
+ ", but they are not included in target node ids. Unload model from these nodes if don't need them any more."
)
);
return;
}
}
} else {
nodeIds.addAll(allEligibleNodeIds);
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
package org.opensearch.ml.action.syncup;

import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS;

import java.io.IOException;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -23,7 +25,13 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
Expand All @@ -34,10 +42,14 @@
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

@Log4j2
public class TransportSyncUpOnNodeAction extends
TransportNodesAction<MLSyncUpNodesRequest, MLSyncUpNodesResponse, MLSyncUpNodeRequest, MLSyncUpNodeResponse> {
Expand All @@ -51,9 +63,12 @@ public class TransportSyncUpOnNodeAction extends
NamedXContentRegistry xContentRegistry;
MLEngine mlEngine;

private volatile Integer mlTaskTimeout;

@Inject
public TransportSyncUpOnNodeAction(
TransportService transportService,
Settings settings,
ActionFilters actionFilters,
ModelHelper modelHelper,
MLTaskManager mlTaskManager,
Expand Down Expand Up @@ -84,6 +99,9 @@ public TransportSyncUpOnNodeAction(
this.client = client;
this.xContentRegistry = xContentRegistry;
this.mlEngine = mlEngine;

this.mlTaskTimeout = ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, it -> { mlTaskTimeout = it; });
}

@Override
Expand Down Expand Up @@ -148,11 +166,62 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM
mlTaskManager.syncRunningLoadModelTasks(runningLoadModelTasks);
}

cleanUpLocalCache();
cleanUpLocalCacheFiles();

return new MLSyncUpNodeResponse(clusterService.localNode(), "ok", loadedModelIds, runningLoadModelTaskIds);
}

@VisibleForTesting
void cleanUpLocalCache() {
String[] allTaskIds = mlTaskManager.getAllTaskIds();
if (allTaskIds == null) {
return;
}
for (String taskId : allTaskIds) {
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
MLTask mlTask = mlTaskCache.getMlTask();
Instant lastUpdateTime = mlTask.getLastUpdateTime();
Instant now = Instant.now();
if (now.isAfter(lastUpdateTime.plusSeconds(mlTaskTimeout))) {
log.info("ML task timeout. task id: {}, task type: {}", taskId, mlTask.getTaskType());
mlTaskManager
.updateMLTask(
taskId,
ImmutableMap
.of(MLTask.STATE_FIELD, MLTaskState.FAILED, MLTask.ERROR_FIELD, "timeout after " + mlTaskTimeout + " seconds"),
10_000,
true
);

if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL) {
String modelId = mlTask.getModelId();
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
MLModelState modelState;
if (workerNodes == null || workerNodes.length == 0) {
modelState = MLModelState.LOAD_FAILED;
} else if (mlTask.getWorkerNodes().size() > workerNodes.length) {
modelState = MLModelState.PARTIALLY_LOADED;
} else {
modelState = MLModelState.LOADED;
if (mlTask.getWorkerNodes().size() < workerNodes.length) {
log
.warn(
"Model loaded on more nodes than target worker nodes. taskId:{}, modelId: {}, workerNodes: {}, targetWorkerNodes: {}",
taskId,
modelId,
Arrays.toString(workerNodes),
Arrays.toString(mlTask.getWorkerNodes().toArray(new String[0]))
);
}
}
log.info("Reset model state as {} for model {}", modelState, modelId);
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, modelState));
}
}
}
}

private void cleanUpLocalCacheFiles() {
Path uploadModelRootPath = mlEngine.getUploadModelRootPath();
Path loadModelRootPath = mlEngine.getLoadModelRootPath();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ public Predictable getPredictor(String modelId) {
return modelCache.getPredictor();
}

/**
* Set target worker nodes of model.
* @param modelId model id
* @param targetWorkerNodes target worker nodes of model
*/
public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes) {
MLModelCache modelCache = modelCaches.get(modelId);
if (modelCache != null) {
modelCache.setTargetWorkerNodes(targetWorkerNodes);
}
}

/**
* Remove model.
* @param modelId model id
Expand Down
22 changes: 21 additions & 1 deletion plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.Files;

/**
Expand Down Expand Up @@ -129,6 +130,16 @@ public class MLModelManager {
private volatile Integer maxUploadTasksPerNode;
private volatile Integer maxLoadTasksPerNode;

public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet
.of(
MLModelState.TRAINED,
MLModelState.UPLOADED,
MLModelState.LOADED,
MLModelState.PARTIALLY_LOADED,
MLModelState.LOAD_FAILED,
MLModelState.UNLOADED
);

public MLModelManager(
ClusterService clusterService,
Client client,
Expand Down Expand Up @@ -422,15 +433,20 @@ public void loadModel(
ActionListener<String> listener
) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, ML_ACTION_REQUEST_COUNT).increment();
List<String> workerNodes = mlTask.getWorkerNodes();
if (modelCacheHelper.isModelLoaded(modelId)) {
if (workerNodes != null && workerNodes.size() > 0) {
log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId);
modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes);
}
listener.onResponse("successful");
return;
}
if (modelCacheHelper.getLoadedModels().length >= maxModelPerNode) {
listener.onFailure(new IllegalArgumentException("Exceed max model per node limit"));
return;
}
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, mlTask.getWorkerNodes());
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, workerNodes);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
checkAndAddRunningTask(mlTask, maxLoadTasksPerNode);
this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> {
Expand Down Expand Up @@ -596,6 +612,10 @@ public void updateModel(String modelId, Map<String, Object> updatedFields, Actio
UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId);
updateRequest.doc(updatedFields);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (updatedFields.containsKey(MLModel.MODEL_STATE_FIELD)
&& MODEL_DONE_STATES.contains(updatedFields.get(MLModel.MODEL_STATE_FIELD))) {
updateRequest.retryOnConflict(3);
}
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE,
MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE,
MLCommonsSettings.ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS,
MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS,
MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT,
MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE,
MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ private MLCommonsSettings() {}
.simpleString("plugins.ml_commons.task_dispatch_policy", "round_robin", Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_MAX_MODELS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_model_on_node", 10, 0, 1000, Setting.Property.NodeScope, Setting.Property.Dynamic);
.intSetting("plugins.ml_commons.max_model_on_node", 10, 0, 10000, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_upload_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_load_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting
Expand All @@ -32,6 +34,8 @@ private MLCommonsSettings() {}
Setting.Property.Dynamic
);

public static final Setting<Integer> ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS = Setting
.intSetting("plugins.ml_commons.ml_task_timeout_in_seconds", 600, 1, 86400, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Long> ML_COMMONS_MONITORING_REQUEST_COUNT = Setting
.longSetting(
"plugins.ml_commons.monitoring_request_count",
Expand All @@ -42,9 +46,6 @@ private MLCommonsSettings() {}
Setting.Property.Dynamic
);

public static final Setting<Integer> ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_upload_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<String> ML_COMMONS_TRUSTED_URL_REGEX = Setting
.simpleString(
"plugins.ml_commons.trusted_url_regex",
Expand Down
10 changes: 10 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

/**
* MLTaskManager is responsible for managing MLTask.
Expand All @@ -63,6 +64,9 @@ public class MLTaskManager {
private final MLIndicesHandler mlIndicesHandler;
private final Map<MLTaskType, AtomicInteger> runningTasksCount;

public static final ImmutableSet TASK_DONE_STATES = ImmutableSet
.of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);

/**
* Constructor to create ML task manager.
*
Expand Down Expand Up @@ -320,6 +324,9 @@ public void updateMLTask(
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
updateRequest.doc(updatedContent);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.containsKey(STATE_FIELD))) {
updateRequest.retryOnConflict(3);
}
ActionListener<UpdateResponse> actionListener = semaphore == null
? listener
: ActionListener.runAfter(listener, () -> semaphore.release());
Expand Down Expand Up @@ -360,6 +367,9 @@ public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedField
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
updateRequest.doc(updatedContent);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.containsKey(STATE_FIELD))) {
updateRequest.retryOnConflict(3);
}
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand Down
Loading

0 comments on commit cc126e3

Please sign in to comment.