diff --git a/.codecov.yml b/.codecov.yml index 7c38e4e63..e5bbd7262 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -1,6 +1,10 @@ codecov: require_ci_to_pass: yes +# ignore files in demo package +ignore: + - "src/main/java/demo" + coverage: precision: 2 round: down diff --git a/build.gradle b/build.gradle index 99170b22a..510788da5 100644 --- a/build.gradle +++ b/build.gradle @@ -56,6 +56,7 @@ opensearchplugin { dependencyLicenses.enabled = false // This requires an additional Jar not published as part of build-tools loggerUsageCheck.enabled = false +thirdPartyAudit.enabled = false // No need to validate pom, as we do not upload to maven/sonatype validateNebulaPom.enabled = false @@ -105,7 +106,8 @@ repositories { dependencies { implementation "org.opensearch:opensearch:${opensearch_version}" implementation 'org.junit.jupiter:junit-jupiter:5.10.0' - compileOnly "com.google.guava:guava:32.1.2-jre" + implementation "com.google.code.gson:gson:2.10.1" + implementation "com.google.guava:guava:32.1.2-jre" api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" configurations.all { diff --git a/formatter/formatting.gradle b/formatter/formatting.gradle index e3bc090e0..8f842128f 100644 --- a/formatter/formatting.gradle +++ b/formatter/formatting.gradle @@ -35,6 +35,7 @@ allprojects { trimTrailingWhitespace() endWithNewline() + indentWithSpaces() } format("license", { licenseHeaderFile("${rootProject.file("formatter/license-header.txt")}", "package "); diff --git a/src/main/java/demo/CreateIndexWorkflowStep.java b/src/main/java/demo/CreateIndexWorkflowStep.java new file mode 100644 index 000000000..6b2ab0a7b --- /dev/null +++ b/src/main/java/demo/CreateIndexWorkflowStep.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Sample to show other devs how to pass data around. Will be deleted once other PRs are merged. + */ +public class CreateIndexWorkflowStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(CreateIndexWorkflowStep.class); + + private final String name; + + /** + * Instantiate this class. + */ + public CreateIndexWorkflowStep() { + this.name = "CREATE_INDEX"; + } + + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + // TODO we will be passing a thread pool to this object when it's instantiated + // we should either add the generic executor from that pool to this call + // or use executorservice.submit or any of various threading options + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 + CompletableFuture.runAsync(() -> { + String inputIndex = null; + boolean first = true; + for (WorkflowData wfData : data) { + logger.debug( + "{} sent params: {}, content: {}", + first ? "Initialization" : "Previous step", + wfData.getParams(), + wfData.getContent() + ); + if (first) { + Map params = data.get(0).getParams(); + if (params.containsKey("index")) { + inputIndex = params.get("index"); + } + first = false; + } + } + // do some work, simulating a REST API call + try { + Thread.sleep(2000); + } catch (InterruptedException e) {} + // Simulate response of created index + CreateIndexResponse response = new CreateIndexResponse(true, true, inputIndex); + future.complete(new WorkflowData(Map.of("index", response.index()))); + }); + + return future; + } + + @Override + public String getName() { + return name; + } +} diff --git a/src/main/java/demo/DataDemo.java b/src/main/java/demo/DataDemo.java new file mode 100644 index 000000000..f2d606f07 --- /dev/null +++ b/src/main/java/demo/DataDemo.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.PathUtils; +import org.opensearch.flowframework.template.ProcessNode; +import org.opensearch.flowframework.template.TemplateParser; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * Demo class exercising {@link TemplateParser}. This will be moved to a unit test. + */ +public class DataDemo { + + private static final Logger logger = LogManager.getLogger(DataDemo.class); + + // This is temporary. We need a factory class to generate these workflow steps + // based on a field in the JSON. + private static Map workflowMap = new HashMap<>(); + static { + workflowMap.put("create_index", new CreateIndexWorkflowStep()); + workflowMap.put("create_another_index", new CreateIndexWorkflowStep()); + } + + /** + * Demonstrate parsing a JSON graph. + * + * @param args unused + */ + @SuppressForbidden(reason = "just a demo class that will be deleted") + public static void main(String[] args) { + String path = "src/test/resources/template/datademo.json"; + String json; + try { + json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + logger.error("Failed to read JSON at path {}", path); + return; + } + + logger.info("Parsing graph to sequence..."); + List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + List> futureList = new ArrayList<>(); + + for (ProcessNode n : processSequence) { + Set predecessors = n.getPredecessors(); + logger.info( + "Queueing process [{}].{}", + n.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + futureList.add(n.execute()); + } + futureList.forEach(CompletableFuture::join); + logger.info("All done!"); + } + +} diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java new file mode 100644 index 000000000..58d977827 --- /dev/null +++ b/src/main/java/demo/Demo.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.SuppressForbidden; +import org.opensearch.common.io.PathUtils; +import org.opensearch.flowframework.template.ProcessNode; +import org.opensearch.flowframework.template.TemplateParser; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +/** + * Demo class exercising {@link TemplateParser}. This will be moved to a unit test. + */ +public class Demo { + + private static final Logger logger = LogManager.getLogger(Demo.class); + + // This is temporary. We need a factory class to generate these workflow steps + // based on a field in the JSON. + private static Map workflowMap = new HashMap<>(); + static { + workflowMap.put("fetch_model", new DemoWorkflowStep(3000)); + workflowMap.put("create_ingest_pipeline", new DemoWorkflowStep(3000)); + workflowMap.put("create_search_pipeline", new DemoWorkflowStep(5000)); + workflowMap.put("create_neural_search_index", new DemoWorkflowStep(2000)); + } + + /** + * Demonstrate parsing a JSON graph. + * + * @param args unused + */ + @SuppressForbidden(reason = "just a demo class that will be deleted") + public static void main(String[] args) { + String path = "src/test/resources/template/demo.json"; + String json; + try { + json = new String(Files.readAllBytes(PathUtils.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + logger.error("Failed to read JSON at path {}", path); + return; + } + + logger.info("Parsing graph to sequence..."); + List processSequence = TemplateParser.parseJsonGraphToSequence(json, workflowMap); + List> futureList = new ArrayList<>(); + + for (ProcessNode n : processSequence) { + Set predecessors = n.getPredecessors(); + logger.info( + "Queueing process [{}].{}", + n.id(), + predecessors.isEmpty() + ? " Can start immediately!" + : String.format( + Locale.getDefault(), + " Must wait for [%s] to complete first.", + predecessors.stream().map(p -> p.id()).collect(Collectors.joining(", ")) + ) + ); + // TODO need to handle this better, passing an argument when we start them all at the beginning is silly + futureList.add(n.execute()); + } + futureList.forEach(CompletableFuture::join); + logger.info("All done!"); + } + +} diff --git a/src/main/java/demo/DemoWorkflowStep.java b/src/main/java/demo/DemoWorkflowStep.java new file mode 100644 index 000000000..037d9b6f6 --- /dev/null +++ b/src/main/java/demo/DemoWorkflowStep.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package demo; + +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.List; +import java.util.concurrent.CompletableFuture; + +/** + * Demo workflowstep to show sequenced execution + */ +public class DemoWorkflowStep implements WorkflowStep { + + private final long delay; + private final String name; + + /** + * Instantiate a step with a delay. + * @param delay milliseconds to take pretending to do work while really sleeping + */ + public DemoWorkflowStep(long delay) { + this.delay = delay; + this.name = "DEMO_DELAY_" + delay; + } + + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + CompletableFuture.runAsync(() -> { + try { + Thread.sleep(this.delay); + future.complete(null); + } catch (InterruptedException e) { + future.completeExceptionally(e); + } + }); + return future; + } + + @Override + public String getName() { + return name; + } +} diff --git a/src/main/java/demo/README.txt b/src/main/java/demo/README.txt new file mode 100644 index 000000000..4fef77960 --- /dev/null +++ b/src/main/java/demo/README.txt @@ -0,0 +1,13 @@ + +DO NOT DEPEND ON CLASSES IN THIS PACKAGE. + +The contents of this folder are for demo/proof-of-concept use. + +Feel free to look at the classes in this folder for potential "how could I" scenarios. + +Tests will not be written against them. +Documentation may be incomplete, wrong, or outdated. +These are not for production use. +They will be deleted without notice at some point, and altered without notice at other points. + +DO NOT DEPEND ON CLASSES IN THIS PACKAGE. diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index f810767eb..e5df0bf46 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -8,11 +8,49 @@ */ package org.opensearch.flowframework; +import com.google.common.collect.ImmutableList; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.flowframework.workflow.CreateIndex.CreateIndexStep; +import org.opensearch.flowframework.workflow.CreateIngestPipelineStep; import org.opensearch.plugins.Plugin; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.watcher.ResourceWatcherService; + +import java.util.Collection; +import java.util.function.Supplier; /** * An OpenSearch plugin that enables builders to innovate AI apps on OpenSearch. */ public class FlowFrameworkPlugin extends Plugin { - // Implement the relevant Plugin Interfaces here + + private Client client; + + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { + this.client = client; + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + CreateIndexStep createIndexStep = new CreateIndexStep(client); + return ImmutableList.of(createIngestPipelineStep, createIndexStep); + } } diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessNode.java b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java new file mode 100644 index 000000000..08a7ec841 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/ProcessNode.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * Representation of a process node in a workflow graph. Tracks predecessor nodes which must be completed before it can start execution. + */ +public class ProcessNode { + + private static final Logger logger = LogManager.getLogger(ProcessNode.class); + + private final String id; + private final WorkflowStep workflowStep; + private final WorkflowData input; + private CompletableFuture future = null; + + // will be populated during graph parsing + private Set predecessors = Collections.emptySet(); + + /** + * Create this node linked to its executing process. + * + * @param id A string identifying the workflow step + * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + */ + ProcessNode(String id, WorkflowStep workflowStep) { + this(id, workflowStep, WorkflowData.EMPTY); + } + + /** + * Create this node linked to its executing process. + * + * @param id A string identifying the workflow step + * @param workflowStep A java class implementing {@link WorkflowStep} to be executed when it's this node's turn. + * @param input Input required by the node + */ + public ProcessNode(String id, WorkflowStep workflowStep, WorkflowData input) { + this.id = id; + this.workflowStep = workflowStep; + this.input = input; + } + + /** + * Returns the node's id. + * @return the node id. + */ + public String id() { + return id; + } + + /** + * Returns the node's workflow implementation. + * @return the workflow step + */ + public WorkflowStep workflowStep() { + return workflowStep; + } + + /** + * Returns the input data for this node. + * @return the input data + */ + public WorkflowData input() { + return input; + } + + /** + * Returns a {@link CompletableFuture} if this process is executing. + * Relies on the node having been sorted and executed in an order such that all predecessor nodes have begun execution first (and thus populated this value). + * + * @return A future indicating the processing state of this node. + * Returns {@code null} if it has not begun executing, should not happen if a workflow is sorted and executed topologically. + */ + public CompletableFuture getFuture() { + return future; + } + + /** + * Returns the predecessors of this node in the workflow. + * The predecessor's {@link #getFuture()} must complete before execution begins on this node. + * + * @return a set of predecessor nodes, if any. At least one node in the graph must have no predecessors and serve as a start node. + */ + public Set getPredecessors() { + return predecessors; + } + + /** + * Sets the predecessor node. Called by {@link TemplateParser}. + * + * @param predecessors The predecessors of this node. + */ + void setPredecessors(Set predecessors) { + this.predecessors = Set.copyOf(predecessors); + } + + /** + * Execute this node in the sequence. Initializes the node's {@link CompletableFuture} and completes it when the process completes. + * + * @return this node's future. This is returned immediately, while process execution continues asynchronously. + */ + public CompletableFuture execute() { + this.future = new CompletableFuture<>(); + // TODO this class will be instantiated with the OpenSearch thread pool (or one for tests!) + // the generic executor from that pool should be passed to this runAsync call + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 + CompletableFuture.runAsync(() -> { + List> predFutures = predecessors.stream().map(p -> p.getFuture()).collect(Collectors.toList()); + if (!predecessors.isEmpty()) { + CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); + try { + // We need timeouts to be part of the user template or in settings + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/45 + waitForPredecessors.orTimeout(30, TimeUnit.SECONDS).get(); + } catch (InterruptedException | ExecutionException e) { + handleException(e); + return; + } + } + logger.info(">>> Starting {}.", this.id); + // get the input data from predecessor(s) + List input = new ArrayList(); + input.add(this.input); + for (CompletableFuture cf : predFutures) { + try { + input.add(cf.get()); + } catch (InterruptedException | ExecutionException e) { + handleException(e); + return; + } + } + CompletableFuture stepFuture = this.workflowStep.execute(input); + try { + stepFuture.join(); + future.complete(stepFuture.get()); + logger.debug("<<< Completed {}", this.id); + } catch (InterruptedException | ExecutionException e) { + handleException(e); + } + }); + return this.future; + } + + private void handleException(Exception e) { + // TODO: better handling of getCause + this.future.completeExceptionally(e); + logger.debug("<<< Completed Exceptionally {}", this.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + ProcessNode other = (ProcessNode) obj; + return Objects.equals(id, other.id); + } + + @Override + public String toString() { + return this.id; + } +} diff --git a/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java new file mode 100644 index 000000000..9544620fb --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/ProcessSequenceEdge.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import java.util.Objects; + +/** + * Representation of an edge between process nodes in a workflow graph. + */ +public class ProcessSequenceEdge { + private final String source; + private final String destination; + + /** + * Create this edge with the id's of the source and destination nodes. + * + * @param source The source node id. + * @param destination The destination node id. + */ + ProcessSequenceEdge(String source, String destination) { + this.source = source; + this.destination = destination; + } + + /** + * Gets the source node id. + * + * @return the source node id. + */ + public String getSource() { + return source; + } + + /** + * Gets the destination node id. + * + * @return the destination node id. + */ + public String getDestination() { + return destination; + } + + @Override + public int hashCode() { + return Objects.hash(destination, source); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null) return false; + if (getClass() != obj.getClass()) return false; + ProcessSequenceEdge other = (ProcessSequenceEdge) obj; + return Objects.equals(destination, other.destination) && Objects.equals(source, other.source); + } + + @Override + public String toString() { + return this.source + "->" + this.destination; + } +} diff --git a/src/main/java/org/opensearch/flowframework/template/TemplateParser.java b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java new file mode 100644 index 000000000..56635f1b4 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/template/TemplateParser.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Utility class for parsing templates. + */ +public class TemplateParser { + + private static final Logger logger = LogManager.getLogger(TemplateParser.class); + + // Field names in the JSON. Package private for tests. + static final String WORKFLOW = "sequence"; + static final String NODES = "nodes"; + static final String NODE_ID = "id"; + static final String EDGES = "edges"; + static final String SOURCE = "source"; + static final String DESTINATION = "dest"; + + /** + * Prevent instantiating this class. + */ + private TemplateParser() {} + + /** + * Parse a JSON representation of nodes and edges into a topologically sorted list of process nodes. + * @param json A string containing a JSON representation of nodes and edges + * @param workflowSteps A map linking JSON node names to Java objects implementing {@link WorkflowStep} + * @return A list of Process Nodes sorted topologically. All predecessors of any node will occur prior to it in the list. + */ + public static List parseJsonGraphToSequence(String json, Map workflowSteps) { + Gson gson = new Gson(); + JsonObject jsonObject = gson.fromJson(json, JsonObject.class); + + JsonObject graph = jsonObject.getAsJsonObject(WORKFLOW); + + List nodes = new ArrayList<>(); + List edges = new ArrayList<>(); + + for (JsonElement nodeJson : graph.getAsJsonArray(NODES)) { + JsonObject nodeObject = nodeJson.getAsJsonObject(); + String nodeId = nodeObject.get(NODE_ID).getAsString(); + // The below steps will be replaced by a generator class that instantiates a WorkflowStep + // based on user_input data from the template. + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/41 + WorkflowStep workflowStep = workflowSteps.get(nodeId); + // temporary demo POC of getting from a request to input data + // this will be refactored into something pulling from user template as part of the above issue + WorkflowData inputData = WorkflowData.EMPTY; + if (List.of("create_index", "create_another_index").contains(nodeId)) { + CreateIndexRequest request = new CreateIndexRequest(nodeObject.get("index_name").getAsString()); + inputData = new WorkflowData( + Map.of("mappings", request.mappings(), "settings", request.settings(), "aliases", request.aliases()), + Map.of("index", request.index()) + ); + } + nodes.add(new ProcessNode(nodeId, workflowStep, inputData)); + } + + for (JsonElement edgeJson : graph.getAsJsonArray(EDGES)) { + JsonObject edgeObject = edgeJson.getAsJsonObject(); + String sourceNodeId = edgeObject.get(SOURCE).getAsString(); + String destNodeId = edgeObject.get(DESTINATION).getAsString(); + if (sourceNodeId.equals(destNodeId)) { + throw new IllegalArgumentException("Edge connects node " + sourceNodeId + " to itself."); + } + edges.add(new ProcessSequenceEdge(sourceNodeId, destNodeId)); + } + + return topologicalSort(nodes, edges); + } + + private static List topologicalSort(List nodes, List edges) { + // Define the graph + Set graph = new HashSet<>(edges); + // Map node id string to object + Map nodeMap = nodes.stream().collect(Collectors.toMap(ProcessNode::id, Function.identity())); + // Build predecessor and successor maps + Map> predecessorEdges = new HashMap<>(); + Map> successorEdges = new HashMap<>(); + for (ProcessSequenceEdge edge : edges) { + ProcessNode source = nodeMap.get(edge.getSource()); + ProcessNode dest = nodeMap.get(edge.getDestination()); + predecessorEdges.computeIfAbsent(dest, k -> new HashSet<>()).add(edge); + successorEdges.computeIfAbsent(source, k -> new HashSet<>()).add(edge); + } + // update predecessors on the node object + nodes.stream().filter(n -> predecessorEdges.containsKey(n)).forEach(n -> { + n.setPredecessors(predecessorEdges.get(n).stream().map(e -> nodeMap.get(e.getSource())).collect(Collectors.toSet())); + }); + + // See https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + // L <- Empty list that will contain the sorted elements + List sortedNodes = new ArrayList<>(); + // S <- Set of all nodes with no incoming edge + Queue sourceNodes = new ArrayDeque<>(); + nodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); + if (sourceNodes.isEmpty()) { + throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); + } + logger.debug("Start node(s): {}", sourceNodes); + + // while S is not empty do + while (!sourceNodes.isEmpty()) { + // remove a node n from S + ProcessNode n = sourceNodes.poll(); + // add n to L + sortedNodes.add(n); + // for each node m with an edge e from n to m do + for (ProcessSequenceEdge e : successorEdges.getOrDefault(n, Collections.emptySet())) { + ProcessNode m = nodeMap.get(e.getDestination()); + // remove edge e from the graph + graph.remove(e); + // if m has no other incoming edges then + if (!predecessorEdges.get(m).stream().anyMatch(i -> graph.contains(i))) { + // insert m into S + sourceNodes.add(m); + } + } + } + if (!graph.isEmpty()) { + throw new IllegalArgumentException("Cycle detected: " + graph); + } + logger.debug("Execution sequence: {}", sortedNodes); + return sortedNodes; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStep.java new file mode 100644 index 000000000..7f92b8057 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStep.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.CreateIndex; + +import com.google.common.base.Charsets; +import com.google.common.io.Resources; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; + +import java.io.IOException; +import java.net.URL; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Step to create an index + */ +public class CreateIndexStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(CreateIndexStep.class); + private Client client; + private final String NAME = "create_index_step"; + + /** + * Instantiate this class + * @param client Client to create an index + */ + public CreateIndexStep(Client client) { + this.client = client; + } + + @Override + public CompletableFuture execute(List data) { + CompletableFuture future = new CompletableFuture<>(); + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(CreateIndexResponse createIndexResponse) { + logger.info("created index: {}", createIndexResponse.index()); + future.complete(new WorkflowData(Map.of("index-name", createIndexResponse.index()))); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to create an index", e); + future.completeExceptionally(e); + } + }; + + String index = null; + String type = null; + Settings settings = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + index = (String) content.get("index-name"); + type = (String) content.get("type"); + if (index != null && type != null && settings != null) { + break; + } + } + + // TODO: + // 1. Create settings based on the index settings received from content + + try { + CreateIndexRequest request = new CreateIndexRequest(index).mapping( + getIndexMappings("mappings/" + type + ".json"), + XContentType.JSON + ); + client.admin().indices().create(request, actionListener); + } catch (Exception e) { + logger.error("Failed to find the right mapping for the index", e); + } + + return future; + } + + @Override + public String getName() { + return NAME; + } + + /** + * Get index mapping json content. + * + * @param mapping type of the index to fetch the specific mapping file + * @return index mapping + * @throws IOException IOException if mapping file can't be read correctly + */ + private static String getIndexMappings(String mapping) throws IOException { + URL url = CreateIndexStep.class.getClassLoader().getResource(mapping); + return Resources.toString(url, Charsets.UTF_8); + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java new file mode 100644 index 000000000..8382925b2 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -0,0 +1,191 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ingest.PutPipelineRequest; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +/** + * Workflow step to create an ingest pipeline + */ +public class CreateIngestPipelineStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(CreateIngestPipelineStep.class); + private static final String NAME = "create_ingest_pipeline_step"; + + // Common pipeline configuration fields + private static final String PIPELINE_ID_FIELD = "id"; + private static final String DESCRIPTION_FIELD = "description"; + private static final String PROCESSORS_FIELD = "processors"; + private static final String TYPE_FIELD = "type"; + + // Temporary text embedding processor fields + private static final String FIELD_MAP = "field_map"; + private static final String MODEL_ID_FIELD = "model_id"; + private static final String INPUT_FIELD = "input_field_name"; + private static final String OUTPUT_FIELD = "output_field_name"; + + // Client to store a pipeline in the cluster state + private final ClusterAdminClient clusterAdminClient; + + /** + * Instantiates a new CreateIngestPipelineStep + * + * @param client The client to create a pipeline and store workflow data into the global context index + */ + public CreateIngestPipelineStep(Client client) { + this.clusterAdminClient = client.admin().cluster(); + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture createIngestPipelineFuture = new CompletableFuture<>(); + + String pipelineId = null; + String description = null; + String type = null; + String modelId = null; + String inputFieldName = null; + String outputFieldName = null; + BytesReference configuration = null; + + // Extract required content from workflow data and generate the ingest pipeline configuration + for (WorkflowData workflowData : data) { + + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case PIPELINE_ID_FIELD: + pipelineId = (String) content.get(PIPELINE_ID_FIELD); + break; + case DESCRIPTION_FIELD: + description = (String) content.get(DESCRIPTION_FIELD); + break; + case TYPE_FIELD: + type = (String) content.get(TYPE_FIELD); + break; + case MODEL_ID_FIELD: + modelId = (String) content.get(MODEL_ID_FIELD); + break; + case INPUT_FIELD: + inputFieldName = (String) content.get(INPUT_FIELD); + break; + case OUTPUT_FIELD: + outputFieldName = (String) content.get(OUTPUT_FIELD); + break; + default: + break; + } + } + + // Determmine if fields have been populated, else iterate over remaining workflow data + if (Stream.of(pipelineId, description, modelId, type, inputFieldName, outputFieldName).allMatch(x -> x != null)) { + try { + configuration = BytesReference.bytes( + buildIngestPipelineRequestContent(description, modelId, type, inputFieldName, outputFieldName) + ); + } catch (IOException e) { + logger.error("Failed to create ingest pipeline configuration: " + e.getMessage()); + createIngestPipelineFuture.completeExceptionally(e); + } + break; + } + } + + if (configuration == null) { + // Required workflow data not found + createIngestPipelineFuture.completeExceptionally(new Exception("Failed to create ingest pipeline, required inputs not found")); + } else { + // Create PutPipelineRequest and execute + PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configuration, XContentType.JSON); + clusterAdminClient.putPipeline(putPipelineRequest, ActionListener.wrap(response -> { + logger.info("Created ingest pipeline : " + putPipelineRequest.getId()); + + // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead + createIngestPipelineFuture.complete(new WorkflowData(Map.of("pipelineId", putPipelineRequest.getId()))); + + // TODO : Use node client to index response data to global context (pending global context index implementation) + + }, exception -> { + logger.error("Failed to create ingest pipeline : " + exception.getMessage()); + createIngestPipelineFuture.completeExceptionally(exception); + })); + } + + return createIngestPipelineFuture; + } + + @Override + public String getName() { + return NAME; + } + + /** + * Temporary, generates the ingest pipeline request content for text_embedding processor from workflow data + * { + * "description" : "", + * "processors" : [ + * { + * "" : { + * "model_id" : "", + * "field_map" : { + * "" : "" + * } + * } + * ] + * } + * + * @param description The description of the ingest pipeline configuration + * @param modelId The ID of the model that will be used in the embedding interface + * @param type The processor type + * @param inputFieldName The field name used to cache text for text embeddings + * @param outputFieldName The field name in which output text is stored + * @throws IOException if the request content fails to be generated + * @return the xcontent builder with the formatted ingest pipeline configuration + */ + private XContentBuilder buildIngestPipelineRequestContent( + String description, + String modelId, + String type, + String inputFieldName, + String outputFieldName + ) throws IOException { + return XContentFactory.jsonBuilder() + .startObject() + .field(DESCRIPTION_FIELD, description) + .startArray(PROCESSORS_FIELD) + .startObject() + .startObject(type) + .field(MODEL_ID_FIELD, modelId) + .startObject(FIELD_MAP) + .field(inputFieldName, outputFieldName) + .endObject() + .endObject() + .endObject() + .endArray() + .endObject(); + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java index 3e8dc81b2..fbe4a5708 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowData.java @@ -8,7 +8,57 @@ */ package org.opensearch.flowframework.workflow; +import java.util.Collections; +import java.util.Map; + /** - * Interface for handling the input/output of the building blocks. + * Class encapsulating data provided as input to, and produced as output from, {@link WorkflowStep}s. */ -public interface WorkflowData {} +public class WorkflowData { + + /** + * An object representing no data, useful when a workflow step has no required input or output. + */ + public static WorkflowData EMPTY = new WorkflowData(); + + private final Map content; + private final Map params; + + private WorkflowData() { + this(Collections.emptyMap(), Collections.emptyMap()); + } + + /** + * Instantiate this object with content and empty params. + * @param content The content map + */ + public WorkflowData(Map content) { + this(content, Collections.emptyMap()); + } + + /** + * Instantiate this object with content and params. + * @param content The content map + * @param params The params map + */ + public WorkflowData(Map content, Map params) { + this.content = Map.copyOf(content); + this.params = Map.copyOf(params); + } + + /** + * Returns a map which represents the content associated with a Rest API request or response. + * @return the content of this data. + */ + public Map getContent() { + return this.content; + }; + + /** + * Returns a map represents the params associated with a Rest API request, parsed from the URI. + * @return the params of this data. + */ + public Map getParams() { + return this.params; + }; +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index 6a65ce6e3..6cd5f5a28 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -8,8 +8,7 @@ */ package org.opensearch.flowframework.workflow; -import org.opensearch.common.Nullable; - +import java.util.List; import java.util.concurrent.CompletableFuture; /** @@ -18,11 +17,11 @@ public interface WorkflowStep { /** - * Triggers the processing of the building block. - * @param data for input/output params of the building blocks. - * @return CompletableFuture of the building block. + * Triggers the actual processing of the building block. + * @param data representing input params and content, or output content of previous steps. The first element of the list is data (if any) provided from parsing the template, and may be {@link WorkflowData#EMPTY}. + * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. */ - CompletableFuture execute(@Nullable WorkflowData data); + CompletableFuture execute(List data); /** * diff --git a/src/main/resources/log4j2.xml b/src/main/resources/log4j2.xml new file mode 100644 index 000000000..21a4c6fa5 --- /dev/null +++ b/src/main/resources/log4j2.xml @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/src/main/resources/mappings/knn.json b/src/main/resources/mappings/knn.json new file mode 100644 index 000000000..c31946e62 --- /dev/null +++ b/src/main/resources/mappings/knn.json @@ -0,0 +1,16 @@ +{ + "properties": { + "desc_v": { + "type": "keyword" + }, + "name_v": { + "type": "keyword" + }, + "description": { + "type": "keyword" + }, + "name": { + "type": "keyword" + } + } +} diff --git a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java index d54dc2c63..0dccc27ce 100644 --- a/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java +++ b/src/test/java/org/opensearch/flowframework/FlowFrameworkPluginIT.java @@ -22,8 +22,6 @@ import java.util.Collection; import java.util.Collections; -import static org.hamcrest.Matchers.containsString; - @ThreadLeakScope(ThreadLeakScope.Scope.NONE) @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE) public class FlowFrameworkPluginIT extends OpenSearchIntegTestCase { @@ -38,6 +36,6 @@ public void testPluginInstalled() throws IOException, ParseException { String body = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8); logger.info("response body: {}", body); - assertThat(body, containsString("flowframework")); + assertTrue(body.contains("flowframework")); } } diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java new file mode 100644 index 000000000..d9f365708 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/template/ProcessNodeTests.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope.Scope; + +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.flowframework.workflow.WorkflowStep; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +@ThreadLeakScope(Scope.NONE) +public class ProcessNodeTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testNode() throws InterruptedException, ExecutionException { + ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { + @Override + public CompletableFuture execute(List data) { + CompletableFuture f = new CompletableFuture<>(); + f.complete(WorkflowData.EMPTY); + return f; + } + + @Override + public String getName() { + return "test"; + } + }); + assertEquals("A", nodeA.id()); + assertEquals("test", nodeA.workflowStep().getName()); + assertEquals(WorkflowData.EMPTY, nodeA.input()); + assertEquals(Collections.emptySet(), nodeA.getPredecessors()); + assertEquals("A", nodeA.toString()); + + // TODO: This test is flaky on Windows. Disabling until thread pool is integrated + // https://github.com/opensearch-project/opensearch-ai-flow-framework/issues/42 + // CompletableFuture f = nodeA.execute(); + // assertEquals(f, nodeA.future()); + // f.orTimeout(5, TimeUnit.SECONDS); + // assertTrue(f.isDone()); + // assertEquals(WorkflowData.EMPTY, f.get()); + + ProcessNode nodeB = new ProcessNode("B", null); + assertNotEquals(nodeA, nodeB); + + ProcessNode nodeA2 = new ProcessNode("A", null); + assertEquals(nodeA, nodeA2); + } +} diff --git a/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java b/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java new file mode 100644 index 000000000..80cecd96e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/template/ProcessSequenceEdgeTests.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import org.opensearch.test.OpenSearchTestCase; + +public class ProcessSequenceEdgeTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testEdge() { + ProcessSequenceEdge edgeAB = new ProcessSequenceEdge("A", "B"); + assertEquals("A", edgeAB.getSource()); + assertEquals("B", edgeAB.getDestination()); + assertEquals("A->B", edgeAB.toString()); + + ProcessSequenceEdge edgeAB2 = new ProcessSequenceEdge("A", "B"); + assertEquals(edgeAB, edgeAB2); + + ProcessSequenceEdge edgeAC = new ProcessSequenceEdge("A", "C"); + assertNotEquals(edgeAB, edgeAC); + } +} diff --git a/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java new file mode 100644 index 000000000..24dcf0640 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/template/TemplateParserTests.java @@ -0,0 +1,153 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.template; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.opensearch.flowframework.template.TemplateParser.DESTINATION; +import static org.opensearch.flowframework.template.TemplateParser.EDGES; +import static org.opensearch.flowframework.template.TemplateParser.NODES; +import static org.opensearch.flowframework.template.TemplateParser.NODE_ID; +import static org.opensearch.flowframework.template.TemplateParser.SOURCE; +import static org.opensearch.flowframework.template.TemplateParser.WORKFLOW; + +public class TemplateParserTests extends OpenSearchTestCase { + + private static final String NO_START_NODE_DETECTED = "No start node detected: all nodes have a predecessor."; + private static final String CYCLE_DETECTED = "Cycle detected:"; + + // Input JSON generators + private static String node(String id) { + return "{\"" + NODE_ID + "\": \"" + id + "\"}"; + } + + private static String edge(String sourceId, String destId) { + return "{\"" + SOURCE + "\": \"" + sourceId + "\", \"" + DESTINATION + "\": \"" + destId + "\"}"; + } + + private static String workflow(List nodes, List edges) { + return "{\"" + WORKFLOW + "\": {" + arrayField(NODES, nodes) + ", " + arrayField(EDGES, edges) + "}}"; + } + + private static String arrayField(String fieldName, List objects) { + return "\"" + fieldName + "\": [" + objects.stream().collect(Collectors.joining(", ")) + "]"; + } + + // Output list elements + private static ProcessNode expectedNode(String id) { + return new ProcessNode(id, null, null); + } + + // Less verbose parser + private static List parse(String json) { + return TemplateParser.parseJsonGraphToSequence(json, Collections.emptyMap()); + } + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testOrdering() { + List workflow; + + workflow = parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("C", "B"), edge("B", "A")))); + assertEquals(0, workflow.indexOf(expectedNode("C"))); + assertEquals(1, workflow.indexOf(expectedNode("B"))); + assertEquals(2, workflow.indexOf(expectedNode("A"))); + + workflow = parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D")), + List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("C", "D")) + ) + ); + assertEquals(0, workflow.indexOf(expectedNode("A"))); + int b = workflow.indexOf(expectedNode("B")); + int c = workflow.indexOf(expectedNode("C")); + assertTrue(b == 1 || b == 2); + assertTrue(c == 1 || c == 2); + assertEquals(3, workflow.indexOf(expectedNode("D"))); + + workflow = parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D"), node("E")), + List.of(edge("A", "B"), edge("A", "C"), edge("B", "D"), edge("D", "E"), edge("C", "E")) + ) + ); + assertEquals(0, workflow.indexOf(expectedNode("A"))); + b = workflow.indexOf(expectedNode("B")); + c = workflow.indexOf(expectedNode("C")); + int d = workflow.indexOf(expectedNode("D")); + assertTrue(b == 1 || b == 2); + assertTrue(c == 1 || c == 2); + assertTrue(d == 2 || d == 3); + assertEquals(4, workflow.indexOf(expectedNode("E"))); + } + + public void testCycles() { + Exception ex; + + ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); + assertEquals("Edge connects node A to itself.", ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "B")))) + ); + assertEquals("Edge connects node B to itself.", ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "A")))) + ); + assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("A", "B"), edge("B", "C"), edge("C", "B")))) + ); + assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); + assertTrue(ex.getMessage().contains("B->C")); + assertTrue(ex.getMessage().contains("C->B")); + + ex = assertThrows( + IllegalArgumentException.class, + () -> parse( + workflow( + List.of(node("A"), node("B"), node("C"), node("D")), + List.of(edge("A", "B"), edge("B", "C"), edge("C", "D"), edge("D", "B")) + ) + ) + ); + assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); + assertTrue(ex.getMessage().contains("B->C")); + assertTrue(ex.getMessage().contains("C->D")); + assertTrue(ex.getMessage().contains("D->B")); + } + + public void testNoEdges() { + Exception ex = assertThrows( + IllegalArgumentException.class, + () -> parse(workflow(Collections.emptyList(), Collections.emptyList())) + ); + assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + + assertEquals(List.of(expectedNode("A")), parse(workflow(List.of(node("A")), Collections.emptyList()))); + + List workflow = parse(workflow(List.of(node("A"), node("B")), Collections.emptyList())); + assertEquals(2, workflow.size()); + assertTrue(workflow.contains(expectedNode("A"))); + assertTrue(workflow.contains(expectedNode("B"))); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStepTests.java new file mode 100644 index 000000000..c5d680a94 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndex/CreateIndexStepTests.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.CreateIndex; + +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class CreateIndexStepTests extends OpenSearchTestCase { + + private WorkflowData inputData = WorkflowData.EMPTY; + + private Client client; + + private AdminClient adminClient; + + private IndicesAdminClient indicesAdminClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData(Map.ofEntries(Map.entry("index-name", "demo"), Map.entry("type", "knn"))); + client = mock(Client.class); + adminClient = mock(AdminClient.class); + indicesAdminClient = mock(IndicesAdminClient.class); + + when(adminClient.indices()).thenReturn(indicesAdminClient); + when(client.admin()).thenReturn(adminClient); + + } + + public void testCreateIndexStep() throws ExecutionException, InterruptedException, IOException { + + CreateIndexStep createIndexStep = new CreateIndexStep(client); + + ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + CompletableFuture future = createIndexStep.execute(List.of(inputData)); + assertFalse(future.isDone()); + verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onResponse(new CreateIndexResponse(true, true, "demo")); + + assertTrue(future.isDone() && !future.isCompletedExceptionally()); + + Map outputData = Map.of("index-name", "demo"); + assertEquals(outputData, future.get().getContent()); + + } + + public void testCreateIndexStepFailure() throws ExecutionException, InterruptedException { + + CreateIndexStep createIndexStep = new CreateIndexStep(client); + + ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + CompletableFuture future = createIndexStep.execute(List.of(inputData)); + assertFalse(future.isDone()); + verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); + + actionListenerCaptor.getValue().onFailure(new Exception("Failed to create an index")); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof Exception); + assertEquals("Failed to create an index", ex.getCause().getMessage()); + } +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java new file mode 100644 index 000000000..286bc2de9 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java @@ -0,0 +1,126 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow.CreateIngestPipeline; + +import org.opensearch.action.ingest.PutPipelineRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.ClusterAdminClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.workflow.CreateIngestPipelineStep; +import org.opensearch.flowframework.workflow.WorkflowData; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class CreateIngestPipelineStepTests extends OpenSearchTestCase { + + private WorkflowData inputData; + private WorkflowData outpuData; + private Client client; + private AdminClient adminClient; + private ClusterAdminClient clusterAdminClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("id", "pipelineId"), + Map.entry("description", "some description"), + Map.entry("type", "text_embedding"), + Map.entry("model_id", "model_id"), + Map.entry("input_field_name", "inputField"), + Map.entry("output_field_name", "outputField") + ) + ); + + // Set output data to returned pipelineId + outpuData = new WorkflowData(Map.ofEntries(Map.entry("pipelineId", "pipelineId"))); + + client = mock(Client.class); + adminClient = mock(AdminClient.class); + clusterAdminClient = mock(ClusterAdminClient.class); + + when(client.admin()).thenReturn(adminClient); + when(adminClient.cluster()).thenReturn(clusterAdminClient); + } + + public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException { + + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + + ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + + assertFalse(future.isDone()); + + // Mock put pipeline request execution and return true + verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true)); + + assertTrue(future.isDone() && !future.isCompletedExceptionally()); + assertEquals(outpuData.getContent(), future.get().getContent()); + } + + public void testCreateIngestPipelineStepFailure() throws InterruptedException { + + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + + ArgumentCaptor actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + CompletableFuture future = createIngestPipelineStep.execute(List.of(inputData)); + + assertFalse(future.isDone()); + + // Mock put pipeline request execution and return false + verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture()); + actionListenerCaptor.getValue().onFailure(new Exception("Failed to create ingest pipeline")); + + assertTrue(future.isDone() && future.isCompletedExceptionally()); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals("Failed to create ingest pipeline", exception.getCause().getMessage()); + } + + public void testMissingData() throws InterruptedException { + CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client); + + // Data with missing input and output fields + WorkflowData incorrectData = new WorkflowData( + Map.ofEntries( + Map.entry("id", "pipelineId"), + Map.entry("description", "some description"), + Map.entry("type", "text_embedding"), + Map.entry("model_id", "model_id") + ) + ); + + CompletableFuture future = createIngestPipelineStep.execute(List.of(incorrectData)); + assertTrue(future.isDone() && future.isCompletedExceptionally()); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); + assertTrue(exception.getCause() instanceof Exception); + assertEquals("Failed to create ingest pipeline, required inputs not found", exception.getCause().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java new file mode 100644 index 000000000..e2464dace --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowDataTests.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; + +public class WorkflowDataTests extends OpenSearchTestCase { + + @Override + public void setUp() throws Exception { + super.setUp(); + } + + public void testWorkflowData() { + + WorkflowData empty = WorkflowData.EMPTY; + assertTrue(empty.getParams().isEmpty()); + assertTrue(empty.getContent().isEmpty()); + + Map expectedContent = Map.of("baz", new String[] { "qux", "quxx" }); + WorkflowData contentOnly = new WorkflowData(expectedContent); + assertTrue(contentOnly.getParams().isEmpty()); + assertEquals(expectedContent, contentOnly.getContent()); + + Map expectedParams = Map.of("foo", "bar"); + WorkflowData contentAndParams = new WorkflowData(expectedContent, expectedParams); + assertEquals(expectedParams, contentAndParams.getParams()); + assertEquals(expectedContent, contentAndParams.getContent()); + } +} diff --git a/src/test/resources/template/datademo.json b/src/test/resources/template/datademo.json new file mode 100644 index 000000000..a1323ed2c --- /dev/null +++ b/src/test/resources/template/datademo.json @@ -0,0 +1,20 @@ +{ + "sequence": { + "nodes": [ + { + "id": "create_index", + "index_name": "demo" + }, + { + "id": "create_another_index", + "index_name": "second_demo" + } + ], + "edges": [ + { + "source": "create_index", + "dest": "create_another_index" + } + ] + } +} diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json new file mode 100644 index 000000000..38f1d0644 --- /dev/null +++ b/src/test/resources/template/demo.json @@ -0,0 +1,36 @@ +{ + "sequence": { + "nodes": [ + { + "id": "fetch_model" + }, + { + "id": "create_ingest_pipeline" + }, + { + "id": "create_search_pipeline" + }, + { + "id": "create_neural_search_index" + } + ], + "edges": [ + { + "source": "fetch_model", + "dest": "create_ingest_pipeline" + }, + { + "source": "fetch_model", + "dest": "create_search_pipeline" + }, + { + "source": "create_ingest_pipeline", + "dest": "create_neural_search_index" + }, + { + "source": "create_search_pipeline", + "dest": "create_neural_search_index" + } + ] + } +}