From 401382096429aea0c477fb88e23fd4081df1325c Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Mon, 19 Feb 2024 14:39:02 -0800 Subject: [PATCH] Backport missing PR(#1443) to enable bwc in 2.10 (#2090) * add status code to model tensor (#1443) * add status code to model tensor Signed-off-by: Yaliang Wu * fix ut Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu Signed-off-by: Sicheng Song * fix cherrypick conflict Signed-off-by: Sicheng Song --------- Signed-off-by: Yaliang Wu Signed-off-by: Sicheng Song Co-authored-by: Yaliang Wu --- .../ml/common/output/model/ModelTensors.java | 9 +++++++ .../remote/AwsConnectorExecutor.java | 2 ++ .../remote/HttpJsonConnectorExecutor.java | 3 +++ .../remote/AwsConnectorExecutorTest.java | 10 +++++++- .../remote/HttpJsonConnectorExecutorTest.java | 25 +++++++++++++++++-- 5 files changed, 46 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 9073345550..03b0ce5fca 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -7,6 +7,7 @@ import lombok.Builder; import lombok.Getter; +import lombok.Setter; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; @@ -24,7 +25,10 @@ @Getter public class ModelTensors implements Writeable, ToXContentObject { public static final String OUTPUT_FIELD = "output"; + public static final String STATUS_CODE_FIELD = "status_code"; private List mlModelTensors; + @Setter + private Integer statusCode; @Builder public ModelTensors(List mlModelTensors) { @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } + if (statusCode != null) { + builder.field(STATUS_CODE_FIELD, statusCode); + } builder.endObject(); return builder; } @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException { mlModelTensors.add(new ModelTensor(in)); } } + statusCode = in.readOptionalInt(); } @Override @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalInt(statusCode); } public void filter(ModelResultFilter resultFilter) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index d2e692d001..bc0e954816 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -82,6 +82,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { return httpClient.prepareRequest(executeRequest).call(); }); + int statusCode = response.httpResponse().statusCode(); AbortableInputStream body = null; if (response.responseBody().isPresent()) { @@ -102,6 +103,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String modelResponse = responseBuilder.toString(); ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCode); tensorOutputs.add(tensors); } catch (RuntimeException exception) { log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 0c7d868351..016baf7229 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -52,6 +52,7 @@ public HttpJsonConnectorExecutor(Connector connector) { public void invokeRemoteModel(MLInput mlInput, Map parameters, String payload, List tensorOutputs) { try { AtomicReference responseRef = new AtomicReference<>(""); + AtomicReference statusCodeRef = new AtomicReference<>(); HttpUriRequest request; switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { @@ -97,12 +98,14 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S String responseBody = EntityUtils.toString(responseEntity); EntityUtils.consume(responseEntity); responseRef.set(responseBody); + statusCodeRef.set(response.getStatusLine().getStatusCode()); } return null; }); String modelResponse = responseRef.get(); ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); + tensors.setStatusCode(statusCodeRef.get()); tensorOutputs.add(tensors); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index ecc143ea6f..33aaea2de2 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -29,6 +29,7 @@ import software.amazon.awssdk.http.ExecutableHttpRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpResponse; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -38,6 +39,7 @@ import java.util.Optional; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; @@ -89,6 +91,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio exceptionRule.expectMessage("No response from model"); when(response.responseBody()).thenReturn(Optional.empty()); when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); ConnectorAction predictAction = ConnectorAction.builder() @@ -113,6 +118,9 @@ public void executePredict_RemoteInferenceInput() throws IOException { InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(200); + when(response.httpResponse()).thenReturn(httpResponse); when(httpRequest.call()).thenReturn(response); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); @@ -136,4 +144,4 @@ public void executePredict_RemoteInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } -} +} \ No newline at end of file diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 8d04603d2a..0c0f5a0741 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -7,9 +7,12 @@ import com.google.common.collect.ImmutableMap; import org.apache.http.HttpEntity; +import org.apache.http.ProtocolVersion; +import org.apache.http.StatusLine; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -17,6 +20,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -32,6 +36,7 @@ import org.opensearch.script.ScriptService; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.util.Arrays; import static org.mockito.ArgumentMatchers.any; @@ -84,6 +89,8 @@ public void executePredict_RemoteInferenceInput() throws IOException { when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -94,7 +101,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { } @Test - public void executePredict_TextDocsInput_NoPreprocessFunction() { + public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input."); ConnectorAction predictAction = ConnectorAction.builder() @@ -103,6 +110,11 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); + when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); @@ -133,7 +145,16 @@ public void executePredict_TextDocsInput() throws IOException { HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; + String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" + + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" + + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" + + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" + + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" + + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + " }\n" + "}"; + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient);