From 3495e8dc81eaade593056d78a0b387e4e165c6fd Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 6 Oct 2023 13:45:33 -0700 Subject: [PATCH] add status code to model tensor (#1443) (#1453) Signed-off-by: Yaliang Wu --- .../ml/common/output/model/ModelTensors.java | 9 ++++ .../remote/AwsConnectorExecutor.java | 2 + .../remote/HttpJsonConnectorExecutor.java | 3 ++ .../remote/AwsConnectorExecutorTest.java | 45 +++++++++++----- .../remote/HttpJsonConnectorExecutorTest.java | 54 +++++++------------ 5 files changed, 64 insertions(+), 49 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 9073345550..03b0ce5fca 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -7,6 +7,7 @@ import lombok.Builder; import lombok.Getter; +import lombok.Setter; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; @@ -24,7 +25,10 @@ @Getter public class ModelTensors implements Writeable, ToXContentObject { public static final String OUTPUT_FIELD = "output"; + public static final String STATUS_CODE_FIELD = "status_code"; private List mlModelTensors; + @Setter + private Integer statusCode; @Builder public ModelTensors(List mlModelTensors) { @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } + if (statusCode != null) { + builder.field(STATUS_CODE_FIELD, statusCode); + } builder.endObject(); return builder; } @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException { mlModelTensors.add(new ModelTensor(in)); } } + statusCode = in.readOptionalInt(); } @Override @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalInt(statusCode); } public void filter(ModelResultFilter resultFilter) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 2c18b363f8..6fc69621af 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -86,6 +86,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { return httpClient.prepareRequest(executeRequest).call(); }); + int statusCode = response.httpResponse().statusCode(); AbortableInputStream body = null; if (response.responseBody().isPresent()) { @@ -106,6 +107,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String modelResponse = responseBuilder.toString(); ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCode); tensorOutputs.add(tensors); } catch (RuntimeException exception) { log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 5337fd9948..3dea04a7e7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -54,6 +54,7 @@ public HttpJsonConnectorExecutor(Connector connector) { public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { try { AtomicReference responseRef = new AtomicReference<>(""); + AtomicReference statusCodeRef = new AtomicReference<>(); HttpUriRequest request; switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { @@ -98,12 +99,14 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String responseBody = EntityUtils.toString(responseEntity); EntityUtils.consume(responseEntity); responseRef.set(responseBody); + statusCodeRef.set(response.getStatusLine().getStatusCode()); } return null; }); String modelResponse = responseRef.get(); ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCodeRef.get()); tensorOutputs.add(tensors); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 6ce1b00df6..4cdb4387c2 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,21 +5,12 @@ package org.opensearch.ml.engine.algorithms.remote; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.util.Arrays; -import java.util.Map; -import java.util.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; +import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -49,6 +40,23 @@ import software.amazon.awssdk.http.ExecutableHttpRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpResponse; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Map; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; public class AwsConnectorExecutorTest { @@ -101,6 +109,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio exceptionRule.expectMessage("No response from model"); when(response.responseBody()).thenReturn(Optional.empty()); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction @@ -135,6 +146,9 @@ public void executePredict_RemoteInferenceInput() throws IOException { InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpRequest.call()).thenReturn(response); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); @@ -177,6 +191,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException { AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 7a52d621f5..4c4e20b6d4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -13,9 +13,12 @@ import java.util.Arrays; import org.apache.http.HttpEntity; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -23,6 +26,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -99,6 +103,8 @@ public void executePredict_RemoteInferenceInput() throws IOException { when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor @@ -125,13 +131,9 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); @@ -174,34 +176,16 @@ public void executePredict_TextDocsInput() throws IOException { HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\n" - + " \"object\": \"list\",\n" - + " \"data\": [\n" - + " {\n" - + " \"object\": \"embedding\",\n" - + " \"index\": 0,\n" - + " \"embedding\": [\n" - + " -0.014555434,\n" - + " -0.002135904,\n" - + " 0.0035105038\n" - + " ]\n" - + " },\n" - + " {\n" - + " \"object\": \"embedding\",\n" - + " \"index\": 1,\n" - + " \"embedding\": [\n" - + " -0.014555434,\n" - + " -0.002135904,\n" - + " 0.0035105038\n" - + " ]\n" - + " }\n" - + " ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" - + " \"usage\": {\n" - + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" - + " }\n" - + "}"; + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" + + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" + + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient);