From 51856cb10150b18762519dbbaad3a9d2ac4ebf54 Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Wed, 8 Jun 2022 23:56:49 +0800 Subject: [PATCH] fix rounds over decimal in Spark 330+ (#5786) Passes the datatype of round-like functions directly to GPU overrides, so as to adapt different Spark versions. Signed-off-by: sperlingxx --- .../src/main/python/arithmetic_ops_test.py | 14 ++++- .../nvidia/spark/rapids/GpuOverrides.scala | 4 +- .../spark/sql/rapids/mathExpressions.scala | 51 ++++++++----------- 3 files changed, 36 insertions(+), 33 deletions(-) diff --git a/integration_tests/src/main/python/arithmetic_ops_test.py b/integration_tests/src/main/python/arithmetic_ops_test.py index a1d897f2042..2a28f05f800 100644 --- a/integration_tests/src/main/python/arithmetic_ops_test.py +++ b/integration_tests/src/main/python/arithmetic_ops_test.py @@ -487,9 +487,19 @@ def test_shift_right_unsigned(data_gen): 'shiftrightunsigned(a, cast(null as INT))', 'shiftrightunsigned(a, b)')) +_arith_data_gens_for_round = numeric_gens + _arith_decimal_gens_no_neg_scale + [ + decimal_gen_32bit_neg_scale, + DecimalGen(precision=15, scale=-8), + DecimalGen(precision=30, scale=-5), + pytest.param(_decimal_gen_36_neg5, marks=pytest.mark.skipif( + is_spark_330_or_later(), reason='This case overflows in Spark 3.3.0+')), + pytest.param(_decimal_gen_38_neg10, marks=pytest.mark.skipif( + is_spark_330_or_later(), reason='This case overflows in Spark 3.3.0+')) +] + @incompat @approximate_float -@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn) def test_decimal_bround(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).selectExpr( @@ -501,7 +511,7 @@ def test_decimal_bround(data_gen): @incompat @approximate_float -@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn) +@pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn) def test_decimal_round(data_gen): assert_gpu_and_cpu_are_equal_collect( lambda spark: unary_op_df(spark, data_gen).selectExpr( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 307eb9e5a3a..050183db1de 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2401,7 +2401,7 @@ object GpuOverrides extends Logging { } } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuBRound(lhs, rhs) + GpuBRound(lhs, rhs, a.dataType) }), expr[Round]( "Round an expression to d decimal places using HALF_UP rounding mode", @@ -2422,7 +2422,7 @@ object GpuOverrides extends Logging { } } override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = - GpuRound(lhs, rhs) + GpuRound(lhs, rhs, a.dataType) }), expr[PythonUDF]( "UDF run in an external python process. Does not actually run on the GPU, but " + diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala index c444e72090e..2e2d1686162 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/mathExpressions.scala @@ -21,9 +21,8 @@ import java.io.Serializable import ai.rapids.cudf._ import ai.rapids.cudf.ast.BinaryOperator import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression -import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, ImplicitCastInputTypes} +import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes} import org.apache.spark.sql.rapids.shims.RapidsFloorCeilUtils import org.apache.spark.sql.types._ @@ -556,32 +555,16 @@ abstract class CudfBinaryMathExpression(name: String) extends CudfBinaryExpressi override def dataType: DataType = DoubleType } -abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBinaryExpression - with Serializable with ImplicitCastInputTypes { +// Due to SPARK-39226, the dataType of round-like functions differs by Spark versions. +abstract class GpuRoundBase(child: Expression, scale: Expression, outputType: DataType) + extends GpuBinaryExpression with Serializable with ImplicitCastInputTypes { override def left: Expression = child override def right: Expression = scale def roundMode: RoundMode - override lazy val dataType: DataType = child.dataType match { - // if the new scale is bigger which means we are scaling up, - // keep the original scale as `Decimal` does - case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale) - case t => t - } - - // Avoid repeated evaluation since `scale` is a constant int, - // avoid unnecessary `child` evaluation in both codegen and non-codegen eval - // by checking if scaleV == null as well. - private lazy val scaleV: Any = scale match { - case _: GpuExpression => - withResource(scale.columnarEval(null).asInstanceOf[GpuScalar]) { s => - s.getValue - } - case _ => scale.eval(EmptyRow) - } - private lazy val _scale: Int = scaleV.asInstanceOf[Int] + override def dataType: DataType = outputType override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) @@ -590,9 +573,19 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin val lhsValue = value.getBase val scaleVal = scale.getValue.asInstanceOf[Int] - dataType match { - case DecimalType.Fixed(_, scaleVal) => - lhsValue.round(scaleVal, roundMode) + child.dataType match { + case DecimalType.Fixed(_, s) => + // Only needs to perform round when required scale < input scale + val rounded = if (scaleVal < s) { + lhsValue.round(scaleVal, roundMode) + } else { + lhsValue.incRefCount() + } + withResource(rounded) { _ => + // Fit the output datatype + rounded.castTo( + DecimalUtil.createCudfDecimal(dataType.asInstanceOf[DecimalType])) + } case ByteType => fixUpOverflowInts(() => Scalar.fromByte(0.toByte), scaleVal, lhsValue) case ShortType => @@ -766,13 +759,13 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin } } -case class GpuBRound(child: Expression, scale: Expression) extends - GpuRoundBase(child, scale) { +case class GpuBRound(child: Expression, scale: Expression, outputType: DataType) extends + GpuRoundBase(child, scale, outputType) { override def roundMode: RoundMode = RoundMode.HALF_EVEN } -case class GpuRound(child: Expression, scale: Expression) extends - GpuRoundBase(child, scale) { +case class GpuRound(child: Expression, scale: Expression, outputType: DataType) extends + GpuRoundBase(child, scale, outputType) { override def roundMode: RoundMode = RoundMode.HALF_UP }