diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index eaeca2001ab..f11c9649510 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -1464,12 +1464,20 @@ object GpuCast extends Arm { val targetType = DecimalUtil.createCudfDecimal(dt) // If target scale reaches DECIMAL128_MAX_PRECISION, container DECIMAL can not // be created because of precision overflow. In this case, we perform casting op directly. - val casted = if (targetType.getDecimalMaxPrecision == dt.scale) { + val casted = if (DType.DECIMAL128_MAX_PRECISION == dt.scale) { checked.castTo(targetType) } else { - val containerType = DecimalUtils.createDecimalType(dt.precision, dt.scale + 1) + // Increase the precision by one along with the scale in case of overflow, which may lead to + // an upcast of the cuDF decimal type. If the precision already hits the max precision, it is + // safe to increase only the scale, because out-of-range values have already been checked and replaced. + val containerType = DecimalUtils.createDecimalType( + dt.precision + 1 min DType.DECIMAL128_MAX_PRECISION, dt.scale + 1) withResource(checked.castTo(containerType)) { container => - container.round(dt.scale, cudf.RoundMode.HALF_UP) + withResource(container.round(dt.scale, cudf.RoundMode.HALF_UP)) { rd => + // The cast here handles cases where the cuDF decimal type was promoted to precision + 1. + // We need to convert back to the original cuDF type, to stay aligned with the target precision. 
+ rd.castTo(targetType) + } } } // Cast NaN values to nulls diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala index 23cbc961763..b34dfc18821 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala @@ -702,6 +702,22 @@ class CastOpSuite extends GpuExpressionTestSuite { } } + test("cast float/double to decimal (include upcast of cuDF decimal type)") { + val genFloats: SparkSession => DataFrame = (ss: SparkSession) => { + ss.createDataFrame(List(Tuple1(459.288333f), Tuple1(-123.456789f), Tuple1(789.100001f))) + .selectExpr("_1 AS col") + } + testCastToDecimal(DataTypes.FloatType, precision = 9, scale = 6, + customDataGenerator = Option(genFloats)) + + val genDoubles: SparkSession => DataFrame = (ss: SparkSession) => { + ss.createDataFrame(List(Tuple1(459.288333), Tuple1(-123.456789), Tuple1(789.100001))) + .selectExpr("_1 AS col") + } + testCastToDecimal(DataTypes.DoubleType, precision = 9, scale = 6, + customDataGenerator = Option(genDoubles)) + } + test("cast decimal to decimal") { // fromScale == toScale testCastToDecimal(DataTypes.createDecimalType(18, 0),