diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java index dae61b6c6c..7ca22c3cdc 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); - Map processedResult = Map.of("parameters", Map.of("inputText", processTextDocs(inputData).get(0))); + Map processedResult = Map.of("parameters", Map.of("inputText", inputData.getDocs().get(0))); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java index d82210f4a3..0b66be089d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/CohereEmbeddingPreProcessFunction.java @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); - Map processedResult = Map.of("parameters", Map.of("texts", processTextDocs(inputData))); + Map processedResult = Map.of("parameters", Map.of("texts", inputData.getDocs())); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java index d29c70048e..eae2cb6524 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/ConnectorPreProcessFunction.java @@ -10,12 +10,8 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; -import java.util.ArrayList; -import java.util.List; import java.util.function.Function; -import static org.opensearch.ml.common.utils.StringUtils.gson; - @Log4j2 public abstract class ConnectorPreProcessFunction implements Function { @@ -38,21 +34,6 @@ public RemoteInferenceInputDataSet apply(MLInput mlInput) { public abstract RemoteInferenceInputDataSet process(MLInput mlInput); - List processTextDocs(TextDocsInputDataSet inputDataSet) { - List docs = new ArrayList<>(); - for (String doc : inputDataSet.getDocs()) { - if (doc != null) { - String gsonString = gson.toJson(doc); - // in 2.9, user will add " before and after string - // gson.toString(string) will add extra " before after string, so need to remove - docs.add(gsonString.substring(1, gsonString.length() - 1)); - } else { - docs.add(null); - } - } - return docs; - } - public void validateTextDocsInput(MLInput mlInput) { if (!(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) { throw new IllegalArgumentException("This pre_process_function can only support TextDocsInputDataSet"); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java index 32f294fdcc..83f7ebd74d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/OpenAIEmbeddingPreProcessFunction.java @@ -28,7 +28,7 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); - Map processedResult = Map.of("parameters", Map.of("input", processTextDocs(inputData))); + Map processedResult = Map.of("parameters", Map.of("input", inputData.getDocs())); return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index f66d1e58c4..bf3432e5ba 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -18,6 +18,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -120,4 +121,23 @@ public static Map convertScriptStringToJsonString(Map processTextDocs(List inputDocs) { + List docs = new ArrayList<>(); + for (String doc : inputDocs) { + docs.add(processTextDoc(doc)); + } + return docs; + } + + public static String processTextDoc(String doc) { + if (doc != null) { + String gsonString = gson.toJson(doc); + // in 2.9, user will add " before and after string + // gson.toString(string) will add extra " before after string, so need to remove + return gsonString.substring(1, gsonString.length() - 1); + } else { + return null; + } + } } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java index 93d23b338a..2e344cbd0f 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/DefaultPreProcessFunctionTest.java @@ -19,7 +19,6 @@ import org.opensearch.script.ScriptService; import java.util.Arrays; -import java.util.Map; import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index 3022c97e0a..63654cf8d5 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -8,6 +8,7 @@ import org.junit.Assert; import org.junit.Test; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -99,4 +100,13 @@ public void getParameterMap() { Assert.assertEquals("[10,20]", parameterMap.get("key4")); Assert.assertEquals("[1.01,\"abc\"]", parameterMap.get("key5")); } + + @Test + public void processTextDocs() { + List processedDocs = StringUtils.processTextDocs(Arrays.asList("abc \n\n123\"4", null, "[1.01,\"abc\"]")); + Assert.assertEquals(3, processedDocs.size()); + Assert.assertEquals("abc \\n\\n123\\\"4", processedDocs.get(0)); + Assert.assertNull(processedDocs.get(1)); + Assert.assertEquals("[1.01,\\\"abc\\\"]", processedDocs.get(2)); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index c3e385ca4e..893f923fbd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -10,6 +10,8 @@ import static org.opensearch.ml.common.connector.MLPreProcessFunction.CONVERT_INPUT_TO_JSON_STRING; import static org.opensearch.ml.common.connector.MLPreProcessFunction.PROCESS_REMOTE_INFERENCE_INPUT; import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.processTextDoc; +import static org.opensearch.ml.common.utils.StringUtils.processTextDocs; import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import java.io.IOException; @@ -29,6 +31,7 @@ import org.opensearch.ml.common.connector.functions.preprocess.DefaultPreProcessFunction; import org.opensearch.ml.common.connector.functions.preprocess.RemoteInferencePreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; @@ -68,20 +71,7 @@ public static RemoteInferenceInputDataSet processInput( throw new IllegalArgumentException("no predict action found"); } RemoteInferenceInputDataSet inputData = processMLInput(mlInput, connector, parameters, scriptService); - if (inputData.getParameters() != null) { - Map newParameters = new HashMap<>(); - inputData.getParameters().forEach((key, value) -> { - if (value == null) { - newParameters.put(key, null); - } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { - // no need to escape if it's already valid json - newParameters.put(key, value); - } else { - newParameters.put(key, escapeJson(value)); - } - }); - inputData.setParameters(newParameters); - } + escapeRemoteInferenceInputData(inputData); return inputData; } @@ -112,6 +102,7 @@ private static RemoteInferenceInputDataSet processMLInput( return (RemoteInferenceInputDataSet) mlInput.getInputDataset(); } } else { + MLInput newInput = escapeMLInput(mlInput); boolean convertInputToJsonString = parameters.containsKey(CONVERT_INPUT_TO_JSON_STRING) && Boolean.parseBoolean(parameters.get(CONVERT_INPUT_TO_JSON_STRING)); DefaultPreProcessFunction function = DefaultPreProcessFunction @@ -120,11 +111,51 @@ private static RemoteInferenceInputDataSet processMLInput( .preProcessFunction(preProcessFunction) .convertInputToJsonString(convertInputToJsonString) .build(); - return function.apply(mlInput); + return function.apply(newInput); } } } + private static MLInput escapeMLInput(MLInput mlInput) { + if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { + List docs = ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs(); + List newDocs = processTextDocs(docs); + TextDocsInputDataSet newInputData = ((TextDocsInputDataSet) mlInput.getInputDataset()).toBuilder().docs(newDocs).build(); + return mlInput.toBuilder().inputDataset(newInputData).build(); + } + + if (mlInput.getInputDataset() instanceof TextSimilarityInputDataSet) { + String query = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getQueryText(); + String newQuery = processTextDoc(query); + List docs = ((TextSimilarityInputDataSet) mlInput.getInputDataset()).getTextDocs(); + List newDocs = processTextDocs(docs); + TextSimilarityInputDataSet newInputData = ((TextSimilarityInputDataSet) mlInput.getInputDataset()) + .toBuilder() + .queryText(newQuery) + .textDocs(newDocs) + .build(); + return mlInput.toBuilder().inputDataset(newInputData).build(); + } + return mlInput; + } + + public static void escapeRemoteInferenceInputData(RemoteInferenceInputDataSet inputData) { + Map newParameters = new HashMap<>(); + if (inputData.getParameters() != null) { + inputData.getParameters().forEach((key, value) -> { + if (value == null) { + newParameters.put(key, null); + } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { + // no need to escape if it's already valid json + newParameters.put(key, value); + } else { + newParameters.put(key, escapeJson(value)); + } + }); + inputData.setParameters(newParameters); + } + } + private static String getPreprocessFunction(MLInput mlInput, Connector connector) { Optional predictAction = connector.findPredictAction(); String preProcessFunction = predictAction.get().getPreProcessFunction(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 4f46c67906..bb3e13b24a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; import java.util.ArrayList; @@ -108,6 +109,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List inputParameters = new HashMap<>(); if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet) inputDataset).getParameters() != null) { + escapeRemoteInferenceInputData((RemoteInferenceInputDataSet) inputDataset); inputParameters.putAll(((RemoteInferenceInputDataSet) inputDataset).getParameters()); } parameters.putAll(inputParameters);