From f902193b7a51447a823a402b4b23ef68cccef7d7 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 25 Oct 2023 22:47:51 -0700 Subject: [PATCH] test Signed-off-by: Yaliang Wu --- .../opensearch/ml/rest/RestMLPredictionAction.java | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index e5806103a9..60bfe07984 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -18,15 +18,13 @@ import java.util.Optional; import org.opensearch.client.node.NodeClient; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.transport.model.MLModelGetAction; -import org.opensearch.ml.common.transport.model.MLModelGetRequest; -import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.model.MLModelManager; @@ -91,9 +89,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client } return channel -> { - MLModelGetRequest getModelRequest = new MLModelGetRequest(modelId, false); - ActionListener listener = ActionListener.wrap(r -> { - MLModel mlModel = r.getMlModel(); + ActionListener listener = ActionListener.wrap(mlModel -> { String algoName = mlModel.getAlgorithm().name(); client .execute( @@ -109,8 +105,9 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client log.error("Failed to send error response", ex); } }); - client.execute(MLModelGetAction.INSTANCE, getModelRequest, listener); - + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + modelManager.getModel(modelId, ActionListener.runBefore(listener, () -> context.restore())); + } }; }