Skip to content

Commit

Permalink
add ttl to un-deploy model automatically
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Apr 29, 2024
1 parent c9758ca commit 8e2f22e
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,45 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.time.Duration;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@Setter
@Getter
public class MLDeploySetting implements ToXContentObject, Writeable {
public static final String IS_AUTO_DEPLOY_ENABLED_FIELD = "is_auto_deploy_enabled";
public static final String MODEL_TTL_FIELD = "model_ttl";
private static final long DEFAULT_TTL_HOUR = -1;

private Boolean isAutoDeployEnabled;
private Long modelTTL; // Time to live in hours

@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled) {
this.isAutoDeployEnabled = isAutoDeployEnabled;
this.modelTTL = DEFAULT_TTL_HOUR;
}
@Builder(toBuilder = true)
public MLDeploySetting(Boolean isAutoDeployEnabled, Long modelTTL) {
this.isAutoDeployEnabled = isAutoDeployEnabled;
this.modelTTL = modelTTL;
}

public MLDeploySetting(StreamInput in) throws IOException {
this.isAutoDeployEnabled = in.readOptionalBoolean();
this.modelTTL = in.readOptionalLong();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalBoolean(isAutoDeployEnabled);
out.writeOptionalLong(modelTTL);
}

public static MLDeploySetting parse(XContentParser parser) throws IOException {
Boolean isAutoDeployEnabled = null;
Long modelTTL = null;
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
Expand All @@ -51,12 +64,14 @@ public static MLDeploySetting parse(XContentParser parser) throws IOException {
case IS_AUTO_DEPLOY_ENABLED_FIELD:
isAutoDeployEnabled = parser.booleanValue();
break;
case MODEL_TTL_FIELD:
modelTTL = parser.longValue();
default:
parser.skipChildren();
break;
}
}
return new MLDeploySetting(isAutoDeployEnabled);
return new MLDeploySetting(isAutoDeployEnabled, modelTTL);
}

@Override
Expand All @@ -65,6 +80,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (isAutoDeployEnabled != null) {
builder.field(IS_AUTO_DEPLOY_ENABLED_FIELD, isAutoDeployEnabled);
}
if (modelTTL != null) {
builder.field(MODEL_TTL_FIELD, modelTTL);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ public class MLSyncUpNodeResponse extends BaseNodeResponse {
private String[] deployedModelIds;
private String[] runningDeployModelIds; // model ids which have deploying model task running
private String[] runningDeployModelTaskIds; // deploy model task ids which is running
private String[] expiredModelIds;

public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds,
String[] runningDeployModelTaskIds) {
String[] runningDeployModelTaskIds, String[] expiredModelIds) {
super(node);
this.modelStatus = modelStatus;
this.deployedModelIds = deployedModelIds;
this.runningDeployModelIds = runningDeployModelIds;
this.runningDeployModelTaskIds = runningDeployModelTaskIds;
this.expiredModelIds = expiredModelIds;
}

public MLSyncUpNodeResponse(StreamInput in) throws IOException {
Expand All @@ -38,6 +40,7 @@ public MLSyncUpNodeResponse(StreamInput in) throws IOException {
this.deployedModelIds = in.readOptionalStringArray();
this.runningDeployModelIds = in.readOptionalStringArray();
this.runningDeployModelTaskIds = in.readOptionalStringArray();
this.expiredModelIds = in.readOptionalStringArray();
}

public static MLSyncUpNodeResponse readStats(StreamInput in) throws IOException {
Expand All @@ -51,6 +54,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalStringArray(deployedModelIds);
out.writeOptionalStringArray(runningDeployModelIds);
out.writeOptionalStringArray(runningDeployModelTaskIds);
out.writeOptionalStringArray(expiredModelIds);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class MLDeployingSettingTests {

private MLDeploySetting deploySettingNull;

private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true}";
private final String expectedInputStr = "{\"is_auto_deploy_enabled\":true,\"model_ttl\":-1}";

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
Expand Down Expand Up @@ -66,7 +66,7 @@ public void testToXContent() throws Exception {

@Test
public void testToXContentIncomplete() throws Exception {
final String expectedIncompleteInputStr = "{}";
final String expectedIncompleteInputStr = "{\"model_ttl\":-1}";

String jsonStr = serializationWithToXContent(deploySettingNull);
assertEquals(expectedIncompleteInputStr, jsonStr);
Expand Down Expand Up @@ -109,7 +109,7 @@ public void parseWithIllegalArgumentInteger() throws Exception {

@Test
public void parseWithIllegalField() throws Exception {
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," +
final String expectedInputStrWithIllegalField = "{\"is_auto_deploy_enabled\":true," + "\"model_ttl\":-1," +
"\"illegal_field\":\"This field need to be skipped.\"}";

testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class MLSyncUpNodeResponseTest {
private final String[] loadedModelIds = {"loadedModelIds"};
private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"};
private final String[] runningLoadModelIds = {"modelid1"};

private final String[] expiredModelIds = {"modelExpired"};
@Before
public void setUp() throws Exception {
localNode = new DiscoveryNode(
Expand All @@ -38,7 +40,7 @@ public void setUp() throws Exception {

@Test
public void testSerializationDeserialization() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput());
Expand All @@ -51,7 +53,7 @@ public void testSerializationDeserialization() throws IOException {

@Test
public void testReadProfile() throws IOException {
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds);
MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds, expiredModelIds);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ private void executePredict(
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
modelCacheHelper.refreshLastAccessTime(modelId);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,13 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
String[] deployedModelIds = null;
String[] runningDeployModelTaskIds = null;
String[] runningDeployModelIds = null;
String[] expiredModelIds = null;
if (syncUpInput.isGetDeployedModels()) {
deployedModelIds = mlModelManager.getLocalDeployedModels();
List<String[]> localRunningDeployModel = mlTaskManager.getLocalRunningDeployModelTasks();
runningDeployModelTaskIds = localRunningDeployModel.get(0);
runningDeployModelIds = localRunningDeployModel.get(1);
expiredModelIds = mlModelManager.getExpiredModels();
}

if (syncUpInput.isClearRoutingTable()) {
Expand All @@ -186,7 +188,8 @@ private MLSyncUpNodeResponse createSyncUpNodeResponse(MLSyncUpNodesRequest syncU
"ok",
deployedModelIds,
runningDeployModelIds,
runningDeployModelTaskIds
runningDeployModelTaskIds,
expiredModelIds
);
}

Expand Down
31 changes: 31 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -101,11 +103,23 @@ public void run() {
Map<String, Set<String>> runningDeployModelTasks = new HashMap<>();
// key is model id, value is set of worker node ids
Map<String, Set<String>> deployingModels = new HashMap<>();
// key is expired model_id, value is set of worker node ids
Map<String, Set<String>> expiredModels = new HashMap<>();
for (MLSyncUpNodeResponse response : responses) {
String nodeId = response.getNode().getId();
String[] expiredModelIds = response.getExpiredModelIds();
if (expiredModelIds != null && expiredModelIds.length > 0) {
Arrays
.stream(expiredModelIds)
.forEach(modelId -> { expiredModels.computeIfAbsent(modelId, it -> new HashSet<>()).add(nodeId); });
}

String[] deployedModelIds = response.getDeployedModelIds();
if (deployedModelIds != null && deployedModelIds.length > 0) {
for (String modelId : deployedModelIds) {
if (expiredModels.containsKey(modelId)) {
continue;
}
Set<String> workerNodes = modelWorkerNodes.computeIfAbsent(modelId, it -> new HashSet<>());
workerNodes.add(nodeId);
}
Expand Down Expand Up @@ -154,6 +168,8 @@ public void run() {
log.error("Failed to sync model routing", ex);
})
);
// Undeploy expired models
undeployExpiredModels(expiredModels.keySet(), modelWorkerNodes);

// refresh model status
mlIndicesHandler
Expand All @@ -163,6 +179,21 @@ public void run() {
}, e -> { log.error("Failed to sync model routing", e); }));
}

private void undeployExpiredModels(Set<String> expiredModels, Map<String, Set<String>> modelWorkerNodes) {
expiredModels.forEach(modelId -> {
String[] targetNodeIds = modelWorkerNodes.keySet().toArray(new String[0]);

MLUndeployModelNodesRequest mlUndeployModelNodesRequest = new MLUndeployModelNodesRequest(
targetNodeIds,
new String[] { modelId }
);

client.execute(MLUndeployModelAction.INSTANCE, mlUndeployModelNodesRequest, ActionListener.wrap(r -> {
log.debug("model {} is un_deployed", modelId);
}, e -> { log.error("Failed to undeploy model {}", modelId, e); }));
});
}

@VisibleForTesting
void initMLConfig() {
if (mlConfigInited) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.model;

import java.time.Instant;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -51,6 +52,7 @@ public class MLModelCache {
// In rare case, this could be null, e.g. model info not synced up yet a predict request comes in.
@Setter
private Boolean deployToAllNodes;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Instant lastAccessTime;

public MLModelCache() {
targetWorkerNodes = ConcurrentHashMap.newKeySet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT;

import java.time.Duration;
import java.time.Instant;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -68,6 +70,7 @@ public synchronized void initModelState(
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(deployToAllNodes);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
}

Expand All @@ -87,6 +90,7 @@ public synchronized void initModelStateLocal(
modelCache.setFunctionName(functionName);
modelCache.setTargetWorkerNodes(targetWorkerNodes);
modelCache.setDeployToAllNodes(false);
modelCache.setLastAccessTime(Instant.now());
modelCaches.put(modelId, modelCache);
}

Expand Down Expand Up @@ -341,6 +345,26 @@ public String[] getLocalDeployedModels() {
.toArray(new String[0]);
}

/**
* Get expired models on node.
*
* @return array of expired model id
*/
public String[] getExpiredModels() {
return modelCaches.entrySet().stream().filter(entry -> {
MLModel mlModel = entry.getValue().getCachedModelInfo();
if (mlModel.getDeploySetting() == null) {
return false; // no TTL, never expire
}
Duration liveDuration = Duration.between(entry.getValue().getLastAccessTime(), Instant.now());
Long timeToLive = mlModel.getDeploySetting().getModelTTL();
boolean isModelExpired = (timeToLive != null
&& timeToLive > 0
&& liveDuration.getSeconds() > Duration.ofHours(timeToLive).getSeconds());
return isModelExpired && mlModel.getModelState() == MLModelState.DEPLOYED;
}).map(entry -> entry.getKey()).collect(Collectors.toList()).toArray(new String[0]);
}

/**
* Check if model is running on node.
*
Expand Down Expand Up @@ -403,6 +427,16 @@ public void setTargetWorkerNodes(String modelId, List<String> targetWorkerNodes)
}
}

/**
* Set the last access time to Instant.now()
*
* @param modelId model id
*/
public void refreshLastAccessTime(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
modelCache.setLastAccessTime(Instant.now());
}

/**
* Remove model.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ public void deployModel(
log.info("Set new target node ids {} for model {}", Arrays.toString(workerNodes.toArray(new String[0])), modelId);
modelCacheHelper.setDeployToAllNodes(modelId, deployToAllNodes);
modelCacheHelper.setTargetWorkerNodes(modelId, workerNodes);
modelCacheHelper.refreshLastAccessTime(modelId);
}
listener.onResponse("successful");
return;
Expand Down Expand Up @@ -1041,6 +1042,7 @@ public void deployModel(
modelCacheHelper.setMLExecutor(modelId, mlExecutable);
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
modelCacheHelper.refreshLastAccessTime(modelId);
wrappedListener.onResponse("successful");
} catch (Exception e) {
log.error("Failed to add predictor to cache", e);
Expand All @@ -1053,6 +1055,7 @@ public void deployModel(
modelCacheHelper.setPredictor(modelId, predictable);
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
modelCacheHelper.refreshLastAccessTime(modelId);
Long modelContentSizeInBytes = mlModel.getModelContentSizeInBytes();
long contentSize = modelContentSizeInBytes == null
? mlModel.getTotalChunks() * CHUNK_SIZE
Expand Down Expand Up @@ -1105,6 +1108,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou
setupParamsAndPredictable(modelId, mlModel);
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
modelCacheHelper.refreshLastAccessTime(modelId);
wrappedListener.onResponse("successful");
return;
}
Expand All @@ -1114,6 +1118,7 @@ private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCou
setupParamsAndPredictable(modelId, mlModel);
mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment();
modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED);
modelCacheHelper.refreshLastAccessTime(modelId);
wrappedListener.onResponse("successful");
log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId);
}, wrappedListener::onFailure));
Expand Down Expand Up @@ -1857,6 +1862,10 @@ public String[] getLocalDeployedModels() {
return modelCacheHelper.getDeployedModels();
}

public String[] getExpiredModels() {
return modelCacheHelper.getExpiredModels();
}

/**
* Sync model routing table.
*
Expand Down
Loading

0 comments on commit 8e2f22e

Please sign in to comment.