From 9acb60dac8bf263ba209421bae0c3a6f74043b1f Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 27 Sep 2023 12:43:07 +0800 Subject: [PATCH] Add neural search default processor for non OpenAI/Cohere scenario (#1274) * Fix breaking change caused by opensearch core Signed-off-by: zane-neo * Add neural search default pre/post process function support Signed-off-by: zane-neo * Fix UT failures Signed-off-by: zane-neo * Fix conflicts when backport Signed-off-by: zane-neo * Fix conflict when backport Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --- .../ml/common/connector/Connector.java | 22 +-- .../connector/MLPostProcessFunction.java | 85 +++++------ .../connector/MLPreProcessFunction.java | 41 +++--- .../input/remote/RemoteInferenceMLInput.java | 5 - .../ml/common/utils/StringUtils.java | 1 + .../connector/MLPostProcessFunctionTest.java | 29 ++++ .../org/opensearch/ml/engine/ModelHelper.java | 4 +- .../algorithms/remote/ConnectorUtils.java | 135 ++++++++++-------- .../remote/RemoteConnectorExecutor.java | 12 +- .../ml/engine/utils/ScriptUtils.java | 33 ++--- .../remote/AwsConnectorExecutorTest.java | 34 +++++ .../algorithms/remote/ConnectorUtilsTest.java | 32 +++-- .../remote/HttpJsonConnectorExecutorTest.java | 44 +++--- .../ml/engine/utils/ScriptUtilsTest.java | 59 ++++++++ 14 files changed, 332 insertions(+), 204 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index dad7d86012..c227d4431d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -5,6 +5,17 @@ package org.opensearch.ml.common.connector; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.commons.text.StringSubstitutor; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; @@ -21,17 +32,6 @@ import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 662db37341..9d9ba90171 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,61 +5,64 @@ package org.opensearch.ml.common.connector; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; public class MLPostProcessFunction { - private static Map POST_PROCESS_FUNCTIONS; public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; public static final String OPENAI_EMBEDDING = "connector.post_process.openai.embedding"; + public static final String DEFAULT_EMBEDDING = "connector.post_process.default.embedding"; + + private static final Map JSON_PATH_EXPRESSION = new HashMap<>(); + + private static final Map>, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); + + static { - POST_PROCESS_FUNCTIONS = new HashMap<>(); - POST_PROCESS_FUNCTIONS.put(COHERE_EMBEDDING, "\n def name = \"sentence_embedding\";\n" + - " def dataType = \"FLOAT32\";\n" + - " if (params.embeddings == null || params.embeddings.length == 0) {\n" + - " return null;\n" + - " }\n" + - " def embeddings = params.embeddings;\n" + - " StringBuilder builder = new StringBuilder(\"[\");\n" + - " for (int i=0; i>, List> buildModelTensorList() { + return embeddings -> { + List modelTensors = new ArrayList<>(); + if (embeddings == null) { + throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); + } + embeddings.forEach(embedding -> modelTensors.add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[]{embedding.size()}) + .data(embedding.toArray(new Number[0])) + .build() + )); + return modelTensors; + }; } - public static boolean contains(String functionName) { - return POST_PROCESS_FUNCTIONS.containsKey(functionName); + public static String getResponseFilter(String postProcessFunction) { + return JSON_PATH_EXPRESSION.get(postProcessFunction); } - public static String get(String postProcessFunction) { + public static Function>, List> get(String postProcessFunction) { return POST_PROCESS_FUNCTIONS.get(postProcessFunction); } + + public static boolean contains(String postProcessFunction) { + return POST_PROCESS_FUNCTIONS.containsKey(postProcessFunction); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java index b49e075aea..0a41e17a9b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPreProcessFunction.java @@ -6,44 +6,37 @@ package org.opensearch.ml.common.connector; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Function; public class MLPreProcessFunction { - private static Map PRE_PROCESS_FUNCTIONS; + private static final Map, Map>> PRE_PROCESS_FUNCTIONS = new HashMap<>(); public static final String TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT = "connector.pre_process.cohere.embedding"; public static final String TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT = "connector.pre_process.openai.embedding"; + public static final String TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT = "connector.pre_process.default.embedding"; + + private static Function, Map> cohereTextEmbeddingPreProcess() { + return inputs -> Map.of("parameters", Map.of("texts", inputs)); + } + + private static Function, Map> openAiTextEmbeddingPreProcess() { + return inputs -> Map.of("parameters", Map.of("input", inputs)); + } + static { - PRE_PROCESS_FUNCTIONS = new HashMap<>(); - //TODO: change to java for openAI, embedding and Titan - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + - " builder.append(\"[\");\n" + - " for (int i=0; i< params.text_docs.length; i++) {\n" + - " builder.append(\"\\\"\");\n" + - " builder.append(params.text_docs[i]);\n" + - " builder.append(\"\\\"\");\n" + - " if (i < params.text_docs.length - 1) {\n" + - " builder.append(\",\")\n" + - " }\n" + - " }\n" + - " builder.append(\"]\");\n" + - " def parameters = \"{\" +\"\\\"texts\\\":\" + builder + \"}\";\n" + - " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); - - PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" + - " builder.append(\"\\\"\");\n" + - " builder.append(params.text_docs[0]);\n" + - " builder.append(\"\\\"\");\n" + - " def parameters = \"{\" +\"\\\"input\\\":\" + builder + \"}\";\n" + - " return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";"); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, cohereTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); + PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT, openAiTextEmbeddingPreProcess()); } public static boolean contains(String functionName) { return PRE_PROCESS_FUNCTIONS.containsKey(functionName); } - public static String get(String postProcessFunction) { + public static Function, Map> get(String postProcessFunction) { return PRE_PROCESS_FUNCTIONS.get(postProcessFunction); } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index 6b262129c4..445759ac66 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -14,14 +14,9 @@ import org.opensearch.ml.common.utils.StringUtils; import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; import java.util.Map; import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; @org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) public class RemoteInferenceMLInput extends MLInput { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 968cda1575..edbd94b37f 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -24,6 +24,7 @@ public class StringUtils { public static final Gson gson; + static { gson = new Gson(); } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 346d5901a8..5d4c0c88d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -6,12 +6,21 @@ package org.opensearch.ml.common.connector; import org.junit.Assert; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; public class MLPostProcessFunctionTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Test public void contains() { Assert.assertTrue(MLPostProcessFunction.contains(OPENAI_EMBEDDING)); @@ -23,4 +32,24 @@ public void get() { Assert.assertNotNull(MLPostProcessFunction.get(OPENAI_EMBEDDING)); Assert.assertNull(MLPostProcessFunction.get("wrong value")); } + + @Test + public void test_getResponseFilter() { + assert null != MLPostProcessFunction.getResponseFilter(OPENAI_EMBEDDING); + assert null == MLPostProcessFunction.getResponseFilter("wrong value"); + } + + @Test + public void test_buildModelTensorList() { + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList()); + List> numbersList = new ArrayList<>(); + numbersList.add(Collections.singletonList(1.0f)); + Assert.assertNotNull(MLPostProcessFunction.buildModelTensorList().apply(numbersList)); + } + + @Test + public void test_buildModelTensorList_exception() { + exceptionRule.expect(IllegalArgumentException.class); + MLPostProcessFunction.buildModelTensorList().apply(null); + } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index abcc9a9ecb..df70c66ea7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -7,7 +7,6 @@ import ai.djl.training.util.DownloadUtils; import ai.djl.training.util.ProgressBar; -import com.google.gson.Gson; import com.google.gson.stream.JsonReader; import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; @@ -32,6 +31,7 @@ import java.util.zip.ZipEntry; import java.util.zip.ZipFile; +import static org.opensearch.ml.common.utils.StringUtils.gson; import static org.opensearch.ml.engine.utils.FileUtils.calculateFileHash; import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.engine.utils.FileUtils.splitFileIntoChunks; @@ -48,11 +48,9 @@ public class ModelHelper { public static final String PYTORCH_ENGINE = "PyTorch"; public static final String ONNX_ENGINE = "OnnxRuntime"; private final MLEngine mlEngine; - private Gson gson; public ModelHelper(MLEngine mlEngine) { this.mlEngine = mlEngine; - gson = new Gson(); } public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput registerModelInput, ActionListener listener) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index cd3038f49c..ac3f8a7eda 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -5,17 +5,19 @@ package org.opensearch.ml.engine.algorithms.remote; -import com.google.common.collect.ImmutableMap; import com.jayway.jsonpath.JsonPath; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringSubstitutor; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.script.ScriptService; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentials; @@ -37,10 +39,12 @@ import static org.apache.commons.text.StringEscapeUtils.escapeJson; import static org.opensearch.ml.common.connector.HttpConnector.RESPONSE_FILTER_FIELD; -import static org.opensearch.ml.engine.utils.ScriptUtils.executePostprocessFunction; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.utils.ScriptUtils.executeBuildInPostProcessFunction; +import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction; import static org.opensearch.ml.engine.utils.ScriptUtils.executePreprocessFunction; -import static org.opensearch.ml.engine.utils.ScriptUtils.gson; +@Log4j2 public class ConnectorUtils { private static final Aws4Signer signer; @@ -54,43 +58,7 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } RemoteInferenceInputDataSet inputData; if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { - TextDocsInputDataSet inputDataSet = (TextDocsInputDataSet)mlInput.getInputDataset(); - List docs = new ArrayList<>(inputDataSet.getDocs()); - Map params = ImmutableMap.of("text_docs", docs); - Optional predictAction = connector.findPredictAction(); - if (!predictAction.isPresent()) { - throw new IllegalArgumentException("no predict action found"); - } - String preProcessFunction = predictAction.get().getPreProcessFunction(); - if (preProcessFunction == null) { - throw new IllegalArgumentException("Must provide pre_process_function for predict action to process text docs input."); - } - if (preProcessFunction != null && preProcessFunction.contains("${parameters")) { - StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); - preProcessFunction = substitutor.replace(preProcessFunction); - } - Optional processedResponse = executePreprocessFunction(scriptService, preProcessFunction, params); - if (!processedResponse.isPresent()) { - throw new IllegalArgumentException("Wrong input"); - } - Map map = gson.fromJson(processedResponse.get(), Map.class); - Map parametersMap = (Map) map.get("parameters"); - Map processedParameters = new HashMap<>(); - for (String key : parametersMap.keySet()) { - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - if (parametersMap.get(key) instanceof String) { - processedParameters.put(key, (String) parametersMap.get(key)); - } else { - processedParameters.put(key, gson.toJson(parametersMap.get(key))); - } - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } - } - inputData = RemoteInferenceInputDataSet.builder().parameters(processedParameters).build(); + inputData = processTextDocsInput((TextDocsInputDataSet) mlInput.getInputDataset(), connector, parameters, scriptService); } else if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) { inputData = (RemoteInferenceInputDataSet)mlInput.getInputDataset(); } else { @@ -98,20 +66,65 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } if (inputData.getParameters() != null) { Map newParameters = new HashMap<>(); - inputData.getParameters().entrySet().forEach(entry -> { - if (entry.getValue() == null) { - newParameters.put(entry.getKey(), entry.getValue()); - } else if (StringUtils.isJson(entry.getValue())) { + inputData.getParameters().forEach((key, value) -> { + if (value == null) { + newParameters.put(key, null); + } else if (org.opensearch.ml.common.utils.StringUtils.isJson(value)) { // no need to escape if it's already valid json - newParameters.put(entry.getKey(), entry.getValue()); + newParameters.put(key, value); } else { - newParameters.put(entry.getKey(), escapeJson(entry.getValue())); + newParameters.put(key, escapeJson(value)); } }); inputData.setParameters(newParameters); } return inputData; } + private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDataSet inputDataSet, Connector connector, Map parameters, ScriptService scriptService) { + List docs = new ArrayList<>(inputDataSet.getDocs()); + Optional predictAction = connector.findPredictAction(); + if (predictAction.isEmpty()) { + throw new IllegalArgumentException("no predict action found"); + } + String preProcessFunction = predictAction.get().getPreProcessFunction(); + preProcessFunction = preProcessFunction == null ? MLPreProcessFunction.TEXT_DOCS_TO_DEFAULT_EMBEDDING_INPUT : preProcessFunction; + if (MLPreProcessFunction.contains(preProcessFunction)) { + Map buildInFunctionResult = MLPreProcessFunction.get(preProcessFunction).apply(docs); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(buildInFunctionResult)).build(); + } else { + if (preProcessFunction.contains("${parameters")) { + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + preProcessFunction = substitutor.replace(preProcessFunction); + } + Optional processedInput = executePreprocessFunction(scriptService, preProcessFunction, docs); + if (processedInput.isEmpty()) { + throw new IllegalArgumentException("Wrong input"); + } + Map map = gson.fromJson(processedInput.get(), Map.class); + return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(map)).build(); + } + } + + private static Map convertScriptStringToJsonString(Map processedInput) { + Map parameterStringMap = new HashMap<>(); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + Map parametersMap = (Map) processedInput.get("parameters"); + for (String key : parametersMap.keySet()) { + if (parametersMap.get(key) instanceof String) { + parameterStringMap.put(key, (String) parametersMap.get(key)); + } else { + parameterStringMap.put(key, gson.toJson(parametersMap.get(key))); + } + } + return null; + }); + } catch (PrivilegedActionException e) { + log.error("Error processing parameters", e); + throw new RuntimeException(e); + } + return parameterStringMap; + } public static ModelTensors processOutput(String modelResponse, Connector connector, ScriptService scriptService, Map parameters) throws IOException { if (modelResponse == null) { @@ -119,26 +132,36 @@ public static ModelTensors processOutput(String modelResponse, Connector connect } List modelTensors = new ArrayList<>(); Optional predictAction = connector.findPredictAction(); - if (!predictAction.isPresent()) { + if (predictAction.isEmpty()) { throw new IllegalArgumentException("no predict action found"); } - String postProcessFunction = predictAction.get().getPostProcessFunction(); + ConnectorAction connectorAction = predictAction.get(); + String postProcessFunction = connectorAction.getPostProcessFunction(); if (postProcessFunction != null && postProcessFunction.contains("${parameters")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); postProcessFunction = substitutor.replace(postProcessFunction); } - Optional processedResponse = executePostprocessFunction(scriptService, postProcessFunction, modelResponse); + String responseFilter = parameters.get(RESPONSE_FILTER_FIELD); + if (MLPostProcessFunction.contains(postProcessFunction)) { + // in this case, we can use jsonpath to build a List> result from model response. + if (StringUtils.isBlank(responseFilter)) responseFilter = MLPostProcessFunction.getResponseFilter(postProcessFunction); + List> vectors = JsonPath.read(modelResponse, responseFilter); + List processedResponse = executeBuildInPostProcessFunction(vectors, MLPostProcessFunction.get(postProcessFunction)); + return ModelTensors.builder().mlModelTensors(processedResponse).build(); + } + + // execute user defined painless script. + Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - if (parameters.get(RESPONSE_FILTER_FIELD) == null) { - connector.parseResponse(response, modelTensors, postProcessFunction != null && processedResponse.isPresent()); + boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent(); + if (responseFilter == null) { + connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { Object filteredResponse = JsonPath.parse(response).read(parameters.get(RESPONSE_FILTER_FIELD)); - connector.parseResponse(filteredResponse, modelTensors, postProcessFunction != null && processedResponse.isPresent()); + connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor); } - - ModelTensors tensors = ModelTensors.builder().mlModelTensors(modelTensors).build(); - return tensors; + return ModelTensors.builder().mlModelTensors(modelTensors).build(); } public static SdkHttpFullRequest signRequest(SdkHttpFullRequest request, String accessKey, String secretKey, String sessionToken, String signingName, String region) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index c9b6e78873..8712f771c7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -32,14 +32,8 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - List textDocs = new ArrayList(textDocsInputDataSet.getDocs()); - for (int i = 0; i < textDocsInputDataSet.getDocs().size(); i++) { - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); - if (tensorOutputs.size() >= textDocsInputDataSet.getDocs().size()) { - break; - } - textDocs.remove(0); - } + List textDocs = new ArrayList<>(textDocsInputDataSet.getDocs()); + preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); } else { preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); } @@ -65,7 +59,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List executePreprocessFunction(ScriptService scriptService, String preProcessFunction, List inputSentences) { + return Optional.ofNullable(executeScript(scriptService, preProcessFunction, ImmutableMap.of("text_docs", inputSentences))); } - public static Optional executePreprocessFunction(ScriptService scriptService, - String preProcessFunction, - Map params) { - if (MLPreProcessFunction.contains(preProcessFunction)) { - preProcessFunction = MLPreProcessFunction.get(preProcessFunction); - } - if (preProcessFunction != null) { - return Optional.ofNullable(executeScript(scriptService, preProcessFunction, params)); - } - return Optional.empty(); + public static List executeBuildInPostProcessFunction(List> vectors, Function>, List> function) { + return function.apply(vectors); } - public static Optional executePostprocessFunction(ScriptService scriptService, - String postProcessFunction, - String resultJson) { + public static Optional executePostProcessFunction(ScriptService scriptService, String postProcessFunction, String resultJson) { Map result = StringUtils.fromJson(resultJson, "result"); - if (MLPostProcessFunction.contains(postProcessFunction)) { - postProcessFunction = MLPostProcessFunction.get(postProcessFunction); - } if (postProcessFunction != null) { return Optional.ofNullable(executeScript(scriptService, postProcessFunction, result)); } return Optional.empty(); } - public static String executeScript(ScriptService scriptService, String painlessScript, Map params) { Script script = new Script(ScriptType.INLINE, "painless", painlessScript, Collections.emptyMap()); TemplateScript templateScript = scriptService.compile(script, TemplateScript.CONTEXT).newInstance(params); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index ecc143ea6f..5dbbf2090e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.Assert; import org.junit.Before; @@ -18,7 +19,9 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; @@ -136,4 +139,35 @@ public void executePredict_RemoteInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void executePredict_TextDocsInferenceInput() throws IOException { + String jsonString = "{\"key\":\"value\"}"; + InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); + AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); + when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + when(httpRequest.call()).thenReturn(response); + when(httpClient.prepareRequest(any())).thenReturn(httpRequest); + + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build(); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 2a84e2fee1..8e046b151c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -27,12 +27,16 @@ import org.opensearch.script.ScriptService; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.utils.StringUtils.gson; public class ConnectorUtilsTest { @@ -56,8 +60,6 @@ public void processInput_NullInput() { @Test public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); @@ -121,18 +123,20 @@ private void processInput_RemoteInferenceInputDataSet(String input, String expec @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_OneTextDoc() { + List input = Collections.singletonList("test_value"); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( - "{\"input\": \"${parameters.input}\"}", - "{\"parameters\": { \"input\": \"test_value\" } }", - "test_value"); + "{\"input\": \"${parameters.input}\"}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT, "texts"); } @Test public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() { + List input = new ArrayList<>(); + input.add("test_value1"); + input.add("test_value2"); + String inputJson = gson.toJson(input); processInput_TextDocsInputDataSet_PreprocessFunction( - "{\"input\": ${parameters.input}}", - "{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }", - "[\"test_value1\",\"test_value2\"]"); + "{\"input\": ${parameters.input}}", input, inputJson, MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "input"); } @Test @@ -143,7 +147,7 @@ public void processOutput_NullResponse() throws IOException { } @Test - public void processOutput_NoPostprocessFunction() throws IOException { + public void processOutput_NoPostprocessFunction_jsonResponse() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") @@ -186,10 +190,8 @@ public void processOutput_PostprocessFunction() throws IOException { Assert.assertEquals(0.0035105038, tensors.getMlModelTensors().get(0).getData()[2]); } - private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, String preprocessResult, String expectedProcessedInput) { - when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult)); - - TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test1", "test2")).build(); + private void processInput_TextDocsInputDataSet_PreprocessFunction(String requestBody, List inputs, String expectedProcessedInput, String preProcessName, String resultKey) { + TextDocsInputDataSet dataSet = TextDocsInputDataSet.builder().docs(inputs).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build(); ConnectorAction predictAction = ConnectorAction.builder() @@ -197,7 +199,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request .method("POST") .url("http://test.com/mock") .requestBody(requestBody) - .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_COHERE_EMBEDDING_INPUT) + .preProcessFunction(preProcessName) .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); @@ -205,6 +207,6 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); - Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get("input")); + Assert.assertEquals(expectedProcessedInput, remoteInferenceInputDataSet.getParameters().get(resultKey)); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 8d04603d2a..9caf621087 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -29,12 +29,15 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import java.io.IOException; import java.util.Arrays; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -94,32 +97,34 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); + public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); + Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")); } @Test public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; - String postprocessResult1 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[1, 2, 3]}"; String preprocessResult2 = "{\"parameters\": { \"input\": \"test doc2\" } }"; - String postprocessResult2 = "{\"name\":\"sentence_embedding\",\"data_type\":\"FLOAT32\",\"shape\":[3],\"data\":[4, 5, 6]}"; when(scriptService.compile(any(), any())) .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult1)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)) - .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(postprocessResult2)); + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory(preprocessResult2)); ConnectorAction predictAction = ConnectorAction.builder() .actionType(ConnectorAction.ActionType.PREDICT) @@ -127,21 +132,28 @@ public void executePredict_TextDocsInput() throws IOException { .url("http://test.com/mock") .preProcessFunction(MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT) .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) - .requestBody("{\"input\": \"${parameters.input}\"}") + .requestBody("{\"input\": ${parameters.input}}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" + + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" + + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); - Assert.assertArrayEquals(new Number[] {1, 2, 3}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); - Assert.assertArrayEquals(new Number[] {4, 5, 6}, modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); + Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java new file mode 100644 index 0000000000..6ca1401efd --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/ScriptUtilsTest.java @@ -0,0 +1,59 @@ +package org.opensearch.ml.engine.utils; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.script.ScriptService; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class ScriptUtilsTest { + + @Mock + ScriptService scriptService; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("test result")); + } + + @Test + public void test_executePreprocessFunction() { + Optional resultOpt = ScriptUtils.executePreprocessFunction(scriptService, "any function", Collections.singletonList("any input")); + assertEquals("test result", resultOpt.get()); + } + + @Test + public void test_executeBuildInPostProcessFunction() { + List> input = Arrays.asList(Arrays.asList(1.0f, 2.0f), Arrays.asList(3.0f, 4.0f)); + List modelTensors = ScriptUtils.executeBuildInPostProcessFunction(input, MLPostProcessFunction.get(MLPostProcessFunction.DEFAULT_EMBEDDING)); + assertNotNull(modelTensors); + assertEquals(2, modelTensors.size()); + } + + @Test + public void test_executePostProcessFunction() { + when(scriptService.compile(any(), any())).then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"test result\"}")); + Optional resultOpt = ScriptUtils.executePostProcessFunction(scriptService, "any function", "{\"result\": \"test result\"}"); + assertEquals("{\"result\": \"test result\"}", resultOpt.get()); + } + + @Test + public void test_executeScript() { + String result = ScriptUtils.executeScript(scriptService, "any function", Collections.singletonMap("key", "value")); + assertEquals("test result", result); + } +}