diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 4a7555d69b..aed6288629 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -63,6 +63,9 @@ public ConnectorAction( if (method == null) { throw new IllegalArgumentException("method can't null"); } + if (requestBody == null) { + throw new IllegalArgumentException("request body can't null"); + } this.actionType = actionType; this.method = method; this.url = url; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 697f27494f..17af851714 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -93,6 +93,15 @@ public MLCreateConnectorInput( if (protocol == null) { throw new IllegalArgumentException("Connector protocol is null"); } + if (description == null) { + throw new IllegalArgumentException("Connector description is null"); + } + if (parameters == null || parameters.isEmpty()) { + throw new IllegalArgumentException("Connector parameters is null or empty list"); + } + if (credential == null || credential.isEmpty()) { + throw new IllegalArgumentException("Connector credential is null or empty list"); + } } this.name = name; this.description = description; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index 1539b9b432..c4af406ecf 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; import java.io.IOException; @@ -12,10 +14,7 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Assert; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; @@ -27,37 +26,54 @@ import org.opensearch.search.SearchModule; public class ConnectorActionTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); @Test public void constructor_NullActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("action type can't null"); - ConnectorAction.ActionType actionType = null; - String method = "post"; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = null; + String method = "post"; + String url = "https://test.com"; + String requestBody = "{\"input\": \"${parameters.input}\"}"; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("action type can't null", exception.getMessage()); + } @Test public void constructor_NullUrl() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("url can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = null; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "post"; + String url = null; + String requestBody = "{\"input\": \"${parameters.input}\"}"; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("url can't null", exception.getMessage()); } @Test public void constructor_NullMethod() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("method can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = null; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = null; + String url = "https://test.com"; + String requestBody = "{\"input\": \"${parameters.input}\"}"; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("method can't null", exception.getMessage()); + } + + @Test + public void constructor_NullRequestBody() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "post"; + String url = "https://test.com"; + String requestBody = null; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("request body can't null", exception.getMessage()); } @Test @@ -65,11 +81,12 @@ public void writeTo_NullValue() throws IOException { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "http"; String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + String requestBody = "{\"input\": \"${parameters.input}\"}"; + ConnectorAction action = new ConnectorAction(actionType, method, url, null, requestBody, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test @@ -95,7 +112,7 @@ public void writeTo() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test @@ -103,12 +120,17 @@ public void toXContent_NullValue() throws IOException { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "http"; String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + String requestBody = "{\"input\": \"${parameters.input}\"}"; + ConnectorAction action = new ConnectorAction(actionType, method, url, null, requestBody, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com",\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\ + """; + assertEquals(expctedContent, content); } @Test @@ -135,22 +157,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert - .assertEquals( - "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}", - content - ); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}\ + """; + assertEquals(expctedContent, content); } @Test public void parse() throws IOException { - String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}"; + String jsonStr = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}"\ + """; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -160,24 +183,23 @@ public void parse() throws IOException { ); parser.nextToken(); ConnectorAction action = ConnectorAction.parse(parser); - Assert.assertEquals("http", action.getMethod()); - Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); - Assert.assertEquals("https://test.com", action.getUrl()); - Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); - Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); - Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); + assertEquals("http", action.getMethod()); + assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); + assertEquals("https://test.com", action.getUrl()); + assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); + assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); + assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); } @Test public void test_wrongActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Action Type"); - ConnectorAction.ActionType.from("badAction"); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); }); + assertEquals("Wrong Action Type of badAction", exception.getMessage()); } @Test public void test_invalidActionInModelPrediction() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); - Assert.assertEquals(isValidActionInModelPrediction(actionType), false); + assertEquals(isValidActionInModelPrediction(actionType), false); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 28e597e186..a3b44f1321 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; @@ -19,9 +20,7 @@ import java.util.function.Consumer; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -46,20 +45,19 @@ public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; - @Rule - public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," - + "\"connection_timeout\":10000,\"read_timeout\":10000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + private final String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1","protocol":"http",\ + "parameters":{"input":"test input value"},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,\ + "access_mode":"PUBLIC","client_config":{"max_connection":20,\ + "connection_timeout":10000,"read_timeout":10000,\ + "retry_backoff_millis":10,"retry_timeout_seconds":10,"max_retry_times":-1,"retry_backoff_policy":"constant"}}\ + """; @Before public void setUp() { @@ -102,59 +100,162 @@ public void setUp() { @Test public void constructorMLCreateConnectorInput_NullName() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput - .builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(null) + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector name is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullVersion() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version(null) + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector version is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol(null) + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector protocol is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullDescription() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description(null) + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector description is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullParameters() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(null) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector parameters is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyParameters() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of()) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector parameters is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(null) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of()) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); } @Test @@ -187,16 +288,15 @@ public void testParse() throws Exception { @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1",\ + "protocol":"http","parameters":{"input":["test input value"]},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,"access_mode":"PUBLIC"};\ + """; testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("test_connector_name", parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); @@ -223,8 +323,11 @@ public void readInputStream_SuccessWithNullFields() throws IOException { MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput .builder() .name("test_connector_name") + .description("this is a test connector") .version("1") .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); @@ -258,10 +361,8 @@ public void testParse_MissingNameField_ShouldThrowException() throws IOException String jsonMissingName = "{\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"}"; XContentParser parser = createParser(jsonMissingName); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - - MLCreateConnectorInput.parse(parser); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); }); + assertEquals("Connector name is null", exception.getMessage()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f2c93ef5fd..a0206d7036 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -336,7 +336,7 @@ public static ConnectorAction createConnectorAction(Connector connector, Connect // Initialize the default method and requestBody String method = "POST"; - String requestBody = null; + String requestBody = "{}"; String url = ""; switch (getRemoteServerFromURL(predictEndpoint)) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 335dc95245..cb73e18e1f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -344,7 +344,7 @@ public void testGetTask_createBatchStatusActionForOpenAI() { assertEquals(ConnectorAction.ActionType.BATCH_PREDICT_STATUS, result.getActionType()); assertEquals("GET", result.getMethod()); assertEquals("https://api.openai.com/v1/batches/${parameters.id}", result.getUrl()); - assertNull(result.getRequestBody()); + assertEquals("{}", result.getRequestBody()); assertTrue(result.getHeaders().containsKey("Authorization")); } @@ -355,6 +355,7 @@ public void testGetTask_createCancelBatchActionForBedrock() { .name("test") .protocol("http") .version("1") + .description("this is a test connector") .credential(Map.of("api_key", "credential_value")) .parameters(Map.of("param1", "value1")) .actions( @@ -384,6 +385,6 @@ public void testGetTask_createCancelBatchActionForBedrock() { "https://bedrock.${parameters.region}.amazonaws.com/model-invocation-job/${parameters.processedJobArn}/stop", result.getUrl() ); - assertNull(result.getRequestBody()); + assertEquals("{}", result.getRequestBody()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index e16400bc56..33052e40d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -133,6 +133,7 @@ public void setup() { .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") + .requestBody("{ \"inputText\": \"${parameters.inputText}\" }") .url("https://${parameters.endpoint}/v1/completions") .build() ); @@ -142,6 +143,7 @@ public void setup() { input = MLCreateConnectorInput .builder() .name("test_name") + .description("this is a test connector") .version("1") .actions(actions) .parameters(parameters) @@ -447,21 +449,24 @@ public void test_execute_URL_notMatchingExpression_exception() { .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("https://${parameters.endpoint}/v1/completions") + .requestBody("{ \"inputText\": \"${parameters.inputText}\" }") .build() ); + Map parameters = ImmutableMap.of("endpoint", "api.openai1.com"); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput .builder() .name(randomAlphaOfLength(5)) .description(randomAlphaOfLength(10)) .version("1") .protocol(ConnectorProtocols.HTTP) + .parameters(parameters) + .credential(credential) .actions(actions) .build(); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); - Map parameters = ImmutableMap.of("endpoint", "api.openai1.com"); - mlCreateConnectorInput.setParameters(parameters); TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, actionFilters,