diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java b/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java
index 01917a77cf72a..df3ba483f6787 100644
--- a/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java
+++ b/azure-client-runtime/src/main/java/com/microsoft/azure/DAGraph.java
@@ -153,9 +153,8 @@ public void reportedCompleted(U completed) {
*/
private void initializeDependentKeys() {
visit(new Visitor() {
- // This 'visit' will be called only once per each node.
@Override
- public void visit(U node) {
+ public void visitNode(U node) {
if (node.dependencyKeys().isEmpty()) {
return;
}
@@ -166,6 +165,14 @@ public void visit(U node) {
.addDependent(dependentKey);
}
}
+
+ @Override
+ public void visitEdge(String fromKey, String toKey, GraphEdgeType edgeType) {
+ System.out.println("{" + fromKey + ", " + toKey + "} " + edgeType);
+ if (edgeType == GraphEdgeType.BACK) {
+ throw new IllegalStateException("Detected circular dependency: " + findPath(fromKey, toKey));
+ }
+ }
});
}
diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java b/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java
index 40ceebaa50b2b..17440ddbeb193 100644
--- a/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java
+++ b/azure-client-runtime/src/main/java/com/microsoft/azure/Graph.java
@@ -12,6 +12,16 @@
import java.util.Map;
import java.util.Set;
+/**
+ * The edge types in a graph.
+ */
+enum GraphEdgeType {
+ TREE,
+ FORWARD,
+ BACK,
+ CROSS
+}
+
/**
* Type representing a directed graph data structure.
*
@@ -23,6 +33,11 @@
public class Graph> {
protected Map graph;
private Set visited;
+ private Integer time;
+ private Map entryTime;
+ private Map exitTime;
+ private Map parent;
+ private Set processed;
/**
* Creates a directed graph.
@@ -30,6 +45,11 @@ public class Graph> {
public Graph() {
this.graph = new HashMap<>();
this.visited = new HashSet<>();
+ this.time = 0;
+ this.entryTime = new HashMap<>();
+ this.exitTime = new HashMap<>();
+ this.parent = new HashMap<>();
+ this.processed = new HashSet<>();
}
/**
@@ -53,14 +73,23 @@ interface Visitor {
*
* @param node the node to visited
*/
- void visit(U node);
+ void visitNode(U node);
+
+ /**
+ * visit an edge.
+ *
+ * @param fromKey key of the from node
+ * @param toKey key of the to node
+ * @param graphEdgeType the edge type
+ */
+ void visitEdge(String fromKey, String toKey, GraphEdgeType graphEdgeType);
}
/**
* Perform DFS visit in this graph.
*
* The directed graph will be traversed in DFS order and the visitor will be notified as
- * search explores each node
+ * search explores each node and edge.
*
* @param visitor the graph visitor
*/
@@ -71,15 +100,61 @@ public void visit(Visitor visitor) {
}
}
visited.clear();
+ time = 0;
+ entryTime.clear();
+ exitTime.clear();
+ parent.clear();
+ processed.clear();
}
private void dfs(Visitor visitor, Node node) {
- visitor.visit(node);
- visited.add(node.key());
- for (String childKey : node.children()) {
- if (!visited.contains(childKey)) {
- this.dfs(visitor, this.graph.get(childKey));
+ visitor.visitNode(node);
+
+ String fromKey = node.key();
+ visited.add(fromKey);
+ time++;
+ entryTime.put(fromKey, time);
+ for (String toKey : node.children()) {
+ if (!visited.contains(toKey)) {
+ parent.put(toKey, fromKey);
+ visitor.visitEdge(fromKey, toKey, edgeType(fromKey, toKey));
+ this.dfs(visitor, this.graph.get(toKey));
+ } else {
+ visitor.visitEdge(fromKey, toKey, edgeType(fromKey, toKey));
}
}
+ time++;
+ exitTime.put(fromKey, time);
+ processed.add(fromKey);
+ }
+
+ private GraphEdgeType edgeType(String fromKey, String toKey) {
+ if (parent.containsKey(toKey) && parent.get(toKey).equals(fromKey)) {
+ return GraphEdgeType.TREE;
+ }
+
+ if (visited.contains(toKey) && !processed.contains(toKey)) {
+ return GraphEdgeType.BACK;
+ }
+
+ if (processed.contains(toKey) && entryTime.containsKey(toKey) && entryTime.containsKey(fromKey)) {
+ if (entryTime.get(toKey) > entryTime.get(fromKey)) {
+ return GraphEdgeType.FORWARD;
+ }
+
+ if (entryTime.get(toKey) < entryTime.get(fromKey)) {
+ return GraphEdgeType.CROSS;
+ }
+ }
+
+ throw new IllegalStateException("Internal Error: Unable to locate the edge type {" + fromKey + ", " + toKey + "}");
+ }
+
+ protected String findPath(String start, String end) {
+ if (start.equals(end)) {
+ return start;
+ } else {
+ return findPath(start, parent.get(end)) + " -> " + end;
+ }
}
}
diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java
index 1cd0050e5ddb6..b607a41f54add 100644
--- a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java
+++ b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskGroupBase.java
@@ -9,54 +9,10 @@
import com.microsoft.rest.ServiceCall;
import com.microsoft.rest.ServiceCallback;
+import com.microsoft.rest.ServiceResponse;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
-/**
- /**
- * An instance of this class provides access to the underlying REST service call running
- * in parallel.
- *
- * @param
- */
-class ParallelServiceCall extends ServiceCall {
- private TaskGroupBase taskGroup;
-
- /**
- * Creates a ParallelServiceCall.
- *
- * @param taskGroup the task group
- */
- ParallelServiceCall(TaskGroupBase taskGroup) {
- super(null);
- this.taskGroup = taskGroup;
- }
-
- /**
- * Cancels all the service calls currently executing.
- */
- public void cancel() {
- for (ServiceCall call : this.taskGroup.calls()) {
- call.cancel();
- }
- }
-
- /**
- * @return true if the call has been canceled; false otherwise.
- */
- public boolean isCancelled() {
- for (ServiceCall call : this.taskGroup.calls()) {
- if (!call.isCanceled()) {
- return false;
- }
- }
- return true;
- }
-}
-
/**
* The base implementation of TaskGroup interface.
*
@@ -65,8 +21,7 @@ public boolean isCancelled() {
public abstract class TaskGroupBase
implements TaskGroup> {
private DAGraph, DAGNode>> dag;
- private ConcurrentLinkedQueue serviceCalls = new ConcurrentLinkedQueue<>();
- private ParallelServiceCall parallelServiceCall;
+ private ParallelServiceCall parallelServiceCall;
/**
* Creates TaskGroupBase.
@@ -76,11 +31,7 @@ public abstract class TaskGroupBase
*/
public TaskGroupBase(String rootTaskItemId, TaskItem rootTaskItem) {
this.dag = new DAGraph<>(new DAGNode<>(rootTaskItemId, rootTaskItem));
- this.parallelServiceCall = new ParallelServiceCall<>(this);
- }
-
- List calls() {
- return Collections.unmodifiableList(Arrays.asList(serviceCalls.toArray(new ServiceCall[0])));
+ this.parallelServiceCall = new ParallelServiceCall();
}
@Override
@@ -93,14 +44,6 @@ public boolean isPreparer() {
return dag.isPreparer();
}
- /**
- * @return Gets the ParallelServiceCall instance that wraps the service calls running
- * in parallel.
- */
- public ParallelServiceCall parallelServiceCall() {
- return this.parallelServiceCall;
- }
-
@Override
public void merge(TaskGroup> parentTaskGroup) {
dag.merge(parentTaskGroup.dag());
@@ -116,32 +59,109 @@ public void prepare() {
@Override
public void execute() throws Exception {
DAGNode> nextNode = dag.getNext();
- if (nextNode == null) {
- return;
+ while (nextNode != null) {
+ nextNode.data().execute();
+ this.dag().reportedCompleted(nextNode);
+ nextNode = dag.getNext();
}
-
- nextNode.data().execute(this, nextNode);
}
@Override
public ServiceCall executeAsync(final ServiceCallback callback) {
- ServiceCall serviceCall = null;
+ executeReadyTasksAsync(callback);
+ return parallelServiceCall;
+ }
+
+ @Override
+ public T taskResult(String taskId) {
+ return dag.getNodeData(taskId).result();
+ }
+
+ /**
+ * Executes all runnable tasks, a task is runnable when all the tasks its depends
+ * on are finished running.
+ *
+ * @param callback the callback
+ */
+ private void executeReadyTasksAsync(final ServiceCallback callback) {
DAGNode> nextNode = dag.getNext();
while (nextNode != null) {
- serviceCall = nextNode.data().executeAsync(this, nextNode, dag.isRootNode(nextNode), callback);
- if (serviceCall != null) {
- // Filter out the null value returned by executeAsync. that happen
- // when TaskItem::executeAsync invokes TaskGroupBase::executeAsync
- // but there is no task available in the queue at the moment.
- this.serviceCalls.add(serviceCall);
- }
+ ServiceCall serviceCall = nextNode.data().executeAsync(taskCallback(nextNode, callback));
+ this.parallelServiceCall.addCall(serviceCall);
nextNode = dag.getNext();
}
- return serviceCall;
}
- @Override
- public T taskResult(String taskId) {
- return dag.getNodeData(taskId).result();
+ /**
+ * This method create and return a callback for the runnable task stored in the given node.
+ * This callback wraps the given callback.
+ *
+ * @param taskNode the node containing runnable task
+ * @param callback the callback to wrap
+ * @return the task callback
+ */
+ private ServiceCallback taskCallback(final DAGNode> taskNode, final ServiceCallback callback) {
+ final TaskGroupBase self = this;
+ return new ServiceCallback() {
+ @Override
+ public void failure(Throwable t) {
+ callback.failure(t);
+ }
+
+ @Override
+ public void success(ServiceResponse result) {
+ self.dag().reportedCompleted(taskNode);
+ if (self.dag().isRootNode(taskNode)) {
+ callback.success(result);
+ } else {
+ self.executeReadyTasksAsync(callback);
+ }
+ }
+ };
+ }
+
+ /**
+ * Type represents a set of REST calls running possibly in parallel.
+ */
+ private class ParallelServiceCall extends ServiceCall {
+ private ConcurrentLinkedQueue serviceCalls;
+
+ /**
+ * Creates a ParallelServiceCall.
+ */
+ ParallelServiceCall() {
+ super(null);
+ this.serviceCalls = new ConcurrentLinkedQueue<>();
+ }
+
+ /**
+ * Cancels all the service calls currently executing.
+ */
+ public void cancel() {
+ for (ServiceCall call : this.serviceCalls) {
+ call.cancel();
+ }
+ }
+
+ /**
+ * @return true if the call has been canceled; false otherwise.
+ */
+ public boolean isCancelled() {
+ for (ServiceCall call : this.serviceCalls) {
+ if (!call.isCanceled()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Add a call to the list of parallel calls.
+ *
+ * @param call the call
+ */
+ private void addCall(ServiceCall call) {
+ this.serviceCalls.add(call);
+ }
}
}
diff --git a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java
index 1612ff4f835a7..8f0a3459a2e92 100644
--- a/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java
+++ b/azure-client-runtime/src/main/java/com/microsoft/azure/TaskItem.java
@@ -26,22 +26,17 @@ public interface TaskItem {
*
* once executed the result will be available through result getter
*
- * @param taskGroup the task group dispatching tasks
- * @param node the node the task item is associated with
* @throws Exception exception
*/
- void execute(TaskGroup> taskGroup, DAGNode> node) throws Exception;
+ void execute() throws Exception;
/**
* Executes the task asynchronously.
*
* once executed the result will be available through result getter
-
- * @param taskGroup the task group dispatching tasks
- * @param node the node the task item is associated with
- * @param isRootNode true if the node is root node
+ *
* @param callback callback to call on success or failure
* @return the handle of the REST call
*/
- ServiceCall executeAsync(TaskGroup> taskGroup, DAGNode> node, final boolean isRootNode, ServiceCallback callback);
+ ServiceCall executeAsync(ServiceCallback callback);
}