diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java index ff2045451..c08fdf6ff 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkRestTestCase.java @@ -47,7 +47,7 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.test.rest.OpenSearchRestTestCase; -import org.junit.After; +import org.junit.AfterClass; import org.junit.Before; import javax.net.ssl.SSLEngine; @@ -62,6 +62,7 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; @@ -206,9 +207,10 @@ protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOE } + // Cleans up resources after all test execution has been completed @SuppressWarnings("unchecked") - @After - protected void wipeAllSystemIndices() throws IOException { + @AfterClass + protected static void wipeAllSystemIndices() throws IOException { Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); MediaType xContentType = MediaType.fromMediaType(response.getEntity().getContentType()); try ( @@ -299,6 +301,14 @@ protected boolean preserveIndicesUponCompletion() { return true; } + /** + * Required to persist cluster settings between test executions + */ + @Override + protected boolean preserveClusterSettings() { + return true; + } + /** * Helper method to invoke the Create Workflow Rest Action * @param template the template to create @@ -319,6 +329,24 @@ protected Response createWorkflowDryRun(Template template) throws Exception { return TestHelpers.makeRequest(client(), "POST", WORKFLOW_URI + "?dryrun=true", ImmutableMap.of(), template.toJson(), null); } + /** + * Helper method to invoke the Update Workflow API + * @param workflowId the document id + * @param template the template used to update + * @throws Exception if the request fails + * @return a rest response + */ + protected Response updateWorkflow(String workflowId, Template template) throws Exception { + return TestHelpers.makeRequest( + client(), + "PUT", + String.format(Locale.ROOT, "%s/%s", WORKFLOW_URI, workflowId), + ImmutableMap.of(), + template.toJson(), + null + ); + } + /** * Helper method to invoke the Provision Workflow Rest Action * @param workflowId the workflow ID to provision @@ -376,13 +404,18 @@ protected void getAndAssertWorkflowStatus(String workflowId, State stateStatus, /** * Helper method to wait until a workflow provisioning has completed and retrieve any resources created * @param workflowId the workflow id to retrieve resources from + * @param timeout the max wait time in seconds * @return a list of created resources * @throws Exception if the request fails */ - protected List getResourcesCreated(String workflowId) throws Exception { + protected List getResourcesCreated(String workflowId, int timeout) throws Exception { // wait and ensure state is completed/done - assertBusy(() -> { getAndAssertWorkflowStatus(workflowId, State.COMPLETED, ProvisioningProgress.DONE); }); + assertBusy( + () -> { getAndAssertWorkflowStatus(workflowId, State.COMPLETED, ProvisioningProgress.DONE); }, + timeout, + TimeUnit.SECONDS + ); Response response = getWorkflowStatus(workflowId, true); diff --git a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java index a1a6a73c1..6e977e88a 100644 --- a/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java +++ b/src/test/java/org/opensearch/flowframework/rest/FlowFrameworkRestApiIT.java @@ -19,7 +19,9 @@ import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; +import org.opensearch.flowframework.model.WorkflowNode; +import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -28,6 +30,71 @@ public class FlowFrameworkRestApiIT extends FlowFrameworkRestTestCase { + public void testCreateAndProvisionLocalModelWorkflow() throws Exception { + + // Using a 3 step template to create a model group, register a remote model and deploy model + Template template = TestHelpers.createTemplateFromFile("registermodelgroup-registerlocalmodel-deploymodel.json"); + + // Remove register model input to test validation + Workflow originalWorkflow = template.workflows().get(PROVISION_WORKFLOW); + + List modifiednodes = new ArrayList<>(); + modifiednodes.add( + new WorkflowNode( + "workflow_step_1", + "model_group", + Map.of(), + Map.of() // empty user inputs + ) + ); + for (WorkflowNode node : originalWorkflow.nodes()) { + if (!node.id().equals("workflow_step_1")) { + modifiednodes.add(node); + } + } + + Workflow missingInputs = new Workflow(originalWorkflow.userParams(), modifiednodes, originalWorkflow.edges()); + + Template templateWithMissingInputs = new Template.Builder().name(template.name()) + .description(template.description()) + .useCase(template.useCase()) + .templateVersion(template.templateVersion()) + .compatibilityVersion(template.compatibilityVersion()) + .workflows(Map.of(PROVISION_WORKFLOW, missingInputs)) + .uiMetadata(template.getUiMetadata()) + .user(template.getUser()) + .build(); + + // Hit Create Workflow API with invalid template + Response response = createWorkflow(templateWithMissingInputs); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + + // Retrieve workflow ID + Map responseMap = entityAsMap(response); + String workflowId = (String) responseMap.get(WORKFLOW_ID); + getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); + + // Attempt provision + ResponseException exception = expectThrows(ResponseException.class, () -> provisionWorkflow(workflowId)); + assertTrue(exception.getMessage().contains("Invalid graph, missing the following required inputs : [name]")); + + // update workflow with updated inputs + response = updateWorkflow(workflowId, template); + assertEquals(RestStatus.CREATED, TestHelpers.restStatus(response)); + getAndAssertWorkflowStatus(workflowId, State.NOT_STARTED, ProvisioningProgress.NOT_STARTED); + + // Reattempt Provision + response = provisionWorkflow(workflowId); + assertEquals(RestStatus.OK, TestHelpers.restStatus(response)); + getAndAssertWorkflowStatus(workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); + + // Wait until provisioning has completed successfully before attempting to retrieve created resources + List resourcesCreated = getResourcesCreated(workflowId, 100); + + // TODO : This template should create 2 resources, model_group_id and model_id, need to fix after feature branch is merged + assertEquals(0, resourcesCreated.size()); + } + public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { // Using a 3 step template to create a connector, register remote model and deploy model @@ -69,7 +136,7 @@ public void testCreateAndProvisionRemoteModelWorkflow() throws Exception { getAndAssertWorkflowStatus(workflowId, State.PROVISIONING, ProvisioningProgress.IN_PROGRESS); // Wait until provisioning has completed successfully before attempting to retrieve created resources - List resourcesCreated = getResourcesCreated(workflowId); + List resourcesCreated = getResourcesCreated(workflowId, 10); // TODO : This template should create 2 resources, connector_id and model_id, need to fix after feature branch is merged assertEquals(1, resourcesCreated.size());