Changes based on Kay's review.
pwendell committed Jul 18, 2014
1 parent 9f18bad commit 5d8b156
Showing 10 changed files with 98 additions and 69 deletions.
38 changes: 24 additions & 14 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -36,28 +36,30 @@ import org.apache.spark.serializer.JavaSerializer
*
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `R` and `T`
* @param _name human-readable name for use in Spark's web UI
* @param display whether to show accumulator values in Spark's web UI
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
class Accumulable[R, T] (
@transient initialValue: R,
param: AccumulableParam[R, T])
param: AccumulableParam[R, T],
_name: Option[String],
val display: Boolean)
extends Serializable {

val id = Accumulators.newId
def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
this(initialValue, param, None, true)

val id: Long = Accumulators.newId
val name = _name.getOrElse(s"accumulator_$id")

@transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
private var deserialized = false

Accumulators.register(this, true)

/** A name for this accumulator / accumulable for display in Spark's UI.
* Note that names must be unique within a SparkContext. */
def name: String = s"accumulator_$id"

/** Whether to display this accumulator in the web UI. */
def display: Boolean = true

/**
* Add more data to this accumulator / accumulable
* @param term the data to add
@@ -97,6 +99,16 @@ class Accumulable[R, T] (
}
}

/**
* Function to customize printing values of this accumulator.
*/
def prettyValue(_value: R) = s"${_value}"

/**
* Function to customize printing partially accumulated (local) values of this accumulator.
*/
def prettyPartialValue(_value: R) = prettyValue(_value)

/**
* Get the current value of this accumulator from within a task.
*
@@ -226,11 +238,9 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], _name: String,
_display: Boolean) extends Accumulable[T,T](initialValue, param) {
override def name = if (_name.eq(null)) s"accumulator_$id" else _name
override def display = _display
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, null, true)
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String],
display: Boolean) extends Accumulable[T,T](initialValue, param, name, display) {
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None, true)
}

/**
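With this change the accumulable's name and display flag move into the constructor, and the string shown in the UI comes from the overridable prettyValue / prettyPartialValue hooks. A minimal sketch of driver-side code against the revised constructor; SumAndCountParam and the averaging display are illustrative assumptions, not part of this patch:

// Illustrative helper: accumulates a (sum, count) pair from Double samples.
object SumAndCountParam extends AccumulableParam[(Double, Long), Double] {
  def addAccumulator(acc: (Double, Long), sample: Double) = (acc._1 + sample, acc._2 + 1)
  def addInPlace(a: (Double, Long), b: (Double, Long)) = (a._1 + b._1, a._2 + b._2)
  def zero(initial: (Double, Long)) = (0.0, 0L)
}

// Named accumulable whose UI value is rendered as a running average rather than the raw pair.
val latency = new Accumulable[(Double, Long), Double](
    (0.0, 0L), SumAndCountParam, Some("request latency"), true) {
  override def prettyValue(v: (Double, Long)) =
    if (v._2 == 0) "n/a" else (v._1 / v._2).toString
}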
9 changes: 4 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -758,13 +758,12 @@ class SparkContext(config: SparkConf) extends Logging {
new Accumulator(initialValue, param)

/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
* values to using the `+=` method. Only the driver can access the accumulator's `value`.
*
* This version adds a custom name to the accumulator for display in the Spark UI.
* Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
* in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the
* driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = {
new Accumulator(initialValue, param, name, true)
new Accumulator(initialValue, param, Some(name), true)
}

/**
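A short usage sketch of the new named overload, assuming an existing SparkContext sc; the input path and counter are illustrative:

// The name appears on the stage page of the web UI; the unnamed overload is unchanged.
val parseErrors = sc.accumulator(0, "parse errors")

sc.textFile("hdfs:///logs").foreach { line =>
  if (line.startsWith("ERROR")) parseErrors += 1
}

println(parseErrors.value)   // only the driver may read the value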
@@ -818,14 +818,15 @@ class DAGScheduler(
// TODO: fail the stage if the accumulator update fails...
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
event.accumUpdates.foreach { case (id, partialValue) =>
val acc = Accumulators.originals(id)
val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
val name = acc.name
// To avoid UI cruft, ignore cases where value wasn't updated
if (partialValue != acc.zero) {
val stringPartialValue = s"${partialValue}"
val stringValue = s"${acc.value}"
stageToInfos(stage).accumulatedValues(name) = stringValue
event.taskInfo.accumulableValues += ((name, stringPartialValue))
val stringPartialValue = acc.prettyPartialValue(partialValue)
val stringValue = acc.prettyValue(acc.value)
stageToInfos(stage).accumulables(id) = AccumulableInfo(id, acc.name, stringValue)
event.taskInfo.accumulables +=
AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
}
}
}
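The handler above wraps each update in an AccumulableInfo, whose definition is not part of this hunk. A sketch of the shape implied by the two call sites (a three-argument form without a partial update and a four-argument form with one); the actual class may differ:

case class AccumulableInfo(id: Long, name: String, update: Option[String], value: String)

object AccumulableInfo {
  def apply(id: Long, name: String, value: String): AccumulableInfo =
    AccumulableInfo(id, name, None, value)
}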
@@ -18,7 +18,6 @@
package org.apache.spark.scheduler

import scala.collection.mutable.HashMap
import scala.collection.mutable.Map

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.storage.RDDInfo
@@ -40,8 +39,8 @@ class StageInfo(
var completionTime: Option[Long] = None
/** If the stage failed, the reason why. */
var failureReason: Option[String] = None
/** Terminal values of accumulables updated during this stage. */
val accumulatedValues: Map[String, String] = HashMap[String, String]()
/** Terminal values of accumulables updated during this stage.*/
val accumulables = HashMap[Long, AccumulableInfo]()

def stageFailed(reason: String) {
failureReason = Some(reason)
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -45,9 +45,10 @@ class TaskInfo(

/**
* Intermediate updates to accumulables during this task. Note that it is valid for the same
* accumulable to be updated multiple times in a single task.
* accumulable to be updated multiple times in a single task or for two accumulables with the
* same name but different IDs to exist in a task.
*/
val accumulableValues = ListBuffer[(String, String)]()
val accumulables = ListBuffer[AccumulableInfo]()

/**
* The time when the task has completed successfully (including the time to remotely fetch
@@ -48,7 +48,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {

// TODO: Should probably consolidate all following into a single hash map.
val stageIdToTime = HashMap[Int, Long]()
val stageIdToAccumulables = HashMap[Int, Map[String, String]]()
val stageIdToAccumulables = HashMap[Int, Map[Long, AccumulableInfo]]()
val stageIdToInputBytes = HashMap[Int, Long]()
val stageIdToShuffleRead = HashMap[Int, Long]()
val stageIdToShuffleWrite = HashMap[Int, Long]()
@@ -75,9 +75,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
// Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
poolToActiveStages(stageIdToPool(stageId)).remove(stageId)

val accumulables = stageIdToAccumulables.getOrElseUpdate(stageId, HashMap[String, String]())
for ((name, value) <- stageCompleted.stageInfo.accumulatedValues) {
accumulables(name) = value
val emptyMap = HashMap[Long, AccumulableInfo]()
val accumulables = stageIdToAccumulables.getOrElseUpdate(stageId, emptyMap)
for ((id, info) <- stageCompleted.stageInfo.accumulables) {
accumulables(id) = info
}

activeStages.remove(stageId)
@@ -155,9 +156,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener {
val info = taskEnd.taskInfo

if (info != null) {
val accumulables = stageIdToAccumulables.getOrElseUpdate(sid, HashMap[String, String]())
for ((name, value) <- info.accumulableValues) {
accumulables(name) = value
val emptyMap = HashMap[Long, AccumulableInfo]()
val accumulables = stageIdToAccumulables.getOrElseUpdate(sid, emptyMap)
for (accumulableInfo <- info.accumulables) {
accumulables(accumulableInfo.id) = accumulableInfo
}

// create executor summary map if necessary
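Keying these maps by accumulator ID rather than by name means a later update for the same accumulable replaces the earlier AccumulableInfo, while two accumulables that share a name but not an ID stay separate. A small illustration using the AccumulableInfo shape sketched earlier:

import scala.collection.mutable.HashMap

val accumulables = HashMap[Long, AccumulableInfo]()
accumulables(1L) = AccumulableInfo(1L, "records", Some("10"), "10")
accumulables(1L) = AccumulableInfo(1L, "records", Some("5"), "15")  // same ID: replaced
accumulables(2L) = AccumulableInfo(2L, "records", Some("3"), "3")   // same name, new ID: kept
assert(accumulables.size == 2)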
11 changes: 7 additions & 4 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -20,10 +20,11 @@ package org.apache.spark.ui.jobs
import java.util.Date
import javax.servlet.http.HttpServletRequest

import scala.xml.{Unparsed, Node}
import scala.xml.{Node, Unparsed}

import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.{Utils, Distribution}
import org.apache.spark.scheduler.AccumulableInfo

/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
@@ -104,9 +105,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
</div>
// scalastyle:on
val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value")
def accumulableRow(acc: (String, String)) = <tr><td>{acc._1}</td><td>{acc._2}</td></tr>
def accumulableRow(acc: AccumulableInfo) = <tr><td>{acc.name}</td><td>{acc.value}</td></tr>
val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow,
accumulables.toSeq)
accumulables.values.toSeq)

val taskHeaders: Seq[String] =
Seq(
@@ -291,7 +292,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
{if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""}
</td>
<td>
{Unparsed(info.accumulableValues.map{ case (k, v) => s"$k: $v" }.mkString("<br/>"))}
{Unparsed(
info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("<br/>")
)}
</td>
<!--
TODO: Add this back after we add support to hide certain columns.
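For reference, a self-contained sketch of the markup the two snippets above produce for one accumulable; the values are illustrative, and the row builder here mirrors, rather than calls, the private helper in StagePage:

import scala.xml.Node

def accumulableRow(acc: AccumulableInfo): Node =
  <tr><td>{acc.name}</td><td>{acc.value}</td></tr>

val acc = AccumulableInfo(1L, "records read", Some("100"), "5000")
accumulableRow(acc)                    // <tr><td>records read</td><td>5000</td></tr>
s"${acc.name}: ${acc.update.get}"      // "records read: 100", the per-task cell text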
49 changes: 31 additions & 18 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -191,13 +191,12 @@ private[spark] object JsonProtocol {
("Submission Time" -> submissionTime) ~
("Completion Time" -> completionTime) ~
("Failure Reason" -> failureReason) ~
("Accumulated Values" -> mapToJson(stageInfo.accumulatedValues))
("Accumulables" -> JArray(
stageInfo.accumulables.values.map(accumulableInfoToJson).toList))
}

def taskInfoToJson(taskInfo: TaskInfo): JValue = {
val accumUpdateMap = taskInfo.accumulableValues.map { case (k, v) =>
mapToJson(Map(k -> v))
}.toList
val accumUpdateMap = taskInfo.accumulables
("Task ID" -> taskInfo.taskId) ~
("Index" -> taskInfo.index) ~
("Attempt" -> taskInfo.attempt) ~
@@ -209,7 +208,14 @@
("Getting Result Time" -> taskInfo.gettingResultTime) ~
("Finish Time" -> taskInfo.finishTime) ~
("Failed" -> taskInfo.failed) ~
("Accumulable Updates" -> JArray(accumUpdateMap))
("Accumulables" -> JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList))
}

def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = {
("ID" -> accumulableInfo.id) ~
("Name" -> accumulableInfo.name) ~
("Update" -> accumulableInfo.update.map(new JString(_)).getOrElse(JNothing)) ~
("Value" -> accumulableInfo.value)
}

def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = {
@@ -485,21 +491,22 @@
val stageId = (json \ "Stage ID").extract[Int]
val stageName = (json \ "Stage Name").extract[String]
val numTasks = (json \ "Number of Tasks").extract[Int]
val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson)
val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson(_))
val details = (json \ "Details").extractOpt[String].getOrElse("")
val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
val accumulatedValues = (json \ "Accumulated Values").extractOpt[JObject].map(mapFromJson(_))
val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match {
case Some(values) => values.map(accumulableInfoFromJson(_))
case None => Seq[AccumulableInfo]()
}

val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfos, details)
stageInfo.submissionTime = submissionTime
stageInfo.completionTime = completionTime
stageInfo.failureReason = failureReason
accumulatedValues.foreach { values =>
for ((k, v) <- values) {
stageInfo.accumulatedValues(k) = v
}
for (accInfo <- accumulatedValues) {
stageInfo.accumulables(accInfo.id) = accInfo
}
stageInfo
}
@@ -516,22 +523,28 @@
val gettingResultTime = (json \ "Getting Result Time").extract[Long]
val finishTime = (json \ "Finish Time").extract[Long]
val failed = (json \ "Failed").extract[Boolean]
val accumulableUpdates = (json \ "Accumulable Updates").extractOpt[Seq[JValue]].map(
updates => updates.map(mapFromJson(_)))
val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
case Some(values) => values.map(accumulableInfoFromJson(_))
case None => Seq[AccumulableInfo]()
}

val taskInfo =
new TaskInfo(taskId, index, attempt, launchTime, executorId, host, taskLocality, speculative)
taskInfo.gettingResultTime = gettingResultTime
taskInfo.finishTime = finishTime
taskInfo.failed = failed
accumulableUpdates.foreach { maps =>
for (m <- maps) {
taskInfo.accumulableValues += m.head
}
}
accumulables.foreach { taskInfo.accumulables += _ }
taskInfo
}

def accumulableInfoFromJson(json: JValue): AccumulableInfo = {
val id = (json \ "ID").extract[Long]
val name = (json \ "Name").extract[String]
val update = Utils.jsonOption(json \ "Update").map(_.extract[String])
val value = (json \ "Value").extract[String]
AccumulableInfo(id, name, update, value)
}

def taskMetricsFromJson(json: JValue): TaskMetrics = {
if (json == JNothing) {
return TaskMetrics.empty
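A round-trip sketch for the new per-accumulable JSON form, using the capitalized field names written by accumulableInfoToJson; the values are illustrative, and since JsonProtocol is private[spark] this would live under org.apache.spark, e.g. in a test:

val info = AccumulableInfo(2L, "records read", Some("100"), "5000")
val json = JsonProtocol.accumulableInfoToJson(info)
// renders roughly as {"ID":2,"Name":"records read","Update":"100","Value":"5000"}
val restored = JsonProtocol.accumulableInfoFromJson(json)
assert(restored.id == info.id && restored.update == info.update && restored.value == info.value)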
@@ -129,7 +129,7 @@ class JsonProtocolSuite extends FunSuite {

// Fields added after 1.0.0.
assert(info.details.nonEmpty)
assert(info.accumulatedValues.nonEmpty)
assert(info.accumulables.nonEmpty)
val oldJson = newJson
.removeField { case (field, _) => field == "Details" }
.removeField { case (field, _) => field == "Accumulables" }
@@ -138,7 +138,7 @@

assert(info.name === newInfo.name)
assert("" === newInfo.details)
assert(0 === newInfo.accumulatedValues.size)
assert(0 === newInfo.accumulables.size)
}

test("InputMetrics backward compatibility") {
@@ -268,7 +268,7 @@ class JsonProtocolSuite extends FunSuite {
(0 until info1.rddInfos.size).foreach { i =>
assertEquals(info1.rddInfos(i), info2.rddInfos(i))
}
assert(info1.accumulatedValues === info2.accumulatedValues)
assert(info1.accumulables === info2.accumulables)
assert(info1.details === info2.details)
}

@@ -301,7 +301,7 @@
assert(info1.gettingResultTime === info2.gettingResultTime)
assert(info1.finishTime === info2.finishTime)
assert(info1.failed === info2.failed)
assert(info1.accumulableValues === info2.accumulableValues)
assert(info1.accumulables === info2.accumulables)
}

private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
@@ -487,17 +487,17 @@ class JsonProtocolSuite extends FunSuite {
private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = {
val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) }
val stageInfo = new StageInfo(a, "greetings", b, rddInfos, "details")
stageInfo.accumulatedValues("acc1") = "val1"
stageInfo.accumulatedValues("acc2") = "val2"
stageInfo.accumulables(1L) = AccumulableInfo(1L, "acc1", "val1")
stageInfo.accumulables(2L) = AccumulableInfo(2L, "acc2", "val2")
stageInfo
}

private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = {
val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL,
speculative)
taskInfo.accumulableValues += (("acc1", "val1"))
taskInfo.accumulableValues += (("acc1", "val1"))
taskInfo.accumulableValues += (("acc2", "val2"))
taskInfo.accumulables += AccumulableInfo(1L, "acc1", Some("val1"), "val1")
taskInfo.accumulables += AccumulableInfo(1L, "acc1", Some("val1"), "val1")
taskInfo.accumulables += AccumulableInfo(2L, "acc2", Some("val2"), "val2")
taskInfo
}

5 changes: 3 additions & 2 deletions docs/programming-guide.md
@@ -1180,7 +1180,8 @@ value of the broadcast variable (e.g. if the variable is shipped to a new node l
Accumulators are variables that are only "added" to through an associative operation and can
therefore be efficiently supported in parallel. They can be used to implement counters (as in
MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers
can add support for new types.
can add support for new types. Accumulator values are displayed in Spark's UI and can be
useful for understanding the progress of running stages.

An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks
running on the cluster can then add to it using the `add` method or the `+=` operator (in Scala and Python).
@@ -1194,7 +1195,7 @@ The code below shows an accumulator being used to add up the elements of an arra
<div data-lang="scala" markdown="1">

{% highlight scala %}
scala> val accum = sc.accumulator(0)
scala> val accum = sc.accumulator(0, "My Accumulator")
accum: spark.Accumulator[Int] = 0

scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x)
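The guide notes that programmers can add support for new types. A minimal sketch of a custom AccumulatorParam that also uses the named form introduced in this change, assuming a SparkContext sc; the Vector payload is an illustrative choice:

object VectorAccumulatorParam extends AccumulatorParam[Vector[Double]] {
  def zero(initialValue: Vector[Double]): Vector[Double] =
    Vector.fill(initialValue.length)(0.0)
  def addInPlace(v1: Vector[Double], v2: Vector[Double]): Vector[Double] =
    (v1, v2).zipped.map(_ + _)
}

// Shows up in the web UI under the given name, like any other named accumulator.
val vecAccum = sc.accumulator(Vector(0.0, 0.0, 0.0), "gradient")(VectorAccumulatorParam)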
