diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index e17bf2aa0..533d82c1e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -16,6 +16,7 @@ import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorAction.ActionType; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -24,8 +25,10 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.CompletableFuture; @@ -84,39 +87,44 @@ public void onFailure(Exception e) { String description = null; String version = null; String protocol = null; - Map parameters = new HashMap<>(); - Map credentials = new HashMap<>(); - List actions = new ArrayList<>(); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - name = (String) content.get(NAME_FIELD); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case VERSION_FIELD: - version = (String) content.get(VERSION_FIELD); - break; - case PROTOCOL_FIELD: - protocol = (String) content.get(PROTOCOL_FIELD); - break; - case PARAMETERS_FIELD: - parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); - break; - case CREDENTIALS_FIELD: - credentials = (Map) content.get(CREDENTIALS_FIELD); - break; - case ACTIONS_FIELD: - actions = (List) content.get(ACTIONS_FIELD); - break; + Map parameters = Collections.emptyMap(); + Map credentials = Collections.emptyMap(); + List actions = Collections.emptyList(); + + try { + for (WorkflowData workflowData : data) { + for (Entry entry : workflowData.getContent().entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + name = (String) entry.getValue(); + break; + case DESCRIPTION_FIELD: + description = (String) entry.getValue(); + break; + case VERSION_FIELD: + version = (String) entry.getValue(); + break; + case PROTOCOL_FIELD: + protocol = (String) entry.getValue(); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap(entry.getValue()); + break; + case CREDENTIALS_FIELD: + credentials = getStringToStringMap(entry.getValue(), CREDENTIALS_FIELD); + break; + case ACTIONS_FIELD: + actions = getConnectorActionList(entry.getValue()); + break; + } } - } + } catch (IllegalArgumentException iae) { + createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST)); + return createConnectorFuture; + } catch (PrivilegedActionException pae) { + createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED)); + return createConnectorFuture; } if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { @@ -145,21 +153,48 @@ public String getName() { return NAME; } - private static Map getParameterMap(Map params) { + @SuppressWarnings("unchecked") + private static Map getStringToStringMap(Object map, String fieldName) { + if (map instanceof Map) { + return (Map) map; + } + throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); + } + private static Map getParameterMap(Object parameterMap) throws PrivilegedActionException { Map parameters = new HashMap<>(); - for (String key : params.keySet()) { - String value = params.get(key); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - parameters.put(key, value); - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException(e); - } + for (Entry entry : getStringToStringMap(parameterMap, PARAMETERS_FIELD).entrySet()) { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + parameters.put(entry.getKey(), entry.getValue()); + return null; + }); } return parameters; } + private static List getConnectorActionList(Object array) { + if (!(array instanceof Map[])) { + throw new IllegalArgumentException("[" + ACTIONS_FIELD + "] must be an array of key-value maps."); + } + List actions = new ArrayList<>(); + for (Map map : (Map[]) array) { + String actionType = (String) map.get(ConnectorAction.ACTION_TYPE_FIELD); + if (actionType == null) { + throw new IllegalArgumentException("[" + ConnectorAction.ACTION_TYPE_FIELD + "] is missing."); + } + @SuppressWarnings("unchecked") + ConnectorAction action = ConnectorAction.builder() + .actionType(ActionType.valueOf(actionType.toUpperCase(Locale.ROOT))) + .method((String) map.get(ConnectorAction.METHOD_FIELD)) + .url((String) map.get(ConnectorAction.URL_FIELD)) + .headers((Map) map.get(ConnectorAction.HEADERS_FIELD)) + .requestBody((String) map.get(ConnectorAction.REQUEST_BODY_FIELD)) + .preProcessFunction((String) map.get(ConnectorAction.ACTION_PRE_PROCESS_FUNCTION)) + .postProcessFunction((String) map.get(ConnectorAction.ACTION_POST_PROCESS_FUNCTION)) + .build(); + actions.add(action); + } + return actions; + } + } diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index b54b2a27c..63855f7bd 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -10,6 +10,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; @@ -27,7 +28,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -44,38 +44,30 @@ public void setUp() throws Exception { Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); + Map[] actions = new Map[] { + Map.ofEntries( + Map.entry(ConnectorAction.ACTION_TYPE_FIELD, ConnectorAction.ActionType.PREDICT.name()), + Map.entry(ConnectorAction.METHOD_FIELD, "post"), + Map.entry(ConnectorAction.URL_FIELD, "foo.test"), + Map.entry( + ConnectorAction.REQUEST_BODY_FIELD, + "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }" + ) + ) }; MockitoAnnotations.openMocks(this); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = "foot.test"; - inputData = new WorkflowData( Map.ofEntries( - Map.entry("name", "test"), - Map.entry("description", "description"), - Map.entry("version", "1"), - Map.entry("protocol", "test"), - Map.entry("params", params), - Map.entry("credentials", credentials), - Map.entry( - "actions", - List.of( - new ConnectorAction( - actionType, - method, - url, - null, - "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }", - null, - null - ) - ) - ) + Map.entry(CommonValue.NAME_FIELD, "test"), + Map.entry(CommonValue.DESCRIPTION_FIELD, "description"), + Map.entry(CommonValue.VERSION_FIELD, "1"), + Map.entry(CommonValue.PROTOCOL_FIELD, "test"), + Map.entry(CommonValue.PARAMETERS_FIELD, params), + Map.entry(CommonValue.CREDENTIALS_FIELD, credentials), + Map.entry(CommonValue.ACTIONS_FIELD, actions) ) ); - } public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { @@ -83,6 +75,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr String connectorId = "connect"; CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { @@ -104,6 +97,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr public void testCreateConnectorFailure() throws IOException { CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> {