From e380395cfca2261d348e900bf39c72b3df0d401f Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Mon, 29 Apr 2024 21:38:55 -0700 Subject: [PATCH] Add TTL to un-deploy model automatically (#2365) * add ttl to un-deploy model automatically Signed-off-by: Xun Zhang * undeploy only for models that expired in all nodes Signed-off-by: Xun Zhang * add bwc version for model ttl Signed-off-by: Xun Zhang * only use minutes in ttl Signed-off-by: Xun Zhang * move MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL to MLDeploySetting Signed-off-by: Xun Zhang --------- Signed-off-by: Xun Zhang --- .../ml/common/model/MLDeploySetting.java | 27 ++++++++++++- .../common/transport/sync/MLSyncUpInput.java | 1 + .../transport/sync/MLSyncUpNodeResponse.java | 15 +++++++- .../common/model/MLDeployingSettingTests.java | 8 ++-- .../sync/MLSyncUpNodeResponseTest.java | 6 ++- .../TransportPredictionTaskAction.java | 1 + .../syncup/TransportSyncUpOnNodeAction.java | 5 ++- .../opensearch/ml/cluster/MLSyncUpCron.java | 38 +++++++++++++++++++ .../org/opensearch/ml/model/MLModelCache.java | 2 + .../ml/model/MLModelCacheHelper.java | 37 ++++++++++++++++++ .../opensearch/ml/model/MLModelManager.java | 9 +++++ .../TransportSyncUpOnNodeActionTests.java | 4 +- .../ml/cluster/MLSyncUpCronTests.java | 13 ++++++- 13 files changed, 153 insertions(+), 13 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java index 51acd8bea3..4bd864b237 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.Getter; import lombok.Setter; +import org.opensearch.Version; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -24,25 +25,42 @@ 
@Getter
 public class MLDeploySetting implements ToXContentObject, Writeable {
     public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled";
+    public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes";
+    private static final long DEFAULT_TTL_MINUTES = -1; // applied when no TTL is supplied
+    public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0;
 
     private Boolean isAutoDeployEnabled;
+    private Long modelTTLInMinutes; // in minutes
 
     @Builder(toBuilder = true)
-    public MLDeploySetting(Boolean isAutoDeployEnabled) {
+    public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInMinutes) {
         this.isAutoDeployEnabled = isAutoDeployEnabled;
+        this.modelTTLInMinutes = modelTTLInMinutes;
+        if (modelTTLInMinutes == null) {
+            this.modelTTLInMinutes = DEFAULT_TTL_MINUTES;
+        }
     }
 
     public MLDeploySetting(StreamInput in) throws IOException {
         this.isAutoDeployEnabled = in.readOptionalBoolean();
+        Version streamInputVersion = in.getVersion();
+        Long ttlFromStream = streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL) ? in.readOptionalLong() : null;
+        // same default as the builder constructor, so the TTL is never null after deserialization
+        this.modelTTLInMinutes = ttlFromStream == null ? Long.valueOf(DEFAULT_TTL_MINUTES) : ttlFromStream;
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        Version streamOutputVersion = out.getVersion();
         out.writeOptionalBoolean(isAutoDeployEnabled);
+        if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { // mirrors the gate used when reading
+            out.writeOptionalLong(modelTTLInMinutes);
+        }
     }
 
     public static MLDeploySetting parse(XContentParser parser) throws IOException {
         Boolean isAutoDeployEnabled = null;
+        Long modelTTLMinutes = null;
         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
         while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
             String fieldName = parser.currentName();
@@ -51,12 +69,15 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException {
                 case IS_AUTO_DEPLOY_ENABLED_FIELD:
                     isAutoDeployEnabled = parser.booleanValue();
                     break;
+                case MODEL_TTL_MINUTES_FIELD:
+                    modelTTLMinutes = parser.longValue();
+                    break; // without this the case falls through into default
                 default:
                     parser.skipChildren();
                     break;
             }
         }
-        return new MLDeploySetting(isAutoDeployEnabled);
+        return new MLDeploySetting(isAutoDeployEnabled, modelTTLMinutes);
     }
 
     @Override
@@ -65,6 +86,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
         if (isAutoDeployEnabled != null) {
             builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled);
         }
+        if (modelTTLInMinutes != null) {
+            builder.field(MODEL_TTL_MINUTES_FIELD, modelTTLInMinutes);
+        }
         builder.endObject();
         return builder;
     }
diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java
index de04b2936d..7ad34321b8 100644
--- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java
+++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java
@@ -7,6 +7,7 @@
 
 import lombok.Builder;
 import lombok.Data;
+import org.opensearch.Version;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.common.io.stream.Writeable;
diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java
index e7ac993fba..74893ec91e 100644
--- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java
+++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java
@@ -7,10 +7,12 @@
 
 import lombok.Getter;
 import lombok.extern.log4j.Log4j2;
+import org.opensearch.Version;
 import org.opensearch.action.support.nodes.BaseNodeResponse;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.ml.common.model.MLDeploySetting;
 
 import
java.io.IOException;
 
@@ -22,22 +24,28 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse {
     private String[] deployedModelIds;
     private String[] runningDeployModelIds; // model ids which have deploying model task running
     private String[] runningDeployModelTaskIds; // deploy model task ids which is running
+    private String[] expiredModelIds; // model ids this node reports as expired
 
     public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds,
-        String[] runningDeployModelTaskIds) {
+        String[] runningDeployModelTaskIds, String[] expiredModelIds) {
         super(node);
         this.modelStatus = modelStatus;
         this.deployedModelIds = deployedModelIds;
         this.runningDeployModelIds = runningDeployModelIds;
         this.runningDeployModelTaskIds = runningDeployModelTaskIds;
+        this.expiredModelIds = expiredModelIds;
     }
 
     public MLSyncUpNodeResponse(StreamInput in) throws IOException {
         super(in);
+        Version streamInputVersion = in.getVersion();
         this.modelStatus = in.readOptionalString();
         this.deployedModelIds = in.readOptionalStringArray();
         this.runningDeployModelIds = in.readOptionalStringArray();
         this.runningDeployModelTaskIds = in.readOptionalStringArray();
+        if (streamInputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { // not read from older streams; the field then stays null
+            this.expiredModelIds = in.readOptionalStringArray();
+        }
     }
 
     public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException {
@@ -46,11 +54,14 @@ public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        Version streamOutputVersion = out.getVersion();
         super.writeTo(out);
         out.writeOptionalString(modelStatus);
         out.writeOptionalStringArray(deployedModelIds);
         out.writeOptionalStringArray(runningDeployModelIds);
         out.writeOptionalStringArray(runningDeployModelTaskIds);
+        if (streamOutputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { // same version gate as the StreamInput constructor
+            out.writeOptionalStringArray(expiredModelIds);
+        }
     }
-
 }
diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java index e30cf10d4d..72b5e883b5 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java @@ -36,7 +36,7 @@ public class MLDeployingSettingTests { private MLDeploySetting deploySettingNull; - private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true}"; + private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -66,7 +66,7 @@ public void testToXContent() throws Exception { @Test public void testToXContentIncomplete() throws Exception { - final String expectedIncompleteInputStr = "{}"; + final String expectedIncompleteInputStr = "{\"model_ttl_minutes\":-1}"; String jsonStr = serializationWithToXContent(deploySettingNull); assertEquals(expectedIncompleteInputStr, jsonStr); @@ -109,12 +109,12 @@ public void parseWithIllegalArgumentInteger() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + + final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," + "\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { - assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + assertEquals("{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}", serializationWithToXContent(parsedInput)); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java 
b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java index 56e1672852..8599002354 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java @@ -24,6 +24,8 @@ public class MLSyncUpNodeResponseTest { private final String[] loadedModelIds = {"loadedModelIds"}; private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"}; private final String[] runningLoadModelIds = {"modelid1"}; + + private final String[] expiredModelIds = {"modelExpired"}; @Before public void setUp() throws Exception { localNode = new DiscoveryNode( @@ -38,7 +40,7 @@ public void setUp() throws Exception { @Test public void testSerializationDeserialization() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput()); @@ -51,7 +53,7 @@ public void testSerializationDeserialization() throws IOException { @Test public void testReadProfile() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput()); diff --git 
a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 34be01b3a2..b396b39016 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -223,6 +223,7 @@ private void executePredict( long endTime = System.nanoTime(); double durationInMs = (endTime - startTime) / 1e6; modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); + modelCacheHelper.refreshLastAccessTime(modelId); log.debug("completed predict request " + requestId + " for model " + modelId); }) ); diff --git a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java index 94327f190e..5fad4aa4d2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeAction.java @@ -162,11 +162,13 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU String[] deployedModelIds = null; String[] runningDeployModelTaskIds = null; String[] runningDeployModelIds = null; + String[] expiredModelIds = null; if (syncUpInput.isGetDeployedModels()) { deployedModelIds = mlModelManager.getLocalDeployedModels(); List localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks(); runningDeployModelTaskIds = localRunningDeployModel.get(0); runningDeployModelIds = localRunningDeployModel.get(1); + expiredModelIds = mlModelManager.getExpiredModels(); } if (syncUpInput.isClearRoutingTable()) { @@ -186,7 +188,8 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU "ok", deployedModelIds, runningDeployModelIds, - runningDeployModelTaskIds + 
runningDeployModelTaskIds, + expiredModelIds ); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 12e37b7b4d..b11fe7afc2 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -41,6 +41,8 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction; +import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; @@ -101,8 +103,17 @@ public void run() { Map> runningDeployModelTasks = new HashMap<>(); // key is model id, value is set of worker node ids Map> deployingModels = new HashMap<>(); + // key is expired model_id, value is set of worker node ids + Map> expiredModelToNodes = new HashMap<>(); for (MLSyncUpNodeResponse response : responses) { String nodeId = response.getNode().getId(); + String[] expiredModelIds = response.getExpiredModelIds(); + if (expiredModelIds != null && expiredModelIds.length > 0) { + Arrays + .stream(expiredModelIds) + .forEach(modelId -> { expiredModelToNodes.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); }); + } + String[] deployedModelIds = response.getDeployedModelIds(); if (deployedModelIds != null && deployedModelIds.length > 0) { for (String modelId : deployedModelIds) { @@ -126,6 +137,17 @@ public void run() { } } } + + Set modelsToUndeploy = new HashSet<>(); + for (String modelId : expiredModelToNodes.keySet()) { + if (modelWorkerNodes.containsKey(modelId) + && expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) { 
+                    // this model has expired in all the nodes
+                    modelWorkerNodes.remove(modelId);
+                    modelsToUndeploy.add(modelId);
+                }
+            }
+
             for (Map.Entry<String, Set<String>> entry : modelWorkerNodes.entrySet()) {
                 String modelId = entry.getKey();
                 log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0]));
@@ -154,6 +176,8 @@ public void run() {
                     log.error("Failed to sync model routing", ex);
                 })
             );
+            // Undeploy expired models; pass expiredModelToNodes because these ids were removed from modelWorkerNodes above
+            undeployExpiredModels(modelsToUndeploy, expiredModelToNodes);
 
             // refresh model status
             mlIndicesHandler
@@ -163,6 +187,26 @@ public void run() {
         }, e -> { log.error("Failed to sync model routing", e); }));
     }
 
+    /**
+     * Send an undeploy request for every expired model to the nodes recorded for it.
+     *
+     * @param expiredModels model ids to undeploy
+     * @param modelWorkerNodes key is model id, value is the set of node ids to undeploy it from
+     */
+    private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
+        expiredModels.forEach(modelId -> {
+            Set<String> workerNodes = modelWorkerNodes.get(modelId);
+            if (workerNodes == null || workerNodes.isEmpty()) {
+                log.warn("No worker nodes found for expired model {}, skip undeploying", modelId);
+                return;
+            }
+            // target this model's own nodes; the map keys are model ids, not node ids
+            String[] targetNodeIds = workerNodes.toArray(new String[0]);
+
+            MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
+                targetNodeIds,
+                new String[] { modelId }
+            );
+            client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
+                log.debug("model {} is un_deployed", modelId);
+            }, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
+        });
+    }
+
     @VisibleForTesting
     void initMLConfig() {
         if (mlConfigInited) {
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
index 06509c30ca..61d27576d1 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.ml.model;
 
+import java.time.Instant;
 import java.util.DoubleSummaryStatistics;
 import java.util.List;
 import java.util.Map;
@@ -51,6 +52,7 @@ public class MLModelCache {
     // In rare case, this could be null, e.g. model info not synced up yet a predict request comes in.
@Setter
     private Boolean deployToAllNodes;
+    private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;
 
     public MLModelCache() {
         targetWorkerNodes = ConcurrentHashMap.newKeySet();
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
index fdfb703677..429323297c 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
@@ -7,6 +7,8 @@
 
 import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;
 
+import java.time.Duration;
+import java.time.Instant;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -68,6 +70,7 @@ public synchronized void initModelState(
         modelCache.setFunctionName(functionName);
         modelCache.setTargetWorkerNodes(targetWorkerNodes);
         modelCache.setDeployToAllNodes(deployToAllNodes);
+        modelCache.setLastAccessTime(Instant.now());
         modelCaches.put(modelId, modelCache);
     }
 
@@ -87,6 +90,7 @@ public synchronized void initModelStateLocal(
         modelCache.setFunctionName(functionName);
         modelCache.setTargetWorkerNodes(targetWorkerNodes);
         modelCache.setDeployToAllNodes(false);
+        modelCache.setLastAccessTime(Instant.now());
         modelCaches.put(modelId, modelCache);
     }
 
@@ -341,6 +345,36 @@ public String[] getLocalDeployedModels() {
             .toArray(new String[0]);
     }
 
+    /**
+     * Get expired models on node.
+     * A model is expired when its TTL is non-negative, at least that many minutes have
+     * passed since its last access, and it is DEPLOYED or PARTIALLY_DEPLOYED.
+     * Entries without cached model info, deploy setting, TTL or last access time never expire.
+     *
+     * @return array of expired model id
+     */
+    public String[] getExpiredModels() {
+        return modelCaches.entrySet().stream().filter(entry -> {
+            MLModelCache modelCache = entry.getValue();
+            MLModel mlModel = modelCache.getCachedModelInfo();
+            Instant lastAccessTime = modelCache.getLastAccessTime();
+            // any of these may be absent on a cache entry; treat that as "not expired" instead of throwing
+            if (mlModel == null || mlModel.getDeploySetting() == null || lastAccessTime == null) {
+                return false; // no TTL, never expire
+            }
+            Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
+            // the TTL is a boxed Long: check for null before the comparison unboxes it
+            if (ttlInMinutes == null || ttlInMinutes < 0) {
+                return false;
+            }
+            Duration liveDuration = Duration.between(lastAccessTime, Instant.now());
+            Duration ttl = Duration.ofMinutes(ttlInMinutes);
+            boolean isModelExpired = liveDuration.getSeconds() >= ttl.getSeconds();
+            return isModelExpired
+                && (mlModel.getModelState() == MLModelState.DEPLOYED || mlModel.getModelState() == MLModelState.PARTIALLY_DEPLOYED);
+        }).map(Map.Entry::getKey).toArray(String[]::new);
+    }
+
     /**
      * Check if model is running on node.
      *
@@ -403,6 +437,19 @@ public void setTargetWorkerNodes(String modelId, List targetWorkerNodes)
         }
     }
 
+    /**
+     * Set the last access time to Instant.now()
+     * Does nothing when the model is not in the local cache.
+     *
+     * @param modelId model id
+     */
+    public void refreshLastAccessTime(String modelId) {
+        MLModelCache modelCache = modelCaches.get(modelId);
+        if (modelCache != null) {
+            modelCache.setLastAccessTime(Instant.now());
+        }
+    }
+
     /**
      * Remove model.
* diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index bb519efd37..73b585162b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -967,6 +967,7 @@ public void deployModel( log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId); modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes); modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes); + modelCacheHelper.refreshLastAccessTime(modelId); } listener.onResponse("successful"); return; @@ -1041,6 +1042,7 @@ public void deployModel( modelCacheHelper.setMLExecutor(modelId, mlExecutable); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); wrappedListener.onResponse("successful"); } catch (Exception e) { log.error("Failed to add predictor to cache", e); @@ -1053,6 +1055,7 @@ public void deployModel( modelCacheHelper.setPredictor(modelId, predictable); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes(); long contentSize = modelContentSizeInBytes == null ? 
mlModel.getTotalChunks() * CHUNK_SIZE @@ -1105,6 +1108,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); wrappedListener.onResponse("successful"); return; } @@ -1114,6 +1118,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + modelCacheHelper.refreshLastAccessTime(modelId); wrappedListener.onResponse("successful"); log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); }, wrappedListener::onFailure)); @@ -1857,6 +1862,10 @@ public String[] getLocalDeployedModels() { return modelCacheHelper.getDeployedModels(); } + public String[] getExpiredModels() { + return modelCacheHelper.getExpiredModels(); + } + /** * Sync model routing table. 
* diff --git a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java index 49293a96f6..392fc2379a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/syncup/TransportSyncUpOnNodeActionTests.java @@ -171,12 +171,14 @@ public void testNewNodeResponse() throws IOException { String[] deployedModelIds = new String[] { "123" }; String[] runningDeployModelIds = new String[] { "model1" }; String[] runningDeployModelTaskIds = new String[] { "1" }; + String[] expiredModelIds = new String[] { "modelExpired" }; MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( mlNode1, "DEPLOYED", deployedModelIds, runningDeployModelIds, - runningDeployModelTaskIds + runningDeployModelTaskIds, + expiredModelIds ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index b9f169e3d2..0b6709fc16 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -415,7 +415,18 @@ private void mockSyncUp_GatherRunningTasks() { String[] deployedModelIds = new String[] { randomAlphaOfLength(10) }; String[] runningDeployModelIds = new String[] { randomAlphaOfLength(10) }; String[] runningDeployModelTaskIds = new String[] { randomAlphaOfLength(10) }; - nodeResponses.add(new MLSyncUpNodeResponse(mlNode1, "ok", deployedModelIds, runningDeployModelIds, runningDeployModelTaskIds)); + String[] expiredModelIds = new String[] { randomAlphaOfLength(10) }; + nodeResponses + .add( + new MLSyncUpNodeResponse( + mlNode1, + "ok", + deployedModelIds, + runningDeployModelIds, + runningDeployModelTaskIds, + 
expiredModelIds + ) + ); MLSyncUpNodesResponse response = new MLSyncUpNodesResponse(ClusterName.DEFAULT, nodeResponses, Arrays.asList()); listener.onResponse(response); return null;