From 99e34a0c3022d88c5b37f91b3477fe0461275901 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 10 Oct 2023 23:34:28 -0700 Subject: [PATCH 1/2] fix no worker node exception for remote embedding model Signed-off-by: Yaliang Wu --- .../action/prediction/TransportPredictionTaskAction.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 53aa0fa267..1d58a5d0db 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -137,9 +137,16 @@ private void executePredict( String requestId = mlPredictionTaskRequest.getRequestID(); log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); long startTime = System.nanoTime(); + // For remote text embedding model, neural search will set mlPredictionTaskRequest.getMlInput().getAlgorithm() as + // TEXT_EMBEDDING. In ml-commons we should always use the real function name of model: REMOTE. So we try to get + // from model cache first. + FunctionName functionName = modelCacheHelper + .getOptionalFunctionName(modelId) + .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm()); mlPredictTaskRunner .run( - mlPredictionTaskRequest.getMlInput().getAlgorithm(), + // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here + functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(wrappedListener, () -> { From ef8da7b43fc604ce4d8a94d8c19fd82381f96488 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Wed, 11 Oct 2023 00:32:13 -0700 Subject: [PATCH 2/2] only add model info to cache if model cache exist Signed-off-by: Yaliang Wu --- .../java/org/opensearch/ml/model/MLModelCacheHelper.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 554065ed95..553ffeb664 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -431,8 +431,10 @@ public boolean getDeployToAllNodes(String modelId) { } public void setModelInfo(String modelId, MLModel mlModel) { - MLModelCache mlModelCache = getExistingModelCache(modelId); - mlModelCache.setModelInfo(mlModel); + MLModelCache mlModelCache = modelCaches.get(modelId); + if (mlModelCache != null) { + mlModelCache.setModelInfo(mlModel); + } } public MLModel getModelInfo(String modelId) {