Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ML task timeout setting and clean up expired tasks from cache #662

Merged
merged 3 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String
}
log.info("Load model {} successfully on {} devices", modelId, devices.length);
return null;
} catch (Exception e) {
} catch (Throwable e) {
String errorMessage = "Failed to load model " + modelId;
log.error(errorMessage, e);
close();
Expand Down Expand Up @@ -296,7 +296,7 @@ protected ModelTensorOutput predictTextEmbedding(String modelId, MLInputDataset
}
return new ModelTensorOutput(tensorOutputs);
});
} catch (PrivilegedActionException e) {
} catch (Throwable e) {
String errorMsg = "Failed to inference text embedding";
log.error(errorMsg, e);
throw new MLException(errorMsg, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.text_embedding;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -268,8 +269,9 @@ public void initModel_WrongModelFile() throws URISyntaxException {
textEmbeddingModel.initModel(model, params);
} catch (Exception e) {
assertEquals(MLException.class, e.getClass());
assertEquals(IllegalArgumentException.class, e.getCause().getClass());
assertEquals("found multiple models", e.getCause().getMessage());
Throwable rootCause = ExceptionUtils.getRootCause(e);
assertEquals(IllegalArgumentException.class, rootCause.getClass());
assertEquals("found multiple models", rootCause.getMessage());
}
}

Expand Down Expand Up @@ -311,7 +313,7 @@ public void predict_NullModelId() {
@Test
public void predict_AfterModelClosed() {
exceptionRule.expect(MLException.class);
exceptionRule.expectMessage("model not loaded");
exceptionRule.expectMessage("Failed to inference text embedding");
textEmbeddingModel.initModel(model, params);
textEmbeddingModel.close();
textEmbeddingModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -128,6 +129,22 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<LoadMo
nodeIds.add(nodeId);
}
}
// Reject the request if the model is already deployed on nodes that are absent from
// the new target list: those nodes would keep serving a model the caller no longer
// tracks, so the caller must explicitly unload from them first.
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
if (workerNodes != null && workerNodes.length > 0) {
    Set<String> difference = new HashSet<>(Arrays.asList(workerNodes));
    difference.removeAll(Arrays.asList(targetNodeIds));
    if (!difference.isEmpty()) {
        listener
            .onFailure(
                new IllegalArgumentException(
                    "Model already deployed to these nodes: "
                        + Arrays.toString(difference.toArray(new String[0]))
                        + ", but they are not included in target node ids. Unload model from these nodes if you don't need them any more."
                )
            );
        return;
    }
}
} else {
nodeIds.addAll(allEligibleNodeIds);
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
package org.opensearch.ml.action.syncup;

import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS;

import java.io.IOException;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -23,7 +25,13 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest;
Expand All @@ -34,10 +42,14 @@
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskCache;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;

@Log4j2
public class TransportSyncUpOnNodeAction extends
TransportNodesAction<MLSyncUpNodesRequest, MLSyncUpNodesResponse, MLSyncUpNodeRequest, MLSyncUpNodeResponse> {
Expand All @@ -51,9 +63,12 @@ public class TransportSyncUpOnNodeAction extends
NamedXContentRegistry xContentRegistry;
MLEngine mlEngine;

private volatile Integer mlTaskTimeout;

@Inject
public TransportSyncUpOnNodeAction(
TransportService transportService,
Settings settings,
ActionFilters actionFilters,
ModelHelper modelHelper,
MLTaskManager mlTaskManager,
Expand Down Expand Up @@ -84,6 +99,9 @@ public TransportSyncUpOnNodeAction(
this.client = client;
this.xContentRegistry = xContentRegistry;
this.mlEngine = mlEngine;

this.mlTaskTimeout = ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, it -> { mlTaskTimeout = it; });
}

@Override
Expand Down Expand Up @@ -148,11 +166,52 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM
mlTaskManager.syncRunningLoadModelTasks(runningLoadModelTasks);
}

cleanUpLocalCache();
cleanUpLocalCacheFiles();

return new MLSyncUpNodeResponse(clusterService.localNode(), "ok", loadedModelIds, runningLoadModelTaskIds);
}

@VisibleForTesting
void cleanUpLocalCache() {
String[] allTaskIds = mlTaskManager.getAllTaskIds();
if (allTaskIds == null) {
return;
}
for (String taskId : allTaskIds) {
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
MLTask mlTask = mlTaskCache.getMlTask();
Instant lastUpdateTime = mlTask.getLastUpdateTime();
Instant now = Instant.now();
if (now.isAfter(lastUpdateTime.plusSeconds(mlTaskTimeout))) {
log.info("ML task timeout. task id: {}, task type: {}", taskId, mlTask.getTaskType());
mlTaskManager
.updateMLTask(
taskId,
ImmutableMap
.of(MLTask.STATE_FIELD, MLTaskState.FAILED, MLTask.ERROR_FIELD, "timeout after " + mlTaskTimeout + " seconds"),
10_000,
true
);

if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL) {
String modelId = mlTask.getModelId();
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
MLModelState modelState;
if (workerNodes == null || workerNodes.length == 0) {
modelState = MLModelState.LOAD_FAILED;
} else if (mlTask.getWorkerNodes().size() > workerNodes.length) {
modelState = MLModelState.PARTIALLY_LOADED;
} else {
modelState = MLModelState.LOADED;
}
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
log.info("Reset model state as {} for model {}", modelState, modelId);
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, modelState));
}
}
}
}

private void cleanUpLocalCacheFiles() {
Path uploadModelRootPath = mlEngine.getUploadModelRootPath();
Path loadModelRootPath = mlEngine.getLoadModelRootPath();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ public Predictable getPredictor(String modelId) {
return modelCache.getPredictor();
}

/**
 * Update the target worker nodes recorded for a cached model.
 * Silently does nothing when the model has no cache entry on this node.
 *
 * @param modelId model id
 * @param targetWorkerNodes target worker nodes of model
 */
public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes) {
    MLModelCache cached = modelCaches.get(modelId);
    if (cached == null) {
        return; // model not cached here; nothing to update
    }
    cached.setTargetWorkerNodes(targetWorkerNodes);
}

/**
* Remove model.
* @param modelId model id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.Files;

/**
Expand Down Expand Up @@ -129,6 +130,16 @@ public class MLModelManager {
private volatile Integer maxUploadTasksPerNode;
private volatile Integer maxLoadTasksPerNode;

// Terminal ("done") model states. Updates that move a model into one of these
// states retry on version conflict (see updateModel). Typed as
// ImmutableSet<MLModelState> — the original raw ImmutableSet defeats compile-time
// checking of contains() arguments.
public static final ImmutableSet<MLModelState> MODEL_DONE_STATES = ImmutableSet
    .of(
        MLModelState.TRAINED,
        MLModelState.UPLOADED,
        MLModelState.LOADED,
        MLModelState.PARTIALLY_LOADED,
        MLModelState.LOAD_FAILED,
        MLModelState.UNLOADED
    );

public MLModelManager(
ClusterService clusterService,
Client client,
Expand Down Expand Up @@ -422,15 +433,20 @@ public void loadModel(
ActionListener<String> listener
) {
mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, ML_ACTION_REQUEST_COUNT).increment();
List<String> workerNodes = mlTask.getWorkerNodes();
if (modelCacheHelper.isModelLoaded(modelId)) {
if (workerNodes != null && workerNodes.size() > 0) {
log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId);
modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes);
}
listener.onResponse("successful");
return;
}
if (modelCacheHelper.getLoadedModels().length >= maxModelPerNode) {
listener.onFailure(new IllegalArgumentException("Exceed max model per node limit"));
return;
}
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, mlTask.getWorkerNodes());
modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, workerNodes);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
checkAndAddRunningTask(mlTask, maxLoadTasksPerNode);
this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> {
Expand Down Expand Up @@ -596,6 +612,10 @@ public void updateModel(String modelId, Map<String, Object> updatedFields, Actio
UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId);
updateRequest.doc(updatedFields);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (updatedFields.containsKey(MLModel.MODEL_STATE_FIELD)
&& MODEL_DONE_STATES.contains(updatedFields.get(MLModel.MODEL_STATE_FIELD))) {
updateRequest.retryOnConflict(3);
}
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE,
MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE,
MLCommonsSettings.ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS,
MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS,
MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT,
MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE,
MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ private MLCommonsSettings() {}
.simpleString("plugins.ml_commons.task_dispatch_policy", "round_robin", Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_MAX_MODELS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_model_on_node", 10, 0, 1000, Setting.Property.NodeScope, Setting.Property.Dynamic);
.intSetting("plugins.ml_commons.max_loaded_models_per_node", 10, 0, 10000, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_upload_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_load_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting
Expand All @@ -32,6 +34,8 @@ private MLCommonsSettings() {}
Setting.Property.Dynamic
);

public static final Setting<Integer> ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS = Setting
.intSetting("plugins.ml_commons.ml_task_timeout_in_seconds", 600, 1, 86400, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Long> ML_COMMONS_MONITORING_REQUEST_COUNT = Setting
.longSetting(
"plugins.ml_commons.monitoring_request_count",
Expand All @@ -42,9 +46,6 @@ private MLCommonsSettings() {}
Setting.Property.Dynamic
);

public static final Setting<Integer> ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_upload_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<String> ML_COMMONS_TRUSTED_URL_REGEX = Setting
.simpleString(
"plugins.ml_commons.trusted_url_regex",
Expand Down
10 changes: 10 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

/**
* MLTaskManager is responsible for managing MLTask.
Expand All @@ -63,6 +64,9 @@ public class MLTaskManager {
private final MLIndicesHandler mlIndicesHandler;
private final Map<MLTaskType, AtomicInteger> runningTasksCount;

// Terminal ("done") task states; updates into one of these states retry on version
// conflict. Typed as ImmutableSet<MLTaskState> — the original raw ImmutableSet
// defeats compile-time checking of contains() arguments.
public static final ImmutableSet<MLTaskState> TASK_DONE_STATES = ImmutableSet
    .of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED);

/**
* Constructor to create ML task manager.
*
Expand Down Expand Up @@ -320,6 +324,9 @@ public void updateMLTask(
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
updateRequest.doc(updatedContent);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
// Retry terminal-state updates on version conflict.
// BUG FIX: the original passed updatedFields.containsKey(STATE_FIELD) — a Boolean —
// into TASK_DONE_STATES.contains(...), which is always false, so retryOnConflict
// was never applied. Pass the state VALUE via get(), matching updateModel's check.
if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.get(STATE_FIELD))) {
    updateRequest.retryOnConflict(3);
}
ActionListener<UpdateResponse> actionListener = semaphore == null
? listener
: ActionListener.runAfter(listener, () -> semaphore.release());
Expand Down Expand Up @@ -360,6 +367,9 @@ public void updateMLTaskDirectly(String taskId, Map<String, Object> updatedField
updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli());
updateRequest.doc(updatedContent);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
// Retry terminal-state updates on version conflict.
// BUG FIX: the original passed updatedFields.containsKey(STATE_FIELD) — a Boolean —
// into TASK_DONE_STATES.contains(...), which is always false, so retryOnConflict
// was never applied. Pass the state VALUE via get(), matching updateModel's check.
if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.get(STATE_FIELD))) {
    updateRequest.retryOnConflict(3);
}
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore()));
} catch (Exception e) {
Expand Down
Loading