From a9f8559c84210228e2eb6b8121e705364f792607 Mon Sep 17 00:00:00 2001
From: Daniel Widdis
Date: Fri, 22 Sep 2023 11:44:07 -0700
Subject: [PATCH] Separate WorkflowNode and ProcessNode functionality

Signed-off-by: Daniel Widdis
---
 build.gradle                                  |   1 -
 src/main/java/demo/DataDemo.java              |  12 +-
 src/main/java/demo/Demo.java                  |  10 +-
 src/main/java/demo/TemplateParseDemo.java     |  23 ++--
 .../flowframework/template/ProcessNode.java   |  59 ++------
 .../flowframework/template/Template.java      | 101 +++++++++++---
 .../template/TemplateParser.java              | 126 +++++++++++-------
 .../flowframework/template/WorkflowEdge.java  |   4 +-
 .../flowframework/template/WorkflowNode.java  |   8 +-
 .../flowframework/workflow/Workflow.java      |  65 +++++----
 .../template/ProcessNodeTests.java            |  12 +-
 .../template/TemplateParserTests.java         |  75 ++++++-----
 .../template/WorkflowEdgeTests.java           |   4 +-
 src/test/resources/template/datademo.json     |  22 ---
 src/test/resources/template/demo.json         | 104 ++++++++++-----
 15 files changed, 352 insertions(+), 274 deletions(-)
 delete mode 100644 src/test/resources/template/datademo.json

diff --git a/build.gradle b/build.gradle
index aa20423ee..748757484 100644
--- a/build.gradle
+++ b/build.gradle
@@ -105,7 +105,6 @@ repositories {
 dependencies {
     implementation "org.opensearch:opensearch:${opensearch_version}"
     implementation 'org.junit.jupiter:junit-jupiter:5.10.0'
-    implementation "com.google.code.gson:gson:2.10.1"
     compileOnly "com.google.guava:guava:32.1.2-jre"
     api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"

diff --git a/src/main/java/demo/DataDemo.java b/src/main/java/demo/DataDemo.java
index 4f304d2e5..d3e7eab91 100644
--- a/src/main/java/demo/DataDemo.java
+++ b/src/main/java/demo/DataDemo.java
@@ -13,6 +13,7 @@
 import org.opensearch.common.SuppressForbidden;
 import org.opensearch.common.io.PathUtils;
 import org.opensearch.flowframework.template.ProcessNode;
+import org.opensearch.flowframework.template.Template;
 import org.opensearch.flowframework.template.TemplateParser;

 import java.io.IOException;
@@ -21,7 +22,6 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
-import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;

@@ -36,10 +36,11 @@ public class DataDemo {
      * Demonstrate parsing a JSON graph.
      *
      * @param args unused
+     * @throws IOException on failure to parse the template
      */
     @SuppressForbidden(reason = "just a demo class that will be deleted")
-    public static void main(String[] args) {
-        String path = "src/test/resources/template/datademo.json";
+    public static void main(String[] args) throws IOException {
+        String path = "src/test/resources/template/demo.json";
         String json;
         try {
             json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8);
@@ -49,11 +50,12 @@ public static void main(String[] args) {
         }

         logger.info("Parsing graph to sequence...");
-        List<ProcessNode> processSequence = TemplateParser.parseJsonGraphToSequence(json);
+        Template t = TemplateParser.parseJsonToTemplate(json);
+        List<ProcessNode> processSequence = TemplateParser.parseWorkflowToSequence(t.workflows().get("datademo"));
         List<CompletableFuture<WorkflowData>> futureList = new ArrayList<>();

         for (ProcessNode n : processSequence) {
-            Set<ProcessNode> predecessors = n.getPredecessors();
+            List<ProcessNode> predecessors = n.predecessors();
             logger.info(
                 "Queueing process [{}].{}",
                 n.id(),
diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java
index 15dde22a9..19396dc2e 100644
--- a/src/main/java/demo/Demo.java
+++ b/src/main/java/demo/Demo.java
@@ -13,6 +13,7 @@
 import org.opensearch.common.SuppressForbidden;
 import org.opensearch.common.io.PathUtils;
 import org.opensearch.flowframework.template.ProcessNode;
+import org.opensearch.flowframework.template.Template;
 import org.opensearch.flowframework.template.TemplateParser;

 import java.io.IOException;
@@ -21,7 +22,6 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
-import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;

@@ -36,9 +36,10 @@ public class Demo {
      * Demonstrate parsing a JSON graph.
      *
      * @param args unused
+     * @throws IOException on failure to parse the template
      */
     @SuppressForbidden(reason = "just a demo class that will be deleted")
-    public static void main(String[] args) {
+    public static void main(String[] args) throws IOException {
         String path = "src/test/resources/template/demo.json";
         String json;
         try {
@@ -49,11 +50,12 @@ public static void main(String[] args) {
         }

         logger.info("Parsing graph to sequence...");
-        List<ProcessNode> processSequence = TemplateParser.parseJsonGraphToSequence(json);
+        Template t = TemplateParser.parseJsonToTemplate(json);
+        List<ProcessNode> processSequence = TemplateParser.parseWorkflowToSequence(t.workflows().get("demo"));
         List<CompletableFuture<WorkflowData>> futureList = new ArrayList<>();

         for (ProcessNode n : processSequence) {
-            Set<ProcessNode> predecessors = n.getPredecessors();
+            List<ProcessNode> predecessors = n.predecessors();
             logger.info(
                 "Queueing process [{}].{}",
                 n.id(),
diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java
index 0e92ba8ba..6e9800861 100644
--- a/src/main/java/demo/TemplateParseDemo.java
+++ b/src/main/java/demo/TemplateParseDemo.java
@@ -12,18 +12,14 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.common.SuppressForbidden;
 import org.opensearch.common.io.PathUtils;
-import org.opensearch.common.xcontent.LoggingDeprecationHandler;
-import org.opensearch.common.xcontent.json.JsonXContent;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
-import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.flowframework.template.Template;
 import org.opensearch.flowframework.template.TemplateParser;
+import org.opensearch.flowframework.workflow.Workflow;

 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
-
-import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+import java.util.Map.Entry;

 /**
  * Demo class exercising {@link TemplateParser}. This will be moved to a unit test.
@@ -49,15 +45,14 @@ public static void main(String[] args) throws IOException {
             return;
         }

-        logger.info("Parsing template...");
-        XContentParser parser = JsonXContent.jsonXContent.createParser(
-            NamedXContentRegistry.EMPTY,
-            LoggingDeprecationHandler.INSTANCE,
-            json
-        );
-        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-        Template t = Template.parse(parser);
+        Template t = TemplateParser.parseJsonToTemplate(json);
+        System.out.println(t.toJson());
         System.out.println(t.toYaml());
+
+        for (Entry<String, Workflow> e : t.workflows().entrySet()) {
+            logger.info("Parsing {} workflow.", e.getKey());
+            TemplateParser.parseWorkflowToSequence(e.getValue());
+        }
     }
 }
diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java
index 08a7ec841..b774384a4 100644
--- a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java
+++ b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java
@@ -14,10 +14,7 @@
 import org.opensearch.flowframework.workflow.WorkflowStep;

 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
-import java.util.Objects;
-import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
@@ -33,32 +30,22 @@ public class ProcessNode {
     private final String id;
     private final WorkflowStep workflowStep;
     private final WorkflowData input;
-    private CompletableFuture<WorkflowData> future = null;
-
-    // will be populated during graph parsing
-    private Set<ProcessNode> predecessors = Collections.emptySet();
-
-    /**
-     * Create this node linked to its executing process.
-     *
-     * @param id A string identifying the workflow step
-     * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn.
-     */
-    ProcessNode(String id, WorkflowStep workflowStep) {
-        this(id, workflowStep, WorkflowData.EMPTY);
-    }
+    private final List<ProcessNode> predecessors;
+    private final CompletableFuture<WorkflowData> future = new CompletableFuture<>();

     /**
-     * Create this node linked to its executing process.
+     * Create this node linked to its executing process, including input data and any predecessor nodes.
      *
      * @param id A string identifying the workflow step
      * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn.
-     * @param input Input required by the node
+     * @param input Input required by the node encoded in a {@link WorkflowData} instance.
+     * @param predecessors Nodes preceding this one in the workflow
      */
-    public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input) {
+    public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input, List<ProcessNode> predecessors) {
         this.id = id;
         this.workflowStep = workflowStep;
         this.input = input;
+        this.predecessors = predecessors;
     }

@@ -92,41 +79,31 @@ public WorkflowData input() {
      * @return A future indicating the processing state of this node.
      * It completes, normally or exceptionally, when this node's process finishes.
      */
-    public CompletableFuture<WorkflowData> getFuture() {
+    public CompletableFuture<WorkflowData> future() {
         return future;
     }

     /**
      * Returns the predecessors of this node in the workflow.
-     * The predecessor's {@link #getFuture()} must complete before execution begins on this node.
+     * The predecessor's {@link #future()} must complete before execution begins on this node.
      *
      * @return a list of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as a start node.
      */
-    public Set<ProcessNode> getPredecessors() {
+    public List<ProcessNode> predecessors() {
         return predecessors;
     }

-    /**
-     * Sets the predecessor node. Called by {@link TemplateParser}.
-     *
-     * @param predecessors The predecessors of this node.
-     */
-    void setPredecessors(Set<ProcessNode> predecessors) {
-        this.predecessors = Set.copyOf(predecessors);
-    }
-
     /**
      * Execute this node in the sequence. Completes the node's {@link CompletableFuture} when the process completes.
      *
      * @return this node's future. This is returned immediately, while process execution continues asynchronously.
      */
     public CompletableFuture<WorkflowData> execute() {
-        this.future = new CompletableFuture<>();
         // TODO this class will be instantiated with the OpenSearch thread pool (or one for tests!)
         // the generic executor from that pool should be passed to this runAsync call
         // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42
         CompletableFuture.runAsync(() -> {
-            List<CompletableFuture<WorkflowData>> predFutures = predecessors.stream().map(p -> p.getFuture()).collect(Collectors.toList());
+            List<CompletableFuture<WorkflowData>> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList());
             if (!predecessors.isEmpty()) {
                 CompletableFuture<Void> waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0]));
                 try {
@@ -168,20 +145,6 @@ private void handleException(Exception e) {
         logger.debug("<<< Completed Exceptionally {}", this.id);
     }

-    @Override
-    public int hashCode() {
-        return Objects.hash(id);
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-        if (this == obj) return true;
-        if (obj == null) return false;
-        if (getClass() != obj.getClass()) return false;
-        ProcessNode other = (ProcessNode) obj;
-        return Objects.equals(id, other.id);
-    }
-
     @Override
     public String toString() {
         return this.id;
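[Illustration, not part of this patch: the execution model above — a future created at node construction and completed by execute(), with each node blocking on CompletableFuture.allOf over its predecessors' futures — can be sketched standalone. Class and method names below are hypothetical.]

import java.util.List;
import java.util.concurrent.CompletableFuture;

public class PredecessorWaitSketch {

    // Each "node" owns one future; its work starts only after all predecessors complete.
    static CompletableFuture<String> runAfter(String id, List<CompletableFuture<String>> predecessors) {
        CompletableFuture<String> future = new CompletableFuture<>();
        CompletableFuture.runAsync(() -> {
            try {
                // Same pattern as ProcessNode.execute(): block on allOf(predecessor futures)
                CompletableFuture.allOf(predecessors.toArray(new CompletableFuture<?>[0])).join();
                future.complete(id + " done");
            } catch (Exception e) {
                future.completeExceptionally(e);
            }
        });
        return future;
    }

    public static void main(String[] args) {
        CompletableFuture<String> a = runAfter("A", List.of());
        CompletableFuture<String> b = runAfter("B", List.of(a));
        CompletableFuture<String> c = runAfter("C", List.of(a, b));
        // Completes only after A and B have completed, mirroring topological execution.
        System.out.println(c.join());
    }
}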
diff --git a/src/main/java/org/opensearch/flowframework/template/Template.java b/src/main/java/org/opensearch/flowframework/template/Template.java
index 48a1630a2..2e353605b 100644
--- a/src/main/java/org/opensearch/flowframework/template/Template.java
+++ b/src/main/java/org/opensearch/flowframework/template/Template.java
@@ -18,7 +18,6 @@
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -53,9 +52,9 @@ public class Template implements ToXContentObject {
     private final String name;
     private final String description;
     private final String useCase; // probably an ENUM actually
-    private final String[] operations; // probably an ENUM actually
+    private final List<String> operations; // probably an ENUM actually
     private final Version templateVersion;
-    private final Version[] compatibilityVersion;
+    private final List<Version> compatibilityVersion;
     private final Map<String, Object> userInputs;
     private final Map<String, Workflow> workflows;

@@ -75,20 +74,20 @@ public Template(
         String name,
         String description,
         String useCase,
-        String[] operations,
+        List<String> operations,
         Version templateVersion,
-        Version[] compatibilityVersion,
+        List<Version> compatibilityVersion,
         Map<String, Object> userInputs,
         Map<String, Workflow> workflows
     ) {
         this.name = name;
         this.description = description;
         this.useCase = useCase;
-        this.operations = operations;
+        this.operations = List.copyOf(operations);
         this.templateVersion = templateVersion;
-        this.compatibilityVersion = compatibilityVersion;
-        this.userInputs = userInputs;
-        this.workflows = workflows;
+        this.compatibilityVersion = List.copyOf(compatibilityVersion);
+        this.userInputs = Map.copyOf(userInputs);
+        this.workflows = Map.copyOf(workflows);
     }

     @Override
@@ -103,12 +102,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         }
         xContentBuilder.endArray();

-        if (this.templateVersion != null || this.compatibilityVersion.length > 0) {
+        if (this.templateVersion != null || !this.compatibilityVersion.isEmpty()) {
             xContentBuilder.startObject(VERSION_FIELD);
             if (this.templateVersion != null) {
                 xContentBuilder.field(TEMPLATE_FIELD, this.templateVersion);
             }
-            if (this.compatibilityVersion.length > 0) {
+            if (!this.compatibilityVersion.isEmpty()) {
                 xContentBuilder.startArray(COMPATIBILITY_FIELD);
                 for (Version v : this.compatibilityVersion) {
                     xContentBuilder.value(v);
@@ -146,9 +145,9 @@ public static Template parse(XContentParser parser) throws IOException {
         String name = null;
         String description = "";
         String useCase = "";
-        String[] operations = new String[0];
+        List<String> operations = new ArrayList<>();
         Version templateVersion = null;
-        Version[] compatibilityVersion = new Version[0];
+        List<Version> compatibilityVersion = new ArrayList<>();
         Map<String, Object> userInputs = new HashMap<>();
         Map<String, Workflow> workflows = new HashMap<>();

@@ -168,11 +167,9 @@ public static Template parse(XContentParser parser) throws IOException {
                     break;
                 case OPERATIONS_FIELD:
                     ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
-                    List<String> operationsList = new ArrayList<>();
                     while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
-                        operationsList.add(parser.text());
+                        operations.add(parser.text());
                     }
-                    operations = operationsList.toArray(new String[0]);
                     break;
                 case VERSION_FIELD:
                     ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
@@ -185,11 +182,9 @@
                     break;
                 case COMPATIBILITY_FIELD:
                     ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
-                    List<Version> compatibilityList = new ArrayList<>();
                     while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
-                        compatibilityList.add(Version.fromString(parser.text()));
+                        compatibilityVersion.add(Version.fromString(parser.text()));
                     }
-                    compatibilityVersion = compatibilityList.toArray(new Version[0]);
                     break;
                 default:
                     throw new IOException("Unable to parse field [" + fieldName + "] in a version object.");
@@ -293,6 +288,70 @@ public String toYaml() {
         }
     }

+    /**
+     * The name of this template
+     * @return the name
+     */
+    public String name() {
+        return name;
+    }
+
+    /**
+     * A description of what this template does
+     * @return the description
+     */
+    public String description() {
+        return description;
+    }
+
+    /**
+     * A canonical use case name for this template
+     * @return the useCase
+     */
+    public String useCase() {
+        return useCase;
+    }
+
+    /**
+     * Operations this use case supports
+     * @return the operations
+     */
+    public List<String> operations() {
+        return operations;
+    }
+
+    /**
+     * The version of this template
+     * @return the templateVersion
+     */
+    public Version templateVersion() {
+        return templateVersion;
+    }
+
+    /**
+     * OpenSearch version compatibility of this template
+     * @return the compatibilityVersion
+     */
+    public List<Version> compatibilityVersion() {
+        return compatibilityVersion;
+    }
+
+    /**
+     * A map of user inputs
+     * @return the userInputs
+     */
+    public Map<String, Object> userInputs() {
+        return userInputs;
+    }
+
+    /**
+     * Workflows encoded in this template, generally corresponding to the operations returned by {@link #operations()}.
+     * @return the workflows
+     */
+    public Map<String, Workflow> workflows() {
+        return workflows;
+    }
+
     @Override
     public String toString() {
         return "Template [name="
@@ -302,11 +361,11 @@ public String toString() {
             + ", useCase="
             + useCase
             + ", operations="
-            + Arrays.toString(operations)
+            + operations
             + ", templateVersion="
             + templateVersion
             + ", compatibilityVersion="
-            + Arrays.toString(compatibilityVersion)
+            + compatibilityVersion
             + ", userInputs="
             + userInputs
             + ", workflows="
diff --git a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java
index 8da58149a..2a9092880 100644
--- a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java
+++ b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java
@@ -8,16 +8,18 @@
  */
 package org.opensearch.flowframework.template;

-import com.google.gson.Gson;
-import com.google.gson.JsonElement;
-import com.google.gson.JsonObject;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.flowframework.workflow.Workflow;
 import org.opensearch.flowframework.workflow.WorkflowData;
 import org.opensearch.flowframework.workflow.WorkflowStep;
 import org.opensearch.flowframework.workflow.WorkflowStepFactory;

+import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -30,6 +32,8 @@
 import java.util.function.Function;
 import java.util.stream.Collectors;

+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
 /**
  * Utility class for parsing templates.
  */
@@ -42,71 +46,95 @@ public class TemplateParser {
      */
     private TemplateParser() {}

+    /**
+     * Parse a JSON use case template.
+     *
+     * @param json A string containing a JSON representation of a use case template
+     * @return A {@link Template} represented by the JSON.
+     * @throws IOException on failure to parse
+     */
+    public static Template parseJsonToTemplate(String json) throws IOException {
+        logger.info("Parsing template...");
+        XContentParser parser = JsonXContent.jsonXContent.createParser(
+            NamedXContentRegistry.EMPTY,
+            LoggingDeprecationHandler.INSTANCE,
+            json
+        );
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+        return Template.parse(parser);
+    }
+
     /**
      * Sort a workflow's graph of nodes and edges into a topologically sorted list of process nodes.
-     * @param json A string containing a JSON representation of nodes and edges
+     * @param workflow A {@link Workflow} containing the nodes and edges to sort
      * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list.
     */
-    public static List<ProcessNode> parseJsonGraphToSequence(String json) {
-        Gson gson = new Gson();
-        JsonObject jsonObject = gson.fromJson(json, JsonObject.class);
-
-        JsonObject graph = jsonObject.getAsJsonObject("workflow");
+    public static List<ProcessNode> parseWorkflowToSequence(Workflow workflow) {
+        List<WorkflowNode> sortedNodes = topologicalSort(workflow.nodes(), workflow.edges());

         List<ProcessNode> nodes = new ArrayList<>();
-        List<WorkflowEdge> edges = new ArrayList<>();
-
-        for (JsonElement nodeJson : graph.getAsJsonArray(Workflow.NODES_FIELD)) {
-            JsonObject nodeObject = nodeJson.getAsJsonObject();
-            String nodeId = nodeObject.get(WorkflowNode.ID_FIELD).getAsString();
-            String stepType = nodeObject.get(WorkflowNode.TYPE_FIELD).getAsString();
-            WorkflowStep workflowStep = WorkflowStepFactory.get().createStep(stepType);
+        Map<String, ProcessNode> idToNodeMap = new HashMap<>();
+        for (WorkflowNode node : sortedNodes) {
+            WorkflowStep step = WorkflowStepFactory.get().createStep(node.type());
+            WorkflowData data = new WorkflowData() {
+                public Map<String, Object> getContent() {
+                    return node.inputs();
+                };

-            // TODO as part of this PR: Fetch data from the template here
-            WorkflowData inputData = new WorkflowData() {
-                // TODO override params and content from user template
+                public Map<String, String> getParams() {
+                    return Collections.emptyMap();
+                };
             };
-            nodes.add(new ProcessNode(nodeId, workflowStep, inputData));
-        }
+            List<ProcessNode> predecessorNodes = workflow.edges()
+                .stream()
+                .filter(e -> e.destination().equals(node.id()))
+                // since we are iterating in topological order we know all predecessors will be in the map
+                .map(e -> idToNodeMap.get(e.source()))
+                .collect(Collectors.toList());

-        for (JsonElement edgeJson : graph.getAsJsonArray(Workflow.EDGES_FIELD)) {
-            JsonObject edgeObject = edgeJson.getAsJsonObject();
-            String sourceNodeId = edgeObject.get(WorkflowEdge.SOURCE_FIELD).getAsString();
-            String destNodeId = edgeObject.get(WorkflowEdge.DEST_FIELD).getAsString();
-            if (sourceNodeId.equals(destNodeId)) {
-                throw new IllegalArgumentException("Edge connects node " + sourceNodeId + " to itself.");
-            }
-            edges.add(new WorkflowEdge(sourceNodeId, destNodeId));
+            ProcessNode processNode = new ProcessNode(node.id(), step, data, predecessorNodes);
+            idToNodeMap.put(processNode.id(), processNode);
+            nodes.add(processNode);
         }

-        return topologicalSort(nodes, edges);
+        return nodes;
     }

-    private static List<ProcessNode> topologicalSort(List<ProcessNode> nodes, List<WorkflowEdge> edges) {
-        // Define the graph
-        Set<WorkflowEdge> graph = new HashSet<>(edges);
-        // Map node id string to object
-        Map<String, ProcessNode> nodeMap = nodes.stream().collect(Collectors.toMap(ProcessNode::id, Function.identity()));
+    private static List<WorkflowNode> topologicalSort(List<WorkflowNode> workflowNodes, List<WorkflowEdge> workflowEdges) {
+        // Basic validation
+        Set<String> nodeIds = workflowNodes.stream().map(n -> n.id()).collect(Collectors.toSet());
+        for (WorkflowEdge edge : workflowEdges) {
+            String source = edge.source();
+            if (!nodeIds.contains(source)) {
+                throw new IllegalArgumentException("Edge source " + source + " does not correspond to a node.");
+            }
+            String dest = edge.destination();
+            if (!nodeIds.contains(dest)) {
+                throw new IllegalArgumentException("Edge destination " + dest + " does not correspond to a node.");
+            }
+            if (source.equals(dest)) {
+                throw new IllegalArgumentException("Edge connects node " + source + " to itself.");
+            }
+        }
+
         // Build predecessor and successor maps
-        Map<ProcessNode, Set<WorkflowEdge>> predecessorEdges = new HashMap<>();
-        Map<ProcessNode, Set<WorkflowEdge>> successorEdges = new HashMap<>();
-        for (WorkflowEdge edge : edges) {
-            ProcessNode source = nodeMap.get(edge.getSource());
-            ProcessNode dest = nodeMap.get(edge.getDestination());
+        Map<WorkflowNode, Set<WorkflowEdge>> predecessorEdges = new HashMap<>();
+        Map<WorkflowNode, Set<WorkflowEdge>> successorEdges = new HashMap<>();
+        Map<String, WorkflowNode> nodeMap = workflowNodes.stream().collect(Collectors.toMap(WorkflowNode::id, Function.identity()));
+        for (WorkflowEdge edge : workflowEdges) {
+            WorkflowNode source = nodeMap.get(edge.source());
+            WorkflowNode dest = nodeMap.get(edge.destination());
             predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge);
             successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge);
         }
-        // update predecessors on the node object
-        nodes.stream().filter(n -> predecessorEdges.containsKey(n)).forEach(n -> {
-            n.setPredecessors(predecessorEdges.get(n).stream().map(e -> nodeMap.get(e.getSource())).collect(Collectors.toSet()));
-        });

         // See https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
+        Set<WorkflowEdge> graph = new HashSet<>(workflowEdges);
         // L <- Empty list that will contain the sorted elements
-        List<ProcessNode> sortedNodes = new ArrayList<>();
+        List<WorkflowNode> sortedNodes = new ArrayList<>();
         // S <- Set of all nodes with no incoming edge
-        Queue<ProcessNode> sourceNodes = new ArrayDeque<>();
-        nodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n));
+        Queue<WorkflowNode> sourceNodes = new ArrayDeque<>();
+        workflowNodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n));
         if (sourceNodes.isEmpty()) {
             throw new IllegalArgumentException("No start node detected: all nodes have a predecessor.");
         }
@@ -115,12 +143,12 @@ private static List<ProcessNode> topologicalSort(List<ProcessNode> nodes, List<WorkflowEdge> edges) {
 [remainder of this hunk, and the WorkflowEdge.java diff renaming getSource()/getDestination() to source()/destination(), not recovered]
diff --git a/src/main/java/org/opensearch/flowframework/template/WorkflowNode.java b/src/main/java/org/opensearch/flowframework/template/WorkflowNode.java
@@ ... @@
     public WorkflowNode(String id, String type, Map<String, Object> inputs) {
         this.id = id;
         this.type = type;
-        this.inputs = inputs;
+        this.inputs = Map.copyOf(inputs);
     }

@@ -140,7 +140,7 @@ public static WorkflowNode parse(XContentParser parser) throws IOException {
      * Return this node's id
      * @return the id
      */
-    public String getId() {
+    public String id() {
         return id;
     }

@@ -148,7 +148,7 @@
      * Return this node's type
      * @return the type
      */
-    public String getType() {
+    public String type() {
         return type;
     }

@@ -156,7 +156,7 @@
      * Return this node's input data
      * @return the inputs
      */
-    public Map<String, Object> getInputs() {
+    public Map<String, Object> inputs() {
         return inputs;
     }

diff --git a/src/main/java/org/opensearch/flowframework/workflow/Workflow.java b/src/main/java/org/opensearch/flowframework/workflow/Workflow.java
index 97d38b16d..f0bddc6de 100644
--- a/src/main/java/org/opensearch/flowframework/workflow/Workflow.java
+++ b/src/main/java/org/opensearch/flowframework/workflow/Workflow.java
@@ -17,7 +17,6 @@
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -40,8 +39,8 @@ public class Workflow implements ToXContentObject {
     public static final String EDGES_FIELD = "edges";

     private final Map<String, Object> userParams;
-    private final WorkflowNode[] nodes;
-    private final WorkflowEdge[] edges;
+    private final List<WorkflowNode> nodes;
+    private final List<WorkflowEdge> edges;

     /**
      * Create this workflow with any user params and the graph of nodes and edges.
      *
      * @param userParams A map of user params
      * @param nodes A list of {@link WorkflowNode} objects
      * @param edges A list of {@link WorkflowEdge} objects.
      */
-    public Workflow(Map<String, Object> userParams, WorkflowNode[] nodes, WorkflowEdge[] edges) {
-        this.userParams = userParams;
-        this.nodes = nodes;
-        this.edges = edges;
+    public Workflow(Map<String, Object> userParams, List<WorkflowNode> nodes, List<WorkflowEdge> edges) {
+        this.userParams = Map.copyOf(userParams);
+        this.nodes = List.copyOf(nodes);
+        this.edges = List.copyOf(edges);
     }

     @Override
@@ -82,16 +81,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     }

     /**
-     * Parse raw json content into a workflow instance.
+     * Parse raw JSON content into a workflow instance.
      *
-     * @param parser json based content parser
+     * @param parser JSON based content parser
      * @return the parsed Workflow instance
      * @throws IOException if content can't be parsed correctly
      */
     public static Workflow parse(XContentParser parser) throws IOException {
         Map<String, Object> userParams = new HashMap<>();
-        WorkflowNode[] nodes = null;
-        WorkflowEdge[] edges = null;
+        List<WorkflowNode> nodes = new ArrayList<>();
+        List<WorkflowEdge> edges = new ArrayList<>();

         ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
         while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -109,40 +108,58 @@ public static Workflow parse(XContentParser parser) throws IOException {
                 case NODES_FIELD:
                 case STEPS_FIELD:
                     ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
-                    List<WorkflowNode> nodesList = new ArrayList<>();
                     while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
-                        nodesList.add(WorkflowNode.parse(parser));
+                        nodes.add(WorkflowNode.parse(parser));
                     }
-                    nodes = nodesList.toArray(new WorkflowNode[0]);
                     break;
                 case EDGES_FIELD:
                     ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
-                    List<WorkflowEdge> edgesList = new ArrayList<>();
                     while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
-                        edgesList.add(WorkflowEdge.parse(parser));
+                        edges.add(WorkflowEdge.parse(parser));
                     }
-                    edges = edgesList.toArray(new WorkflowEdge[0]);
                     break;
             }
         }
-        if (nodes == null || nodes.length == 0) {
+        if (nodes.isEmpty()) {
             throw new IOException("A workflow must have at least one node.");
         }
-        if (edges == null || edges.length == 0) {
+        if (edges.isEmpty()) {
             // infer edges from sequence of nodes
-            List<WorkflowEdge> edgesList = new ArrayList<>();
             // Start iteration at 1, will skip for a one-node array
-            for (int i = 1; i < nodes.length; i++) {
-                edgesList.add(new WorkflowEdge(nodes[i - 1].getId(), nodes[i].getId()));
+            for (int i = 1; i < nodes.size(); i++) {
+                edges.add(new WorkflowEdge(nodes.get(i - 1).id(), nodes.get(i).id()));
             }
-            edges = edgesList.toArray(new WorkflowEdge[0]);
         }

         return new Workflow(userParams, nodes, edges);
     }

+    /**
+     * Get user parameters. These will be passed to all workflow nodes and available as {@link WorkflowData#getParams()}
+     * @return the userParams
+     */
+    public Map<String, Object> userParams() {
+        return userParams;
+    }
+
+    /**
+     * Get the nodes in the workflow. Ordering matches the user template which may or may not match execution order.
+     * @return the nodes
+     */
+    public List<WorkflowNode> nodes() {
+        return nodes;
+    }
+
+    /**
+     * Get the edges in the workflow. These specify connections of nodes which form a graph.
+     * @return the edges
+     */
+    public List<WorkflowEdge> edges() {
+        return edges;
+    }
+
     @Override
     public String toString() {
-        return "Workflow [userParams=" + userParams + ", nodes=" + Arrays.toString(nodes) + ", edges=" + Arrays.toString(edges) + "]";
+        return "Workflow [userParams=" + userParams + ", nodes=" + nodes + ", edges=" + edges + "]";
     }
 }
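[Illustration, not part of this patch: the edge-inference branch in Workflow.parse() above means a template may omit the "edges" array entirely and still get a linear pipeline. A minimal sketch, assuming the same XContent setup used by the tests below; the class name is hypothetical.]

import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.workflow.Workflow;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

public class EdgeInferenceSketch {
    public static void main(String[] args) throws IOException {
        // Three nodes and no "edges" array: parse() infers the chain A->B, B->C.
        String json = "{\"nodes\":["
            + "{\"id\":\"A\",\"type\":\"create_index\"},"
            + "{\"id\":\"B\",\"type\":\"create_index\"},"
            + "{\"id\":\"C\",\"type\":\"create_index\"}]}";
        XContentParser parser = JsonXContent.jsonXContent.createParser(
            NamedXContentRegistry.EMPTY,
            LoggingDeprecationHandler.INSTANCE,
            json
        );
        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
        Workflow w = Workflow.parse(parser);
        System.out.println(w.edges()); // [A->B, B->C], via WorkflowEdge.toString()
    }
}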
diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java
index 3feab9f3b..e05e2555c 100644
--- a/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java
+++ b/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java
@@ -42,24 +42,18 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
             public String getName() {
                 return "test";
             }
-        });
+        }, WorkflowData.EMPTY, Collections.emptyList());

         assertEquals("A", nodeA.id());
         assertEquals("test", nodeA.workflowStep().getName());
         assertEquals(WorkflowData.EMPTY, nodeA.input());
-        assertEquals(Collections.emptySet(), nodeA.getPredecessors());
+        assertEquals(Collections.emptyList(), nodeA.predecessors());
         assertEquals("A", nodeA.toString());

         // TODO: Once we can get OpenSearch Thread Pool for this execute method, create an IT and don't test execute here
         CompletableFuture<WorkflowData> f = nodeA.execute();
-        assertEquals(f, nodeA.getFuture());
+        assertEquals(f, nodeA.future());
         f.orTimeout(5, TimeUnit.SECONDS);
         assertTrue(f.isDone());
         assertEquals(WorkflowData.EMPTY, f.get());
-
-        ProcessNode nodeB = new ProcessNode("B", null);
-        assertNotEquals(nodeA, nodeB);
-
-        ProcessNode nodeA2 = new ProcessNode("A", null);
-        assertEquals(nodeA, nodeA2);
     }
 }
diff --git a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java
index 1a50da906..1d85e059c 100644
--- a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java
+++ b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java
@@ -8,28 +8,39 @@
  */
 package org.opensearch.flowframework.template;

+import org.opensearch.common.xcontent.LoggingDeprecationHandler;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.flowframework.workflow.Workflow;
 import org.opensearch.test.OpenSearchTestCase;

+import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;

+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.opensearch.flowframework.template.GraphJsonUtil.edge;
 import static org.opensearch.flowframework.template.GraphJsonUtil.node;
 import static org.opensearch.flowframework.template.GraphJsonUtil.workflow;

 public class TemplateParserTests extends OpenSearchTestCase {

+    private static final String MUST_HAVE_AT_LEAST_ONE_NODE = "A workflow must have at least one node.";
     private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor.";
     private static final String CYCLE_DETECTED = "Cycle detected:";

-    // Output list elements
-    private static ProcessNode expectedNode(String id) {
-        return new ProcessNode(id, null, null);
-    }
-
-    // Less verbose parser
-    private static List<ProcessNode> parse(String json) {
-        return TemplateParser.parseJsonGraphToSequence(json);
+    // Wrap parser into string list
+    private static List<String> parse(String json) throws IOException {
+        XContentParser parser = JsonXContent.jsonXContent.createParser(
+            NamedXContentRegistry.EMPTY,
+            LoggingDeprecationHandler.INSTANCE,
+            json
+        );
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+        Workflow w = Workflow.parse(parser);
+        return TemplateParser.parseWorkflowToSequence(w).stream().map(ProcessNode::id).collect(Collectors.toList());
     }

     @Override
@@ -37,13 +48,13 @@ public void setUp() throws Exception {
         super.setUp();
     }

-    public void testOrdering() {
-        List<ProcessNode> workflow;
+    public void testOrdering() throws IOException {
+        List<String> workflow;

         workflow = parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("C", "B"), edge("B", "A"))));
-        assertEquals(0, workflow.indexOf(expectedNode("C")));
-        assertEquals(1, workflow.indexOf(expectedNode("B")));
-        assertEquals(2, workflow.indexOf(expectedNode("A")));
+        assertEquals(0, workflow.indexOf("C"));
+        assertEquals(1, workflow.indexOf("B"));
+        assertEquals(2, workflow.indexOf("A"));

         workflow = parse(
             workflow(
                 List.of(node("A"), node("B"), node("C"), node("D")),
                 List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("C", "D"))
             )
         );
-        assertEquals(0, workflow.indexOf(expectedNode("A")));
-        int b = workflow.indexOf(expectedNode("B"));
-        int c = workflow.indexOf(expectedNode("C"));
+        assertEquals(0, workflow.indexOf("A"));
+        int b = workflow.indexOf("B");
+        int c = workflow.indexOf("C");
         assertTrue(b == 1 || b == 2);
         assertTrue(c == 1 || c == 2);
-        assertEquals(3, workflow.indexOf(expectedNode("D")));
+        assertEquals(3, workflow.indexOf("D"));

         workflow = parse(
             workflow(
                 List.of(node("A"), node("B"), node("C"), node("D"), node("E")),
                 List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("D", "E"), edge("C", "E"))
             )
         );
-        assertEquals(0, workflow.indexOf(expectedNode("A")));
-        b = workflow.indexOf(expectedNode("B"));
-        c = workflow.indexOf(expectedNode("C"));
-        int d = workflow.indexOf(expectedNode("D"));
+        assertEquals(0, workflow.indexOf("A"));
+        b = workflow.indexOf("B");
+        c = workflow.indexOf("C");
+        int d = workflow.indexOf("D");
         assertTrue(b == 1 || b == 2);
         assertTrue(c == 1 || c == 2);
         assertTrue(d == 2 || d == 3);
-        assertEquals(4, workflow.indexOf(expectedNode("E")));
+        assertEquals(4, workflow.indexOf("E"));
     }

     public void testCycles() {
@@ -115,18 +126,18 @@ public void testCycles() {
         assertTrue(ex.getMessage().contains("D->B"));
     }

-    public void testNoEdges() {
-        Exception ex = assertThrows(
-            IllegalArgumentException.class,
-            () -> parse(workflow(Collections.emptyList(), Collections.emptyList()))
-        );
-        assertEquals(NO_START_NODE_DETECTED, ex.getMessage());
+    public void testNoEdges() throws IOException {
+        List<String> workflow;
+        Exception ex = assertThrows(IOException.class, () -> parse(workflow(Collections.emptyList(), Collections.emptyList())));
+        assertEquals(MUST_HAVE_AT_LEAST_ONE_NODE, ex.getMessage());

-        assertEquals(List.of(expectedNode("A")), parse(workflow(List.of(node("A")), Collections.emptyList())));
+        workflow = parse(workflow(List.of(node("A")), Collections.emptyList()));
+        assertEquals(1, workflow.size());
+        assertEquals("A", workflow.get(0));

-        List<ProcessNode> workflow = parse(workflow(List.of(node("A"), node("B")), Collections.emptyList()));
+        workflow = parse(workflow(List.of(node("A"), node("B")), Collections.emptyList()));
         assertEquals(2, workflow.size());
-        assertTrue(workflow.contains(expectedNode("A")));
-        assertTrue(workflow.contains(expectedNode("B")));
+        assertTrue(workflow.contains("A"));
+        assertTrue(workflow.contains("B"));
     }
 }
diff --git a/src/test/java/org/opensearch/flowframework/template/WorkflowEdgeTests.java b/src/test/java/org/opensearch/flowframework/template/WorkflowEdgeTests.java
index f44b54527..6e863445a 100644
--- a/src/test/java/org/opensearch/flowframework/template/WorkflowEdgeTests.java
+++ b/src/test/java/org/opensearch/flowframework/template/WorkflowEdgeTests.java
@@ -19,8 +19,8 @@ public void setUp() throws Exception {

     public void testEdge() {
         WorkflowEdge edgeAB = new WorkflowEdge("A", "B");
-        assertEquals("A", edgeAB.getSource());
-        assertEquals("B", edgeAB.getDestination());
+        assertEquals("A", edgeAB.source());
+        assertEquals("B", edgeAB.destination());
         assertEquals("A->B", edgeAB.toString());

         WorkflowEdge edgeAB2 = new WorkflowEdge("A", "B");
diff --git a/src/test/resources/template/datademo.json b/src/test/resources/template/datademo.json
deleted file mode 100644
index 10a2bfdc6..000000000
--- a/src/test/resources/template/datademo.json
+++ /dev/null
@@ -1,22 +0,0 @@
-{
-  "sequence": {
-    "nodes": [
-      {
-        "id": "create_index",
-        "step_type": "create_index",
-        "index_name": "demo"
-      },
-      {
-        "id": "create_another_index",
-        "step_type": "create_index",
-        "index_name": "second_demo"
-      }
-    ],
-    "edges": [
-      {
-        "source": "create_index",
-        "dest": "create_another_index"
-      }
-    ]
-  }
-}
diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json
index c068301e8..86ba1005c 100644
--- a/src/test/resources/template/demo.json
+++ b/src/test/resources/template/demo.json
@@ -1,40 +1,70 @@
 {
-  "sequence": {
-    "nodes": [
-      {
-        "id": "fetch_model",
-        "step_type": "fetch_model"
-      },
-      {
-        "id": "create_ingest_pipeline",
-        "step_type": "create_ingest_pipeline"
-      },
-      {
-        "id": "create_search_pipeline",
-        "step_type": "create_search_pipeline"
-      },
-      {
-        "id": "create_neural_search_index",
-        "step_type": "create_neural_search_index"
-      }
-    ],
-    "edges": [
-      {
-        "source": "fetch_model",
-        "dest": "create_ingest_pipeline"
-      },
-      {
-        "source": "fetch_model",
-        "dest": "create_search_pipeline"
-      },
-      {
-        "source": "create_ingest_pipeline",
-        "dest": "create_neural_search_index"
-      },
-      {
-        "source": "create_search_pipeline",
-        "dest": "create_neural_search_index"
-      }
-    ]
+  "name": "demo-template",
+  "description": "Demonstrates workflow steps and passing around of input/output",
+  "user_inputs": {
+    "index_name": "my-knn-index",
+    "index_settings": {
+
+    }
+  },
+  "workflows": {
+    "demo": {
+      "nodes": [
+        {
+          "id": "fetch_model",
+          "type": "fetch_model"
+        },
+        {
+          "id": "create_ingest_pipeline",
+          "type": "create_ingest_pipeline"
+        },
+        {
+          "id": "create_search_pipeline",
+          "type": "create_search_pipeline"
+        },
+        {
+          "id": "create_neural_search_index",
+          "type": "create_neural_search_index"
+        }
+      ],
+      "edges": [
+        {
+          "source": "fetch_model",
+          "dest": "create_ingest_pipeline"
+        },
+        {
+          "source": "fetch_model",
+          "dest": "create_search_pipeline"
+        },
+        {
+          "source": "create_ingest_pipeline",
+          "dest": "create_neural_search_index"
+        },
+        {
+          "source": "create_search_pipeline",
+          "dest": "create_neural_search_index"
+        }
+      ]
+    },
+    "datademo": {
+      "nodes": [
+        {
+          "id": "create_index",
+          "type": "create_index",
+          "index_name": "demo"
+        },
+        {
+          "id": "create_another_index",
+          "type": "create_index",
+          "index_name": "second_demo"
+        }
+      ],
+      "edges": [
+        {
+          "source": "create_index",
+          "dest": "create_another_index"
+        }
+      ]
+    }
   }
 }
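[Illustration, not part of this patch: end-to-end usage of the reworked API, mirroring Demo.java. A minimal sketch assuming the step types in demo.json are registered with WorkflowStepFactory; the class name is hypothetical.]

import org.opensearch.flowframework.template.ProcessNode;
import org.opensearch.flowframework.template.Template;
import org.opensearch.flowframework.template.TemplateParser;
import org.opensearch.flowframework.workflow.WorkflowData;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class EndToEndSketch {
    public static void main(String[] args) throws IOException {
        String json = new String(
            Files.readAllBytes(Paths.get("src/test/resources/template/demo.json")),
            StandardCharsets.UTF_8
        );

        // One parse step for the whole use case template...
        Template template = TemplateParser.parseJsonToTemplate(json);

        // ...then each named workflow is sorted into an executable sequence.
        List<ProcessNode> sequence = TemplateParser.parseWorkflowToSequence(template.workflows().get("demo"));

        // Executing in topological order; each node internally awaits its predecessors' futures.
        List<CompletableFuture<WorkflowData>> futures = new ArrayList<>();
        for (ProcessNode n : sequence) {
            futures.add(n.execute());
        }
        futures.forEach(CompletableFuture::join);
    }
}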