diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 051f7ff3cb..d88ccc9b00 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -35,6 +35,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.connector.Connector; @@ -167,8 +168,8 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build(); String modelId = mlTask.getModelId(); - try { - mlModelManager.getModel(modelId, null, null, ActionListener.wrap(model -> { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener getModelListener = ActionListener.wrap(model -> { if (model.getConnector() != null) { Connector connector = model.getConnector(); executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener); @@ -181,7 +182,7 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi }); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { connectorAccessControlHelper - .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); + .getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore)); } } else { actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId())); @@ -189,7 +190,9 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi }, e -> { log.error("Failed to retrieve the ML model with the given ID", e); actionListener.onFailure(e); - })); + }); + + mlModelManager.getModel(modelId, null, null, ActionListener.runBefore(getModelListener, context::restore)); } catch (Exception e) { // fetch the connector log.error("Unable to fetch status for ml task ", e);