From 6f83b9fee002026d7a8d0fa3550fe8cf80b30371 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 11 Oct 2023 02:38:06 -0700 Subject: [PATCH] fix no worker node exception for remote embedding model (#1482) * fix no worker node exception for remote embedding model Signed-off-by: Yaliang Wu * only add model info to cache if model cache exist Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../action/prediction/TransportPredictionTaskAction.java | 9 ++++++++- .../java/org/opensearch/ml/model/MLModelCacheHelper.java | 6 ++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 53aa0fa267..1d58a5d0db 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -137,9 +137,16 @@ private void executePredict( String requestId = mlPredictionTaskRequest.getRequestID(); log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); long startTime = System.nanoTime(); + // For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as + // TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get + // from model cache first. + FunctionName functionName = modelCacheHelper + .getOptionalFunctionName(modelId) + .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm()); mlPredictTaskRunner .run( - mlPredictionTaskRequest.getMlInput().getAlgorithm(), + // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here + functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(wrappedListener, () -> { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 554065ed95..553ffeb664 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -431,8 +431,10 @@ public boolean getDeployToAllNodes(String modelId) { } public void setModelInfo(String modelId, MLModel mlModel) { - MLModelCache mlModelCache = getExistingModelCache(modelId); - mlModelCache.setModelInfo(mlModel); + MLModelCache mlModelCache = modelCaches.get(modelId); + if (mlModelCache != null) { + mlModelCache.setModelInfo(mlModel); + } } public MLModel getModelInfo(String modelId) {