Commit

fix rounds over decimal in Spark 330+ (#5786)
Passes the data type of round-like functions directly to the GPU overrides, so that the plugin adapts to the different result types reported by different Spark versions.

Signed-off-by: sperlingxx <[email protected]>
sperlingxx authored Jun 8, 2022
1 parent 8d3c6e7 commit 51856cb
Showing 3 changed files with 36 additions and 33 deletions.
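
The reasoning behind the change: before Spark 3.3.0, rounding a decimal kept the input precision and only ever lowered the scale, but with SPARK-39226 Spark 3.3.0+ reports a different result type for round-like functions, so a GPU expression that recomputes the old rule no longer matches the CPU plan. Forwarding `a.dataType` from the CPU expression keeps the two in agreement on every Spark version. Below is a minimal standalone sketch of the legacy rule, lifted from the `dataType` implementation removed further down; the object name and `main` driver are only for illustration and are not part of the change.

import org.apache.spark.sql.types.{DataType, DecimalType}

object LegacyRoundResultType {
  // The pre-3.3.0 rule, copied from the removed GpuRoundBase.dataType below:
  // keep the input precision and only lower the scale; "rounding" to a larger
  // scale keeps the original scale. Spark 3.3.0+ derives a different type,
  // which is why the plugin now takes the type from the CPU expression
  // instead of recomputing it here.
  def resultType(childType: DataType, roundScale: Int): DataType = childType match {
    case DecimalType.Fixed(p, s) => DecimalType(p, if (roundScale > s) s else roundScale)
    case t => t
  }

  def main(args: Array[String]): Unit = {
    println(resultType(DecimalType(30, 5), 2))  // DecimalType(30,2)
    println(resultType(DecimalType(30, 5), 8))  // scaling up: stays DecimalType(30,5)
  }
}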
integration_tests/src/main/python/arithmetic_ops_test.py (12 additions, 2 deletions)
@@ -487,9 +487,19 @@ def test_shift_right_unsigned(data_gen):
             'shiftrightunsigned(a, cast(null as INT))',
             'shiftrightunsigned(a, b)'))
 
+_arith_data_gens_for_round = numeric_gens + _arith_decimal_gens_no_neg_scale + [
+    decimal_gen_32bit_neg_scale,
+    DecimalGen(precision=15, scale=-8),
+    DecimalGen(precision=30, scale=-5),
+    pytest.param(_decimal_gen_36_neg5, marks=pytest.mark.skipif(
+        is_spark_330_or_later(), reason='This case overflows in Spark 3.3.0+')),
+    pytest.param(_decimal_gen_38_neg10, marks=pytest.mark.skipif(
+        is_spark_330_or_later(), reason='This case overflows in Spark 3.3.0+'))
+]
+
 @incompat
 @approximate_float
-@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn)
 def test_decimal_bround(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -501,7 +511,7 @@ def test_decimal_bround(data_gen):
 
 @incompat
 @approximate_float
-@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn)
+@pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn)
 def test_decimal_round(data_gen):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark: unary_op_df(spark, data_gen).selectExpr(
@@ -2401,7 +2401,7 @@ object GpuOverrides extends Logging {
           }
         }
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-          GpuBRound(lhs, rhs)
+          GpuBRound(lhs, rhs, a.dataType)
       }),
     expr[Round](
       "Round an expression to d decimal places using HALF_UP rounding mode",
@@ -2422,7 +2422,7 @@
           }
         }
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
-          GpuRound(lhs, rhs)
+          GpuRound(lhs, rhs, a.dataType)
       }),
     expr[PythonUDF](
       "UDF run in an external python process. Does not actually run on the GPU, but " +
@@ -21,9 +21,8 @@ import java.io.Serializable
 import ai.rapids.cudf._
 import ai.rapids.cudf.ast.BinaryOperator
 import com.nvidia.spark.rapids._
-import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
 
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, ImplicitCastInputTypes}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
 import org.apache.spark.sql.rapids.shims.RapidsFloorCeilUtils
 import org.apache.spark.sql.types._
 
@@ -556,32 +555,16 @@ abstract class CudfBinaryMathExpression(name: String) extends CudfBinaryExpression
   override def dataType: DataType = DoubleType
 }
 
-abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBinaryExpression
-  with Serializable with ImplicitCastInputTypes {
+// Due to SPARK-39226, the dataType of round-like functions differs by Spark versions.
+abstract class GpuRoundBase(child: Expression, scale: Expression, outputType: DataType)
+  extends GpuBinaryExpression with Serializable with ImplicitCastInputTypes {
 
   override def left: Expression = child
   override def right: Expression = scale
 
   def roundMode: RoundMode
 
-  override lazy val dataType: DataType = child.dataType match {
-    // if the new scale is bigger which means we are scaling up,
-    // keep the original scale as `Decimal` does
-    case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
-    case t => t
-  }
-
-  // Avoid repeated evaluation since `scale` is a constant int,
-  // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
-  // by checking if scaleV == null as well.
-  private lazy val scaleV: Any = scale match {
-    case _: GpuExpression =>
-      withResource(scale.columnarEval(null).asInstanceOf[GpuScalar]) { s =>
-        s.getValue
-      }
-    case _ => scale.eval(EmptyRow)
-  }
-  private lazy val _scale: Int = scaleV.asInstanceOf[Int]
+  override def dataType: DataType = outputType
 
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
 
@@ -590,9 +573,19 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBinaryExpression
     val lhsValue = value.getBase
     val scaleVal = scale.getValue.asInstanceOf[Int]
 
-    dataType match {
-      case DecimalType.Fixed(_, scaleVal) =>
-        lhsValue.round(scaleVal, roundMode)
+    child.dataType match {
+      case DecimalType.Fixed(_, s) =>
+        // Only needs to perform round when required scale < input scale
+        val rounded = if (scaleVal < s) {
+          lhsValue.round(scaleVal, roundMode)
+        } else {
+          lhsValue.incRefCount()
+        }
+        withResource(rounded) { _ =>
+          // Fit the output datatype
+          rounded.castTo(
+            DecimalUtil.createCudfDecimal(dataType.asInstanceOf[DecimalType]))
+        }
       case ByteType =>
        fixUpOverflowInts(() => Scalar.fromByte(0.toByte), scaleVal, lhsValue)
      case ShortType =>
@@ -766,13 +759,13 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBinaryExpression
   }
 }
 
-case class GpuBRound(child: Expression, scale: Expression) extends
-  GpuRoundBase(child, scale) {
+case class GpuBRound(child: Expression, scale: Expression, outputType: DataType) extends
+  GpuRoundBase(child, scale, outputType) {
   override def roundMode: RoundMode = RoundMode.HALF_EVEN
 }
 
-case class GpuRound(child: Expression, scale: Expression) extends
-  GpuRoundBase(child, scale) {
+case class GpuRound(child: Expression, scale: Expression, outputType: DataType) extends
+  GpuRoundBase(child, scale, outputType) {
   override def roundMode: RoundMode = RoundMode.HALF_UP
 }
 
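As a plain-CPU illustration of the new decimal branch above: rounding is only performed when the requested scale is tighter than the input scale, and GpuRound versus GpuBRound differ only in rounding mode (HALF_UP versus HALF_EVEN). The sketch below uses java.math as a stand-in for the cudf kernel; `roundLike` and the sample values are made up for the example, and the real GPU path additionally casts the result to the Spark-computed output decimal type.

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

object RoundModesExample {
  // Mirror of the scaleVal < s check above: only round when the requested
  // scale is smaller than the value's current scale, otherwise pass through.
  def roundLike(v: JBigDecimal, targetScale: Int, mode: RoundingMode): JBigDecimal =
    if (targetScale < v.scale()) v.setScale(targetScale, mode) else v

  def main(args: Array[String]): Unit = {
    val v = new JBigDecimal("2.5")
    println(roundLike(v, 0, RoundingMode.HALF_UP))   // 3   (GpuRound behavior)
    println(roundLike(v, 0, RoundingMode.HALF_EVEN)) // 2   (GpuBRound behavior)
    println(roundLike(v, 3, RoundingMode.HALF_UP))   // 2.5 (scale-up: no rounding)
  }
}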
