diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java index 33fc10ca73..3d92d15bb7 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java @@ -16,6 +16,9 @@ public class ConnectorProtocols { public static final Set VALID_PROTOCOLS = Set.of(HTTP, AWS_SIGV4); public static void validateProtocol(String protocol) { + if (protocol == null) { + throw new IllegalArgumentException("Connector protocol is null. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]"); + } if (!VALID_PROTOCOLS.contains(protocol)) { throw new IllegalArgumentException("Unsupported connector protocol. Please use one of [" + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))+ "]"); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index a811548629..8d6130566a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -80,7 +80,7 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - AwsConnector.awsConnectorBuilder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + AwsConnector.awsConnectorBuilder().name("test connector").protocol("http").version("1").actions(Arrays.asList(predictAction)).build(); } @Test @@ -99,7 +99,7 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio .build(); Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); @@ -124,7 +124,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { .build(); Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 10bff8e82b..9c3057b3a5 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -67,7 +67,7 @@ public void processInput_TextDocsInputDataSet_NoPreprocessFunction() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); } @@ -104,7 +104,7 @@ public void processOutput_NoPostprocessFunction() throws IOException { .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); ModelTensors tensors = ConnectorUtils.processOutput("{\"response\": \"test response\"}", connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); Assert.assertEquals("response", tensors.getMlModelTensors().get(0).getName()); @@ -126,7 +126,7 @@ public void processOutput_PostprocessFunction() throws IOException { .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}"; ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, connector, scriptService, ImmutableMap.of()); Assert.assertEquals(1, tensors.getMlModelTensors().size()); @@ -153,7 +153,7 @@ private void processInput_TextDocsInputDataSet_PreprocessFunction(String request .build(); Map parameters = new HashMap<>(); parameters.put("key1", "value1"); - Connector connector = HttpConnector.builder().name("test connector").version("1").parameters(parameters).actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").parameters(parameters).actions(Arrays.asList(predictAction)).build(); RemoteInferenceInputDataSet remoteInferenceInputDataSet = ConnectorUtils.processInput(mlInput, connector, new HashMap<>(), scriptService); Assert.assertNotNull(remoteInferenceInputDataSet.getParameters()); Assert.assertEquals(1, remoteInferenceInputDataSet.getParameters().size()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 4d1bbec749..8d04603d2a 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -66,7 +66,7 @@ public void invokeRemoteModel_WrongHttpMethod() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); executor.invokeRemoteModel(null, null, null, null); } @@ -79,7 +79,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"response\": \"test result\"}"); @@ -103,7 +103,7 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() { .url("http://test.com/mock") .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -129,7 +129,7 @@ public void executePredict_TextDocsInput() throws IOException { .postProcessFunction(MLPostProcessFunction.OPENAI_EMBEDDING) .requestBody("{\"input\": \"${parameters.input}\"}") .build(); - Connector connector = HttpConnector.builder().name("test connector").version("1").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response);