Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Performance enhacement for predict action by caching model info #1508

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
Expand Down Expand Up @@ -87,42 +88,66 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
} else {
String requestId = mlPredictionTaskRequest.getRequestID();
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner
.run(
functionName,
mlPredictionTaskRequest,
transportService,
ActionListener.runAfter(wrappedListener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
wrappedListener.onFailure(e);
}));
}, e -> {
log.error("Failed to find model " + modelId, e);
wrappedListener.onFailure(e);
}));
MLModel cachedMlModel = modelCacheHelper.getModelInfo(modelId);
ActionListener<MLModel> modelActionListener = new ActionListener<>() {
@Override
public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
} else {
executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
wrappedListener.onFailure(e);
}));
}

@Override
public void onFailure(Exception e) {
log.error("Failed to find model " + modelId, e);
wrappedListener.onFailure(e);
}
};

if (cachedMlModel != null) {
modelActionListener.onResponse(cachedMlModel);
} else if (modelAccessControlHelper.skipModelAccessControl(user)) {
executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
} else {
mlModelManager.getModel(modelId, modelActionListener);
}
}
}

private void executePredict(
MLPredictionTaskRequest mlPredictionTaskRequest,
ActionListener<MLTaskResponse> wrappedListener,
String modelId
) {
String requestId = mlPredictionTaskRequest.getRequestID();
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner
.run(
mlPredictionTaskRequest.getMlInput().getAlgorithm(),
mlPredictionTaskRequest,
transportService,
ActionListener.runAfter(wrappedListener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
}
}
20 changes: 19 additions & 1 deletion plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.stream.DoubleStream;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
Expand All @@ -34,6 +35,7 @@ public class MLModelCache {
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor;
private final Set<String> targetWorkerNodes;
private final Set<String> workerNodes;
private MLModel modelInfo;
private final Queue<Double> modelInferenceDurationQueue;
private final Queue<Double> predictRequestDurationQueue;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU;
Expand Down Expand Up @@ -77,12 +79,16 @@ public String[] getTargetWorkerNodes() {
* @param isFromUndeploy
*/
public void removeWorkerNode(String nodeId, boolean isFromUndeploy) {
if ((deployToAllNodes != null && deployToAllNodes) || isFromUndeploy) {
if (this.isDeployToAllNodes() || isFromUndeploy) {
targetWorkerNodes.remove(nodeId);
}
if (isFromUndeploy)
deployToAllNodes = false;
workerNodes.remove(nodeId);
// when the model is not deployed to any node, we should remove the modelInfo from cache
if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) {
modelInfo = null;
}
}

public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy) {
Expand All @@ -92,6 +98,9 @@ public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy)
if (isFromUndeploy)
deployToAllNodes = false;
workerNodes.removeAll(removedNodes);
if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) {
modelInfo = null;
}
}

/**
Expand All @@ -112,6 +121,14 @@ public String[] getWorkerNodes() {
return workerNodes.toArray(new String[0]);
}

public void setModelInfo(MLModel modelInfo) {
this.modelInfo = modelInfo;
}

public MLModel getCachedModelInfo() {
return modelInfo;
}

public void syncWorkerNode(Set<String> workerNodes) {
this.workerNodes.clear();
this.workerNodes.addAll(workerNodes);
Expand All @@ -129,6 +146,7 @@ public void clear() {
modelState = null;
functionName = null;
workerNodes.clear();
modelInfo = null;
modelInferenceDurationQueue.clear();
predictRequestDurationQueue.clear();
if (predictor != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
Expand Down Expand Up @@ -429,6 +430,16 @@ public boolean getDeployToAllNodes(String modelId) {
return mlModelCache.isDeployToAllNodes();
}

public void setModelInfo(String modelId, MLModel mlModel) {
MLModelCache mlModelCache = getExistingModelCache(modelId);
mlModelCache.setModelInfo(mlModel);
}

public MLModel getModelInfo(String modelId) {
MLModelCache mlModelCache = getExistingModelCache(modelId);
return mlModelCache.getCachedModelInfo();
}

private MLModelCache getExistingModelCache(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
if (modelCache == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.model;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -26,6 +27,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
Expand Down Expand Up @@ -251,6 +253,7 @@ public void testSyncWorkerNodes_ModelState() {
cacheHelper.syncWorkerNodes(modelWorkerNodes);
assertEquals(2, cacheHelper.getAllModels().length);
assertEquals(0, cacheHelper.getWorkerNodes(modelId2).length);
assertNull(cacheHelper.getModelInfo(modelId2));
assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId));
}

Expand Down Expand Up @@ -323,6 +326,15 @@ public void test_removeWorkerNodes_with_deployToAllNodesStatus_isTrue() {
cacheHelper.removeWorkerNodes(ImmutableSet.of(nodeId), false);
cacheHelper.removeWorkerNode(modelId, nodeId, false);
assertEquals(0, cacheHelper.getWorkerNodes(modelId).length);
assertNull(cacheHelper.getModelInfo(modelId));
}

public void test_setModelInfo_success() {
cacheHelper.initModelState(modelId, MLModelState.DEPLOYED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true);
MLModel model = mock(MLModel.class);
when(model.getModelId()).thenReturn("mockId");
cacheHelper.setModelInfo(modelId, model);
assertEquals("mockId", cacheHelper.getModelInfo(modelId).getModelId());
}

}