Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run test on pull request. #160

Merged
merged 13 commits into from
Oct 5, 2024
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ on:
push:
branches:
- main
pull_request:

jobs:
build:
strategy:
fail-fast: false
matrix:
scala: ["2.12.12"]
spark: ["3.0.1"]
spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4"]
zeotuan marked this conversation as resolved.
Show resolved Hide resolved
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: olafurpg/setup-scala@v10
- name: Test
run: sbt -Dspark.testVersion=${{ matrix.spark }} ++${{ matrix.scala }} test
run: sbt -Dspark.testVersion=${{ matrix.spark }} +test
- name: Code Quality
run: sbt scalafmtCheckAll
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
version = 2.6.3

lineEndings = preserve
align = more
maxColumn = 150
docstrings = JavaDoc
19 changes: 15 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,21 @@ organization := "com.github.mrpowers"
name := "spark-daria"

version := "1.2.3"
crossScalaVersions := Seq("2.12.15", "2.13.8")
scalaVersion := "2.12.15"
//scalaVersion := "2.13.8"
val sparkVersion = "3.2.1"

val versionRegex = """^(.*)\.(.*)\.(.*)$""".r

val scala2_13 = "2.13.14"
val scala2_12 = "2.12.20"

val sparkVersion = System.getProperty("spark.testVersion", "3.3.4")
// Scala versions to cross-build against, selected from the Spark version under test.
// Spark 3.2 and later are built for both Scala 2.12 and 2.13; earlier Spark 3.x releases only for 2.12.
crossScalaVersions := {
  sparkVersion match {
    case versionRegex("3", m, _) if m.toInt >= 2 => Seq(scala2_12, scala2_13)
    case versionRegex("3", _, _)                 => Seq(scala2_12)
    // Without this fallback an unexpected -Dspark.testVersion value surfaces as an opaque scala.MatchError
    // while the build definition is loading.
    case other =>
      sys.error(s"Unsupported spark.testVersion '$other': expected a Spark 3.x version of the form 3.<minor>.<patch>")
  }
}

scalaVersion := crossScalaVersions.value.head

libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,36 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandomAdapted

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false) extends TernaryExpression
with ExpectsInputTypes
with Stateful
with ExpressionWithRandomSeed {
import scala.util.Try

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
with Stateful
with ExpressionWithRandomSeed {

override def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}

/**
 * Shape parameter taken from the `shape` expression, evaluated lazily and widened to Double.
 *
 * Handles integer, long, float and double typed expressions; any other data type is not matched
 * and results in a scala.MatchError.
 */
@transient protected lazy val shapeVal: Double = shape.dataType match {
  case IntegerType => shape.eval().asInstanceOf[Int]
  case LongType    => shape.eval().asInstanceOf[Long]
  // A FloatType expression evaluates to a boxed Float; casting it directly to Double throws
  // ClassCastException, so unbox as Float first and widen explicitly.
  case FloatType  => shape.eval().asInstanceOf[Float].toDouble
  case DoubleType => shape.eval().asInstanceOf[Double]
}

/**
 * Scale parameter taken from the `scale` expression, evaluated lazily and widened to Double.
 *
 * Handles integer, long, float and double typed expressions; any other data type is not matched
 * and results in a scala.MatchError.
 */
@transient protected lazy val scaleVal: Double = scale.dataType match {
  case IntegerType => scale.eval().asInstanceOf[Int]
  case LongType    => scale.eval().asInstanceOf[Long]
  // A FloatType expression evaluates to a boxed Float; casting it directly to Double throws
  // ClassCastException, so unbox as Float first and widen explicitly.
  case FloatType  => scale.eval().asInstanceOf[Float].toDouble
  case DoubleType => scale.eval().asInstanceOf[Double]
}

Expand All @@ -38,7 +40,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
}
@transient private var distribution: GammaDistribution = _

def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
def this() = this(Try(org.apache.spark.sql.catalyst.analysis.UnresolvedSeed).getOrElse(Literal(42L, LongType)), Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

Expand All @@ -48,13 +50,13 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
ctx.addPartitionInitializationStatement(
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);")
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
)
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""",
isNull = FalseLiteral)
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
}

override def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)
Expand All @@ -80,5 +82,6 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
}

object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma = RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
}
10 changes: 5 additions & 5 deletions src/main/scala/org/apache/spark/sql/daria/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ object functions {
private def withExpr(expr: Expression): Column = Column(expr)

def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
val mu_ = lit(mu)
val beta_ = lit(beta)
val u = rand(seed)
val u = rand(seed)
when(u < 0.5, mu_ + beta_ * log(lit(2) * u))
.otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u)))
.alias("laplace_random")
}

def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta)
def randLaplace(): Column = randLaplace(0.0, 1.0)
def randLaplace(): Column = randLaplace(0.0, 1.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) wit
nextSeed ^= (nextSeed >>> 35)
nextSeed ^= (nextSeed << 4)
seed = nextSeed
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
(nextSeed & ((1L << bits) - 1)).asInstanceOf[Int]
}

override def setSeed(s: Long): Unit = {
Expand All @@ -29,4 +29,3 @@ class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) wit
this.seed = XORShiftRandom.hashSeed(RandomGeneratorFactory.convertToLong(seed))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.apache.spark.sql.SparkSession
trait SparkSessionTestWrapper {

lazy val spark: SparkSession = {
SparkSession
val session = SparkSession
.builder()
.master("local")
.appName("spark session")
Expand All @@ -14,6 +14,8 @@ trait SparkSessionTestWrapper {
"1"
)
.getOrCreate()
session.sparkContext.setLogLevel("ERROR")
session
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1047,22 +1047,6 @@ object TransformationsTest extends TestSuite with DataFrameComparer with ColumnC
}

'withParquetCompatibleColumnNames - {
"blows up if the column name is invalid for Parquet" - {
val df = spark
.createDF(
List(
("pablo")
),
List(
("Column That {Will} Break\t;", StringType, true)
)
)
val path = new java.io.File("./tmp/blowup/example").getCanonicalPath
val e = intercept[org.apache.spark.sql.AnalysisException] {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed, since newer Spark versions no longer have a problem handling this.

df.write.parquet(path)
}
}

"converts column names to be Parquet compatible" - {
val actualDF = spark
.createDF(
Expand Down
28 changes: 16 additions & 12 deletions src/test/scala/org/apache/spark/sql/daria/functionsTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar
'rand_gamma - {
"has correct mean and standard deviation" - {
val sourceDF = spark.range(100000).select(randGamma(2.0, 2.0))
val stats = sourceDF.agg(
mean("gamma_random").as("mean"),
stddev("gamma_random").as("stddev")
).collect()(0)

val gammaMean = stats.getAs[Double]("mean")
val stats = sourceDF
.agg(
mean("gamma_random").as("mean"),
stddev("gamma_random").as("stddev")
)
.collect()(0)

val gammaMean = stats.getAs[Double]("mean")
val gammaStddev = stats.getAs[Double]("stddev")

// Gamma distribution with shape=2.0 and scale=2.0 has mean=4.0 and stddev=sqrt(8.0)
Expand All @@ -31,12 +33,14 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar
'rand_laplace - {
"has correct mean and standard deviation" - {
val sourceDF = spark.range(100000).select(randLaplace())
val stats = sourceDF.agg(
mean("laplace_random").as("mean"),
stddev("laplace_random").as("std_dev")
).collect()(0)

val laplaceMean = stats.getAs[Double]("mean")
val stats = sourceDF
.agg(
mean("laplace_random").as("mean"),
stddev("laplace_random").as("std_dev")
)
.collect()(0)

val laplaceMean = stats.getAs[Double]("mean")
val laplaceStdDev = stats.getAs[Double]("std_dev")

// Laplace distribution with mean=0.0 and scale=1.0 has mean=0.0 and stddev=sqrt(2.0)
Expand Down
Loading