diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java index 0237bf58ec..4bd864b237 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLDeploySetting.java @@ -16,7 +16,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import java.io.IOException; @@ -28,6 +27,7 @@ public class MLDeploySetting implements ToXContentObject, Writeable { public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled"; public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes"; private static final long DEFAULT_TTL_MINUTES = -1; + public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0; private Boolean isAutoDeployEnabled; private Long modelTTLInMinutes; // in minutes @@ -44,7 +44,7 @@ public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInMinutes) { public MLDeploySetting(StreamInput in) throws IOException { this.isAutoDeployEnabled = in.readOptionalBoolean(); Version streamInputVersion = in.getVersion(); - if (streamInputVersion.onOrAfter(MLSyncUpInput.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { + if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { this.modelTTLInMinutes = in.readOptionalLong(); } } @@ -53,7 +53,7 @@ public MLDeploySetting(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { Version streamOutputVersion = out.getVersion(); out.writeOptionalBoolean(isAutoDeployEnabled); - if (streamOutputVersion.onOrAfter(MLSyncUpInput.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { + if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { out.writeOptionalLong(modelTTLInMinutes); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java index 90ff0b614d..7ad34321b8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java @@ -36,8 +36,6 @@ public class MLSyncUpInput implements Writeable { // profile API has consistent data with model index. private Map deployToAllNodes; - public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0; - @Builder public MLSyncUpInput(boolean getDeployedModels, Map addedWorkerNodes, diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java index b77f96faa9..74893ec91e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java @@ -12,6 +12,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.model.MLDeploySetting; import java.io.IOException; @@ -42,7 +43,7 @@ public MLSyncUpNodeResponse(StreamInput in) throws IOException { this.deployedModelIds = in.readOptionalStringArray(); this.runningDeployModelIds = in.readOptionalStringArray(); this.runningDeployModelTaskIds = in.readOptionalStringArray(); - if (streamInputVersion.onOrAfter(MLSyncUpInput.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { + if (streamInputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { this.expiredModelIds = in.readOptionalStringArray(); } } @@ -59,7 +60,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalStringArray(deployedModelIds); out.writeOptionalStringArray(runningDeployModelIds); out.writeOptionalStringArray(runningDeployModelTaskIds); - if (streamOutputVersion.onOrAfter(MLSyncUpInput.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { + if (streamOutputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) { out.writeOptionalStringArray(expiredModelIds); } } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 3fee52fc37..22c16d46ca 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -140,7 +140,7 @@ public void run() { Set modelsToUndeploy = new HashSet<>(); for (String modelId : expiredModelToNodes.keySet()) { - if (expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) { + if (modelWorkerNodes.containsKey(modelId) && expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) { // this model has expired in all the nodes modelWorkerNodes.remove(modelId); modelsToUndeploy.add(modelId);