From 683648ef6b139c2723770a224a06e3db58900727 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Tue, 19 Sep 2023 10:47:19 -0700 Subject: [PATCH] Add WorkflowStepFactory class Signed-off-by: Daniel Widdis --- src/main/java/demo/DataDemo.java | 13 +-- src/main/java/demo/Demo.java | 16 +--- .../template/TemplateParser.java | 42 +++------- .../workflow/WorkflowStepFactory.java | 80 +++++++++++++++++++ .../template/TemplateParserTests.java | 2 +- src/test/resources/template/datademo.json | 2 + src/test/resources/template/demo.json | 12 ++- 7 files changed, 106 insertions(+), 61 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java diff --git a/src/main/java/demo/DataDemo.java b/src/main/java/demo/DataDemo.java index f2d606f07..4f304d2e5 100644 --- a/src/main/java/demo/DataDemo.java +++ b/src/main/java/demo/DataDemo.java @@ -14,16 +14,13 @@ import org.opensearch.common.io.PathUtils; import org.opensearch.flowframework.template.ProcessNode; import org.opensearch.flowframework.template.TemplateParser; -import org.opensearch.flowframework.workflow.WorkflowStep; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -35,14 +32,6 @@ public class DataDemo { private static final Logger logger = LogManager.getLogger(DataDemo.class); - // This is temporary. We need a factory class to generate these workflow steps - // based on a field in the JSON. - private static Map workflowMap = new HashMap<>(); - static { - workflowMap.put("create_index", new CreateIndexWorkflowStep()); - workflowMap.put("create_another_index", new CreateIndexWorkflowStep()); - } - /** * Demonstrate parsing a JSON graph. * @@ -60,7 +49,7 @@ public static void main(String[] args) { } logger.info("Parsing graph to sequence..."); - List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + List processSequence = TemplateParser.parseJsonGraphToSequence(json); List> futureList = new ArrayList<>(); for (ProcessNode n : processSequence) { diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 58d977827..15dde22a9 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -14,16 +14,13 @@ import org.opensearch.common.io.PathUtils; import org.opensearch.flowframework.template.ProcessNode; import org.opensearch.flowframework.template.TemplateParser; -import org.opensearch.flowframework.workflow.WorkflowStep; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Locale; -import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -35,16 +32,6 @@ public class Demo { private static final Logger logger = LogManager.getLogger(Demo.class); - // This is temporary. We need a factory class to generate these workflow steps - // based on a field in the JSON. - private static Map workflowMap = new HashMap<>(); - static { - workflowMap.put("fetch_model", new DemoWorkflowStep(3000)); - workflowMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000)); - workflowMap.put("create_search_pipeline", new DemoWorkflowStep(5000)); - workflowMap.put("create_neural_search_index", new DemoWorkflowStep(2000)); - } - /** * Demonstrate parsing a JSON graph. * @@ -62,7 +49,7 @@ public static void main(String[] args) { } logger.info("Parsing graph to sequence..."); - List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + List processSequence = TemplateParser.parseJsonGraphToSequence(json); List> futureList = new ArrayList<>(); for (ProcessNode n : processSequence) { @@ -78,7 +65,6 @@ public static void main(String[] args) { predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) ) ); - // TODO need to handle this better, passing an argument when we start them all at the beginning is silly futureList.add(n.execute()); } futureList.forEach(CompletableFuture::join); diff --git a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java index bce07c616..6a224179e 100644 --- a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java +++ b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java @@ -13,9 +13,9 @@ import com.google.gson.JsonObject; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import java.util.ArrayDeque; import java.util.ArrayList; @@ -36,13 +36,16 @@ public class TemplateParser { private static final Logger logger = LogManager.getLogger(TemplateParser.class); - // Field names in the JSON. Package private for tests. + // Field names in the JSON. + // Currently package private for tests. + // These may eventually become part of the template definition in which case they might be better declared public static final String WORKFLOW = "sequence"; static final String NODES = "nodes"; static final String NODE_ID = "id"; static final String EDGES = "edges"; static final String SOURCE = "source"; static final String DESTINATION = "dest"; + static final String STEP_TYPE = "step_type"; /** * Prevent instantiating this class. @@ -52,10 +55,9 @@ private TemplateParser() {} /** * Parse a JSON representation of nodes and edges into a topologically sorted list of process nodes. * @param json A string containing a JSON representation of nodes and edges - * @param workflowSteps A map linking JSON node names to Java objects implementing {@link WorkflowStep} * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. */ - public static List parseJsonGraphToSequence(String json, Map workflowSteps) { + public static List parseJsonGraphToSequence(String json) { Gson gson = new Gson(); JsonObject jsonObject = gson.fromJson(json, JsonObject.class); @@ -67,31 +69,13 @@ public static List parseJsonGraphToSequence(String json, Map getContent() { - // See CreateIndexRequest ParseFields for source of content keys needed - return Map.of("mappings", request.mappings(), "settings", request.settings(), "aliases", request.aliases()); - } - - @Override - public Map getParams() { - // See RestCreateIndexAction for source of param keys needed - return Map.of("index", request.index()); - } - - }; - } + String stepType = nodeObject.get(STEP_TYPE).getAsString(); + WorkflowStep workflowStep = WorkflowStepFactory.get().createStep(stepType); + + // TODO as part of this PR: Fetch data from the template here + WorkflowData inputData = new WorkflowData() { + // TODO override params and content from user template + }; nodes.add(new ProcessNode(nodeId, workflowStep, inputData)); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java new file mode 100644 index 000000000..20fc2be23 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import demo.CreateIndexWorkflowStep; +import demo.DemoWorkflowStep; + +/** + * Generates instances implementing {@link WorkflowStep}. + */ +public class WorkflowStepFactory { + + private static final WorkflowStepFactory INSTANCE = new WorkflowStepFactory(); + + private final Map stepMap = new HashMap<>(); + + private WorkflowStepFactory() { + populateMap(); + } + + /** + * Gets the singleton instance of this class + * @return The instance of this class + */ + public static WorkflowStepFactory get() { + return INSTANCE; + } + + private void populateMap() { + // TODO: These are from the demo class as placeholders + // Replace with actual implementations such as + // https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/38 + // https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/44 + stepMap.put("create_index", new CreateIndexWorkflowStep()); + stepMap.put("fetch_model", new DemoWorkflowStep(3000)); + stepMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000)); + stepMap.put("create_search_pipeline", new DemoWorkflowStep(5000)); + stepMap.put("create_neural_search_index", new DemoWorkflowStep(2000)); + + // Use until all the actual implementations are ready + stepMap.put("placeholder", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + future.complete(WorkflowData.EMPTY); + return future; + } + + @Override + public String getName() { + return "placeholder"; + } + }); + } + + /** + * Create a new instance of a {@link WorkflowStep}. + * @param type The type of instance to create + * @return an instance of the specified type + */ + public WorkflowStep createStep(String type) { + if (stepMap.containsKey(type)) { + return stepMap.get(type); + } + // TODO: replace this with a FlowFrameworkException + // https://github.com/opensearch-project/opensearch-ai-flow-framework/pull/43 + throw new UnsupportedOperationException("No workflow steps of type [" + type + "] are implemented."); + } +} diff --git a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java index 24dcf0640..78e8ff28a 100644 --- a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java +++ b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java @@ -50,7 +50,7 @@ private static ProcessNode expectedNode(String id) { // Less verbose parser private static List parse(String json) { - return TemplateParser.parseJsonGraphToSequence(json, Collections.emptyMap()); + return TemplateParser.parseJsonGraphToSequence(json); } @Override diff --git a/src/test/resources/template/datademo.json b/src/test/resources/template/datademo.json index a1323ed2c..10a2bfdc6 100644 --- a/src/test/resources/template/datademo.json +++ b/src/test/resources/template/datademo.json @@ -3,10 +3,12 @@ "nodes": [ { "id": "create_index", + "step_type": "create_index", "index_name": "demo" }, { "id": "create_another_index", + "step_type": "create_index", "index_name": "second_demo" } ], diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json index 38f1d0644..c068301e8 100644 --- a/src/test/resources/template/demo.json +++ b/src/test/resources/template/demo.json @@ -2,16 +2,20 @@ "sequence": { "nodes": [ { - "id": "fetch_model" + "id": "fetch_model", + "step_type": "fetch_model" }, { - "id": "create_ingest_pipeline" + "id": "create_ingest_pipeline", + "step_type": "create_ingest_pipeline" }, { - "id": "create_search_pipeline" + "id": "create_search_pipeline", + "step_type": "create_search_pipeline" }, { - "id": "create_neural_search_index" + "id": "create_neural_search_index", + "step_type": "create_neural_search_index" } ], "edges": [