From aa2337b5f0f010f97a0a8069f1d86c43b739509d Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 11 Oct 2023 06:22:56 -0500 Subject: [PATCH] fix no worker node exception for remote embedding model (#1482) (#1483) * fix no worker node exception for remote embedding model Signed-off-by: Yaliang Wu (cherry picked from commit 6f83b9fee002026d7a8d0fa3550fe8cf80b30371) Co-authored-by: Yaliang Wu --- .../action/prediction/TransportPredictionTaskAction.java | 9 ++++++++- .../java/org/opensearch/ml/model/MLModelCacheHelper.java | 6 ++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 9ac661196a..6117cc553c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -137,9 +137,16 @@ private void executePredict( String requestId = mlPredictionTaskRequest.getRequestID(); log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); long startTime = System.nanoTime(); + // For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as + // TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get + // from model cache first. + FunctionName functionName = modelCacheHelper + .getOptionalFunctionName(modelId) + .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm()); mlPredictTaskRunner .run( - mlPredictionTaskRequest.getMlInput().getAlgorithm(), + // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here + functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(wrappedListener, () -> { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 554065ed95..553ffeb664 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -431,8 +431,10 @@ public boolean getDeployToAllNodes(String modelId) { } public void setModelInfo(String modelId, MLModel mlModel) { - MLModelCache mlModelCache = getExistingModelCache(modelId); - mlModelCache.setModelInfo(mlModel); + MLModelCache mlModelCache = modelCaches.get(modelId); + if (mlModelCache != null) { + mlModelCache.setModelInfo(mlModel); + } } public MLModel getModelInfo(String modelId) {