diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java b/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java index 01917a77cf72a..df3ba483f6787 100644 --- a/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java +++ b/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java @@ -153,9 +153,8 @@ public void reportedCompleted(U completed) { */ private void initializeDependentKeys() { visit(new Visitor() { - // This 'visit' will be called only once per each node. @Override - public void visit(U node) { + public void visitNode(U node) { if (node.dependencyKeys().isEmpty()) { return; } @@ -166,6 +165,14 @@ public void visit(U node) { .addDependent(dependentKey); } } + + @Override + public void visitEdge(String fromKey, String toKey, GraphEdgeType edgeType) { + System.out.println("{" + fromKey + ", " + toKey + "} " + edgeType); + if (edgeType == GraphEdgeType.BACK) { + throw new IllegalStateException("Detected circular dependency: " + findPath(fromKey, toKey)); + } + } }); } diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java b/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java index 40ceebaa50b2b..17440ddbeb193 100644 --- a/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java +++ b/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java @@ -12,6 +12,16 @@ import java.util.Map; import java.util.Set; +/** + * The edge types in a graph. + */ +enum GraphEdgeType { + TREE, + FORWARD, + BACK, + CROSS +} + /** * Type representing a directed graph data structure. *

@@ -23,6 +33,11 @@ public class Graph> { protected Map graph; private Set visited; + private Integer time; + private Map entryTime; + private Map exitTime; + private Map parent; + private Set processed; /** * Creates a directed graph. @@ -30,6 +45,11 @@ public class Graph> { public Graph() { this.graph = new HashMap<>(); this.visited = new HashSet<>(); + this.time = 0; + this.entryTime = new HashMap<>(); + this.exitTime = new HashMap<>(); + this.parent = new HashMap<>(); + this.processed = new HashSet<>(); } /** @@ -53,14 +73,23 @@ interface Visitor { * * @param node the node to visited */ - void visit(U node); + void visitNode(U node); + + /** + * visit an edge. + * + * @param fromKey key of the from node + * @param toKey key of the to node + * @param graphEdgeType the edge type + */ + void visitEdge(String fromKey, String toKey, GraphEdgeType graphEdgeType); } /** * Perform DFS visit in this graph. *

* The directed graph will be traversed in DFS order and the visitor will be notified as - * search explores each node + * search explores each node and edge. * * @param visitor the graph visitor */ @@ -71,15 +100,61 @@ public void visit(Visitor visitor) { } } visited.clear(); + time = 0; + entryTime.clear(); + exitTime.clear(); + parent.clear(); + processed.clear(); } private void dfs(Visitor visitor, Node node) { - visitor.visit(node); - visited.add(node.key()); - for (String childKey : node.children()) { - if (!visited.contains(childKey)) { - this.dfs(visitor, this.graph.get(childKey)); + visitor.visitNode(node); + + String fromKey = node.key(); + visited.add(fromKey); + time++; + entryTime.put(fromKey, time); + for (String toKey : node.children()) { + if (!visited.contains(toKey)) { + parent.put(toKey, fromKey); + visitor.visitEdge(fromKey, toKey, edgeType(fromKey, toKey)); + this.dfs(visitor, this.graph.get(toKey)); + } else { + visitor.visitEdge(fromKey, toKey, edgeType(fromKey, toKey)); } } + time++; + exitTime.put(fromKey, time); + processed.add(fromKey); + } + + private GraphEdgeType edgeType(String fromKey, String toKey) { + if (parent.containsKey(toKey) && parent.get(toKey).equals(fromKey)) { + return GraphEdgeType.TREE; + } + + if (visited.contains(toKey) && !processed.contains(toKey)) { + return GraphEdgeType.BACK; + } + + if (processed.contains(toKey) && entryTime.containsKey(toKey) && entryTime.containsKey(fromKey)) { + if (entryTime.get(toKey) > entryTime.get(fromKey)) { + return GraphEdgeType.FORWARD; + } + + if (entryTime.get(toKey) < entryTime.get(fromKey)) { + return GraphEdgeType.CROSS; + } + } + + throw new IllegalStateException("Internal Error: Unable to locate the edge type {" + fromKey + ", " + toKey + "}"); + } + + protected String findPath(String start, String end) { + if (start.equals(end)) { + return start; + } else { + return findPath(start, parent.get(end)) + " -> " + end; + } } } diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java index 1cd0050e5ddb6..b607a41f54add 100644 --- a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java +++ b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java @@ -9,54 +9,10 @@ import com.microsoft.rest.ServiceCall; import com.microsoft.rest.ServiceCallback; +import com.microsoft.rest.ServiceResponse; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import java.util.concurrent.ConcurrentLinkedQueue; -/** - /** - * An instance of this class provides access to the underlying REST service call running - * in parallel. - * - * @param - */ -class ParallelServiceCall extends ServiceCall { - private TaskGroupBase taskGroup; - - /** - * Creates a ParallelServiceCall. - * - * @param taskGroup the task group - */ - ParallelServiceCall(TaskGroupBase taskGroup) { - super(null); - this.taskGroup = taskGroup; - } - - /** - * Cancels all the service calls currently executing. - */ - public void cancel() { - for (ServiceCall call : this.taskGroup.calls()) { - call.cancel(); - } - } - - /** - * @return true if the call has been canceled; false otherwise. - */ - public boolean isCancelled() { - for (ServiceCall call : this.taskGroup.calls()) { - if (!call.isCanceled()) { - return false; - } - } - return true; - } -} - /** * The base implementation of TaskGroup interface. * @@ -65,8 +21,7 @@ public boolean isCancelled() { public abstract class TaskGroupBase implements TaskGroup> { private DAGraph, DAGNode>> dag; - private ConcurrentLinkedQueue serviceCalls = new ConcurrentLinkedQueue<>(); - private ParallelServiceCall parallelServiceCall; + private ParallelServiceCall parallelServiceCall; /** * Creates TaskGroupBase. @@ -76,11 +31,7 @@ public abstract class TaskGroupBase */ public TaskGroupBase(String rootTaskItemId, TaskItem rootTaskItem) { this.dag = new DAGraph<>(new DAGNode<>(rootTaskItemId, rootTaskItem)); - this.parallelServiceCall = new ParallelServiceCall<>(this); - } - - List calls() { - return Collections.unmodifiableList(Arrays.asList(serviceCalls.toArray(new ServiceCall[0]))); + this.parallelServiceCall = new ParallelServiceCall(); } @Override @@ -93,14 +44,6 @@ public boolean isPreparer() { return dag.isPreparer(); } - /** - * @return Gets the ParallelServiceCall instance that wraps the service calls running - * in parallel. - */ - public ParallelServiceCall parallelServiceCall() { - return this.parallelServiceCall; - } - @Override public void merge(TaskGroup> parentTaskGroup) { dag.merge(parentTaskGroup.dag()); @@ -116,32 +59,109 @@ public void prepare() { @Override public void execute() throws Exception { DAGNode> nextNode = dag.getNext(); - if (nextNode == null) { - return; + while (nextNode != null) { + nextNode.data().execute(); + this.dag().reportedCompleted(nextNode); + nextNode = dag.getNext(); } - - nextNode.data().execute(this, nextNode); } @Override public ServiceCall executeAsync(final ServiceCallback callback) { - ServiceCall serviceCall = null; + executeReadyTasksAsync(callback); + return parallelServiceCall; + } + + @Override + public T taskResult(String taskId) { + return dag.getNodeData(taskId).result(); + } + + /** + * Executes all runnable tasks, a task is runnable when all the tasks its depends + * on are finished running. + * + * @param callback the callback + */ + private void executeReadyTasksAsync(final ServiceCallback callback) { DAGNode> nextNode = dag.getNext(); while (nextNode != null) { - serviceCall = nextNode.data().executeAsync(this, nextNode, dag.isRootNode(nextNode), callback); - if (serviceCall != null) { - // Filter out the null value returned by executeAsync. that happen - // when TaskItem::executeAsync invokes TaskGroupBase::executeAsync - // but there is no task available in the queue at the moment. - this.serviceCalls.add(serviceCall); - } + ServiceCall serviceCall = nextNode.data().executeAsync(taskCallback(nextNode, callback)); + this.parallelServiceCall.addCall(serviceCall); nextNode = dag.getNext(); } - return serviceCall; } - @Override - public T taskResult(String taskId) { - return dag.getNodeData(taskId).result(); + /** + * This method create and return a callback for the runnable task stored in the given node. + * This callback wraps the given callback. + * + * @param taskNode the node containing runnable task + * @param callback the callback to wrap + * @return the task callback + */ + private ServiceCallback taskCallback(final DAGNode> taskNode, final ServiceCallback callback) { + final TaskGroupBase self = this; + return new ServiceCallback() { + @Override + public void failure(Throwable t) { + callback.failure(t); + } + + @Override + public void success(ServiceResponse result) { + self.dag().reportedCompleted(taskNode); + if (self.dag().isRootNode(taskNode)) { + callback.success(result); + } else { + self.executeReadyTasksAsync(callback); + } + } + }; + } + + /** + * Type represents a set of REST calls running possibly in parallel. + */ + private class ParallelServiceCall extends ServiceCall { + private ConcurrentLinkedQueue serviceCalls; + + /** + * Creates a ParallelServiceCall. + */ + ParallelServiceCall() { + super(null); + this.serviceCalls = new ConcurrentLinkedQueue<>(); + } + + /** + * Cancels all the service calls currently executing. + */ + public void cancel() { + for (ServiceCall call : this.serviceCalls) { + call.cancel(); + } + } + + /** + * @return true if the call has been canceled; false otherwise. + */ + public boolean isCancelled() { + for (ServiceCall call : this.serviceCalls) { + if (!call.isCanceled()) { + return false; + } + } + return true; + } + + /** + * Add a call to the list of parallel calls. + * + * @param call the call + */ + private void addCall(ServiceCall call) { + this.serviceCalls.add(call); + } } } diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java index 1612ff4f835a7..8f0a3459a2e92 100644 --- a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java +++ b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java @@ -26,22 +26,17 @@ public interface TaskItem { *

* once executed the result will be available through result getter * - * @param taskGroup the task group dispatching tasks - * @param node the node the task item is associated with * @throws Exception exception */ - void execute(TaskGroup> taskGroup, DAGNode> node) throws Exception; + void execute() throws Exception; /** * Executes the task asynchronously. *

* once executed the result will be available through result getter - - * @param taskGroup the task group dispatching tasks - * @param node the node the task item is associated with - * @param isRootNode true if the node is root node + * * @param callback callback to call on success or failure * @return the handle of the REST call */ - ServiceCall executeAsync(TaskGroup> taskGroup, DAGNode> node, final boolean isRootNode, ServiceCallback callback); + ServiceCall executeAsync(ServiceCallback callback); }