Skip to content

Commit

Permalink
undeploy only for models that expired in all nodes
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Apr 29, 2024
1 parent 8e2f22e commit 3b682c9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,50 @@
@Getter
public class MLDeploySetting implements ToXContentObject, Writeable {
public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled";
public static final String MODEL_TTL_FIELD = "model_ttl";
public static final String MODEL_TTL_HOURS_FIELD = "model_ttl_hours";
public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes";
private static final long DEFAULT_TTL_HOUR = -1;
private static final long DEFAULT_TTL_MINUTES = -1;

private Boolean isAutoDeployEnabled;
private Long modelTTL; // Time to live in hours
private Long modelTTLInHours; // Time to live in hours
private Long modelTTLInMinutes; // in minutes

@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled) {
public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInHours, Long modelTTLInMinutes) {
this.isAutoDeployEnabled = isAutoDeployEnabled;
this.modelTTL = DEFAULT_TTL_HOUR;
}
@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTL) {
this.isAutoDeployEnabled = isAutoDeployEnabled;
this.modelTTL = modelTTL;
this.modelTTLInHours = modelTTLInHours;
this.modelTTLInMinutes = modelTTLInMinutes;
// Neither TTL component was supplied: store the -1 defaults for both and stop here.
if (modelTTLInHours == null && modelTTLInMinutes == null) {
this.modelTTLInHours = DEFAULT_TTL_HOUR;
this.modelTTLInMinutes = DEFAULT_TTL_MINUTES;
return;
}
// At least one component was supplied; a missing one is stored as 0 so that
// both fields are non-null after construction.
if (modelTTLInHours == null) {
this.modelTTLInHours = 0L;
}
if (modelTTLInMinutes == null) {
this.modelTTLInMinutes = 0L;
}
}

public MLDeploySetting(StreamInput in) throws IOException {
this.isAutoDeployEnabled = in.readOptionalBoolean();
this.modelTTL = in.readOptionalLong();
this.modelTTLInHours = in.readOptionalLong();
this.modelTTLInMinutes = in.readOptionalLong();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalBoolean(isAutoDeployEnabled);
out.writeOptionalLong(modelTTL);
out.writeOptionalLong(modelTTLInHours);
out.writeOptionalLong(modelTTLInMinutes);
}

public static MLDeploySetting parse(XContentParser parser) throws IOException {
Boolean isAutoDeployEnabled = null;
Long modelTTL = null;
Long modelTTLHours = null;
Long modelTTLMinutes = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
Expand All @@ -64,14 +77,16 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException {
case IS_AUTO_DEPLOY_ENABLED_FIELD:
isAutoDeployEnabled = parser.booleanValue();
break;
case MODEL_TTL_FIELD:
modelTTL = parser.longValue();
case MODEL_TTL_HOURS_FIELD:
    modelTTLHours = parser.longValue();
    // break is required: without it control falls through into the minutes case
    // and the hours value is also stored as the minutes value.
    break;
case MODEL_TTL_MINUTES_FIELD:
    modelTTLMinutes = parser.longValue();
    // stop here rather than falling through into the default branch
    break;
default:
parser.skipChildren();
break;
}
}
return new MLDeploySetting(isAutoDeployEnabled, modelTTL);
return new MLDeploySetting(isAutoDeployEnabled, modelTTLHours, modelTTLMinutes);
}

@Override
Expand All @@ -80,8 +95,11 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isAutoDeployEnabled != null) {
builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled);
}
if (modelTTL != null) {
builder.field(MODEL_TTL_FIELD, modelTTL);
if (modelTTLInHours != null) {
builder.field(MODEL_TTL_HOURS_FIELD, modelTTLInHours);
}
if (modelTTLInMinutes != null) {
builder.field(MODEL_TTL_MINUTES_FIELD, modelTTLInMinutes);
}
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MLDeployingSettingTests {

private MLDeploySetting deploySettingNull;

private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl\":-1}";
private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl_hours\":-1,\"model_ttl_minutes\":-1}";

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -66,7 +66,7 @@ public void testToXContent() throws Exception {

@Test
public void testToXContentIncomplete() throws Exception {
final String expectedIncompleteInputStr = "{\"model_ttl\":-1}";
final String expectedIncompleteInputStr = "{\"model_ttl_hours\":-1,\"model_ttl_minutes\":-1}";

String jsonStr = serializationWithToXContent(deploySettingNull);
assertEquals(expectedIncompleteInputStr, jsonStr);
Expand Down Expand Up @@ -109,12 +109,12 @@ public void parseWithIllegalArgumentInteger() throws Exception {

@Test
public void parseWithIllegalField() throws Exception {
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl\":-1," +
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," +
"\"illegal_field\":\"This field need to be skipped.\"}";

testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
try {
assertEquals(expectedInputStr, serializationWithToXContent(parsedInput));
assertEquals("{\"is_auto_deploy_enabled\":true,\"model_ttl_hours\":0,\"model_ttl_minutes\":0}", serializationWithToXContent(parsedInput));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
19 changes: 13 additions & 6 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,19 @@ public void run() {
// key is model id, value is set of worker node ids
Map<String, Set<String>> deployingModels = new HashMap<>();
// key is expired model id, value is the set of node ids on which that model has expired
Map<String, Set<String>> expiredModels = new HashMap<>();
Map<String, Set<String>> expiredModelToNodess = new HashMap<>();
for (MLSyncUpNodeResponse response : responses) {
String nodeId = response.getNode().getId();
String[] expiredModelIds = response.getExpiredModelIds();
if (expiredModelIds != null && expiredModelIds.length > 0) {
Arrays
.stream(expiredModelIds)
.forEach(modelId -> { expiredModels.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); });
.forEach(modelId -> { expiredModelToNodess.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); });
}

String[] deployedModelIds = response.getDeployedModelIds();
if (deployedModelIds != null && deployedModelIds.length > 0) {
for (String modelId : deployedModelIds) {
if (expiredModels.containsKey(modelId)) {
continue;
}
Set<String> workerNodes = modelWorkerNodes.computeIfAbsent(modelId, it -> new HashSet<>());
workerNodes.add(nodeId);
}
Expand All @@ -140,6 +137,16 @@ public void run() {
}
}
}

// Collect the models that have expired on every node that currently reports them as deployed.
Set<String> modelsToUndeploy = new HashSet<>();
for (Map.Entry<String, Set<String>> expiredEntry : expiredModelToNodess.entrySet()) {
    String modelId = expiredEntry.getKey();
    // Compare set contents with equals(), not ==: the expired-node set and the worker-node set
    // are built as separate HashSet instances, so a reference comparison would not match even
    // when they contain the same node ids. equals(null) is false, so a model with no worker
    // nodes recorded is left alone.
    if (expiredEntry.getValue().equals(modelWorkerNodes.get(modelId))) {
        // this model has expired in all the nodes
        modelWorkerNodes.remove(modelId);
        modelsToUndeploy.add(modelId);
    }
}

for (Map.Entry<String, Set<String>> entry : modelWorkerNodes.entrySet()) {
String modelId = entry.getKey();
log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0]));
Expand Down Expand Up @@ -169,7 +176,7 @@ public void run() {
})
);
// Undeploy expired models
undeployExpiredModels(expiredModels.keySet(), modelWorkerNodes);
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);

// refresh model status
mlIndicesHandler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,13 @@ public String[] getExpiredModels() {
return false; // no TTL, never expire
}
Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now());
Long timeToLive = mlModel.getDeploySetting().getModelTTL();
boolean isModelExpired = (timeToLive != null
&& timeToLive > 0
&& liveDuration.getSeconds() > Duration.ofHours(timeToLive).getSeconds());
Long ttlInHour = mlModel.getDeploySetting().getModelTTLInHours();
Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
// Check for null before the numeric comparison: the getters return boxed Longs and
// unboxing a null would throw NullPointerException. A missing or negative component
// means no TTL is configured, so the model never expires.
if (ttlInHour == null || ttlInMinutes == null || ttlInHour < 0 || ttlInMinutes < 0) {
    return false;
}
// Total TTL is the sum of the hour and minute components.
Duration ttl = Duration.ofHours(ttlInHour).plusMinutes(ttlInMinutes);
// Compared at whole-second granularity; the model must have been idle strictly longer than the TTL.
boolean isModelExpired = liveDuration.getSeconds() > ttl.getSeconds();
// Only models that are currently DEPLOYED are reported as expired.
return isModelExpired && mlModel.getModelState() == MLModelState.DEPLOYED;
}).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]);
}
Expand Down

0 comments on commit 3b682c9

Please sign in to comment.