Skip to content

Commit

Permalink
Add TTL to un-deploy model automatically (#2365) (#2374)
Browse files Browse the repository at this point in the history
* add ttl to un-deploy model automatically

Signed-off-by: Xun Zhang <[email protected]>

* undeploy only for models that expired in all nodes

Signed-off-by: Xun Zhang <[email protected]>

* add bwc version for model ttl

Signed-off-by: Xun Zhang <[email protected]>

* only use minutes in ttl

Signed-off-by: Xun Zhang <[email protected]>

* move MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL to MLDeploySetting

Signed-off-by: Xun Zhang <[email protected]>

---------

Signed-off-by: Xun Zhang <[email protected]>
(cherry picked from commit e380395)

Co-authored-by: Xun Zhang <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and Zhangxunmt authored Apr 30, 2024
1 parent 7e76fa9 commit ef435c9
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -24,25 +25,42 @@
@Getter
public class MLDeploySetting implements ToXContentObject, Writeable {
public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled";
public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes";
private static final long DEFAULT_TTL_MINUTES = -1;
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0;

private Boolean isAutoDeployEnabled;
private Long modelTTLInMinutes; // in minutes

@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled) {
public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInMinutes) {
this.isAutoDeployEnabled = isAutoDeployEnabled;
this.modelTTLInMinutes = modelTTLInMinutes;
if (modelTTLInMinutes == null) {
this.modelTTLInMinutes = DEFAULT_TTL_MINUTES;
}
}

public MLDeploySetting(StreamInput in) throws IOException {
this.isAutoDeployEnabled = in.readOptionalBoolean();
Version streamInputVersion = in.getVersion();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
this.modelTTLInMinutes = in.readOptionalLong();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeOptionalBoolean(isAutoDeployEnabled);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
out.writeOptionalLong(modelTTLInMinutes);
}
}

public static MLDeploySetting parse(XContentParser parser) throws IOException {
Boolean isAutoDeployEnabled = null;
Long modelTTLMinutes = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
Expand All @@ -51,12 +69,14 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException {
case IS_AUTO_DEPLOY_ENABLED_FIELD:
isAutoDeployEnabled = parser.booleanValue();
break;
case MODEL_TTL_MINUTES_FIELD:
modelTTLMinutes = parser.longValue();
default:
parser.skipChildren();
break;
}
}
return new MLDeploySetting(isAutoDeployEnabled);
return new MLDeploySetting(isAutoDeployEnabled, modelTTLMinutes);
}

@Override
Expand All @@ -65,6 +85,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isAutoDeployEnabled != null) {
builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled);
}
if (modelTTLInMinutes != null) {
builder.field(MODEL_TTL_MINUTES_FIELD, modelTTLInMinutes);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.model.MLDeploySetting;

import java.io.IOException;

Expand All @@ -22,22 +24,28 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse {
private String[] deployedModelIds;
private String[] runningDeployModelIds; // model ids which have deploying model task running
private String[] runningDeployModelTaskIds; // deploy model task ids which is running
private String[] expiredModelIds;

public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds,
String[] runningDeployModelTaskIds) {
String[] runningDeployModelTaskIds, String[] expiredModelIds) {
super(node);
this.modelStatus = modelStatus;
this.deployedModelIds = deployedModelIds;
this.runningDeployModelIds = runningDeployModelIds;
this.runningDeployModelTaskIds = runningDeployModelTaskIds;
this.expiredModelIds = expiredModelIds;
}

public MLSyncUpNodeResponse(StreamInput in) throws IOException {
super(in);
Version streamInputVersion = in.getVersion();
this.modelStatus = in.readOptionalString();
this.deployedModelIds = in.readOptionalStringArray();
this.runningDeployModelIds = in.readOptionalStringArray();
this.runningDeployModelTaskIds = in.readOptionalStringArray();
if (streamInputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
this.expiredModelIds = in.readOptionalStringArray();
}
}

public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException {
Expand All @@ -46,11 +54,14 @@ public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
super.writeTo(out);
out.writeOptionalString(modelStatus);
out.writeOptionalStringArray(deployedModelIds);
out.writeOptionalStringArray(runningDeployModelIds);
out.writeOptionalStringArray(runningDeployModelTaskIds);
if (streamOutputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
out.writeOptionalStringArray(expiredModelIds);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MLDeployingSettingTests {

private MLDeploySetting deploySettingNull;

private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true}";
private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}";

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -66,7 +66,7 @@ public void testToXContent() throws Exception {

@Test
public void testToXContentIncomplete() throws Exception {
final String expectedIncompleteInputStr = "{}";
final String expectedIncompleteInputStr = "{\"model_ttl_minutes\":-1}";

String jsonStr = serializationWithToXContent(deploySettingNull);
assertEquals(expectedIncompleteInputStr, jsonStr);
Expand Down Expand Up @@ -109,12 +109,12 @@ public void parseWithIllegalArgumentInteger() throws Exception {

@Test
public void parseWithIllegalField() throws Exception {
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," +
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," +
"\"illegal_field\":\"This field need to be skipped.\"}";

testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
try {
assertEquals(expectedInputStr, serializationWithToXContent(parsedInput));
assertEquals("{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}", serializationWithToXContent(parsedInput));
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class MLSyncUpNodeResponseTest {
private final String[] loadedModelIds = {"loadedModelIds"};
private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"};
private final String[] runningLoadModelIds = {"modelid1"};

private final String[] expiredModelIds = {"modelExpired"};
@Before
public void setUp() throws Exception {
localNode = new DiscoveryNode(
Expand All @@ -38,7 +40,7 @@ public void setUp() throws Exception {

@Test
public void testSerializationDeserialization() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput());
Expand All @@ -51,7 +53,7 @@ public void testSerializationDeserialization() throws IOException {

@Test
public void testReadProfile() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ private void executePredict(
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
modelCacheHelper.refreshLastAccessTime(modelId);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,13 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
String[] deployedModelIds = null;
String[] runningDeployModelTaskIds = null;
String[] runningDeployModelIds = null;
String[] expiredModelIds = null;
if (syncUpInput.isGetDeployedModels()) {
deployedModelIds = mlModelManager.getLocalDeployedModels();
List<String[]> localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks();
runningDeployModelTaskIds = localRunningDeployModel.get(0);
runningDeployModelIds = localRunningDeployModel.get(1);
expiredModelIds = mlModelManager.getExpiredModels();
}

if (syncUpInput.isClearRoutingTable()) {
Expand All @@ -186,7 +188,8 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
"ok",
deployedModelIds,
runningDeployModelIds,
runningDeployModelTaskIds
runningDeployModelTaskIds,
expiredModelIds
);
}

Expand Down
38 changes: 38 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -101,8 +103,17 @@ public void run() {
Map<String, Set<String>> runningDeployModelTasks = new HashMap<>();
// key is model id, value is set of worker node ids
Map<String, Set<String>> deployingModels = new HashMap<>();
// key is expired model_id, value is set of worker node ids
Map<String, Set<String>> expiredModelToNodes = new HashMap<>();
for (MLSyncUpNodeResponse response : responses) {
String nodeId = response.getNode().getId();
String[] expiredModelIds = response.getExpiredModelIds();
if (expiredModelIds != null && expiredModelIds.length > 0) {
Arrays
.stream(expiredModelIds)
.forEach(modelId -> { expiredModelToNodes.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); });
}

String[] deployedModelIds = response.getDeployedModelIds();
if (deployedModelIds != null && deployedModelIds.length > 0) {
for (String modelId : deployedModelIds) {
Expand All @@ -126,6 +137,17 @@ public void run() {
}
}
}

Set<String> modelsToUndeploy = new HashSet<>();
for (String modelId : expiredModelToNodes.keySet()) {
if (modelWorkerNodes.containsKey(modelId)
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
// this model has expired in all the nodes
modelWorkerNodes.remove(modelId);
modelsToUndeploy.add(modelId);
}
}

for (Map.Entry<String, Set<String>> entry : modelWorkerNodes.entrySet()) {
String modelId = entry.getKey();
log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0]));
Expand Down Expand Up @@ -154,6 +176,8 @@ public void run() {
log.error("Failed to sync model routing", ex);
})
);
// Undeploy expired models
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);

// refresh model status
mlIndicesHandler
Expand All @@ -163,6 +187,20 @@ public void run() {
}, e -> { log.error("Failed to sync model routing", e); }));
}

private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
expiredModels.forEach(modelId -> {
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);

MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
targetNodeIds,
new String[] { modelId }
);
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
log.debug("model {} is un_deployed", modelId);
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
});
}

@VisibleForTesting
void initMLConfig() {
if (mlConfigInited) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.model;

import java.time.Instant;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -51,6 +52,7 @@ public class MLModelCache {
// In rare case, this could be null, e.g. model info not synced up yet a predict request comes in.
@Setter
private Boolean deployToAllNodes;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;

public MLModelCache() {
targetWorkerNodes = ConcurrentHashMap.newKeySet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;

import java.time.Duration;
import java.time.Instant;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -68,6 +70,7 @@ public synchronized void initModelState(
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(deployToAllNodes);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
}

Expand All @@ -87,6 +90,7 @@ public synchronized void initModelStateLocal(
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(false);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
}

Expand Down Expand Up @@ -341,6 +345,29 @@ public String[] getLocalDeployedModels() {
.toArray(new String[0]);
}

/**
* Get expired models on node.
*
* @return array of expired model id
*/
public String[] getExpiredModels() {
return modelCaches.entrySet().stream().filter(entry -> {
MLModel mlModel = entry.getValue().getCachedModelInfo();
if (mlModel.getDeploySetting() == null) {
return false; // no TTL, never expire
}
Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now());
Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
if (ttlInMinutes < 0) {
return false;
}
Duration ttl = Duration.ofMinutes(ttlInMinutes);
boolean isModelExpired = liveDuration.getSeconds() >= ttl.getSeconds();
return isModelExpired
&& (mlModel.getModelState() == MLModelState.DEPLOYED || mlModel.getModelState() == MLModelState.PARTIALLY_DEPLOYED);
}).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]);
}

/**
* Check if model is running on node.
*
Expand Down Expand Up @@ -403,6 +430,16 @@ public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes)
}
}

/**
* Set the last access time to Instant.now()
*
* @param modelId model id
*/
public void refreshLastAccessTime(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
modelCache.setLastAccessTime(Instant.now());
}

/**
* Remove model.
*
Expand Down
Loading

0 comments on commit ef435c9

Please sign in to comment.