
Add TTL to un-deploy model automatically #2365

Merged: 5 commits, Apr 30, 2024
@@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -24,25 +25,42 @@
@Getter
public class MLDeploySetting implements ToXContentObject, Writeable {
public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled";
public static final String MODEL_TTL_MINUTES_FIELD = "model_ttl_minutes";
Collaborator: Why separate hours and minutes? It seems using minutes alone should be enough.

Zhangxunmt (Collaborator, Author), Apr 29, 2024: Initially only hours were used, but it takes too much time to test expiration in hours. Having both hours and minutes covers most cases. Most likely only hours will be used, e.g. 24 hours or 72 hours (3 days).

Collaborator: Why not just use minutes? 24 hours would be 24 * 60 = 1440 minutes.

Zhangxunmt (Collaborator, Author): It should be easier to set a longer time in hours. One week (7 days) would be 10080 minutes, which is not an obvious number to read as a length of time.

Zhangxunmt (Collaborator, Author): It doesn't require both hours and minutes; either one alone creates a valid deploy setting. If both are null, the default value is -1 and the model never expires.

Collaborator: I would suggest just using the minimal time unit; that would be easier. If you add hours, why not add days and months?

Zhangxunmt (Collaborator, Author): Done.

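The thread above settles on a single minutes field. For reference, the conversions quoted in the discussion check out with java.time — a standalone snippet, not part of the PR:

```java
import java.time.Duration;

public class TtlConversions {
    public static void main(String[] args) {
        // The values discussed in the review thread, expressed in minutes.
        System.out.println(Duration.ofHours(24).toMinutes()); // 1440
        System.out.println(Duration.ofHours(72).toMinutes()); // 4320 (3 days)
        System.out.println(Duration.ofDays(7).toMinutes());   // 10080, the "one week" value
    }
}
```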
private static final long DEFAULT_TTL_MINUTES = -1;
public static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL = Version.V_2_14_0;

private Boolean isAutoDeployEnabled;
private Long modelTTLInMinutes; // in minutes

@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled) {
public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTLInMinutes) {
this.isAutoDeployEnabled = isAutoDeployEnabled;
this.modelTTLInMinutes = modelTTLInMinutes;
if (modelTTLInMinutes == null) {
this.modelTTLInMinutes = DEFAULT_TTL_MINUTES;
}
}

public MLDeploySetting(StreamInput in) throws IOException {
this.isAutoDeployEnabled = in.readOptionalBoolean();
Version streamInputVersion = in.getVersion();
if (streamInputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
this.modelTTLInMinutes = in.readOptionalLong();
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
out.writeOptionalBoolean(isAutoDeployEnabled);
if (streamOutputVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
out.writeOptionalLong(modelTTLInMinutes);
}
}
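The read and write paths above apply the same version gate, so a pre-2.14 peer never receives TTL bytes it cannot parse. The pattern can be sketched with plain Java streams; the types and version constant below are stand-ins for illustration, not the OpenSearch StreamInput/StreamOutput API:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class VersionGatedCodec {
    static final int MIN_VERSION_FOR_TTL = 2_14_00; // stand-in for Version.V_2_14_0

    // Write the optional TTL only when the peer understands it.
    static byte[] write(int peerVersion, boolean autoDeploy, Long ttlMinutes) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bos);
        out.writeBoolean(autoDeploy);
        if (peerVersion >= MIN_VERSION_FOR_TTL) {
            out.writeBoolean(ttlMinutes != null); // presence marker for the optional long
            if (ttlMinutes != null) {
                out.writeLong(ttlMinutes);
            }
        }
        return bos.toByteArray();
    }

    // The read side must apply the exact same gate, or the stream desynchronizes.
    static Long readTtl(int peerVersion, byte[] bytes) throws IOException {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes));
        in.readBoolean(); // autoDeploy
        if (peerVersion >= MIN_VERSION_FOR_TTL && in.readBoolean()) {
            return in.readLong();
        }
        return null;
    }

    public static void main(String[] args) throws IOException {
        System.out.println(readTtl(2_14_00, write(2_14_00, true, 60L))); // 60
        System.out.println(readTtl(2_13_00, write(2_13_00, true, 60L))); // null: old peers never see the field
    }
}
```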

public static MLDeploySetting parse(XContentParser parser) throws IOException {
Boolean isAutoDeployEnabled = null;
Long modelTTLMinutes = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
@@ -51,12 +69,14 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException {
case IS_AUTO_DEPLOY_ENABLED_FIELD:
isAutoDeployEnabled = parser.booleanValue();
break;
case MODEL_TTL_MINUTES_FIELD:
modelTTLMinutes = parser.longValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLDeploySetting(isAutoDeployEnabled);
return new MLDeploySetting(isAutoDeployEnabled, modelTTLMinutes);
}

@Override
@@ -65,6 +85,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isAutoDeployEnabled != null) {
builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled);
}
if (modelTTLInMinutes != null) {
builder.field(MODEL_TTL_MINUTES_FIELD, modelTTLInMinutes);
}
builder.endObject();
return builder;
}
@@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Data;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
@@ -7,10 +7,12 @@

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.action.support.nodes.BaseNodeResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.model.MLDeploySetting;

import java.io.IOException;

@@ -22,22 +24,28 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse {
private String[] deployedModelIds;
private String[] runningDeployModelIds; // model ids which have deploying model task running
private String[] runningDeployModelTaskIds; // deploy model task ids which is running
private String[] expiredModelIds;

public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds,
String[] runningDeployModelTaskIds) {
String[] runningDeployModelTaskIds, String[] expiredModelIds) {
super(node);
this.modelStatus = modelStatus;
this.deployedModelIds = deployedModelIds;
this.runningDeployModelIds = runningDeployModelIds;
this.runningDeployModelTaskIds = runningDeployModelTaskIds;
this.expiredModelIds = expiredModelIds;
}

public MLSyncUpNodeResponse(StreamInput in) throws IOException {
super(in);
Version streamInputVersion = in.getVersion();
this.modelStatus = in.readOptionalString();
this.deployedModelIds = in.readOptionalStringArray();
this.runningDeployModelIds = in.readOptionalStringArray();
this.runningDeployModelTaskIds = in.readOptionalStringArray();
if (streamInputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
this.expiredModelIds = in.readOptionalStringArray();
}
}

public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException {
@@ -46,11 +54,14 @@ public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException

@Override
public void writeTo(StreamOutput out) throws IOException {
Version streamOutputVersion = out.getVersion();
super.writeTo(out);
out.writeOptionalString(modelStatus);
out.writeOptionalStringArray(deployedModelIds);
out.writeOptionalStringArray(runningDeployModelIds);
out.writeOptionalStringArray(runningDeployModelTaskIds);
if (streamOutputVersion.onOrAfter(MLDeploySetting.MINIMAL_SUPPORTED_VERSION_FOR_MODEL_TTL)) {
out.writeOptionalStringArray(expiredModelIds);
}
}

}
@@ -36,7 +36,7 @@ public class MLDeployingSettingTests {

private MLDeploySetting deploySettingNull;

private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true}";
private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}";

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
@@ -66,7 +66,7 @@ public void testToXContent() throws Exception {

@Test
public void testToXContentIncomplete() throws Exception {
final String expectedIncompleteInputStr = "{}";
final String expectedIncompleteInputStr = "{\"model_ttl_minutes\":-1}";

String jsonStr = serializationWithToXContent(deploySettingNull);
assertEquals(expectedIncompleteInputStr, jsonStr);
@@ -109,12 +109,12 @@ public void parseWithIllegalArgumentInteger() throws Exception {

@Test
public void parseWithIllegalField() throws Exception {
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," +
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl_hours\":0," +
"\"illegal_field\":\"This field need to be skipped.\"}";

testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
try {
assertEquals(expectedInputStr, serializationWithToXContent(parsedInput));
assertEquals("{\"is_auto_deploy_enabled\":true,\"model_ttl_minutes\":-1}", serializationWithToXContent(parsedInput));
ylwu-amzn (Collaborator), Apr 29, 2024: What if model_ttl_minutes is 0? Will the model stay on the node forever?

Zhangxunmt (Collaborator, Author): Just the opposite: with model_ttl_minutes == 0, the model is removed right away in the next cron job cycle. By default, with no value given for this TTL, model_ttl_minutes == -1 and the model never expires.
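The TTL semantics described in this exchange (-1 never expires, 0 expires at the next sweep, positive values set an idle deadline) can be captured as a small predicate — an illustrative standalone sketch, not code from the PR:

```java
import java.time.Duration;
import java.time.Instant;

public class TtlPredicate {
    // True when a model whose last access was `lastAccess` should be
    // undeployed by a sweep running at `now`, under the given TTL in minutes.
    static boolean isExpired(long ttlMinutes, Instant lastAccess, Instant now) {
        if (ttlMinutes < 0) {
            return false; // -1 (default): never expire
        }
        // ttlMinutes == 0 expires immediately on the next sweep
        return Duration.between(lastAccess, now).compareTo(Duration.ofMinutes(ttlMinutes)) >= 0;
    }

    public static void main(String[] args) {
        Instant t0 = Instant.parse("2024-04-30T00:00:00Z");
        System.out.println(isExpired(-1, t0, t0.plusSeconds(86_400))); // false: TTL disabled
        System.out.println(isExpired(0, t0, t0));                      // true: expires right away
        System.out.println(isExpired(60, t0, t0.plusSeconds(3_599)));  // false: one second short
        System.out.println(isExpired(60, t0, t0.plusSeconds(3_600)));  // true: deadline reached
    }
}
```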

} catch (IOException e) {
throw new RuntimeException(e);
}
@@ -24,6 +24,8 @@ public class MLSyncUpNodeResponseTest {
private final String[] loadedModelIds = {"loadedModelIds"};
private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"};
private final String[] runningLoadModelIds = {"modelid1"};

private final String[] expiredModelIds = {"modelExpired"};
@Before
public void setUp() throws Exception {
localNode = new DiscoveryNode(
@@ -38,7 +40,7 @@ public void setUp() throws Exception {

@Test
public void testSerializationDeserialization() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput());
@@ -51,7 +53,7 @@ public void testReadProfile() throws IOException {

@Test
public void testReadProfile() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput());
@@ -223,6 +223,7 @@ private void executePredict(
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
modelCacheHelper.refreshLastAccessTime(modelId);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
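Refreshing the last access time on every successful predict gives the TTL sliding-window semantics: the clock measures idle time since the last request, not time since deployment. A minimal standalone sketch of that bookkeeping (hypothetical class, not the ml-commons cache helper):

```java
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class LastAccessTracker {
    // Each successful predict "touches" the model, restarting its TTL clock.
    private final Map<String, Instant> lastAccess = new ConcurrentHashMap<>();

    void refreshLastAccessTime(String modelId) {
        lastAccess.put(modelId, Instant.now());
    }

    Instant lastAccessTime(String modelId) {
        return lastAccess.get(modelId); // null if the model was never touched
    }

    public static void main(String[] args) {
        LastAccessTracker tracker = new LastAccessTracker();
        tracker.refreshLastAccessTime("model-a");
        Instant first = tracker.lastAccessTime("model-a");
        tracker.refreshLastAccessTime("model-a"); // a later predict pushes the time forward
        System.out.println(!tracker.lastAccessTime("model-a").isBefore(first)); // true
    }
}
```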
@@ -162,11 +162,13 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
String[] deployedModelIds = null;
String[] runningDeployModelTaskIds = null;
String[] runningDeployModelIds = null;
String[] expiredModelIds = null;
if (syncUpInput.isGetDeployedModels()) {
deployedModelIds = mlModelManager.getLocalDeployedModels();
List<String[]> localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks();
runningDeployModelTaskIds = localRunningDeployModel.get(0);
runningDeployModelIds = localRunningDeployModel.get(1);
expiredModelIds = mlModelManager.getExpiredModels();
}

if (syncUpInput.isClearRoutingTable()) {
@@ -186,7 +188,8 @@
"ok",
deployedModelIds,
runningDeployModelIds,
runningDeployModelTaskIds
runningDeployModelTaskIds,
expiredModelIds
);
}

@@ -41,6 +41,8 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
@@ -101,8 +103,17 @@ public void run() {
Map<String, Set<String>> runningDeployModelTasks = new HashMap<>();
// key is model id, value is set of worker node ids
Map<String, Set<String>> deployingModels = new HashMap<>();
// key is expired model_id, value is set of worker node ids
Map<String, Set<String>> expiredModelToNodes = new HashMap<>();
for (MLSyncUpNodeResponse response : responses) {
String nodeId = response.getNode().getId();
String[] expiredModelIds = response.getExpiredModelIds();
if (expiredModelIds != null && expiredModelIds.length > 0) {
Arrays
.stream(expiredModelIds)
.forEach(modelId -> { expiredModelToNodes.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); });
}

String[] deployedModelIds = response.getDeployedModelIds();
if (deployedModelIds != null && deployedModelIds.length > 0) {
for (String modelId : deployedModelIds) {
@@ -126,6 +137,17 @@
}
}
}

Set<String> modelsToUndeploy = new HashSet<>();
for (String modelId : expiredModelToNodes.keySet()) {
if (modelWorkerNodes.containsKey(modelId)
&& expiredModelToNodes.get(modelId).size() == modelWorkerNodes.get(modelId).size()) {
// this model has expired in all the nodes
modelWorkerNodes.remove(modelId);
ylwu-amzn (Collaborator), Apr 30, 2024: The cron job uses modelWorkerNodes to check whether a model has been deployed to any node (check line 349 of this class). Removing it here will cause the sync-up cron job to set the model as deploy_failed.

modelsToUndeploy.add(modelId);
}
}
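The aggregation above can be tested in isolation: a model is undeployed only when the set of nodes reporting it expired matches its full worker-node set. A standalone sketch under that assumption, with hypothetical sample data:

```java
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class ExpiredModelAggregator {
    // A model is only undeployed when every worker node hosting it reports it expired.
    static Set<String> modelsToUndeploy(Map<String, Set<String>> expiredModelToNodes,
                                        Map<String, Set<String>> modelWorkerNodes) {
        Set<String> result = new HashSet<>();
        for (Map.Entry<String, Set<String>> e : expiredModelToNodes.entrySet()) {
            Set<String> workers = modelWorkerNodes.get(e.getKey());
            if (workers != null && e.getValue().size() == workers.size()) {
                result.add(e.getKey());
            }
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Set<String>> expired = Map.of(
            "model-a", Set.of("n1", "n2"), // expired on both of its worker nodes
            "model-b", Set.of("n1")        // expired on only one of two nodes
        );
        Map<String, Set<String>> workers = Map.of(
            "model-a", Set.of("n1", "n2"),
            "model-b", Set.of("n1", "n2")
        );
        System.out.println(modelsToUndeploy(expired, workers)); // [model-a]
    }
}
```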

for (Map.Entry<String, Set<String>> entry : modelWorkerNodes.entrySet()) {
String modelId = entry.getKey();
log.debug("will sync model worker nodes for model: {}: {}", modelId, entry.getValue().toArray(new String[0]));
@@ -154,6 +176,8 @@ public void run() {
log.error("Failed to sync model routing", ex);
})
);
// Undeploy expired models
undeployExpiredModels(modelsToUndeploy, modelWorkerNodes);

// refresh model status
mlIndicesHandler
Expand All @@ -163,6 +187,20 @@ public void run() {
}, e -> { log.error("Failed to sync model routing", e); }));
}

private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
expiredModels.forEach(modelId -> {
// Target node ids, not the map's key set (which holds model ids).
String[] targetNodeIds = modelWorkerNodes.values().stream().flatMap(Set::stream).distinct().toArray(String[]::new);

MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
targetNodeIds,
new String[] { modelId }
);
client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
log.debug("model {} is un_deployed", modelId);
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
});
}

@VisibleForTesting
void initMLConfig() {
if (mlConfigInited) {
@@ -5,6 +5,7 @@

package org.opensearch.ml.model;

import java.time.Instant;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
@@ -51,6 +52,7 @@ public class MLModelCache {
// In rare case, this could be null, e.g. model info not synced up yet a predict request comes in.
@Setter
private Boolean deployToAllNodes;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;

public MLModelCache() {
targetWorkerNodes = ConcurrentHashMap.newKeySet();
@@ -7,6 +7,8 @@

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;

import java.time.Duration;
import java.time.Instant;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -68,6 +70,7 @@ public synchronized void initModelState(
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(deployToAllNodes);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
}

@@ -87,6 +90,7 @@ public synchronized void initModelStateLocal(
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(false);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
}

@@ -341,6 +345,29 @@ public String[] getLocalDeployedModels() {
.toArray(new String[0]);
}

/**
* Get expired models on node.
*
* @return array of expired model id
*/
public String[] getExpiredModels() {
return modelCaches.entrySet().stream().filter(entry -> {
MLModel mlModel = entry.getValue().getCachedModelInfo();
Collaborator: Cached model info is only set when a predict request comes in; check https://github.com/opensearch-project/ml-commons/pull/1472/files. For a deployed model that receives no predict request, the cached model info is null, so mlModel will be null and line 356 will throw a NullPointerException, and the sync-up cron job will get no response.

Collaborator: Adding modelCacheHelper.setModelInfo(modelId, mlModel); when deploying the model should fix the issue.

if (mlModel == null || mlModel.getDeploySetting() == null) {
return false; // model info not cached yet, or no TTL: never expire
}
Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now());
Long ttlInMinutes = mlModel.getDeploySetting().getModelTTLInMinutes();
if (ttlInMinutes < 0) {
return false;
}
Duration ttl = Duration.ofMinutes(ttlInMinutes);
boolean isModelExpired = liveDuration.getSeconds() >= ttl.getSeconds();
return isModelExpired
&& (mlModel.getModelState() == MLModelState.DEPLOYED || mlModel.getModelState() == MLModelState.PARTIALLY_DEPLOYED);
Collaborator: Why do we care about the model status for model TTL? For example, if we have some bug and a model stays in DEPLOY_FAILED, TTL could be a good way to remove the model and avoid a memory leak.

Zhangxunmt (Collaborator, Author), Apr 29, 2024: DEPLOY_FAILED may carry insights about the cluster or the ML service. Auto-undeploying the model won't fix the problem, only hide it. We shouldn't use TTL as a tool to clear the failure status; otherwise the model will likely be in a failed status again when it's redeployed. Let's be conservative in the first release of this mechanism.

Collaborator: We have auto deploy; if a model deploy gets stuck, why shouldn't we use TTL to undeploy the model and remove the unnecessary memory usage?

Zhangxunmt (Collaborator, Author): That's doable, but it may hide problems that bite us more in the future. When a deploy gets stuck, why shouldn't we look into the reasons first and then un-deploy only if that's the real solution? Also, auto deploy is only for remote models, which take little memory, so there may be other reasons for the deployment failure, like the one I showed in the meeting.

ylwu-amzn (Collaborator), Apr 30, 2024: Makes sense; you can add a TODO and test more.

}).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]);
}
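The filter above combines several conditions: cached model info present, a non-negative TTL, enough idle time, and a successfully deployed state. A self-contained sketch with simplified stand-in types (not the ml-commons classes) makes the behavior easy to check:

```java
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Map;

public class ExpiredModelFilter {
    enum State { DEPLOYED, PARTIALLY_DEPLOYED, DEPLOY_FAILED }

    record CacheEntry(Long ttlMinutes, Instant lastAccess, State state) {}

    static String[] getExpiredModels(Map<String, CacheEntry> caches, Instant now) {
        return caches.entrySet().stream().filter(e -> {
            CacheEntry c = e.getValue();
            if (c.ttlMinutes() == null || c.ttlMinutes() < 0) {
                return false; // no TTL configured: never expire
            }
            boolean idleLongEnough = Duration.between(c.lastAccess(), now)
                .compareTo(Duration.ofMinutes(c.ttlMinutes())) >= 0;
            // Only successfully deployed models are reclaimed; DEPLOY_FAILED is
            // deliberately left alone in this first release.
            return idleLongEnough
                && (c.state() == State.DEPLOYED || c.state() == State.PARTIALLY_DEPLOYED);
        }).map(Map.Entry::getKey).sorted().toArray(String[]::new);
    }

    public static void main(String[] args) {
        Instant now = Instant.parse("2024-04-30T12:00:00Z");
        Map<String, CacheEntry> caches = Map.of(
            "idle-deployed", new CacheEntry(60L, now.minusSeconds(7_200), State.DEPLOYED),
            "busy-deployed", new CacheEntry(60L, now.minusSeconds(60), State.DEPLOYED),
            "idle-failed", new CacheEntry(60L, now.minusSeconds(7_200), State.DEPLOY_FAILED),
            "no-ttl", new CacheEntry(-1L, now.minusSeconds(7_200), State.DEPLOYED)
        );
        System.out.println(Arrays.toString(getExpiredModels(caches, now))); // [idle-deployed]
    }
}
```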

/**
* Check if model is running on node.
*
@@ -403,6 +430,16 @@ public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes)
}
}

/**
* Set the last access time to Instant.now()
*
* @param modelId model id
*/
public void refreshLastAccessTime(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
if (modelCache != null) { // model may have been undeployed concurrently
modelCache.setLastAccessTime(Instant.now());
}
}

/**
* Remove model.
*