Skip to content

Commit

Permalink
Change httpclient to async (#1958)
Browse files Browse the repository at this point in the history
* Change httpclient from sync to async

Signed-off-by: zane-neo <[email protected]>

* Change from CRTAsyncHttpClient to NettyAsyncHttpClient

Signed-off-by: zane-neo <[email protected]>

* Add publisher to request

Signed-off-by: zane-neo <[email protected]>

* Change sync httpclient to async

Signed-off-by: zane-neo <[email protected]>

* Handle error case and return error response in actionLListener

Signed-off-by: zane-neo <[email protected]>

* Fix no response when exception

Signed-off-by: zane-neo <[email protected]>

* Add content type header

Signed-off-by: zane-neo <[email protected]>

* Fix issues found in functional test

Signed-off-by: zane-neo <[email protected]>

* Fix no response issue in functional test

Signed-off-by: zane-neo <[email protected]>

* fix default step size error

Signed-off-by: zane-neo <[email protected]>

* Add track inference duration for async httpclient

Signed-off-by: zane-neo <[email protected]>

* Change client appsec highlight issues implementation for async httpclient

Signed-off-by: zane-neo <[email protected]>

* Add UTs

Signed-off-by: zane-neo <[email protected]>

* Add UTs

Signed-off-by: zane-neo <[email protected]>

* Remove unused file

Signed-off-by: zane-neo <[email protected]>

* Add UTs

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Change error code to honor remote service error code

Signed-off-by: zane-neo <[email protected]>

* Add more UTs

Signed-off-by: zane-neo <[email protected]>

* Change SSRF code to make it correct for return error stattus

Signed-off-by: zane-neo <[email protected]>

* Fix failure UTs and add more UTs

Signed-off-by: zane-neo <[email protected]>

* Fix failure ITs

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Fix partial success response not correct issue

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Fix failure ITs

Signed-off-by: zane-neo <[email protected]>

* Add more UTs to increase code coverage

Signed-off-by: zane-neo <[email protected]>

* Change url regex

Signed-off-by: zane-neo <[email protected]>

* Address comments

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Fix failure UTs

Signed-off-by: zane-neo <[email protected]>

* Add UT for httpclientFactory throw exception when creating httpclient

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Address comments and add modelTensor status code

Signed-off-by: zane-neo <[email protected]>

* Address comments

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Add status code to process error response

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Rebase main after connector level http parameter support

Signed-off-by: zane-neo <[email protected]>

* Fix UT

Signed-off-by: zane-neo <[email protected]>

* Change error message when remote model return empty and chaange the behavior when one of the requests fails

Signed-off-by: zane-neo <[email protected]>

* Add comments\

Signed-off-by: zane-neo <[email protected]>

* Remove redundant builder and change the error code check

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Add more UTs for throw exception cases

Signed-off-by: zane-neo <[email protected]>

* fix failure UTs

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Fix test cases since the error message change

Signed-off-by: zane-neo <[email protected]>

* Rebase code

Signed-off-by: zane-neo <[email protected]>

* fix failure IT

Signed-off-by: zane-neo <[email protected]>

* Add more UTs

Signed-off-by: zane-neo <[email protected]>

* Fix duplicate response to client issue

Signed-off-by: zane-neo <[email protected]>

* fix duplicate response in channel

Signed-off-by: zane-neo <[email protected]>

* change code for all successfully responses case

Signed-off-by: zane-neo <[email protected]>

* Address comments

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

* Increase nio httpclient version to fix vulnerbility

Signed-off-by: zane-neo <[email protected]>

* Change validate localhost logic to same with existing code

Signed-off-by: zane-neo <[email protected]>

* change method signature to private

Signed-off-by: zane-neo <[email protected]>

* format code

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 30, 2024
1 parent ef435c9 commit d8fd07c
Show file tree
Hide file tree
Showing 21 changed files with 1,298 additions and 759 deletions.
3 changes: 2 additions & 1 deletion ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ dependencies {
}
}

implementation platform('software.amazon.awssdk:bom:2.21.15')
implementation platform('software.amazon.awssdk:bom:2.25.40')
implementation 'software.amazon.awssdk:auth'
implementation 'software.amazon.awssdk:apache-client'
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.1'
implementation 'com.jayway.jsonpath:json-path:2.9.0'
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: '2.25.40'
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;

/**
Expand All @@ -31,7 +33,13 @@ public interface Predictable {
* @param mlInput input data
* @return predicted results
*/
MLOutput predict(MLInput mlInput);
default MLOutput predict(MLInput mlInput) {
throw new IllegalStateException("Method is not implemented");
}

default void asyncPredict(MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
actionListener.onFailure(new IllegalStateException("Method is not implemented"));
}

/**
* Init model (load model into memory) with ML model content and params.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,43 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR;
import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;

@Log4j2
@ConnectorExecutor(AWS_SIGV4)
public class AwsConnectorExecutor extends AbstractConnectorExecutor {

@Getter
private AwsConnector connector;
private SdkHttpClient httpClient;
@Setter
@Getter
private ScriptService scriptService;
Expand All @@ -69,103 +58,52 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
@Getter
private MLGuard mlGuard;

public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
this.connector = (AwsConnector) connector;
this.httpClient = httpClient;
}
private SdkAsyncHttpClient httpClient;

public AwsConnectorExecutor(Connector connector) {
super.initialize(connector);
this.connector = (AwsConnector) connector;
Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout());
try (
AttributeMap attributeMap = AttributeMap
.builder()
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections())
.build()
) {
log
.info(
"Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}",
connectionTimeout,
readTimeout,
super.getConnectorClientConfig().getMaxConnections()
);
this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
} catch (RuntimeException e) {
log.error("Error initializing AWS connector HTTP client.", e);
throw e;
} catch (Throwable e) {
log.error("Error initializing AWS connector HTTP client.", e);
throw new MLException(e);
}
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
}

@SuppressWarnings("removal")
@Override
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
public void invokeRemoteModel(
MLInput mlInput,
Map<String, String> parameters,
String payload,
Map<Integer, ModelTensors> tensorOutputs,
ExecutionContext countDownLatch,
ActionListener<List<ModelTensors>> actionListener
) {
try {
String endpoint = connector.getPredictEndpoint(parameters);
RequestBody requestBody = RequestBody.fromString(payload);

SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(POST)
.uri(URI.create(endpoint))
.contentStreamProvider(requestBody.contentStreamProvider());
Map<String, String> headers = connector.getDecryptedHeaders();
if (headers != null) {
for (String key : headers.keySet()) {
builder.putHeader(key, headers.get(key));
}
}
SdkHttpFullRequest request = builder.build();
HttpExecuteRequest executeRequest = HttpExecuteRequest
SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
AsyncExecuteRequest executeRequest = AsyncExecuteRequest
.builder()
.request(signRequest(request))
.contentStreamProvider(request.contentStreamProvider().orElse(null))
.requestContentPublisher(new SimpleHttpContentPublisher(request))
.responseHandler(
new MLSdkAsyncHttpResponseHandler(
countDownLatch,
actionListener,
parameters,
tensorOutputs,
connector,
scriptService,
mlGuard
)
)
.build();

HttpExecuteResponse response = AccessController
.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> httpClient.prepareRequest(executeRequest).call());
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
body = response.responseBody().get();
}

StringBuilder responseBuilder = new StringBuilder();
if (body != null) {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
String line;
while ((line = reader.readLine()) != null) {
responseBuilder.append(line);
}
}
} else {
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
}
String modelResponse = responseBuilder.toString();
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
if (statusCode < 200 || statusCode >= 300) {
throw new OpenSearchStatusException(REMOTE_SERVICE_ERROR + modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
throw exception;
actionListener.onFailure(exception);
} catch (Throwable e) {
log.error("Failed to execute predict in aws connector", e);
throw new MLException("Fail to execute predict in aws connector", e);
actionListener.onFailure(new MLException("Fail to execute predict in aws connector", e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import static org.opensearch.ml.engine.utils.ScriptUtils.executePostProcessFunction;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -34,6 +36,7 @@
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.script.ScriptService;
Expand All @@ -46,7 +49,9 @@
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.params.Aws4SignerParams;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;

@Log4j2
Expand Down Expand Up @@ -179,11 +184,15 @@ public static ModelTensors processOutput(
String modelResponse,
Connector connector,
ScriptService scriptService,
Map<String, String> parameters
Map<String, String> parameters,
MLGuard mlGuard
) throws IOException {
if (modelResponse == null) {
throw new IllegalArgumentException("model response is null");
}
if (mlGuard != null && !mlGuard.validate(modelResponse, MLGuard.Type.OUTPUT)) {
throw new IllegalArgumentException("guardrails triggered for LLM output");
}
List<ModelTensor> modelTensors = new ArrayList<>();
Optional<ConnectorAction> predictAction = connector.findPredictAction();
if (predictAction.isEmpty()) {
Expand Down Expand Up @@ -252,4 +261,42 @@ public static SdkHttpFullRequest signRequest(

return signer.sign(request, params);
}

public static SdkHttpFullRequest buildSdkRequest(
Connector connector,
Map<String, String> parameters,
String payload,
SdkHttpMethod method
) {
String charset = parameters.getOrDefault("charset", "UTF-8");
RequestBody requestBody;
if (payload != null) {
requestBody = RequestBody.fromString(payload, Charset.forName(charset));
} else {
requestBody = RequestBody.empty();
}
if (SdkHttpMethod.POST == method && 0 == requestBody.optionalContentLength().get()) {
log.error("Content length is 0. Aborting request to remote model");
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
}
String endpoint = connector.getPredictEndpoint(parameters);
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(method)
.uri(URI.create(endpoint))
.contentStreamProvider(requestBody.contentStreamProvider());
Map<String, String> headers = connector.getDecryptedHeaders();
if (headers != null) {
for (String key : headers.keySet()) {
builder.putHeader(key, headers.get(key));
}
}
if (builder.matchingHeaders("Content-Type").isEmpty()) {
builder.putHeader("Content-Type", "application/json");
}
if (builder.matchingHeaders("Content-Length").isEmpty()) {
builder.putHeader("Content-Length", requestBody.optionalContentLength().get().toString());
}
return builder.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
*
* * Copyright OpenSearch Contributors
* * SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.ml.engine.algorithms.remote;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import lombok.AllArgsConstructor;
import lombok.Data;

/**
* This class encapsulates several parameters that are used in a split-batch request case.
* A batch request is that in neural-search side multiple fields are send in one request to ml-commons,
* but the remote model doesn't accept list of string inputs so in ml-commons the request needs split.
* sequence is used to identify the index of the split request.
* countDownLatch is used to wait for all the split requests to finish.
* exceptionHolder is used to hold any exception thrown in a split-batch request.
*/
@Data
@AllArgsConstructor
public class ExecutionContext {
// Should never be null
private int sequence;
private CountDownLatch countDownLatch;
// This is to hold any exception thrown in a split-batch request
private AtomicReference<Exception> exceptionHolder;
}
Loading

0 comments on commit d8fd07c

Please sign in to comment.