diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index ac3f8a7eda..c481725057 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -81,7 +81,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto return inputData; } private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map parameters, ScriptService scriptService) { - List docs = new ArrayList<>(inputDataSet.getDocs()); Optional predictAction = connector.findPredictAction(); if (predictAction.isEmpty()) { throw new IllegalArgumentException("no predict action found"); @@ -89,9 +88,17 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat String preProcessFunction = predictAction.get().getPreProcessFunction(); preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; if (MLPreProcessFunction.contains(preProcessFunction)) { - Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); + Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(inputDataSet.getDocs()); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); } else { + List docs = new ArrayList<>(); + for (String doc : inputDataSet.getDocs()) { + if (doc != null) { + docs.add(gson.toJson(doc)); + } else { + docs.add(null); + } + } if (preProcessFunction.contains("${parameters")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); preProcessFunction = substitutor.replace(preProcessFunction);