Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix rounds over decimal in Spark 330+ #5786

Merged
merged 1 commit into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions integration_tests/src/main/python/arithmetic_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,19 @@ def test_shift_right_unsigned(data_gen):
'shiftrightunsigned(a, cast(null as INT))',
'shiftrightunsigned(a, b)'))

_arith_data_gens_for_round = numeric_gens + _arith_decimal_gens_no_neg_scale + [
decimal_gen_32bit_neg_scale,
DecimalGen(precision=15, scale=-8),
DecimalGen(precision=30, scale=-5),
pytest.param(_decimal_gen_36_neg5, marks=pytest.mark.skipif(
is_spark_330_or_later(), reason='This case overflows in Spark 3.3.0+')),
pytest.param(_decimal_gen_38_neg10, marks=pytest.mark.skipif(
is_spark_330_or_later(), reason='This case overflows in Spark 3.3.0+'))
]

@incompat
@approximate_float
@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn)
def test_decimal_bround(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
Expand All @@ -496,7 +506,7 @@ def test_decimal_bround(data_gen):

@incompat
@approximate_float
@pytest.mark.parametrize('data_gen', _arith_data_gens, ids=idfn)
@pytest.mark.parametrize('data_gen', _arith_data_gens_for_round, ids=idfn)
def test_decimal_round(data_gen):
assert_gpu_and_cpu_are_equal_collect(
lambda spark: unary_op_df(spark, data_gen).selectExpr(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2401,7 +2401,7 @@ object GpuOverrides extends Logging {
}
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuBRound(lhs, rhs)
GpuBRound(lhs, rhs, a.dataType)
}),
expr[Round](
"Round an expression to d decimal places using HALF_UP rounding mode",
Expand All @@ -2422,7 +2422,7 @@ object GpuOverrides extends Logging {
}
}
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuRound(lhs, rhs)
GpuRound(lhs, rhs, a.dataType)
}),
expr[PythonUDF](
"UDF run in an external python process. Does not actually run on the GPU, but " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ import java.io.Serializable
import ai.rapids.cudf._
import ai.rapids.cudf.ast.BinaryOperator
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression

import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.rapids.shims.RapidsFloorCeilUtils
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -556,32 +555,16 @@ abstract class CudfBinaryMathExpression(name: String) extends CudfBinaryExpressi
override def dataType: DataType = DoubleType
}

abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBinaryExpression
with Serializable with ImplicitCastInputTypes {
// Due to SPARK-39226, the dataType of round-like functions differs by Spark versions.
abstract class GpuRoundBase(child: Expression, scale: Expression, outputType: DataType)
extends GpuBinaryExpression with Serializable with ImplicitCastInputTypes {

override def left: Expression = child
override def right: Expression = scale

def roundMode: RoundMode

override lazy val dataType: DataType = child.dataType match {
// if the new scale is bigger which means we are scaling up,
// keep the original scale as `Decimal` does
case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
case t => t
}

// Avoid repeated evaluation since `scale` is a constant int,
// avoid unnecessary `child` evaluation in both codegen and non-codegen eval
// by checking if scaleV == null as well.
private lazy val scaleV: Any = scale match {
case _: GpuExpression =>
withResource(scale.columnarEval(null).asInstanceOf[GpuScalar]) { s =>
s.getValue
}
case _ => scale.eval(EmptyRow)
}
private lazy val _scale: Int = scaleV.asInstanceOf[Int]
override def dataType: DataType = outputType

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

Expand All @@ -590,9 +573,19 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin
val lhsValue = value.getBase
val scaleVal = scale.getValue.asInstanceOf[Int]

dataType match {
case DecimalType.Fixed(_, scaleVal) =>
lhsValue.round(scaleVal, roundMode)
child.dataType match {
case DecimalType.Fixed(_, s) =>
// Only needs to perform round when required scale < input scale
val rounded = if (scaleVal < s) {
lhsValue.round(scaleVal, roundMode)
} else {
lhsValue.incRefCount()
}
withResource(rounded) { _ =>
// Fit the output datatype
rounded.castTo(
DecimalUtil.createCudfDecimal(dataType.asInstanceOf[DecimalType]))
Comment on lines +586 to +587
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This cast seems unnecessary in most cases, may be worth checking if the desired cudf type is already the type computed to avoid an unnecessary copy.

}
case ByteType =>
fixUpOverflowInts(() => Scalar.fromByte(0.toByte), scaleVal, lhsValue)
case ShortType =>
Expand Down Expand Up @@ -766,13 +759,13 @@ abstract class GpuRoundBase(child: Expression, scale: Expression) extends GpuBin
}
}

case class GpuBRound(child: Expression, scale: Expression) extends
GpuRoundBase(child, scale) {
case class GpuBRound(child: Expression, scale: Expression, outputType: DataType) extends
GpuRoundBase(child, scale, outputType) {
override def roundMode: RoundMode = RoundMode.HALF_EVEN
}

case class GpuRound(child: Expression, scale: Expression) extends
GpuRoundBase(child, scale) {
case class GpuRound(child: Expression, scale: Expression, outputType: DataType) extends
GpuRoundBase(child, scale, outputType) {
override def roundMode: RoundMode = RoundMode.HALF_UP
}

Expand Down