Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

Various Task Manage Speedups #373

Merged
merged 10 commits into from
Mar 25, 2020
4 changes: 2 additions & 2 deletions core/src/main/scala/dagr/core/execsystem/TaskManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,10 @@ class TaskManager(taskManagerResources: SystemResources = TaskManagerDefaults.de
// we will make this task dependent on the tasks it creates...
if (tasks.contains(node.task)) throw new IllegalStateException(s"Task [${node.task.name}] contained itself in the list returned by getTasks")
// track the new tasks. If they are already added, that's fine too.
val taskIds: Seq[TaskId] = tasks.map { task => addTask(task = task, enclosingNode = Some(node), ignoreExists = true) }
val taskIds: Seq[TaskId] = addTasks(tasks, enclosingNode = Some(node), ignoreExists = true)
// make this node dependent on those tasks
taskIds.map(taskId => node.addPredecessors(this(taskId)))
// we may need to update precedessors if a returned task was already completed
// we may need to update predecessors if a returned task was already completed
if (tasks.flatMap(t => graphNodeFor(t)).exists(_.state == GraphNodeState.COMPLETED)) updatePredecessors()
// TODO: we could check each new task to see if they are in the PREDECESSORS_AND_UNEXPANDED state
true
Expand Down
100 changes: 58 additions & 42 deletions core/src/main/scala/dagr/core/execsystem/TaskTracker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,52 +70,55 @@ trait TaskTracker extends TaskManagerLike with LazyLogging {
addTask(task=task, enclosingNode=None, ignoreExists=false)
}

private def addTaskNoChecking(task: Task, enclosingNode: Option[GraphNode] = None): TaskId = {
  // NB: performs no "already added" or cycle validation itself; the callers visible in this file
  // (addTask and addTasks) run those checks before delegating here.

  // Claim the next free identifier and advance the counter.
  val id = yieldAndThen(nextId) {nextId += 1}

  // The task must not carry execution info yet; this call is what attaches it.
  require(task._taskInfo.isEmpty) // should not have any info!

  // Script and log locations for the first attempt.
  val firstAttempt = 1
  val scriptPath   = scriptPathFor(task=task, id=id, attemptIndex=firstAttempt)
  val logPath      = logPathFor(task=task, id=id, attemptIndex=firstAttempt)
  task._taskInfo = Some(
    new TaskExecutionInfo(
      task=task,
      taskId=id,
      status=UNKNOWN,
      script=scriptPath,
      logFile=logPath,
      submissionDate=Some(Instant.now())
    )
  )

  // Build the graph node: when predecessorsOf yields nothing the node is created as an ORPHAN with no
  // predecessors, otherwise it is created with the resolved predecessors and the default state.
  val node = predecessorsOf(task=task).fold(
    new GraphNode(task=task, predecessorNodes=Nil, state=GraphNodeState.ORPHAN, enclosingNode=enclosingNode)
  ) { predecessors =>
    new GraphNode(task=task, predecessorNodes=predecessors, enclosingNode=enclosingNode)
  }

  // Register the task and its node under the new identifier.
  idToTask.put(id, task)
  idToNode.put(id, node)

  id
}

/** Adds a task to be managed
*
* Throws an [[IllegalArgumentException]] if a cycle was found after logging each strongly connected component with
* a cycle in the graph.
*
* @param ignoreExists true if we just return the task id for already added tasks, false if we are to throw an [[IllegalArgumentException]]
* @param task the given task.
* @return the task identifier.
*/
*
* Throws an [[IllegalArgumentException]] if a cycle was found after logging each strongly connected component with
* a cycle in the graph.
*
* @param ignoreExists true if we just return the task id for already added tasks, false if we are to throw an [[IllegalArgumentException]]
* @param task the given task.
* @return the task identifier.
*/
protected[execsystem] def addTask(task: Task, enclosingNode: Option[GraphNode], ignoreExists: Boolean = false): TaskId = {
nh13 marked this conversation as resolved.
Show resolved Hide resolved
// Make sure the id we will assign the task are not being tracked.
if (idToTask.contains(nextId)) throw new IllegalArgumentException(s"Task '${task.name}' with id '$nextId' was already added!")
if (idToNode.contains(nextId)) throw new IllegalArgumentException(s"Task '${task.name}' with id '$nextId' was already added!")

taskFor(task) match {
case Some(id) if ignoreExists => id
case Some(id) => throw new IllegalArgumentException(s"Task '${task.name}' with id '$id' was already added!")
case None =>
// check for cycles
checkForCycles(task = task)

// set the task id
val id = yieldAndThen(nextId) {nextId += 1}
// set the task info
require(task._taskInfo.isEmpty) // should not have any info!
val info = new TaskExecutionInfo(
task=task,
taskId=id,
status=UNKNOWN,
script=scriptPathFor(task=task, id=id, attemptIndex=1),
logFile=logPathFor(task=task, id=id, attemptIndex=1),
submissionDate=Some(Instant.now())
)
task._taskInfo = Some(info)

// create the graph node
val node = predecessorsOf(task=task) match {
case None => new GraphNode(task=task, predecessorNodes=Nil, state=GraphNodeState.ORPHAN, enclosingNode=enclosingNode)
case Some(predecessors) => new GraphNode(task=task, predecessorNodes=predecessors, enclosingNode=enclosingNode)
}

// update the lookups
idToTask.put(id, task)
idToNode.put(id, node)

id
// add the task
addTaskNoChecking(task, enclosingNode)
}
}

Expand All @@ -125,12 +128,25 @@ trait TaskTracker extends TaskManagerLike with LazyLogging {
* @param ignoreExists true if we just return the task id for already added tasks, false if we are to throw an [[IllegalArgumentException]]
* @return the task identifiers.
*/
protected[execsystem] def addTasks(tasks: Iterable[Task], enclosingNode: Option[GraphNode] = None, ignoreExists: Boolean = false): List[TaskId] = {
tasks.map(task => addTask(task=task, enclosingNode=enclosingNode, ignoreExists=ignoreExists)).toList
protected[execsystem] def addTasks(tasks: Seq[Task], enclosingNode: Option[GraphNode] = None, ignoreExists: Boolean = false): Seq[TaskId] = {
  // Batch variant of addTask: the "already added" check is done for every task up front and the cycle
  // check runs once over all of the new tasks, instead of once per task.

  // Make sure the id we will assign next is not already being tracked in either lookup (mirrors addTask).
  if (idToTask.contains(nextId) || idToNode.contains(nextId)) {
    throw new IllegalArgumentException(s"Task id '$nextId' was already added!")
  }

  // Keep only the tasks that are not tracked yet. When ignoreExists is false, an already-tracked task
  // fails here, before any task from this batch has been added.
  val tasksToAdd = tasks.flatMap { task =>
    taskFor(task) match {
      case Some(_) if ignoreExists => None
      case Some(id)                => throw new IllegalArgumentException(s"Task '${task.name}' with id '$id' was already added!")
      case None                    => Some(task)
    }
  }

  checkForCycles(tasksToAdd:_*)

  // Add the tasks in order, returning one id per input task. A task that is found at this point and was
  // not found above must have been added earlier in this very batch (i.e. it is listed more than once in
  // `tasks`); that is only acceptable when ignoreExists is true, matching what adding the tasks one by
  // one through addTask would do.
  tasks.map { task =>
    taskFor(task) match {
      case Some(id) if ignoreExists => id
      case Some(id)                 => throw new IllegalArgumentException(s"Task '${task.name}' with id '$id' was already added!")
      case None                     => addTaskNoChecking(task, enclosingNode)
    }
  }
}

override def addTasks(tasks: Task*): Seq[TaskId] = {
tasks.map(task => addTask(task, enclosingNode=None, ignoreExists=false))
this.addTasks(tasks, enclosingNode=None, ignoreExists=false)
}

override def taskFor(id: TaskId): Option[Task] = idToTask.get(id)
Expand Down Expand Up @@ -259,17 +275,17 @@ trait TaskTracker extends TaskManagerLike with LazyLogging {
*
* @param task a task in the graph to check.
*/
protected def checkForCycles(task: Task): Unit = {
protected def checkForCycles(task: Task*): Unit = {
// check for cycles
if (Task.hasCycle(task)) {
if (Task.hasCycle(task:_*)) {
logger.error("Task was part of a graph that had a cycle")
for (component <- Task.findStronglyConnectedComponents(task = task)) {
for (component <- Task.findStronglyConnectedComponents(task = task:_*)) {
if (Task.isComponentACycle(component = component)) {
logger.error("Tasks were part of a strongly connected component with a cycle: "
+ component.map(t => s"[${t.name}]").mkString(", "))
}
}
throw new IllegalArgumentException(s"Task was part of a graph that had a cycle [${task.name}]")
throw new IllegalArgumentException(s"Task(s) had cyclical dependencies [${task.map(_.name).mkString(",")}]")
}
}

Expand Down
22 changes: 11 additions & 11 deletions core/src/main/scala/dagr/core/tasksystem/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,22 @@ object Task {
* @param task the task to begin search.
* @return true if the DAG to which this task belongs has a cycle, false otherwise.
*/
private[core] def hasCycle(task: Task): Boolean = {
findStronglyConnectedComponents(task).exists(component => isComponentACycle(component))
private[core] def hasCycle(task: Task*): Boolean = {
findStronglyConnectedComponents(task:_*).exists(component => isComponentACycle(component))
}

/** Finds all the strongly connected components of the graph to which this task is connected.
*
* Uses Tarjan's algorithm: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
*
* @param task a task in the graph to check.
* @param task one or more tasks in the graph to check.
* @return the set of strongly connected components.
*/
private[core] def findStronglyConnectedComponents(task: Task): Set[Set[Task]] = {
private[core] def findStronglyConnectedComponents(task: Task*): Set[Set[Task]] = {

// 1. find all tasks connected to this task
val visited: mutable.Set[Task] = new mutable.HashSet[Task]()
val toVisit: mutable.Set[Task] = mutable.HashSet[Task](task)
val toVisit: mutable.Set[Task] = mutable.HashSet[Task](task:_*)

while (toVisit.nonEmpty) {
val nextTask: Task = toVisit.head
Expand Down Expand Up @@ -122,16 +122,16 @@ object Task {
if (!data.indexes.contains(w)) {
// Successor w has not yet been visited; recurse on it
findStronglyConnectedComponent(w, data)
data.lowLink.put(v, math.min(data.lowLink.get(v).get, data.lowLink.get(w).get))
data.lowLink.put(v, math.min(data.lowLink(v), data.lowLink(w)))
}
else if (data.onStack(w)) {
// Successor w is in stack S and hence in the current SCC
data.lowLink.put(v, math.min(data.lowLink.get(v).get, data.lowLink.get(w).get))
data.lowLink.put(v, math.min(data.lowLink(v), data.lowLink(w)))
}
}

// If v is a root node, pop the stack and generate an SCC
if (data.indexes.get(v).get == data.lowLink.get(v).get) {
if (data.indexes(v) == data.lowLink(v)) {
val component: mutable.Set[Task] = new mutable.HashSet[Task]()
breakable {
while (data.stack.nonEmpty) {
Expand Down Expand Up @@ -190,9 +190,9 @@ trait Task extends Dependable {
/** Removes this as a dependency for other */
override def !=>(other: Dependable): Unit = other.headTasks.foreach(_.removeDependency(this))

override def headTasks: Iterable[Task] = Seq(this)
override def tailTasks: Iterable[Task] = Seq(this)
override def allTasks: Iterable[Task] = Seq(this)
override def headTasks: Iterable[Task] = Some(this)
override def tailTasks: Iterable[Task] = Some(this)
override def allTasks: Iterable[Task] = Some(this)

/**
* Removes a dependency by removing the supplied task from the list of dependencies for this task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class TaskManagerTest extends UnitSpec with OptionValues with LazyLogging with B
"TaskManager" should "not overwrite an existing task when adding a task, or throw an IllegalArgumentException when ignoreExists is false" in {
val task: UnitTask = new ShellCommand("exit", "0") withName "exit 0" requires ResourceSet.empty
val taskManager: TestTaskManager = getDefaultTaskManager()
taskManager.addTasks(tasks=Seq(task, task), ignoreExists=true) shouldBe List(0, 0)
taskManager.addTasks(tasks=Seq(task, task), enclosingNode=None, ignoreExists=true) shouldBe List(0, 0)
an[IllegalArgumentException] should be thrownBy taskManager.addTask(task=task, enclosingNode=None, ignoreExists=false)
}

Expand Down
8 changes: 4 additions & 4 deletions core/src/test/scala/dagr/core/tasksystem/DependableTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ class DependableTest extends UnitSpec {
k.tasksDependedOn should contain theSameElementsAs Seq(i)
}

"Pipeline.root" should "return the same things as Pipeline from the *tasks methods" in {
"Pipeline.root" should "return the same things as Pipeline from the *tasks methods*" in {
val pipeline = new Pipeline() {
override def build() = root ==> (A :: B :: C) ==> (X :: Y :: Z)
}

pipeline.root.headTasks shouldBe pipeline.headTasks
pipeline.root.tailTasks shouldBe pipeline.tailTasks
pipeline.root.allTasks shouldBe pipeline.allTasks
pipeline.root.headTasks.toList should contain theSameElementsInOrderAs pipeline.headTasks
pipeline.root.tailTasks.toList should contain theSameElementsInOrderAs pipeline.tailTasks
pipeline.root.allTasks.toList should contain theSameElementsInOrderAs pipeline.allTasks
}
}