Skip to content

Commit

Permalink
add IT
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed May 26, 2024
1 parent 69ca427 commit dfc3fc9
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -960,4 +960,25 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt
}, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS);
assertTrue(taskDone.get());
}

public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException,
InterruptedException {
Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = RestMLRemoteInferenceIT.registerRemoteModel(modelName, modelName, connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
if (deploy) {
response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
}
return modelId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.utils.TestHelper;

import com.google.common.collect.ImmutableList;
import com.jayway.jsonpath.JsonPath;

public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
private String modelId;
private String openAIChatModelId;
private String bedrockEmbeddingModelId;
private final String completionModelConnectorEntity = "{\n"
+ " \"name\": \"OpenAI text embedding model Connector\",\n"
+ " \"description\": \"The connector to public OpenAI text embedding model service\",\n"
Expand All @@ -52,26 +52,58 @@ public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
+ " ]\n"
+ "}";

private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
private static final String GITHUB_CI_AWS_REGION = "us-west-2";

private final String bedrockEmbeddingModelConnectorEntity = "{\n"
+ " \"name\": \"Amazon Bedrock Connector: embedding\",\n"
+ " \"description\": \"The connector to bedrock Titan embedding model\",\n"
+ " \"version\": 1,\n"
+ " \"protocol\": \"aws_sigv4\",\n"
+ " \"parameters\": {\n"
+ " \"region\": \""
+ GITHUB_CI_AWS_REGION
+ "\",\n"
+ " \"service_name\": \"bedrock\",\n"
+ " \"model_name\": \"amazon.titan-embed-text-v1\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"access_key\": \""
+ AWS_ACCESS_KEY_ID
+ "\",\n"
+ " \"secret_key\": \""
+ AWS_SECRET_ACCESS_KEY
+ "\",\n"
+ " \"session_token\": \""
+ AWS_SESSION_TOKEN
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n"
+ " \"headers\": {\n"
+ " \"content-type\": \"application/json\",\n"
+ " \"x-amz-content-sha256\": \"required\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n"
+ " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n"
+ " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n"
+ " }\n"
+ " ]\n"
+ "}";

@Before
public void setup() throws IOException, InterruptedException {
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
Thread.sleep(20000);

// create connectors for OPEN AI and register model
Response response = RestMLRemoteInferenceIT.createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String openAIConnectorId = (String) responseMap.get("connector_id");
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 chat model", openAIConnectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
this.modelId = (String) responseMap.get("model_id");
response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5);
this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true);
}

public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
Expand All @@ -82,7 +114,7 @@ public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
+ " {\n"
+ " \"ml_inference\": {\n"
+ " \"model_id\": \""
+ this.modelId
+ this.openAIChatModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
Expand Down Expand Up @@ -141,7 +173,7 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
+ " {\n"
+ " \"ml_inference\": {\n"
+ " \"model_id\": \""
+ this.modelId
+ this.openAIChatModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
Expand Down Expand Up @@ -228,6 +260,96 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
Assert.assertEquals(0.014352738, (Double) embedding4.get(0), 0.005);
}

public void testMLInferenceProcessorWithForEachProcessor() throws Exception {
String indexName = "my_books";
String pipelineName = "my_books_bedrock_embedding_pipeline";
String createIndexRequestBody = "{\n"
+ " \"settings\": {\n"
+ " \"index\": {\n"
+ " \"default_pipeline\": \""
+ pipelineName
+ "\"\n"
+ " }\n"
+ " },\n"
+ " \"mappings\": {\n"
+ " \"properties\": {\n"
+ " \"books\": {\n"
+ " \"type\": \"nested\",\n"
+ " \"properties\": {\n"
+ " \"title_embedding\": {\n"
+ " \"type\": \"float\"\n"
+ " },\n"
+ " \"title\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"description\": {\n"
+ " \"type\": \"text\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";
createIndex(indexName, createIndexRequestBody);

String createPipelineRequestBody = "{\n"
+ " \"description\": \"Test bedrock embeddings\",\n"
+ " \"processors\": [\n"
+ " {\n"
+ " \"foreach\": {\n"
+ " \"field\": \"books\",\n"
+ " \"processor\": {\n"
+ " \"ml_inference\": {\n"
+ " \"model_id\": \""
+ this.bedrockEmbeddingModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
+ " \"input\": \"_ingest._value.title\"\n"
+ " }\n"
+ " ],\n"
+ " \"output_map\": [\n"
+ " {\n"
+ " \"_ingest._value.title_embedding\": \"$.embedding\"\n"
+ " }\n"
+ " ],\n"
+ " \"ignore_missing\": false,\n"
+ " \"ignore_failure\": false\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
createPipelineProcessor(createPipelineRequestBody, pipelineName);

// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
String uploadDocumentRequestBody = "{\n"
+ " \"books\": [{\n"
+ " \"title\": \"first book\",\n"
+ " \"description\": \"This is first book\"\n"
+ " },\n"
+ " {\n"
+ " \"title\": \"second book\",\n"
+ " \"description\": \"This is second book\"\n"
+ " }\n"
+ " ]\n"
+ "}";
uploadDocument(indexName, "1", uploadDocumentRequestBody);
Map document = getDocument(indexName, "1");

List embeddingList = JsonPath.parse(document).read("_source.books[*].title_embedding");
Assert.assertEquals(2, embeddingList.size());

List embedding1 = JsonPath.parse(document).read("_source.books[0].title_embedding");
Assert.assertEquals(1536, embedding1.size());
List embedding2 = JsonPath.parse(document).read("_source.books[1].title_embedding");
Assert.assertEquals(1536, embedding2.size());
}

protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception {
Response pipelineCreateResponse = TestHelper
.makeRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,8 +779,14 @@ public static Response createConnector(String input) throws IOException {
}

public static Response registerRemoteModel(String name, String connectorId) throws IOException {
return registerRemoteModel("remote_model_group", name, connectorId);
}

public static Response registerRemoteModel(String modelGroupName, String name, String connectorId) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
+ " \"name\": \""
+ modelGroupName
+ "\",\n"
+ " \"description\": \"This is an example description\"\n"
+ "}";
Response response = TestHelper
Expand Down

0 comments on commit dfc3fc9

Please sign in to comment.