Skip to content

Commit

Permalink
Fix the overflow of container type when casting floats to decimal (#5766
Browse files Browse the repository at this point in the history
)

Fixes #5765

Fix a potential overflow when casting float/double to decimal. The overflow occurs in the intermediate container decimal that is used for HALF_UP rounding.
 
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx authored Jun 8, 2022
1 parent 51856cb commit 59d05d2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
14 changes: 11 additions & 3 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1464,12 +1464,20 @@ object GpuCast extends Arm {
val targetType = DecimalUtil.createCudfDecimal(dt)
// If target scale reaches DECIMAL128_MAX_PRECISION, container DECIMAL can not
// be created because of precision overflow. In this case, we perform casting op directly.
val casted = if (targetType.getDecimalMaxPrecision == dt.scale) {
val casted = if (DType.DECIMAL128_MAX_PRECISION == dt.scale) {
checked.castTo(targetType)
} else {
val containerType = DecimalUtils.createDecimalType(dt.precision, dt.scale + 1)
// Increase the precision by one together with the scale to avoid overflow in the container,
// which may promote the cuDF decimal type to a wider one. If the precision is already at the
// max precision, it is safe to increase only the scale, because out-of-range values have
// already been checked and replaced.
val containerType = DecimalUtils.createDecimalType(
dt.precision + 1 min DType.DECIMAL128_MAX_PRECISION, dt.scale + 1)
withResource(checked.castTo(containerType)) { container =>
container.round(dt.scale, cudf.RoundMode.HALF_UP)
withResource(container.round(dt.scale, cudf.RoundMode.HALF_UP)) { rd =>
// This cast covers the case where the cuDF decimal type was promoted because of precision + 1.
// Convert back to the original cuDF type so that it stays aligned with the target precision.
rd.castTo(targetType)
}
}
}
// Cast NaN values to nulls
Expand Down
16 changes: 16 additions & 0 deletions tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,22 @@ class CastOpSuite extends GpuExpressionTestSuite {
}
}

test("cast float/double to decimal (include upcast of cuDF decimal type)") {
  // Each generator builds a single-column DataFrame (column name "col") holding three
  // hand-picked values with six fractional digits; the values are then cast to DECIMAL(9, 6).
  // NOTE(review): per the test name, these inputs are meant to cover the path where the
  // intermediate cuDF decimal type gets upcast -- confirm against GpuCast.
  val floatFrame: SparkSession => DataFrame = { spark =>
    val rows = List(459.288333f, -123.456789f, 789.100001f).map(Tuple1(_))
    spark.createDataFrame(rows).selectExpr("_1 AS col")
  }
  testCastToDecimal(DataTypes.FloatType, precision = 9, scale = 6,
    customDataGenerator = Option(floatFrame))

  // Same values and target decimal type, but with double-precision inputs.
  val doubleFrame: SparkSession => DataFrame = { spark =>
    val rows = List(459.288333, -123.456789, 789.100001).map(Tuple1(_))
    spark.createDataFrame(rows).selectExpr("_1 AS col")
  }
  testCastToDecimal(DataTypes.DoubleType, precision = 9, scale = 6,
    customDataGenerator = Option(doubleFrame))
}

test("cast decimal to decimal") {
// fromScale == toScale
testCastToDecimal(DataTypes.createDecimalType(18, 0),
Expand Down

0 comments on commit 59d05d2

Please sign in to comment.