From ce31d6be22938f4a3adc1f38681d56b86cefc794 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Sat, 30 Dec 2023 17:27:40 -0800 Subject: [PATCH] support charset input params and change default charset as utf8 (#1691) (#1828) --- .../remote/HttpJsonConnectorExecutor.java | 3 +- .../ml/rest/RestMLRemoteInferenceIT.java | 62 +++++++++++++------ 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 8282ed1c85..d881707195 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -77,7 +77,8 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S try { String predictEndpoint = connector.getPredictEndpoint(parameters); request = new HttpPost(predictEndpoint); - HttpEntity entity = new StringEntity(payload); + String charset = parameters.containsKey("charset") ? parameters.get("charset") : "UTF-8"; + HttpEntity entity = new StringEntity(payload, charset); ((HttpPost) request).setEntity(entity); } catch (Exception e) { throw new MLException("Failed to create http request for remote model", e); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 4eced40d94..6369bc5169 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -8,12 +8,13 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.function.Consumer; +import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.http.HttpEntity; import org.apache.http.HttpHeaders; import org.apache.http.message.BasicHeader; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.rules.ExpectedException; import org.opensearch.client.Response; @@ -385,8 +386,28 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio assertTrue((Boolean) responseMap.get("violence")); } - @Ignore - public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException { + public void testOpenAITextEmbeddingModel_UTF8() throws IOException, InterruptedException { + testOpenAITextEmbeddingModel("UTF-8", (responseMap) -> { + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("data"); + assertFalse(responseList.isEmpty()); + }, null); + } + + public void testOpenAITextEmbeddingModel_ISO8859_1() throws IOException, InterruptedException { + testOpenAITextEmbeddingModel("ISO-8859-1", null, (exception) -> { + assertTrue(exception instanceof org.opensearch.client.ResponseException); + String stackTrace = ExceptionUtils.getStackTrace(exception); + assertTrue(stackTrace.contains("'utf-8' codec can't decode byte 0xeb")); + }); + } + + private void testOpenAITextEmbeddingModel(String charset, Consumer verifyResponse, Consumer verifyException) + throws IOException, + InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { return; @@ -397,9 +418,6 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept + " \"version\": 1,\n" + " \"protocol\": \"http\",\n" + " \"parameters\": {\n" - + " \"endpoint\": \"api.openai.com\",\n" - + " \"auth\": \"API_Key\",\n" - + " \"content_type\": \"application/json\",\n" + " \"model\": \"text-embedding-ada-002\"\n" + " },\n" + " \"credential\": {\n" @@ -415,9 +433,9 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept + " \"headers\": { \n" + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n" - + " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n" - + " \"post_process_function\": \"openai_embedding\"\n" + + " \"request_body\": \"{ \\\"input\\\": ${parameters.input}, \\\"model\\\": \\\"${parameters.model}\\\" }\",\n" + + " \"pre_process_function\": \"connector.pre_process.openai.embedding\",\n" + + " \"post_process_function\": \"connector.post_process.openai.embedding\"\n" + " }\n" + " ]\n" + "}"; @@ -435,17 +453,21 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"The food was delicious\"\n" + " }\n" + "}"; - response = predictRemoteModel(modelId, predictInput); - responseMap = parseResponseToMap(response); - List responseList = (List) responseMap.get("inference_results"); - responseMap = (Map) responseList.get(0); - responseList = (List) responseMap.get("output"); - responseMap = (Map) responseList.get(0); - responseMap = (Map) responseMap.get("dataAsMap"); - responseList = (List) responseMap.get("data"); - responseMap = (Map) responseList.get(0); - assertFalse(((List) responseMap.get("embedding")).isEmpty()); + String predictInput = "{\n" + + " \"parameters\": {\n" + + " \"input\": [\"This is a string containing Moët Hennessy\"],\n" + + " \"charset\": \"" + + charset + + "\"\n" + + " }\n" + + "}"; + try { + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + verifyResponse.accept(responseMap); + } catch (Exception e) { + verifyException.accept(e); + } } public void testCohereGenerateTextModel() throws IOException, InterruptedException {