diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java index 361e3f5e79..f7725e7cdc 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java @@ -25,37 +25,50 @@ @Getter public class MLDeploySetting implements ToXContentObject, Writeable { public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled"; - public static final String MODEL_TTL_FIELD = "model_ttl"; + public static final String MODEL_TTL_HOURS_FIELD = "model_ttl_hours"; + public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes"; private static final long DEFAULT_TTL_HOUR = -1; + private static final long DEFAULT_TTL_MINUTES = -1; private Boolean isAutoDeployEnabled; - private Long modelTTL; // Time to live in hours + private Long modelTTLInHours; // Time to live in hours + private Long modelTTLInMinutes; // in minutes @Builder(toBuilder = true) - public MLDeploySetting(Boolean isAutoDeployEnabled) { + public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInHours, Long modelTTLInMinutes) { this.isAutoDeployEnabled = isAutoDeployEnabled; - this.modelTTL = DEFAULT_TTL_HOUR; - } - @Builder(toBuilder = true) - public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTL) { - this.isAutoDeployEnabled = isAutoDeployEnabled; - this.modelTTL = modelTTL; + this.modelTTLInHours = modelTTLInHours; + this.modelTTLInMinutes = modelTTLInMinutes; + if (modelTTLInHours == null && modelTTLInMinutes == null) { + this.modelTTLInHours = DEFAULT_TTL_HOUR; + this.modelTTLInMinutes = DEFAULT_TTL_MINUTES; + return; + } + if (modelTTLInHours == null) { + this.modelTTLInHours = 0L; + } + if (modelTTLInMinutes == null) { + this.modelTTLInMinutes = 0L; + } } public MLDeploySetting(StreamInput in) throws IOException { this.isAutoDeployEnabled = in.readOptionalBoolean(); - this.modelTTL = in.readOptionalLong(); + this.modelTTLInHours = in.readOptionalLong(); + this.modelTTLInMinutes = in.readOptionalLong(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(isAutoDeployEnabled); - out.writeOptionalLong(modelTTL); + out.writeOptionalLong(modelTTLInHours); + out.writeOptionalLong(modelTTLInMinutes); } public static MLDeploySetting parse(XContentParser parser) throws IOException { Boolean isAutoDeployEnabled = null; - Long modelTTL = null; + Long modelTTLHours = null; + Long modelTTLMinutes = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { String fieldName = parser.currentName(); @@ -64,14 +77,16 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException { case IS_AUTO_DEPLOY_ENABLED_FIELD: isAutoDeployEnabled = parser.booleanValue(); break; - case MODEL_TTL_FIELD: - modelTTL = parser.longValue(); + case MODEL_TTL_HOURS_FIELD: + modelTTLHours = parser.longValue(); + case MODEL_TTL_MINUTES_FIELD: + modelTTLMinutes = parser.longValue(); default: parser.skipChildren(); break; } } - return new MLDeploySetting(isAutoDeployEnabled, modelTTL); + return new MLDeploySetting(isAutoDeployEnabled, modelTTLHours, modelTTLMinutes); } @Override @@ -80,8 +95,11 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isAutoDeployEnabled != null) { builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled); } - if (modelTTL != null) { - builder.field(MODEL_TTL_FIELD, modelTTL); + if (modelTTLInHours != null) { + builder.field(MODEL_TTL_HOURS_FIELD, modelTTLInHours); + } + if (modelTTLInMinutes != null) { + builder.field(MODEL_TTL_MINUTES_FIELD, modelTTLInMinutes); } builder.endObject(); return builder; diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java index 113248bbe0..f23b270fc3 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLDeployingSettingTests.java @@ -36,7 +36,7 @@ public class MLDeployingSettingTests { private MLDeploySetting deploySettingNull; - private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl\":-1}"; + private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl_hours\":-1,\"model_ttl_minutes\":-1}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -66,7 +66,7 @@ public void testToXContent() throws Exception { @Test public void testToXContentIncomplete() throws Exception { - final String expectedIncompleteInputStr = "{\"model_ttl\":-1}"; + final String expectedIncompleteInputStr = "{\"model_ttl_hours\":-1,\"model_ttl_minutes\":-1}"; String jsonStr = serializationWithToXContent(deploySettingNull); assertEquals(expectedIncompleteInputStr, jsonStr); @@ -109,12 +109,12 @@ public void parseWithIllegalArgumentInteger() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl\":-1," + + final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," + "\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { - assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + assertEquals("{\"is_auto_deploy_enabled\":true,\"model_ttl_hours\":0,\"model_ttl_minutes\":0}", serializationWithToXContent(parsedInput)); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 94bafc32db..29ce9a96aa 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -104,22 +104,19 @@ public void run() { // key is model id, value is set of worker node ids Map> deployingModels = new HashMap<>(); // key is expired model_id, value is set of worker node ids - Map> expiredModels = new HashMap<>(); + Map> expiredModelToNodess = new HashMap<>(); for (MLSyncUpNodeResponse response : responses) { String nodeId = response.getNode().getId(); String[] expiredModelIds = response.getExpiredModelIds(); if (expiredModelIds != null && expiredModelIds.length > 0) { Arrays .stream(expiredModelIds) - .forEach(modelId -> { expiredModels.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); }); + .forEach(modelId -> { expiredModelToNodess.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); }); } String[] deployedModelIds = response.getDeployedModelIds(); if (deployedModelIds != null && deployedModelIds.length > 0) { for (String modelId : deployedModelIds) { - if (expiredModels.containsKey(modelId)) { - continue; - } Set workerNodes = modelWorkerNodes.computeIfAbsent(modelId, it -> new HashSet<>()); workerNodes.add(nodeId); } @@ -140,6 +137,16 @@ public void run() { } } } + + Set modelsToUndeploy = new HashSet<>(); + for (String modelId : expiredModelToNodess.keySet()) { + if (expiredModelToNodess.get(modelId) == modelWorkerNodes.get(modelId)) { + // this model has expired in all the nodes + modelWorkerNodes.remove(modelId); + modelsToUndeploy.add(modelId); + } + } + for (Map.Entry> entry : modelWorkerNodes.entrySet()) { String modelId = entry.getKey(); log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0])); @@ -169,7 +176,7 @@ public void run() { }) ); // Undeploy expired models - undeployExpiredModels(expiredModels.keySet(), modelWorkerNodes); + undeployExpiredModels(modelsToUndeploy, modelWorkerNodes); // refresh model status mlIndicesHandler diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 6152497033..2f64479665 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -357,10 +357,13 @@ public String[] getExpiredModels() { return false; // no TTL, never expire } Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now()); - Long timeToLive = mlModel.getDeploySetting().getModelTTL(); - boolean isModelExpired = (timeToLive != null - && timeToLive > 0 - && liveDuration.getSeconds() > Duration.ofHours(timeToLive).getSeconds()); + Long ttlInHour = mlModel.getDeploySetting().getModelTTLInHours(); + Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes(); + if (ttlInHour < 0 || ttlInMinutes < 0) { + return false; + } + Duration ttl = Duration.ofHours(ttlInHour).plusMinutes(ttlInMinutes); + boolean isModelExpired = liveDuration.getSeconds() > ttl.getSeconds(); return isModelExpired && mlModel.getModelState() == MLModelState.DEPLOYED; }).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]); }