Skip to content

Commit

Permalink
support charset input params and change default charset as utf8 (#1691)…
Browse files Browse the repository at this point in the history
… (#1828)
  • Loading branch information
ylwu-amzn authored Dec 31, 2023
1 parent 1ab0c8c commit ce31d6b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
try {
String predictEndpoint = connector.getPredictEndpoint(parameters);
request = new HttpPost(predictEndpoint);
HttpEntity entity = new StringEntity(payload);
String charset = parameters.containsKey("charset") ? parameters.get("charset") : "UTF-8";
HttpEntity entity = new StringEntity(payload, charset);
((HttpPost) request).setEntity(entity);
} catch (Exception e) {
throw new MLException("Failed to create http request for remote model", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHeaders;
import org.apache.http.message.BasicHeader;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.client.Response;
Expand Down Expand Up @@ -385,8 +386,28 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio
assertTrue((Boolean) responseMap.get("violence"));
}

@Ignore
public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException {
public void testOpenAITextEmbeddingModel_UTF8() throws IOException, InterruptedException {
testOpenAITextEmbeddingModel("UTF-8", (responseMap) -> {
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("data");
assertFalse(responseList.isEmpty());
}, null);
}

public void testOpenAITextEmbeddingModel_ISO8859_1() throws IOException, InterruptedException {
testOpenAITextEmbeddingModel("ISO-8859-1", null, (exception) -> {
assertTrue(exception instanceof org.opensearch.client.ResponseException);
String stackTrace = ExceptionUtils.getStackTrace(exception);
assertTrue(stackTrace.contains("'utf-8' codec can't decode byte 0xeb"));
});
}

private void testOpenAITextEmbeddingModel(String charset, Consumer<Map> verifyResponse, Consumer<Exception> verifyException)
throws IOException,
InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
Expand All @@ -397,9 +418,6 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
+ " \"version\": 1,\n"
+ " \"protocol\": \"http\",\n"
+ " \"parameters\": {\n"
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"model\": \"text-embedding-ada-002\"\n"
+ " },\n"
+ " \"credential\": {\n"
Expand All @@ -415,9 +433,9 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
+ " \"headers\": { \n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\" }\",\n"
+ " \"pre_process_function\": \"text_docs_to_openai_embedding_input\",\n"
+ " \"post_process_function\": \"openai_embedding\"\n"
+ " \"request_body\": \"{ \\\"input\\\": ${parameters.input}, \\\"model\\\": \\\"${parameters.model}\\\" }\",\n"
+ " \"pre_process_function\": \"connector.pre_process.openai.embedding\",\n"
+ " \"post_process_function\": \"connector.post_process.openai.embedding\"\n"
+ " }\n"
+ " ]\n"
+ "}";
Expand All @@ -435,17 +453,21 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"input\": \"The food was delicious\"\n" + " }\n" + "}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("data");
responseMap = (Map) responseList.get(0);
assertFalse(((List) responseMap.get("embedding")).isEmpty());
String predictInput = "{\n"
+ " \"parameters\": {\n"
+ " \"input\": [\"This is a string containing Moët Hennessy\"],\n"
+ " \"charset\": \""
+ charset
+ "\"\n"
+ " }\n"
+ "}";
try {
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
verifyResponse.accept(responseMap);
} catch (Exception e) {
verifyException.accept(e);
}
}

public void testCohereGenerateTextModel() throws IOException, InterruptedException {
Expand Down

0 comments on commit ce31d6b

Please sign in to comment.