diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java index 51acd8bea3..361e3f5e79 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java @@ -17,6 +17,7 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.time.Duration; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -24,25 +25,37 @@ @Getter public class MLDeploySetting implements ToXContentObject, Writeable { public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled"; + public static final String MODEL_TTL_FIELD = "model_ttl"; + private static final long DEFAULT_TTL_HOUR = -1; private Boolean isAutoDeployEnabled; + private Long modelTTL; // Time to live in hours @Builder(toBuilder = true) public MLDeploySetting(Boolean isAutoDeployEnabled) { this.isAutoDeployEnabled = isAutoDeployEnabled; + this.modelTTL = DEFAULT_TTL_HOUR; + } + @Builder(toBuilder = true) + public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTL) { + this.isAutoDeployEnabled = isAutoDeployEnabled; + this.modelTTL = modelTTL; } public MLDeploySetting(StreamInput in) throws IOException { this.isAutoDeployEnabled = in.readOptionalBoolean(); + this.modelTTL = in.readOptionalLong(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(isAutoDeployEnabled); + out.writeOptionalLong(modelTTL); } public static MLDeploySetting parse(XContentParser parser) throws IOException { Boolean isAutoDeployEnabled = null; + Long modelTTL = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -51,12 +64,14 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException { case IS_AUTO_DEPLOY_ENABLED_FIELD: isAutoDeployEnabled = parser.booleanValue(); break; + case MODEL_TTL_FIELD: + modelTTL = parser.longValue(); default: parser.skipChildren(); break; } } - return new MLDeploySetting(isAutoDeployEnabled); + return new MLDeploySetting(isAutoDeployEnabled, modelTTL); } @Override @@ -65,6 +80,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isAutoDeployEnabled != null) { builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled); } + if (modelTTL != null) { + builder.field(MODEL_TTL_FIELD, modelTTL); + } builder.endObject(); return builder; } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java index e7ac993fba..79f2e8a9f4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java @@ -22,14 +22,16 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse { private String[] deployedModelIds; private String[] runningDeployModelIds; // model ids which have deploying model task running private String[] runningDeployModelTaskIds; // deploy model task ids which is running + private String[] expiredModelIds; public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds, - String[] runningDeployModelTaskIds) { + String[] runningDeployModelTaskIds, String[] expiredModelIds) { super(node); this.modelStatus = modelStatus; this.deployedModelIds = deployedModelIds; this.runningDeployModelIds = runningDeployModelIds; this.runningDeployModelTaskIds = runningDeployModelTaskIds; + this.expiredModelIds = expiredModelIds; } public MLSyncUpNodeResponse(StreamInput in) throws IOException { @@ -38,6 +40,7 @@ public MLSyncUpNodeResponse(StreamInput in) throws IOException { this.deployedModelIds = in.readOptionalStringArray(); this.runningDeployModelIds = in.readOptionalStringArray(); this.runningDeployModelTaskIds = in.readOptionalStringArray(); + this.expiredModelIds = in.readOptionalStringArray(); } public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException { @@ -51,6 +54,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalStringArray(deployedModelIds); out.writeOptionalStringArray(runningDeployModelIds); out.writeOptionalStringArray(runningDeployModelTaskIds); + out.writeOptionalStringArray(expiredModelIds); } } diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java index e30cf10d4d..113248bbe0 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java @@ -36,7 +36,7 @@ public class MLDeployingSettingTests { private MLDeploySetting deploySettingNull; - private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true}"; + private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl\":-1}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -66,7 +66,7 @@ public void testToXContent() throws Exception { @Test public void testToXContentIncomplete() throws Exception { - final String expectedIncompleteInputStr = "{}"; + final String expectedIncompleteInputStr = "{\"model_ttl\":-1}"; String jsonStr = serializationWithToXContent(deploySettingNull); assertEquals(expectedIncompleteInputStr, jsonStr); @@ -109,7 +109,7 @@ public void parseWithIllegalArgumentInteger() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + + final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl\":-1," + "\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java index 56e1672852..8599002354 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java @@ -24,6 +24,8 @@ public class MLSyncUpNodeResponseTest { private final String[] loadedModelIds = {"loadedModelIds"}; private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"}; private final String[] runningLoadModelIds = {"modelid1"}; + + private final String[] expiredModelIds = {"modelExpired"}; @Before public void setUp() throws Exception { localNode = new DiscoveryNode( @@ -38,7 +40,7 @@ public void setUp() throws Exception { @Test public void testSerializationDeserialization() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput()); @@ -51,7 +53,7 @@ public void testSerializationDeserialization() throws IOException { @Test public void testReadProfile() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput()); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 34be01b3a2..b396b39016 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -223,6 +223,7 @@ private void executePredict( long endTime = System.nanoTime(); double durationInMs = (endTime - startTime) / 1e6; modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); + modelCacheHelper.refreshLastAccessTime(modelId); log.debug("completed predict request " + requestId + " for model " + modelId); }) ); diff --git a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java index 94327f190e..5fad4aa4d2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java @@ -162,11 +162,13 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU String[] deployedModelIds = null; String[] runningDeployModelTaskIds = null; String[] runningDeployModelIds = null; + String[] expiredModelIds = null; if (syncUpInput.isGetDeployedModels()) { deployedModelIds = mlModelManager.getLocalDeployedModels(); List localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks(); runningDeployModelTaskIds = localRunningDeployModel.get(0); runningDeployModelIds = localRunningDeployModel.get(1); + expiredModelIds = mlModelManager.getExpiredModels(); } if (syncUpInput.isClearRoutingTable()) { @@ -186,7 +188,8 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU "ok", deployedModelIds, runningDeployModelIds, - runningDeployModelTaskIds + runningDeployModelTaskIds, + expiredModelIds ); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 12e37b7b4d..94bafc32db 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -41,6 +41,8 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; @@ -101,11 +103,23 @@ public void run() { Map> runningDeployModelTasks = new HashMap<>(); // key is model id, value is set of worker node ids Map> deployingModels = new HashMap<>(); + // key is expired model_id, value is set of worker node ids + Map> expiredModels = new HashMap<>(); for (MLSyncUpNodeResponse response : responses) { String nodeId = response.getNode().getId(); + String[] expiredModelIds = response.getExpiredModelIds(); + if (expiredModelIds != null && expiredModelIds.length > 0) { + Arrays + .stream(expiredModelIds) + .forEach(modelId -> { expiredModels.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); }); + } + String[] deployedModelIds = response.getDeployedModelIds(); if (deployedModelIds != null && deployedModelIds.length > 0) { for (String modelId : deployedModelIds) { + if (expiredModels.containsKey(modelId)) { + continue; + } Set workerNodes = modelWorkerNodes.computeIfAbsent(modelId, it -> new HashSet<>()); workerNodes.add(nodeId); } @@ -154,6 +168,8 @@ public void run() { log.error("Failed to sync model routing", ex); }) ); + // Undeploy expired models + undeployExpiredModels(expiredModels.keySet(), modelWorkerNodes); // refresh model status mlIndicesHandler @@ -163,6 +179,21 @@ public void run() { }, e -> { log.error("Failed to sync model routing", e); })); } + private void undeployExpiredModels(Set expiredModels, Map> modelWorkerNodes) { + expiredModels.forEach(modelId -> { + String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]); + + MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest( + targetNodeIds, + new String[] { modelId } + ); + + client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> { + log.debug("model {} is un_deployed", modelId); + }, e -> { log.error("Failed to undeploy model {}", modelId, e); })); + }); + } + @VisibleForTesting void initMLConfig() { if (mlConfigInited) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 06509c30ca..61d27576d1 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -5,6 +5,7 @@ package org.opensearch.ml.model; +import java.time.Instant; import java.util.DoubleSummaryStatistics; import java.util.List; import java.util.Map; @@ -51,6 +52,7 @@ public class MLModelCache { // In rare case, this could be null, e.g. model info not synced up yet a predict request comes in. @Setter private Boolean deployToAllNodes; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime; public MLModelCache() { targetWorkerNodes = ConcurrentHashMap.newKeySet(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index fdfb703677..6152497033 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -7,6 +7,8 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; +import java.time.Duration; +import java.time.Instant; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -68,6 +70,7 @@ public synchronized void initModelState( modelCache.setFunctionName(functionName); modelCache.setTargetWorkerNodes(targetWorkerNodes); modelCache.setDeployToAllNodes(deployToAllNodes); + modelCache.setLastAccessTime(Instant.now()); modelCaches.put(modelId, modelCache); } @@ -87,6 +90,7 @@ public synchronized void initModelStateLocal( modelCache.setFunctionName(functionName); modelCache.setTargetWorkerNodes(targetWorkerNodes); modelCache.setDeployToAllNodes(false); + modelCache.setLastAccessTime(Instant.now()); modelCaches.put(modelId, modelCache); } @@ -341,6 +345,26 @@ public String[] getLocalDeployedModels() { .toArray(new String[0]); } + /** + * Get expired models on node. + * + * @return array of expired model id + */ + public String[] getExpiredModels() { + return modelCaches.entrySet().stream().filter(entry -> { + MLModel mlModel = entry.getValue().getCachedModelInfo(); + if (mlModel.getDeploySetting() == null) { + return false; // no TTL, never expire + } + Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now()); + Long timeToLive = mlModel.getDeploySetting().getModelTTL(); + boolean isModelExpired = (timeToLive != null + && timeToLive > 0 + && liveDuration.getSeconds() > Duration.ofHours(timeToLive).getSeconds()); + return isModelExpired && mlModel.getModelState() == MLModelState.DEPLOYED; + }).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]); + } + /** * Check if model is running on node. * @@ -403,6 +427,16 @@ public void setTargetWorkerNodes(String modelId, List targetWorkerNodes) } } + /** + * Set the last access time to Instant.now() + * + * @param modelId model id + */ + public void refreshLastAccessTime(String modelId) { + MLModelCache modelCache = modelCaches.get(modelId); + modelCache.setLastAccessTime(Instant.now()); + } + /** * Remove model. * diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index bb519efd37..73b585162b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -967,6 +967,7 @@ public void deployModel( log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId); modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes); modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes); + modelCacheHelper.refreshLastAccessTime(modelId); } listener.onResponse("successful"); return; @@ -1041,6 +1042,7 @@ public void deployModel( modelCacheHelper.setMLExecutor(modelId, mlExecutable); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); wrappedListener.onResponse("successful"); } catch (Exception e) { log.error("Failed to add predictor to cache", e); @@ -1053,6 +1055,7 @@ public void deployModel( modelCacheHelper.setPredictor(modelId, predictable); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); long contentSize = modelContentSizeInBytes == null ? mlModel.getTotalChunks() * CHUNK_SIZE @@ -1105,6 +1108,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); wrappedListener.onResponse("successful"); return; } @@ -1114,6 +1118,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); wrappedListener.onResponse("successful"); log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); }, wrappedListener::onFailure)); @@ -1857,6 +1862,10 @@ public String[] getLocalDeployedModels() { return modelCacheHelper.getDeployedModels(); } + public String[] getExpiredModels() { + return modelCacheHelper.getExpiredModels(); + } + /** * Sync model routing table. * diff --git a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java index 49293a96f6..392fc2379a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java @@ -171,12 +171,14 @@ public void testNewNodeResponse() throws IOException { String[] deployedModelIds = new String[] { "123" }; String[] runningDeployModelIds = new String[] { "model1" }; String[] runningDeployModelTaskIds = new String[] { "1" }; + String[] expiredModelIds = new String[] { "modelExpired" }; MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( mlNode1, "DEPLOYED", deployedModelIds, runningDeployModelIds, - runningDeployModelTaskIds + runningDeployModelTaskIds, + expiredModelIds ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index b9f169e3d2..0b6709fc16 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -415,7 +415,18 @@ private void mockSyncUp_GatherRunningTasks() { String[] deployedModelIds = new String[] { randomAlphaOfLength(10) }; String[] runningDeployModelIds = new String[] { randomAlphaOfLength(10) }; String[] runningDeployModelTaskIds = new String[] { randomAlphaOfLength(10) }; - nodeResponses.add(new MLSyncUpNodeResponse(mlNode1, "ok", deployedModelIds, runningDeployModelIds, runningDeployModelTaskIds)); + String[] expiredModelIds = new String[] { randomAlphaOfLength(10) }; + nodeResponses + .add( + new MLSyncUpNodeResponse( + mlNode1, + "ok", + deployedModelIds, + runningDeployModelIds, + runningDeployModelTaskIds, + expiredModelIds + ) + ); MLSyncUpNodesResponse response = new MLSyncUpNodesResponse(ClusterName.DEFAULT, nodeResponses, Arrays.asList()); listener.onResponse(response); return null;