Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Sep 4, 2024
1 parent 2c383dd commit 431d5fd
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
Expand Down Expand Up @@ -167,8 +168,8 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
String modelId = mlTask.getModelId();

try {
mlModelManager.getModel(modelId, null, null, ActionListener.wrap(model -> {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLModel> getModelListener = ActionListener.wrap(model -> {
if (model.getConnector() != null) {
Connector connector = model.getConnector();
executeConnector(connector, mlInput, taskId, mlTask, remoteJob, actionListener);
Expand All @@ -181,15 +182,17 @@ private void processRemoteBatchPrediction(MLTask mlTask, String taskId, ActionLi
});
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
connectorAccessControlHelper
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
}
} else {
actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
}
}, e -> {
log.error("Failed to retrieve the ML model with the given ID", e);
actionListener.onFailure(e);
}));
});

mlModelManager.getModel(modelId, null, null, ActionListener.runBefore(getModelListener, context::restore));
} catch (Exception e) {
// fetch the connector
log.error("Unable to fetch status for ml task ", e);
Expand Down

0 comments on commit 431d5fd

Please sign in to comment.