diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java new file mode 100644 index 0000000000..2a4ab98474 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import lombok.Getter; +import lombok.Setter; + +@Setter +@Getter +public abstract class AbstractConnectorExecutor implements RemoteConnectorExecutor{ + private Integer maxConnections; + private Integer connectionTimeoutInMillis; + private Integer readTimeoutInMillis; + + public void validate() { + if (connectionTimeoutInMillis == null) { + throw new IllegalArgumentException("connectionTimeoutInMillis must be set to non null value, please check your configuration"); + } + if (readTimeoutInMillis == null) { + throw new IllegalArgumentException("readTimeoutInMillis must be set to non null value, please check your configuration"); + } + if (maxConnections == null) { + throw new IllegalArgumentException("maxConnections must be set to non null value, please check your configuration"); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index c6cc7efa8b..d7989f42f2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -23,7 +23,9 @@ import software.amazon.awssdk.http.HttpExecuteRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpConfigurationOption; import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.utils.AttributeMap; import java.io.BufferedReader; import java.io.InputStreamReader; @@ -31,6 +33,7 @@ import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.List; import java.util.Map; @@ -41,11 +44,11 @@ @Log4j2 @ConnectorExecutor(AWS_SIGV4) -public class AwsConnectorExecutor implements RemoteConnectorExecutor{ +public class AwsConnectorExecutor extends AbstractConnectorExecutor{ @Getter private AwsConnector connector; - private final SdkHttpClient httpClient; + private SdkHttpClient httpClient; @Setter @Getter private ScriptService scriptService; @@ -55,7 +58,30 @@ public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) { } public AwsConnectorExecutor(Connector connector) { - this(connector, new DefaultSdkHttpClientBuilder().build()); + this.connector = (AwsConnector) connector; + } + + @Override + public void initialize() { + super.validate(); + Duration connectionTimeout = Duration.ofMillis(super.getConnectionTimeoutInMillis()); + Duration readTimeout = Duration.ofMillis(super.getReadTimeoutInMillis()); + try ( + AttributeMap attributeMap = AttributeMap + .builder() + .put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout) + .put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout) + .put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getMaxConnections()) + .build() + ) { + log.info( + "Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}", + connectionTimeout, + readTimeout, + super.getMaxConnections() + ); + this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap); + } } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 677c821503..79c4ffa788 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -40,13 +40,20 @@ @Log4j2 @ConnectorExecutor(HTTP) -public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { +public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Getter private HttpConnector connector; @Setter @Getter private ScriptService scriptService; + private CloseableHttpClient httpClient; + + public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) { + this(connector); + this.httpClient = httpClient; + } + public HttpJsonConnectorExecutor(Connector connector) { this.connector = (HttpConnector)connector; } @@ -95,8 +102,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S } AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - try (CloseableHttpClient httpClient = getHttpClient(); - CloseableHttpResponse response = httpClient.execute(request)) { + try (CloseableHttpResponse response = httpClient.execute(request)) { HttpEntity responseEntity = response.getEntity(); String responseBody = EntityUtils.toString(responseEntity); EntityUtils.consume(responseEntity); @@ -123,7 +129,8 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S } } - public CloseableHttpClient getHttpClient() { - return MLHttpClientFactory.getCloseableHttpClient(); + public void initialize() { + super.validate(); + this.httpClient = MLHttpClientFactory.getCloseableHttpClient(super.getConnectionTimeoutInMillis(), super.getReadTimeoutInMillis(), super.getMaxConnections()); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 2366b150d4..7527802fbf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -56,6 +56,10 @@ default void setScriptService(ScriptService scriptService){} default void setClient(Client client){} default void setXContentRegistry(NamedXContentRegistry xContentRegistry){} default void setClusterService(ClusterService clusterService){} + default void setConnectionTimeoutInMillis(Integer connectionTimeout){} + default void setReadTimeoutInMillis(Integer readTimeout){} + default void setMaxConnections(Integer maxConnections){} + default void initialize(){} default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List tensorOutputs) { Connector connector = getConnector(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java index 4449ee6996..3b7750c410 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteModel.java @@ -33,6 +33,10 @@ public class RemoteModel implements Predictable { public static final String CLIENT = "client"; public static final String XCONTENT_REGISTRY = "xcontent_registry"; + public static final String CONNECTION_TIMEOUT = "ConnectionTimeout"; + public static final String READ_TIMEOUT = "ReadTimeout"; + public static final String MAX_CONNECTIONS = "MaxConnections"; + private RemoteConnectorExecutor connectorExecutor; @VisibleForTesting @@ -79,10 +83,14 @@ public void initModel(MLModel model, Map params, Encryptor encry Connector connector = model.getConnector().cloneConnector(); connector.decrypt((credential) -> encryptor.decrypt(credential)); this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class); + this.connectorExecutor.setConnectionTimeoutInMillis((Integer) params.get(CONNECTION_TIMEOUT)); + this.connectorExecutor.setReadTimeoutInMillis((Integer) params.get(READ_TIMEOUT)); + this.connectorExecutor.setMaxConnections((Integer) params.get(MAX_CONNECTIONS)); this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE)); this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); this.connectorExecutor.setClient((Client) params.get(CLIENT)); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); + this.connectorExecutor.initialize(); } catch (RuntimeException e) { log.error("Failed to init remote model", e); throw e; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index 782bd9501e..02c2467e27 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -11,6 +11,7 @@ import org.apache.http.HttpHost; import org.apache.http.HttpRequest; import org.apache.http.HttpResponse; +import org.apache.http.client.config.RequestConfig; import org.apache.http.conn.UnsupportedSchemeException; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; @@ -27,11 +28,11 @@ @Log4j2 public class MLHttpClientFactory { - public static CloseableHttpClient getCloseableHttpClient() { - return createHttpClient(); + public static CloseableHttpClient getCloseableHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) { + return createHttpClient(connectionTimeout, readTimeout, maxConnections); } - private static CloseableHttpClient createHttpClient() { + private static CloseableHttpClient createHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) { HttpClientBuilder builder = HttpClientBuilder.create(); // Only allow HTTP and HTTPS schemes @@ -52,6 +53,13 @@ public boolean isRedirected(HttpRequest request, HttpResponse response, HttpCont return false; } }); + builder.setMaxConnTotal(maxConnections); + builder.setMaxConnPerRoute(maxConnections); + RequestConfig requestConfig = RequestConfig.custom() + .setConnectTimeout(connectionTimeout) + .setSocketTimeout(readTimeout) + .build(); + builder.setDefaultRequestConfig(requestConfig); return builder.build(); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java new file mode 100644 index 0000000000..21cd0df74f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import org.junit.Test; +import org.mockito.Answers; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; + +public class AbstractConnectorExecutorTest { + private final AbstractConnectorExecutor connectorExecutor = mock(AbstractConnectorExecutor.class, Answers.CALLS_REAL_METHODS); + + @Test + public void test_setters() { + connectorExecutor.setMaxConnections(10); + connectorExecutor.setReadTimeoutInMillis(1000); + connectorExecutor.setConnectionTimeoutInMillis(1000); + } + + @Test + public void test_getters() { + connectorExecutor.setMaxConnections(10); + connectorExecutor.setReadTimeoutInMillis(1000); + connectorExecutor.setConnectionTimeoutInMillis(1000); + assertEquals(10L, (long)connectorExecutor.getMaxConnections()); + assertEquals(1000L, (long)connectorExecutor.getReadTimeoutInMillis()); + assertEquals(1000L, (long)connectorExecutor.getConnectionTimeoutInMillis()); + } + + @Test + public void test_validate() { + connectorExecutor.setMaxConnections(10); + connectorExecutor.setReadTimeoutInMillis(1000); + connectorExecutor.setConnectionTimeoutInMillis(1000); + connectorExecutor.validate(); + } + + @Test + public void test_validate_fail() { + try { + connectorExecutor.validate(); + } catch (IllegalArgumentException e) { + assertEquals("connectionTimeoutInMillis must be set to non null value, please check your configuration", e.getMessage()); + } + connectorExecutor.setConnectionTimeoutInMillis(1000); + try { + connectorExecutor.validate(); + } catch (IllegalArgumentException e) { + assertEquals("readTimeoutInMillis must be set to non null value, please check your configuration", e.getMessage()); + } + connectorExecutor.setReadTimeoutInMillis(1000); + try { + connectorExecutor.validate(); + } catch (IllegalArgumentException e) { + assertEquals("maxConnections must be set to non null value, please check your configuration", e.getMessage()); + } + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 8b0d5a8173..c4183ebf90 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -7,9 +7,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import org.apache.http.ProtocolVersion; -import org.apache.http.StatusLine; -import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -113,7 +110,6 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } @@ -143,7 +139,6 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } @@ -171,7 +166,6 @@ public void executePredict_RemoteInferenceInput() throws IOException { Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); @@ -205,7 +199,6 @@ public void executePredict_TextDocsInferenceInput() throws IOException { Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); @@ -214,4 +207,27 @@ public void executePredict_TextDocsInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void test_initialize() { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + initializeExecutor(executor); + } + + private void initializeExecutor(RemoteConnectorExecutor executor) { + executor.setConnectionTimeoutInMillis(1000); + executor.setReadTimeoutInMillis(1000); + executor.setMaxConnections(30); + executor.initialize(); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 11ae20c470..68b151f31c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -21,7 +21,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; -import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -34,15 +33,12 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import java.io.IOException; import java.util.Arrays; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -88,13 +84,12 @@ public void executePredict_RemoteInferenceInput() throws IOException { .requestBody("{\"input\": \"${parameters.input}\"}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); - when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); @@ -117,8 +112,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); - when(executor.getHttpClient()).thenReturn(httpClient); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); @@ -143,8 +137,7 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException { StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); when(response.getStatusLine()).thenReturn(statusLine); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); - when(executor.getHttpClient()).thenReturn(httpClient); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } @@ -166,7 +159,7 @@ public void executePredict_TextDocsInput() throws IOException { .requestBody("{\"input\": ${parameters.input}}") .build(); Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" @@ -181,7 +174,6 @@ public void executePredict_TextDocsInput() throws IOException { when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); - when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); @@ -190,4 +182,24 @@ public void executePredict_TextDocsInput() throws IOException { Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); } + + @Test + public void test_initialize() { + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); + initializeExecutor(executor); + } + + private void initializeExecutor(RemoteConnectorExecutor executor) { + executor.setConnectionTimeoutInMillis(1000); + executor.setReadTimeoutInMillis(1000); + executor.setMaxConnections(30); + executor.initialize(); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java index 6016748a1e..0d4110c918 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java @@ -44,6 +44,8 @@ public class RemoteModelTest { RemoteModel remoteModel; Encryptor encryptor; + private Map params = Map.of(RemoteModel.CONNECTION_TIMEOUT, 1000, RemoteModel.READ_TIMEOUT, 1000, RemoteModel.MAX_CONNECTIONS, 30); + @Before public void setUp() { MockitoAnnotations.openMocks(this); @@ -71,7 +73,7 @@ public void predict_ModelDeployed_WrongInput() { exceptionRule.expectMessage("Wrong input type"); Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); - remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.initModel(mlModel, params, encryptor); remoteModel.predict(mlInput); } @@ -82,14 +84,14 @@ public void initModel_RuntimeException() { Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); doThrow(new IllegalArgumentException("Tag mismatch!")).when(encryptor).decrypt(any()); - remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.initModel(mlModel, params, encryptor); } @Test public void initModel_NullHeader() { Connector connector = createConnector(null); when(mlModel.getConnector()).thenReturn(connector); - remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.initModel(mlModel, params, encryptor); Map decryptedHeaders = connector.getDecryptedHeaders(); Assert.assertNull(decryptedHeaders); } @@ -98,7 +100,7 @@ public void initModel_NullHeader() { public void initModel_WithHeader() { Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}")); when(mlModel.getConnector()).thenReturn(connector); - remoteModel.initModel(mlModel, ImmutableMap.of(), encryptor); + remoteModel.initModel(mlModel, params, encryptor); Map decryptedHeaders = connector.getDecryptedHeaders(); RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor(); Assert.assertNotNull(executor); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index adb0b72aa3..45e6284d25 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -22,7 +22,7 @@ public class MLHttpClientFactoryTests { @Test public void test_getCloseableHttpClient_success() { - CloseableHttpClient client = MLHttpClientFactory.getCloseableHttpClient(); + CloseableHttpClient client = MLHttpClientFactory.getCloseableHttpClient(1000, 1000, 30); assertNotNull(client); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index fa793bbcea..dad6010cb9 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -25,6 +25,9 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CONNECTION_TIMEOUT; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.MAX_CONNECTIONS; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.READ_TIMEOUT; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.XCONTENT_REGISTRY; import static org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel.ML_ENGINE; @@ -34,6 +37,9 @@ import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_HTTP_CLIENT_MAX_CONNECTIONS; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -156,6 +162,9 @@ public class MLModelManager { private volatile Integer maxModelPerNode; private volatile Integer maxRegisterTasksPerNode; private volatile Integer maxDeployTasksPerNode; + private volatile Integer maxConnections; + private volatile Integer connectionTimeoutInMillis; + private volatile Integer readTimeoutInMillis; public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet .of( @@ -209,6 +218,9 @@ public MLModelManager( clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); + maxConnections = ML_COMMONS_HTTP_CLIENT_MAX_CONNECTIONS.get(settings); + connectionTimeoutInMillis = ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND.get(settings); + readTimeoutInMillis = ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND.get(settings); } public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { @@ -931,7 +943,13 @@ public void deployModel( XCONTENT_REGISTRY, xContentRegistry, CLUSTER_SERVICE, - clusterService + clusterService, + CONNECTION_TIMEOUT, + connectionTimeoutInMillis, + READ_TIMEOUT, + readTimeoutInMillis, + MAX_CONNECTIONS, + maxConnections ); // deploy remote model with internal connector or model trained by built-in algorithm like kmeans if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 23b1e9a047..65dc10cf49 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -695,7 +695,10 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED, MLCommonsSettings.ML_COMMONS_UPDATE_CONNECTOR_ENABLED, MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, - MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED + MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, + MLCommonsSettings.ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND, + MLCommonsSettings.ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND, + MLCommonsSettings.ML_COMMONS_HTTP_CLIENT_MAX_CONNECTIONS ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 03d4cf8647..73a421e155 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -177,4 +177,25 @@ private MLCommonsSettings() {} // Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference. public static final Setting ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED = GenerativeQAProcessorConstants.RAG_PIPELINE_FEATURE_ENABLED; + + public static final Setting ML_COMMONS_HTTP_CLIENT_CONNECTION_TIMEOUT_IN_MILLI_SECOND = Setting + .intSetting( + "plugins.ml_commons.http_client.connection_timeout.in_millisecond", + 1000, + 1, + Setting.Property.NodeScope, + Setting.Property.Final + ); + + public static final Setting ML_COMMONS_HTTP_CLIENT_READ_TIMEOUT_IN_MILLI_SECOND = Setting + .intSetting( + "plugins.ml_commons.http_client.read_timeout.in_millisecond", + 3000, + 1, + Setting.Property.NodeScope, + Setting.Property.Final + ); + + public static final Setting ML_COMMONS_HTTP_CLIENT_MAX_CONNECTIONS = Setting + .intSetting("plugins.ml_commons.http_client.max_connections", 20, 20, Setting.Property.NodeScope, Setting.Property.Final); }