Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable customize http client parameters #1558

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.opensearch.ml.engine.algorithms.remote;
zane-neo marked this conversation as resolved.
Show resolved Hide resolved

import lombok.Getter;
import lombok.Setter;

public abstract class AbstractConnectorExecutor implements RemoteConnectorExecutor{
zane-neo marked this conversation as resolved.
Show resolved Hide resolved
@Setter
@Getter
private Integer maxConnections;
@Setter
@Getter
private Integer connectionTimeoutInMillis;
@Setter
@Getter
private Integer readTimeoutInMillis;

public void validate() {
if (connectionTimeoutInMillis == null) {
throw new IllegalArgumentException("connectionTimeoutInMillis must be set to non null value, please check your configuration");
}
if (readTimeoutInMillis == null) {
throw new IllegalArgumentException("readTimeoutInMillis must be set to non null value, please check your configuration");
}
if (maxConnections == null) {
throw new IllegalArgumentException("maxConnections must be set to non null value, please check your configuration");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.utils.AttributeMap;

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.List;
import java.util.Map;

Expand All @@ -41,11 +44,11 @@

@Log4j2
@ConnectorExecutor(AWS_SIGV4)
public class AwsConnectorExecutor implements RemoteConnectorExecutor{
public class AwsConnectorExecutor extends AbstractConnectorExecutor{

@Getter
private AwsConnector connector;
private final SdkHttpClient httpClient;
private SdkHttpClient httpClient;
@Setter @Getter
private ScriptService scriptService;

Expand All @@ -55,7 +58,30 @@ public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
}

public AwsConnectorExecutor(Connector connector) {
this(connector, new DefaultSdkHttpClientBuilder().build());
this.connector = (AwsConnector) connector;
}

@Override
public void initialize() {
super.validate();
Duration connectionTimeout = Duration.ofMillis(super.getConnectionTimeoutInMillis());
Duration readTimeout = Duration.ofMillis(super.getReadTimeoutInMillis());
try (
AttributeMap attributeMap = AttributeMap
.builder()
.put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
.put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
.put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getMaxConnections())
.build()
) {
log.info(
"Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}",
connectionTimeout,
readTimeout,
super.getMaxConnections()
);
this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,20 @@

@Log4j2
@ConnectorExecutor(HTTP)
public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor {
public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor {

@Getter
private HttpConnector connector;
@Setter @Getter
private ScriptService scriptService;

private CloseableHttpClient httpClient;

public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) {
this(connector);
this.httpClient = httpClient;
}

public HttpJsonConnectorExecutor(Connector connector) {
this.connector = (HttpConnector)connector;
}
Expand Down Expand Up @@ -95,8 +102,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
}

AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
try (CloseableHttpClient httpClient = getHttpClient();
CloseableHttpResponse response = httpClient.execute(request)) {
try (CloseableHttpResponse response = httpClient.execute(request)) {
HttpEntity responseEntity = response.getEntity();
String responseBody = EntityUtils.toString(responseEntity);
EntityUtils.consume(responseEntity);
Expand All @@ -123,7 +129,8 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
}
}

public CloseableHttpClient getHttpClient() {
return MLHttpClientFactory.getCloseableHttpClient();
public void initialize() {
super.validate();
this.httpClient = MLHttpClientFactory.getCloseableHttpClient(super.getConnectionTimeoutInMillis(), super.getReadTimeoutInMillis(), super.getMaxConnections());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ default void setScriptService(ScriptService scriptService){}
default void setClient(Client client){}
default void setXContentRegistry(NamedXContentRegistry xContentRegistry){}
default void setClusterService(ClusterService clusterService){}
default void setConnectionTimeoutInMillis(Integer connectionTimeout){}
default void setReadTimeoutInMillis(Integer readTimeout){}
default void setMaxConnections(Integer maxConnections){}
default void initialize(){}

default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List<ModelTensors> tensorOutputs) {
Connector connector = getConnector();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ public class RemoteModel implements Predictable {
public static final String CLIENT = "client";
public static final String XCONTENT_REGISTRY = "xcontent_registry";

public static final String CONNECTION_TIMEOUT = "ConnectionTimeout";
public static final String READ_TIMEOUT = "ReadTimeout";
public static final String MAX_CONNECTIONS = "MaxConnections";

private RemoteConnectorExecutor connectorExecutor;

@VisibleForTesting
Expand Down Expand Up @@ -79,10 +83,14 @@ public void initModel(MLModel model, Map<String, Object> params, Encryptor encry
Connector connector = model.getConnector().cloneConnector();
connector.decrypt((credential) -> encryptor.decrypt(credential));
this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
this.connectorExecutor.setConnectionTimeoutInMillis((Integer) params.get(CONNECTION_TIMEOUT));
this.connectorExecutor.setReadTimeoutInMillis((Integer) params.get(READ_TIMEOUT));
this.connectorExecutor.setMaxConnections((Integer) params.get(MAX_CONNECTIONS));
this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
this.connectorExecutor.setClient((Client) params.get(CLIENT));
this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
this.connectorExecutor.initialize();
} catch (RuntimeException e) {
log.error("Failed to init remote model", e);
throw e;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpResponse;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.UnsupportedSchemeException;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
Expand All @@ -27,11 +28,11 @@
@Log4j2
public class MLHttpClientFactory {

public static CloseableHttpClient getCloseableHttpClient() {
return createHttpClient();
public static CloseableHttpClient getCloseableHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) {
return createHttpClient(connectionTimeout, readTimeout, maxConnections);
}

private static CloseableHttpClient createHttpClient() {
private static CloseableHttpClient createHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) {
HttpClientBuilder builder = HttpClientBuilder.create();

// Only allow HTTP and HTTPS schemes
Expand All @@ -52,6 +53,13 @@ public boolean isRedirected(HttpRequest request, HttpResponse response, HttpCont
return false;
}
});
builder.setMaxConnTotal(maxConnections);
builder.setMaxConnPerRoute(maxConnections);
RequestConfig requestConfig = RequestConfig.custom()
.setConnectTimeout(connectionTimeout)
.setSocketTimeout(readTimeout)
.build();
builder.setDefaultRequestConfig(requestConfig);
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.opensearch.ml.engine.algorithms.remote;

import org.junit.Test;
import org.mockito.Answers;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;

public class AbstractConnectorExecutorTest {
private final AbstractConnectorExecutor connectorExecutor = mock(AbstractConnectorExecutor.class, Answers.CALLS_REAL_METHODS);

@Test
public void test_setters() {
connectorExecutor.setMaxConnections(10);
connectorExecutor.setReadTimeoutInMillis(1000);
connectorExecutor.setConnectionTimeoutInMillis(1000);
}

@Test
public void test_getters() {
connectorExecutor.setMaxConnections(10);
connectorExecutor.setReadTimeoutInMillis(1000);
connectorExecutor.setConnectionTimeoutInMillis(1000);
assertEquals(10L, (long)connectorExecutor.getMaxConnections());
assertEquals(1000L, (long)connectorExecutor.getReadTimeoutInMillis());
assertEquals(1000L, (long)connectorExecutor.getConnectionTimeoutInMillis());
}

@Test
public void test_validate() {
connectorExecutor.setMaxConnections(10);
connectorExecutor.setReadTimeoutInMillis(1000);
connectorExecutor.setConnectionTimeoutInMillis(1000);
connectorExecutor.validate();
}

@Test
public void test_validate_fail() {
try {
connectorExecutor.validate();
} catch (IllegalArgumentException e) {
assertEquals("connectionTimeoutInMillis must be set to non null value, please check your configuration", e.getMessage());
}
connectorExecutor.setConnectionTimeoutInMillis(1000);
try {
connectorExecutor.validate();
} catch (IllegalArgumentException e) {
assertEquals("readTimeoutInMillis must be set to non null value, please check your configuration", e.getMessage());
}
connectorExecutor.setReadTimeoutInMillis(1000);
try {
connectorExecutor.validate();
} catch (IllegalArgumentException e) {
assertEquals("maxConnections must be set to non null value, please check your configuration", e.getMessage());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -113,7 +110,6 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
Expand Down Expand Up @@ -143,7 +139,6 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}
Expand Down Expand Up @@ -171,7 +166,6 @@ public void executePredict_RemoteInferenceInput() throws IOException {
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand Down Expand Up @@ -205,7 +199,6 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Expand All @@ -214,4 +207,27 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
}

@Test
public void test_initialize() {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));
initializeExecutor(executor);
}

private void initializeExecutor(RemoteConnectorExecutor executor) {
executor.setConnectionTimeoutInMillis(1000);
executor.setReadTimeoutInMillis(1000);
executor.setMaxConnections(30);
executor.initialize();
}
}
Loading