diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java index 1c12be8fad..517dc9ef7e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java @@ -232,7 +232,7 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String } log.info("Load model {} successfully on {} devices", modelId, devices.length); return null; - } catch (Exception e) { + } catch (Throwable e) { String errorMessage = "Failed to load model " + modelId; log.error(errorMessage, e); close(); @@ -296,7 +296,7 @@ protected ModelTensorOutput predictTextEmbedding(String modelId, MLInputDataset } return new ModelTensorOutput(tensorOutputs); }); - } catch (PrivilegedActionException e) { + } catch (Throwable e) { String errorMsg = "Failed to inference text embedding"; log.error(errorMsg, e); throw new MLException(errorMsg, e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java index 06ecb3b475..30b6074ed9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.text_embedding; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -268,8 +269,9 @@ public void initModel_WrongModelFile() throws URISyntaxException { textEmbeddingModel.initModel(model, params); } catch (Exception e) { assertEquals(MLException.class, e.getClass()); - assertEquals(IllegalArgumentException.class, e.getCause().getClass()); - assertEquals("found multiple models", e.getCause().getMessage()); + Throwable rootCause = ExceptionUtils.getRootCause(e); + assertEquals(IllegalArgumentException.class, rootCause.getClass()); + assertEquals("found multiple models", rootCause.getMessage()); } } @@ -311,7 +313,7 @@ public void predict_NullModelId() { @Test public void predict_AfterModelClosed() { exceptionRule.expect(MLException.class); - exceptionRule.expectMessage("model not loaded"); + exceptionRule.expectMessage("Failed to inference text embedding"); textEmbeddingModel.initModel(model, params); textEmbeddingModel.close(); textEmbeddingModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java index 892dae5208..1609f3339a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/load/TransportLoadModelAction.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -128,6 +129,22 @@ protected void doExecute(Task task, ActionRequest request, ActionListener 0) { + Set difference = new HashSet(Arrays.asList(workerNodes)); + difference.removeAll(Arrays.asList(targetNodeIds)); + if (difference.size() > 0) { + listener + .onFailure( + new IllegalArgumentException( + "Model already deployed to these nodes: " + + Arrays.toString(difference.toArray(new String[0])) + + ", but they are not included in target node ids. Unload model from these nodes if don't need them any more." + ) + ); + return; + } + } } else { nodeIds.addAll(allEligibleNodeIds); eligibleNodes.addAll(Arrays.asList(allEligibleNodes)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java index 7917c7d9fb..d932832de9 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java @@ -6,9 +6,11 @@ package org.opensearch.ml.action.syncup; import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS; import java.io.IOException; import java.nio.file.Path; +import java.time.Instant; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -23,7 +25,13 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest; @@ -34,10 +42,14 @@ import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.utils.FileUtils; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskCache; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; + @Log4j2 public class TransportSyncUpOnNodeAction extends TransportNodesAction { @@ -51,9 +63,12 @@ public class TransportSyncUpOnNodeAction extends NamedXContentRegistry xContentRegistry; MLEngine mlEngine; + private volatile Integer mlTaskTimeout; + @Inject public TransportSyncUpOnNodeAction( TransportService transportService, + Settings settings, ActionFilters actionFilters, ModelHelper modelHelper, MLTaskManager mlTaskManager, @@ -84,6 +99,9 @@ public TransportSyncUpOnNodeAction( this.client = client; this.xContentRegistry = xContentRegistry; this.mlEngine = mlEngine; + + this.mlTaskTimeout = ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, it -> { mlTaskTimeout = it; }); } @Override @@ -148,11 +166,62 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest loadM mlTaskManager.syncRunningLoadModelTasks(runningLoadModelTasks); } + cleanUpLocalCache(); cleanUpLocalCacheFiles(); return new MLSyncUpNodeResponse(clusterService.localNode(), "ok", loadedModelIds, runningLoadModelTaskIds); } + @VisibleForTesting + void cleanUpLocalCache() { + String[] allTaskIds = mlTaskManager.getAllTaskIds(); + if (allTaskIds == null) { + return; + } + for (String taskId : allTaskIds) { + MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId); + MLTask mlTask = mlTaskCache.getMlTask(); + Instant lastUpdateTime = mlTask.getLastUpdateTime(); + Instant now = Instant.now(); + if (now.isAfter(lastUpdateTime.plusSeconds(mlTaskTimeout))) { + log.info("ML task timeout. task id: {}, task type: {}", taskId, mlTask.getTaskType()); + mlTaskManager + .updateMLTask( + taskId, + ImmutableMap + .of(MLTask.STATE_FIELD, MLTaskState.FAILED, MLTask.ERROR_FIELD, "timeout after " + mlTaskTimeout + " seconds"), + 10_000, + true + ); + + if (mlTask.getTaskType() == MLTaskType.LOAD_MODEL) { + String modelId = mlTask.getModelId(); + String[] workerNodes = mlModelManager.getWorkerNodes(modelId); + MLModelState modelState; + if (workerNodes == null || workerNodes.length == 0) { + modelState = MLModelState.LOAD_FAILED; + } else if (mlTask.getWorkerNodes().size() > workerNodes.length) { + modelState = MLModelState.PARTIALLY_LOADED; + } else { + modelState = MLModelState.LOADED; + if (mlTask.getWorkerNodes().size() < workerNodes.length) { + log + .warn( + "Model loaded on more nodes than target worker nodes. taskId:{}, modelId: {}, workerNodes: {}, targetWorkerNodes: {}", + taskId, + modelId, + Arrays.toString(workerNodes), + Arrays.toString(mlTask.getWorkerNodes().toArray(new String[0])) + ); + } + } + log.info("Reset model state as {} for model {}", modelState, modelId); + mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, modelState)); + } + } + } + } + private void cleanUpLocalCacheFiles() { Path uploadModelRootPath = mlEngine.getUploadModelRootPath(); Path loadModelRootPath = mlEngine.getLoadModelRootPath(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 081328e55c..78a07f7d37 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -121,6 +121,18 @@ public Predictable getPredictor(String modelId) { return modelCache.getPredictor(); } + /** + * Set target worker nodes of model. + * @param modelId model id + * @param targetWorkerNodes target worker nodes of model + */ + public void setTargetWorkerNodes(String modelId, List targetWorkerNodes) { + MLModelCache modelCache = modelCaches.get(modelId); + if (modelCache != null) { + modelCache.setTargetWorkerNodes(targetWorkerNodes); + } + } + /** * Remove model. * @param modelId model id diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index fe02a64999..f085d03d1c 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -101,6 +101,7 @@ import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.io.Files; /** @@ -129,6 +130,16 @@ public class MLModelManager { private volatile Integer maxUploadTasksPerNode; private volatile Integer maxLoadTasksPerNode; + public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet + .of( + MLModelState.TRAINED, + MLModelState.UPLOADED, + MLModelState.LOADED, + MLModelState.PARTIALLY_LOADED, + MLModelState.LOAD_FAILED, + MLModelState.UNLOADED + ); + public MLModelManager( ClusterService clusterService, Client client, @@ -422,7 +433,12 @@ public void loadModel( ActionListener listener ) { mlStats.createCounterStatIfAbsent(functionName, ActionName.LOAD, ML_ACTION_REQUEST_COUNT).increment(); + List workerNodes = mlTask.getWorkerNodes(); if (modelCacheHelper.isModelLoaded(modelId)) { + if (workerNodes != null && workerNodes.size() > 0) { + log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId); + modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes); + } listener.onResponse("successful"); return; } @@ -430,7 +446,7 @@ public void loadModel( listener.onFailure(new IllegalArgumentException("Exceed max model per node limit")); return; } - modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, mlTask.getWorkerNodes()); + modelCacheHelper.initModelState(modelId, MLModelState.LOADING, functionName, workerNodes); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { checkAndAddRunningTask(mlTask, maxLoadTasksPerNode); this.getModel(modelId, threadedActionListener(LOAD_THREAD_POOL, ActionListener.wrap(mlModel -> { @@ -596,6 +612,10 @@ public void updateModel(String modelId, Map updatedFields, Actio UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_INDEX, modelId); updateRequest.doc(updatedFields); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + if (updatedFields.containsKey(MLModel.MODEL_STATE_FIELD) + && MODEL_DONE_STATES.contains(updatedFields.get(MLModel.MODEL_STATE_FIELD))) { + updateRequest.retryOnConflict(3); + } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index d4ab3821b6..0c64c824b8 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -500,6 +500,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE, MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE, MLCommonsSettings.ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, + MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS, MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT, MLCommonsSettings.ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE, MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE, diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index dfa9fe8f52..0b61a9e249 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -15,7 +15,9 @@ private MLCommonsSettings() {} .simpleString("plugins.ml_commons.task_dispatch_policy", "round_robin", Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_MAX_MODELS_PER_NODE = Setting - .intSetting("plugins.ml_commons.max_model_on_node", 10, 0, 1000, Setting.Property.NodeScope, Setting.Property.Dynamic); + .intSetting("plugins.ml_commons.max_model_on_node", 10, 0, 10000, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE = Setting + .intSetting("plugins.ml_commons.max_upload_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE = Setting .intSetting("plugins.ml_commons.max_load_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting @@ -32,6 +34,8 @@ private MLCommonsSettings() {} Setting.Property.Dynamic ); + public static final Setting ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS = Setting + .intSetting("plugins.ml_commons.ml_task_timeout_in_seconds", 600, 1, 86400, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final Setting ML_COMMONS_MONITORING_REQUEST_COUNT = Setting .longSetting( "plugins.ml_commons.monitoring_request_count", @@ -42,9 +46,6 @@ private MLCommonsSettings() {} Setting.Property.Dynamic ); - public static final Setting ML_COMMONS_MAX_UPLOAD_TASKS_PER_NODE = Setting - .intSetting("plugins.ml_commons.max_upload_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_TRUSTED_URL_REGEX = Setting .simpleString( "plugins.ml_commons.trusted_url_regex", diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index feee503360..2ea8c69183 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -50,6 +50,7 @@ import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; /** * MLTaskManager is responsible for managing MLTask. @@ -63,6 +64,9 @@ public class MLTaskManager { private final MLIndicesHandler mlIndicesHandler; private final Map runningTasksCount; + public static final ImmutableSet TASK_DONE_STATES = ImmutableSet + .of(MLTaskState.COMPLETED, MLTaskState.COMPLETED_WITH_ERROR, MLTaskState.FAILED, MLTaskState.CANCELLED); + /** * Constructor to create ML task manager. * @@ -320,6 +324,9 @@ public void updateMLTask( updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); updateRequest.doc(updatedContent); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.containsKey(STATE_FIELD))) { + updateRequest.retryOnConflict(3); + } ActionListener actionListener = semaphore == null ? listener : ActionListener.runAfter(listener, () -> semaphore.release()); @@ -360,6 +367,9 @@ public void updateMLTaskDirectly(String taskId, Map updatedField updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); updateRequest.doc(updatedContent); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + if (updatedFields.containsKey(STATE_FIELD) && TASK_DONE_STATES.contains(updatedFields.containsKey(STATE_FIELD))) { + updateRequest.retryOnConflict(3); + } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java index 465eefa38f..066410ea49 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java @@ -7,9 +7,16 @@ import static java.util.Collections.emptyMap; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.clusterSetting; @@ -18,6 +25,7 @@ import java.io.IOException; import java.net.InetAddress; import java.nio.file.Paths; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -29,6 +37,7 @@ import org.junit.Before; import org.junit.rules.TemporaryFolder; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -44,6 +53,10 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.MLTask; +import org.opensearch.ml.common.MLTaskType; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeRequest; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; @@ -52,11 +65,13 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskCache; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; public class TransportSyncUpOnNodeActionTests extends OpenSearchTestCase { @@ -104,6 +119,7 @@ public void setup() throws IOException { when(clusterService.getClusterName()).thenReturn(new ClusterName("Local Cluster")); action = new TransportSyncUpOnNodeAction( transportService, + settings, actionFilters, modelHelper, mlTaskManager, @@ -221,6 +237,79 @@ public void testNodeOperation_RemovedWorkerNodes() throws IOException { testFolder.delete(); } + public void testCleanUpLocalCache_NoTasks() { + when(mlTaskManager.getAllTaskIds()).thenReturn(null); + action.cleanUpLocalCache(); + verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); + } + + public void testCleanUpLocalCache_EmptyTasks() { + when(mlTaskManager.getAllTaskIds()).thenReturn(new String[] {}); + action.cleanUpLocalCache(); + verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); + } + + public void testCleanUpLocalCache_NotExpiredMLTask() { + String taskId = randomAlphaOfLength(5); + when(mlTaskManager.getAllTaskIds()).thenReturn(new String[] { taskId }); + MLTask mlTask = MLTask.builder().lastUpdateTime(Instant.now()).build(); + MLTaskCache taskCache = MLTaskCache.builder().mlTask(mlTask).build(); + when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); + action.cleanUpLocalCache(); + verify(mlTaskManager, never()).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); + } + + public void testCleanUpLocalCache_ExpiredMLTask_Upload() { + String taskId = randomAlphaOfLength(5); + when(mlTaskManager.getAllTaskIds()).thenReturn(new String[] { taskId }); + MLTask mlTask = MLTask.builder().taskType(MLTaskType.UPLOAD_MODEL).lastUpdateTime(Instant.now().minusSeconds(86400)).build(); + MLTaskCache taskCache = MLTaskCache.builder().mlTask(mlTask).build(); + when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); + action.cleanUpLocalCache(); + verify(mlTaskManager, times(1)).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); + verify(mlModelManager, never()).updateModel(anyString(), any()); + } + + public void testCleanUpLocalCache_ExpiredMLTask_Load_NullWorkerNode() { + testCleanUpLocalCache_ExpiredMLTask_LoadStatus(MLModelState.LOAD_FAILED); + } + + public void testCleanUpLocalCache_ExpiredMLTask_Load_PartiallyLoaded() { + testCleanUpLocalCache_ExpiredMLTask_LoadStatus(MLModelState.PARTIALLY_LOADED); + } + + public void testCleanUpLocalCache_ExpiredMLTask_Load_Loaded() { + testCleanUpLocalCache_ExpiredMLTask_LoadStatus(MLModelState.LOADED); + } + + private void testCleanUpLocalCache_ExpiredMLTask_LoadStatus(MLModelState modelState) { + String taskId = randomAlphaOfLength(5); + String modelId = randomAlphaOfLength(5); + when(mlTaskManager.getAllTaskIds()).thenReturn(new String[] { taskId }); + MLTask.MLTaskBuilder mlTaskBuilder = MLTask + .builder() + .modelId(modelId) + .taskType(MLTaskType.LOAD_MODEL) + .lastUpdateTime(Instant.now().minusSeconds(86400)); + if (MLModelState.PARTIALLY_LOADED == modelState) { + mlTaskBuilder.workerNodes(ImmutableList.of("node1", "node2")); + } else if (MLModelState.LOADED == modelState) { + mlTaskBuilder.workerNodes(ImmutableList.of("node1")); + } + + MLTask mlTask = mlTaskBuilder.build(); + MLTaskCache taskCache = MLTaskCache.builder().mlTask(mlTask).build(); + if (MLModelState.LOAD_FAILED != modelState) { + when(mlModelManager.getWorkerNodes(modelId)).thenReturn(new String[] { "node1" }); + } + when(mlTaskManager.getMLTaskCache(taskId)).thenReturn(taskCache); + action.cleanUpLocalCache(); + verify(mlTaskManager, times(1)).updateMLTask(anyString(), any(), anyLong(), anyBoolean()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Map.class); + verify(mlModelManager, times(1)).updateModel(eq(modelId), argumentCaptor.capture()); + assertEquals(modelState, argumentCaptor.getValue().get(MLModel.MODEL_STATE_FIELD)); + } + private MLSyncUpInput prepareRequest() { Map addedWorkerNodes = new HashMap<>(); addedWorkerNodes.put("modelId1", new String[] { "nodeId1", "nodeId2", "nodeId3" }); @@ -262,8 +351,12 @@ private MLSyncUpInput prepareRequest2() { } private void mockSettings(boolean onlyRunOnMLNode) { - settings = Settings.builder().put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), onlyRunOnMLNode).build(); - ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ONLY_RUN_ON_ML_NODE); + settings = Settings + .builder() + .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), onlyRunOnMLNode) + .put(ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS.getKey(), 30) + .build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ONLY_RUN_ON_ML_NODE, ML_COMMONS_ML_TASK_TIMEOUT_IN_SECONDS); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); } }