Skip to content
This repository has been archived by the owner on Jun 20, 2024. It is now read-only.

Scatter gather flatMap/groupBy fixup #358

Merged
merged 2 commits into from
Aug 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions tasks/src/main/scala/dagr/tasks/ScatterGather.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ object ScatterGather {
}
}

/**
 * A Partitioner that regroups the partitions produced by another Partitioner.
 *
 * When run, it reads the partitions of the wrapped `partitioner`, groups them by the key that `f`
 * computes, and stores the resulting (key, members) pairs in `partitions`. The order of the pairs
 * is whatever `groupBy(f).toSeq` yields.
 *
 * @param partitioner the upstream partitioner whose partitions are regrouped
 * @param f           computes the grouping key for a single upstream partition
 * @tparam Result the type of an upstream partition
 * @tparam Key    the type of the grouping key
 */
private class GroupByPartitioner[Result, Key](partitioner: Partitioner[Result], f: Result => Key) extends SimpleInJvmTask with Partitioner[(Key, Seq[Result])] {
  /** The grouped partitions; stays `None` until `run()` has assigned them. */
  var partitions: Option[Seq[(Key, Seq[Result])]] = None

  /**
   * Groups the upstream partitions and publishes them through `partitions`.
   *
   * Throws an IllegalStateException when the upstream partitioner has no partitions yet.
   */
  override def run(): Unit = {
    val upstream = partitioner.partitions.getOrElse {
      throw new IllegalStateException(s"partitioner.partitions called before partitions populated.")
    }
    this.partitions = Some(upstream.groupBy(f).toSeq)
  }
}

/**
* Implementation of a Scatter that is just a thinly veiled wrapper around the Scatterer
* being used to generate the set of scatters/partitions to operate on.
Expand All @@ -126,8 +137,11 @@ object ScatterGather {
/**
 * Always fails: this implementation rejects `gather` with an UnsupportedOperationException,
 * so the function `f` is never invoked.
 */
override def gather[NextResult <: Task](f: Seq[Result] => NextResult): Gather[Result,NextResult] = {
  throw new UnsupportedOperationException("gather not supported on an unmapped Scatter")
}

// NOTE(review): these two lines look like the pre-change side of the diff (the file header reports
// 2 deletions); the groupBy definition that follows appears to replace them — confirm.
// As written, this version unconditionally throws UnsupportedOperationException and never calls `f`.
override def groupBy[Key](f: Result => Key) : Scatter[(Key, Seq[Result])] =
throw new UnsupportedOperationException("groupBy not supported on an unmapped Scatter")
/**
 * Groups the results of this scatter by the key that `f` computes.
 *
 * Wraps this scatter's partitioner in a GroupByPartitioner, links this scatter to the
 * grouping partitioner's scatter with `==>`, and returns that scatter.
 *
 * @param f computes the grouping key for a single result
 * @tparam Key the type of the grouping key
 */
override def groupBy[Key](f: Result => Key) : Scatter[(Key, Seq[Result])] = {
  val groupingPartitioner = new GroupByPartitioner[Result, Key](partitioner, f)
  // NOTE(review): `.scatter` is evaluated twice (once to link, once to return); this presumably
  // yields the same Scatter both times — confirm against the definition of `scatter`.
  this ==> groupingPartitioner.scatter
  groupingPartitioner.scatter
}

override def flatMap[NextResult](f: Result => Scatter[NextResult]) : Scatter[NextResult] = {
this.map(f).flatMap(identity)
Expand Down
54 changes: 49 additions & 5 deletions tasks/src/test/scala/dagr/tasks/ScatterGatherTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl
def run(): Unit = Io.writeLines(output, Seq(number.toString))
}

/**
 * Test task that writes a pair of integers to `output` as one line, with the two values
 * separated by a tab.
 *
 * @param numbers the pair of integers to write
 * @param output  the file the line is written to
 */
private case class WriteNumberTuple(numbers: (Int, Int), output: Path) extends SimpleInJvmTask {
  def run(): Unit = {
    val (first, second) = numbers
    Io.writeLines(output, Seq(s"$first\t$second"))
  }
}

"ScatterGather" should "run a simple scatter-gather pipeline on files" in {
val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")
val lengths = Seq(1,2,3,4,5)
Expand All @@ -132,11 +136,7 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl

val taskManager = buildTaskManager
taskManager.addTask(pipeline)
taskManager.runToCompletion(true).foreach { case (task, info) =>
if (TaskStatus.isTaskNotDone(info.status)) {
println(s"${task.name} $info")
}
}
taskManager.runToCompletion(true)

val sum1 = Io.readLines(sumOfCounts).next().toInt
val sum2 = Io.readLines(sumOfSquares).next().toInt
Expand Down Expand Up @@ -351,4 +351,48 @@ class ScatterGatherTests extends UnitSpec with LazyLogging with BeforeAndAfterAl

sumFlatMap shouldBe lengths.sum
}

it should "flatMap a scatter, then groupBy and map" in {
  val lines = Seq("one", "one two", "one two three", "one two three four", "one two three four five")

  // Reserve temp files for the input and the final output, then write the input lines.
  val input = tmp()
  val countsByWordLengthOut = tmp()
  Io.writeLines(input, lines)

  val pipeline = new Pipeline() {
    override def build(): Unit = {
      // First-level scatter: one partition per input line.
      val lineScatter: Scatter[Path] = Scatter(SplitByLine(input=input))

      // Second level: split every line into words; flatMap merges the per-line scatters
      // into a single scatter over all words of all lines.
      val wordScatter: Scatter[Path] = lineScatter.flatMap { pathToLine =>
        val perLineScatter = Scatter(SplitLineByWord(pathToLine))
        root ==> perLineScatter
        perLineScatter
      }

      // Key each word file by the length of the word it contains.
      val byWordLength = wordScatter.groupBy { pathToWord => Io.readLines(pathToWord).next().length }

      // For every word length, write out (length, number of word files with that length).
      val lengthAndCount = byWordLength.map { case (wordLength, wordPaths) =>
        WriteNumberTuple(numbers=(wordLength, wordPaths.length), output=tmp())
      }

      // Concatenate the per-length outputs into the single result file.
      lengthAndCount.gather { tasks => Concat(inputs = tasks.map(_.output), output = countsByWordLengthOut) }

      root ==> lineScatter
    }
  }

  val taskManager = buildTaskManager
  taskManager.addTask(pipeline)
  taskManager.runToCompletion(true)

  // Parse each "<length>\t<count>" line back into a pair of integers.
  val observed = Io.readLines(countsByWordLengthOut).toList.map { line =>
    line.split("\t") match {
      case Array(wordLength, count) => (wordLength.toInt, count.toInt)
    }
  }

  // Lengths 3 ("one", "two"), 4 ("four", "five") and 5 ("three") occur 9, 3 and 3 times.
  observed.sortBy(_._1) should contain theSameElementsInOrderAs Seq((3, 9), (4, 3), (5, 3))
}
}