diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 994298da63..9a2d7634bf 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -7,6 +7,7 @@ import lombok.Builder; import lombok.Getter; +import lombok.Setter; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; @@ -24,7 +25,10 @@ @Getter public class ModelTensors implements Writeable, ToXContentObject { public static final String OUTPUT_FIELD = "output"; + public static final String STATUS_CODE_FIELD = "status_code"; private List mlModelTensors; + @Setter + private Integer statusCode; @Builder public ModelTensors(List mlModelTensors) { @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } + if (statusCode != null) { + builder.field(STATUS_CODE_FIELD, statusCode); + } builder.endObject(); return builder; } @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException { mlModelTensors.add(new ModelTensor(in)); } } + statusCode = in.readOptionalInt(); } @Override @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalInt(statusCode); } public void filter(ModelResultFilter resultFilter) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 3102bd91dc..e7f9c88c66 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -82,6 +82,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { return httpClient.prepareRequest(executeRequest).call(); }); + int statusCode = response.httpResponse().statusCode(); AbortableInputStream body = null; if (response.responseBody().isPresent()) { @@ -102,6 +103,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String modelResponse = responseBuilder.toString(); ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCode); tensorOutputs.add(tensors); } catch (RuntimeException exception) { log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 0c7d868351..016baf7229 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -52,6 +52,7 @@ public HttpJsonConnectorExecutor(Connector connector) { public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { try { AtomicReference responseRef = new AtomicReference<>(""); + AtomicReference statusCodeRef = new AtomicReference<>(); HttpUriRequest request; switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { @@ -97,12 +98,14 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String responseBody = EntityUtils.toString(responseEntity); EntityUtils.consume(responseEntity); responseRef.set(responseBody); + statusCodeRef.set(response.getStatusLine().getStatusCode()); } return null; }); String modelResponse = responseRef.get(); ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCodeRef.get()); tensorOutputs.add(tensors); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 5dbbf2090e..8d39db9a79 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -7,6 +7,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; +import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -32,6 +35,7 @@ import software.amazon.awssdk.http.ExecutableHttpRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpResponse; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -41,6 +45,7 @@ import java.util.Optional; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; @@ -92,6 +97,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio exceptionRule.expectMessage("No response from model"); when(response.responseBody()).thenReturn(Optional.empty()); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction.builder() @@ -116,6 +124,9 @@ public void executePredict_RemoteInferenceInput() throws IOException { InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpRequest.call()).thenReturn(response); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); @@ -147,6 +158,9 @@ public void executePredict_TextDocsInferenceInput() throws IOException { AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction.builder() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 9caf621087..5cbdce053e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -7,9 +7,12 @@ import com.google.common.collect.ImmutableMap; import org.apache.http.HttpEntity; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -17,6 +20,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -87,6 +91,8 @@ public void executePredict_RemoteInferenceInput() throws IOException { when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -107,6 +113,8 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(executor.getHttpClient()).thenReturn(httpClient); @@ -146,6 +154,8 @@ public void executePredict_TextDocsInput() throws IOException { + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + " \"total_tokens\": 5\n" + " }\n" + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient);