Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

throw exception if remote model doesn't return 2xx status code; fix p… #1473

Merged
merged 4 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
}
String modelResponse = responseBuilder.toString();
if (statusCode < 200 || statusCode >= 300) {
Copy link
Contributor

@navneet1v navneet1v Oct 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason why are not throwing exception for 201 or other status code? 200 seems to be the only case in which we should not throw exception.

Or another way to ask is can a remote model give 201 status code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That depends on the remote model service to determine return 200 or other 2xx code.

Check HTTP RFC

2xx (Successful): The request was successfully received,
understood, and accepted

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we are treating all 2XX as successful response.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because that's what HTTP RFC defines

throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
docs.add(null);
}
}
if (preProcessFunction.contains("${parameters")) {
if (preProcessFunction.contains("${parameters.")) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
preProcessFunction = substitutor.replace(preProcessFunction);
}
Expand Down Expand Up @@ -164,7 +164,7 @@ public static ModelTensors processOutput(String modelResponse, Connector connect
// execute user defined painless script.
Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
String response = processedResponse.orElse(modelResponse);
boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent();
boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response);
if (responseFilter == null) {
connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.rest.RestStatus;
import org.opensearch.script.ScriptService;

import java.security.AccessController;
Expand Down Expand Up @@ -103,9 +105,13 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
return null;
});
String modelResponse = responseRef.get();
Integer statusCode = statusCodeRef.get();
if (statusCode < 200 || statusCode >= 300) {
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCodeRef.get());
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,15 @@ default ModelTensorOutput executePredict(MLInput mlInput) {

if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
List<String> textDocs = new ArrayList<>(textDocsInputDataSet.getDocs());
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs);
int processedDocs = 0;
while(processedDocs < textDocsInputDataSet.getDocs().size()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
List<ModelTensors> tempTensorOutputs = new ArrayList<>();
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
processedDocs += Math.max(tempTensorOutputs.size(), 1);
tensorOutputs.addAll(tempTensorOutputs);
}

} else {
preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,36 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}

@Test
public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException {
exceptionRule.expect(OpenSearchStatusException.class);
exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}");
String jsonString = "{\"message\":\"The security token included in the request is invalid\"}";
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(403);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}

@Test
public void executePredict_RemoteInferenceInput() throws IOException {
String jsonString = "{\"key\":\"value\"}";
Expand Down Expand Up @@ -176,7 +206,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build();
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -120,12 +121,34 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response"));
}

@Test
public void executePredict_TextDocsInput_LimitExceed() throws IOException {
exceptionRule.expect(OpenSearchStatusException.class);
exceptionRule.expectMessage("{\"message\": \"Too many requests\"}");
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}

@Test
public void executePredict_TextDocsInput() throws IOException {
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
Expand Down Expand Up @@ -161,7 +184,7 @@ public void executePredict_TextDocsInput() throws IOException {
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,8 @@ public void deployModel(
CLUSTER_SERVICE,
clusterService
);
// deploy remote model or model trained by built-in algorithm like kmeans
if (mlModel.getConnector() != null) {
// deploy remote model with internal connector or model trained by built-in algorithm like kmeans
if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) {
setupPredictable(modelId, mlModel, params);
wrappedListener.onResponse("successful");
return;
Expand All @@ -756,6 +756,7 @@ public void deployModel(
GetRequest getConnectorRequest = new GetRequest();
FetchSourceContext fetchContext = new FetchSourceContext(true, null, null);
getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext);
// get connector and deploy remote model with standalone connector
client.get(getConnectorRequest, ActionListener.wrap(getResponse -> {
if (getResponse != null && getResponse.isExists()) {
try (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
FunctionName algorithm = mlInput.getAlgorithm();
// run predict
if (modelId != null) {
try {
Predictable predictor = mlModelManager.getPredictor(modelId);
if (predictor != null) {
Predictable predictor = mlModelManager.getPredictor(modelId);
if (predictor != null) {
try {
if (!predictor.isModelReady()) {
throw new IllegalArgumentException("Model not ready: " + modelId);
}
Expand All @@ -226,11 +226,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
return;
} else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
throw new IllegalArgumentException("Model not ready to be used: " + modelId);
} catch (Exception e) {
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
} catch (Exception e) {
handlePredictFailure(mlTask, internalListener, e, false);
} else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
throw new IllegalArgumentException("Model not ready to be used: " + modelId);
}

// search model by model id.
Expand All @@ -249,6 +250,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
GetResponse getResponse = r;
String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
MLModel mlModel = MLModel.parse(xContentParser, algorithmName);
mlModel.setModelId(modelId);
User resourceUser = mlModel.getUser();
User requestUser = getUserContext(client);
if (!checkUserPermissions(requestUser, resourceUser, modelId)) {
Expand All @@ -260,7 +262,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
return;
}
// run predict
mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
if (mlTaskManager.contains(mlTask.getTaskId())) {
mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
}
MLOutput output = mlEngine.predict(mlInput, mlModel);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
Expand Down