Accumulators support
manojlds committed Oct 25, 2017
1 parent 1ad4695 commit 9ec9fe0
Showing 2 changed files with 119 additions and 20 deletions.
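
For orientation, here is a minimal usage sketch of the feature added in this commit, inferred from the builder and test changes below. It is not part of the commit; the data and rule are placeholders mirroring the new "set accumulators" test, and the left side of the Either returned by plug is assumed to carry rule-validation errors.

    import org.apache.spark.sql.SparkSession
    import sparkplug.SparkPlug
    import sparkplug.models.{PlugAction, PlugRule}

    implicit val spark: SparkSession = SparkSession.builder.getOrCreate()
    import spark.implicits._

    // Sample rows and a rule, mirroring the new test in SparkPlugSpec.
    val df = Seq(("iPhone", "Apple", 300), ("Galaxy", "Samsung", 200))
      .toDF("title", "brand", "price")
    val rules = List(
      PlugRule("rule1",
               "version1",
               "title like '%iPhone%'",
               Seq(PlugAction("price", "1000"))))

    // enableAccumulators also switches on plug details, which the new
    // "SparkPlug.Changed" long accumulator uses to count rows touched by rules.
    val sparkPlug = SparkPlug.builder.enableAccumulators.create()

    // plug returns an Either; the Right side carries the transformed DataFrame.
    // The accumulator value is visible in the Spark UI, or can be read with a
    // SparkListener as the SpecAccumulatorsSparkListener trait in the test does.
    val plugged = sparkPlug.plug(df, rules).right.get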
58 changes: 41 additions & 17 deletions src/main/scala/sparkplug/SparkPlug.scala
@@ -1,5 +1,6 @@
package sparkplug

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -12,11 +13,12 @@ case class SparkPlugCheckpointDetails(checkpointDir: String,
rulesPerStage: Int,
numberOfPartitions: Int)

case class SparkPlug(isPlugDetailsEnabled: Boolean,
plugDetailsColumn: String,
isValidateRulesEnabled: Boolean,
checkpointDetails: Option[SparkPlugCheckpointDetails])(
implicit val spark: SparkSession) {
case class SparkPlug(
isPlugDetailsEnabled: Boolean,
plugDetailsColumn: String,
isValidateRulesEnabled: Boolean,
checkpointDetails: Option[SparkPlugCheckpointDetails],
isAccumulatorsEnabled: Boolean)(implicit val spark: SparkSession) {

private val tableName = "__plug_table__"

@@ -31,18 +33,36 @@ case class SparkPlug(isPlugDetailsEnabled: Boolean,
} else {
registerUdf(spark)
setupCheckpointing(spark, checkpointDetails)
Right(
spark.sparkContext
.broadcast(rules)
.value
.zipWithIndex
.foldLeft(preProcessInput(in)) {
case (df: DataFrame, (rule: PlugRule, ruleNumber: Int)) =>
repartitionAndCheckpoint(applyRule(df, rule), ruleNumber)
})
Right(plugDf(in, rules))
}
}

private def plugDf(in: DataFrame, rules: List[PlugRule]) = {
val out = spark.sparkContext
.broadcast(rules)
.value
.zipWithIndex
.foldLeft(preProcessInput(in)) {
case (df: DataFrame, (rule: PlugRule, ruleNumber: Int)) =>
repartitionAndCheckpoint(applyRule(df, rule), ruleNumber)
}

Option(isAccumulatorsEnabled)
.filter(identity)
.foreach(_ => {
val accumulatorChanged =
spark.sparkContext.longAccumulator(s"SparkPlug.Changed")
out
.filter(
_.getAs[Seq[GenericRowWithSchema]](plugDetailsColumn).nonEmpty)
.foreach((_: Row) => {
accumulatorChanged.add(1)
})
})

out
}

def validate(schema: StructType, rules: List[PlugRule]) = {
rules
.groupBy(_.name)
@@ -134,8 +154,8 @@ case class SparkPlugBuilder(
isPlugDetailsEnabled: Boolean = false,
plugDetailsColumn: String = "plugDetails",
isValidateRulesEnabled: Boolean = false,
checkpointDetails: Option[SparkPlugCheckpointDetails] = None)(
implicit val spark: SparkSession) {
checkpointDetails: Option[SparkPlugCheckpointDetails] = None,
isAccumulatorsEnabled: Boolean = false)(implicit val spark: SparkSession) {
def enablePlugDetails(plugDetailsColumn: String = plugDetailsColumn) =
copy(isPlugDetailsEnabled = true, plugDetailsColumn = plugDetailsColumn)
def enableRulesValidation = copy(isValidateRulesEnabled = true)
@@ -148,11 +168,15 @@
rulesPerStage,
numberOfParitions)))

def enableAccumulators =
copy(isAccumulatorsEnabled = true, isPlugDetailsEnabled = true)

def create() =
new SparkPlug(isPlugDetailsEnabled,
plugDetailsColumn,
isValidateRulesEnabled,
checkpointDetails)
checkpointDetails,
isAccumulatorsEnabled)
}

object SparkPlug {
81 changes: 78 additions & 3 deletions src/test/scala/sparkplug/SparkPlugSpec.scala
@@ -1,17 +1,24 @@
package sparkplug

import java.io.File

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{
AccumulableInfo,
SparkListener,
SparkListenerJobEnd,
SparkListenerStageCompleted
}
import org.apache.spark.sql.SparkSession
import org.scalatest._
import org.scalatest.concurrent.ScalaFutures
import sparkplug.models.{
PlugAction,
PlugDetail,
PlugRule,
PlugRuleValidationError
}

import scala.concurrent.{Future, Promise}

case class TestRow(title: String, brand: String, price: Int)
case class TestRowWithPlugDetails(title: String,
brand: String,
@@ -28,7 +35,42 @@ case class TestRowWithStruct(title: String,
brand: String,
price: Option[TestPriceDetails])

class SparkPlugSpec extends FlatSpec with Matchers {
trait SpecAccumulatorsSparkListener extends ScalaFutures {

implicit val spark: SparkSession

def addListener: Future[Map[String, Long]] = {
val promise = Promise[Map[String, Long]]()

spark.sparkContext.addSparkListener(new SparkListener {
val accumulatorsNamespace = "SparkPlug"
var accumulators = Map[String, Long]()

override def onStageCompleted(
stageCompleted: SparkListenerStageCompleted) {
stageCompleted.stageInfo.accumulables.foreach {
case (_, info: AccumulableInfo) =>
info.name
.filter(_.startsWith(accumulatorsNamespace))
.foreach((s: String) => {
accumulators = accumulators ++ Map(s -> info.value.get.asInstanceOf[Long])
})
}
}

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
if (!promise.isCompleted) promise success accumulators
}
})

promise.future
}
}

class SparkPlugSpec
extends FlatSpec
with Matchers
with SpecAccumulatorsSparkListener {
implicit val spark: SparkSession = SparkSession.builder
.config(new SparkConf())
.enableHiveSupport()
@@ -167,6 +209,39 @@ class SparkPlugSpec extends FlatSpec with Matchers {
output.filter(_.title == "Galaxy").head.price should be(700)
}

it should "set accumulators" in {
val df = spark.createDataFrame(
List(
TestRow("iPhone", "Apple", 300),
TestRow("Galaxy", "Samsung", 200)
))
val sparkPlug = SparkPlug.builder.enableAccumulators
.create()
val rules = List(
PlugRule("rule1",
"version1",
"title like '%iPhone%'",
Seq(PlugAction("price", "1000"))),
PlugRule("rule2",
"version1",
"title like '%Galaxy%'",
Seq(PlugAction("price", "700"))),
PlugRule("rule3",
"version1",
"title like '%Galaxy%'",
Seq(PlugAction("price", "700")))
)

import spark.implicits._
val accumulator = addListener
val output = sparkPlug.plug(df, rules).right.get.as[TestRow].collect()
output.length should be(2)
accumulator.futureValue should be(
Map(
"SparkPlug.Changed" -> 2
))
}

it should "be able to validate derived values" in {
val df = spark.createDataFrame(List.empty[TestRow])
val sparkPlug = SparkPlug.builder.enableRulesValidation.create()
