diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index fd1e62b650..3c36d0b40e 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -14,9 +14,7 @@ import static org.opensearch.ml.common.CommonValue.NOT_FOUND; import static org.opensearch.ml.common.CommonValue.UNDEPLOYED; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; -import static org.opensearch.ml.common.MLTask.ERROR_FIELD; -import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; -import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTask.*; import static org.opensearch.ml.common.MLTaskState.COMPLETED; import static org.opensearch.ml.common.MLTaskState.FAILED; import static org.opensearch.ml.engine.ModelHelper.CHUNK_FILES; @@ -756,7 +754,8 @@ private void registerPrebuiltModel(MLRegisterModelInput registerModelInput, MLTa throw new IllegalArgumentException("This model is not in the pre-trained model list, please check your parameters."); } modelHelper.downloadPrebuiltModelConfig(taskId, registerModelInput, ActionListener.wrap(mlRegisterModelInput -> { - mlTask.setFunctionName(registerModelInput.getFunctionName()); + mlTask.setFunctionName(mlRegisterModelInput.getFunctionName()); + mlTaskManager.updateMLTask(taskId, ImmutableMap.of(FUNCTION_NAME_FIELD, mlRegisterModelInput.getFunctionName()), TIMEOUT_IN_MILLIS, false); registerModelFromUrl(mlRegisterModelInput, mlTask, modelVersion); }, e -> { log.error("Failed to register prebuilt model", e);