From f93e7ce565240c4776be86f5be136aa11f394dcc Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 09:37:24 -0500 Subject: [PATCH 1/8] Update nested cast support to be more generic --- docs/supported_ops.md | 4 +- .../src/main/python/cast_test.py | 17 ++ .../com/nvidia/spark/rapids/GpuCast.scala | 261 ++++++++---------- .../com/nvidia/spark/rapids/TypeChecks.scala | 18 +- 4 files changed, 142 insertions(+), 158 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index ed56fe63359..4548f2be41c 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -20587,7 +20587,7 @@ and the accelerator produces the same result. -PS
missing nested BOOLEAN, BYTE, SHORT, LONG, DATE, TIMESTAMP, STRING, DECIMAL, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT
+PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, STRUCT, UDT
@@ -20991,7 +20991,7 @@ and the accelerator produces the same result. -PS
missing nested BOOLEAN, BYTE, SHORT, LONG, DATE, TIMESTAMP, STRING, DECIMAL, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT
+PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, STRUCT, UDT
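
Illustration (not part of the patch set): a minimal Spark-level sketch of the array-to-array casts the note above describes, assuming a made-up object name, column name and sample data. Each child cast (INT to LONG, STRING or DECIMAL here) must itself be supported for the array cast to be handled, which is what the recursive check added in CastExprMeta verifies.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{ArrayType, DecimalType, LongType, StringType}

    object NestedArrayCastSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("nested-array-cast-sketch").getOrCreate()
        import spark.implicits._

        // An ARRAY<INT> column; the array cast is only supported when each child cast is too.
        val df = Seq(Seq(1, 2, 3), Seq(4, 5)).toDF("a")

        df.select(
          col("a").cast(ArrayType(LongType)),         // upcast the children
          col("a").cast(ArrayType(StringType)),       // children to string
          col("a").cast(ArrayType(DecimalType(6, 2))) // children to decimal
        ).show(truncate = false)

        spark.stop()
      }
    }
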
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index d04066f4cb2..2a361ffb413 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -30,3 +30,20 @@ def test_cast_empty_string_to_int(): 'CAST(a as INTEGER)', 'CAST(a as LONG)')) +# These tests are not intended to be exhaustive. The scala test CastOpSuite should cover +# just about everything for non-nested values. This is intended to check that the +# recursive code in nested type checks, like arrays, is working properly. So we are going +# pick child types that are simple to cast. Upcasting integer values and casting them to strings +@pytest.mark.parametrize('data_gen,to_type', [ + (ArrayGen(byte_gen), ArrayType(IntegerType())), + (ArrayGen(byte_gen), ArrayType(StringType())), + (ArrayGen(byte_gen), ArrayType(DecimalType(6, 2))), + (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(IntegerType()))), + (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(StringType()))), + (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(DecimalType(6, 2)))), + ], ids=idfn) +def test_cast_nested(data_gen, to_type): + assert_gpu_and_cpu_are_equal_collect( + lambda spark : unary_op_df(spark, data_gen).select(f.col('a').cast(to_type))) + + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 561448c16e7..f765a9e19da 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -39,9 +39,9 @@ class CastExprMeta[INPUT <: CastBase]( rule: DataFromReplacementRule) extends UnaryExprMeta[INPUT](cast, conf, parent, rule) { - val fromType = cast.child.dataType - val toType = cast.dataType - var legacyCastToString = ShimLoader.getSparkShims.getLegacyComplexTypeToString() + val fromType: DataType = cast.child.dataType + val toType: DataType = cast.dataType + val legacyCastToString: Boolean = ShimLoader.getSparkShims.getLegacyComplexTypeToString() override def tagExprForGpu(): Unit = recursiveTagExprForGpuCheck() @@ -49,80 +49,62 @@ class CastExprMeta[INPUT <: CastBase]( fromDataType: DataType = fromType, toDataType: DataType = toType) : Unit = { - if (!conf.isCastFloatToDecimalEnabled && toDataType.isInstanceOf[DecimalType] && - (fromDataType == DataTypes.FloatType || fromDataType == DataTypes.DoubleType)) { - willNotWorkOnGpu("the GPU will use a different strategy from Java's BigDecimal to convert " + - "floating point data types to decimals and this can produce results that slightly " + - "differ from the default behavior in Spark. To enable this operation on the GPU, set " + - s"${RapidsConf.ENABLE_CAST_FLOAT_TO_DECIMAL} to true.") - } - if (!conf.isCastFloatToStringEnabled && toDataType == DataTypes.StringType && - (fromDataType == DataTypes.FloatType || fromDataType == DataTypes.DoubleType)) { - willNotWorkOnGpu("the GPU will use different precision than Java's toString method when " + - "converting floating point data types to strings and this can produce results that " + - "differ from the default behavior in Spark. To enable this operation on the GPU, set" + - s" ${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to true.") - } - if (!conf.isCastStringToFloatEnabled && cast.child.dataType == DataTypes.StringType && - Seq(DataTypes.FloatType, DataTypes.DoubleType).contains(cast.dataType)) { - willNotWorkOnGpu("Currently hex values aren't supported on the GPU. 
Also note " + - "that casting from string to float types on the GPU returns incorrect results when the " + - "string represents any number \"1.7976931348623158E308\" <= x < " + - "\"1.7976931348623159E308\" and \"-1.7976931348623159E308\" < x <= " + - "\"-1.7976931348623158E308\" in both these cases the GPU returns Double.MaxValue while " + - "CPU returns \"+Infinity\" and \"-Infinity\" respectively. To enable this operation on " + - "the GPU, set" + s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.") - } - if (!conf.isCastStringToTimestampEnabled && fromDataType == DataTypes.StringType - && toDataType == DataTypes.TimestampType) { - willNotWorkOnGpu("the GPU only supports a subset of formats " + - "when casting strings to timestamps. Refer to the CAST documentation " + - "for more details. To enable this operation on the GPU, set" + - s" ${RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP} to true.") - } - // FIXME: https://github.com/NVIDIA/spark-rapids/issues/2019 - if (!conf.isCastStringToDecimalEnabled && cast.child.dataType == DataTypes.StringType && - cast.dataType.isInstanceOf[DecimalType]) { - willNotWorkOnGpu("Currently string to decimal type on the GPU might produce results which " + - "slightly differed from the correct results when the string represents any number " + - "exceeding the max precision that CAST_STRING_TO_FLOAT can keep. For instance, the GPU " + - "returns 99999999999999987 given input string \"99999999999999999\". The cause of " + - "divergence is that we can not cast strings containing scientific notation to decimal " + - "directly. So, we have to cast strings to floats firstly. Then, cast floats to decimals. " + - "The first step may lead to precision loss. To enable this operation on the GPU, set " + - s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.") - } - if (fromDataType.isInstanceOf[StructType]) { - val checks = rule.getChecks.get.asInstanceOf[CastChecks] - fromDataType.asInstanceOf[StructType].foreach{field => - recursiveTagExprForGpuCheck(field.dataType) - if (toDataType == StringType) { - if (!checks.gpuCanCast(field.dataType, toDataType)) { - willNotWorkOnGpu(s"Unsupported type ${field.dataType} found in Struct column. " + - s"Casting ${field.dataType} to ${toDataType} not currently supported. Refer to " + - "CAST documentation for more details.") + (fromDataType, toDataType) match { + case (_: FloatType | _: DoubleType, _: DecimalType) if !conf.isCastFloatToDecimalEnabled => + willNotWorkOnGpu("the GPU will use a different strategy from Java's BigDecimal " + + "to convert floating point data types to decimals and this can produce results that " + + "slightly differ from the default behavior in Spark. To enable this operation on " + + s"the GPU, set ${RapidsConf.ENABLE_CAST_FLOAT_TO_DECIMAL} to true.") + case (_: FloatType | _: DoubleType, _: StringType) if !conf.isCastFloatToStringEnabled => + willNotWorkOnGpu("the GPU will use different precision than Java's toString method when " + + "converting floating point data types to strings and this can produce results that " + + "differ from the default behavior in Spark. To enable this operation on the GPU, set" + + s" ${RapidsConf.ENABLE_CAST_FLOAT_TO_STRING} to true.") + case (_: StringType, _: FloatType | _: DoubleType) if !conf.isCastStringToFloatEnabled => + willNotWorkOnGpu("Currently hex values aren't supported on the GPU. 
Also note " + + "that casting from string to float types on the GPU returns incorrect results when " + + "the string represents any number \"1.7976931348623158E308\" <= x < " + + "\"1.7976931348623159E308\" and \"-1.7976931348623159E308\" < x <= " + + "\"-1.7976931348623158E308\" in both these cases the GPU returns Double.MaxValue " + + "while CPU returns \"+Infinity\" and \"-Infinity\" respectively. To enable this " + + s"operation on the GPU, set ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.") + case (_: StringType, _: TimestampType) if !conf.isCastStringToTimestampEnabled => + willNotWorkOnGpu("the GPU only supports a subset of formats " + + "when casting strings to timestamps. Refer to the CAST documentation " + + "for more details. To enable this operation on the GPU, set" + + s" ${RapidsConf.ENABLE_CAST_STRING_TO_TIMESTAMP} to true.") + case (_: StringType, _: DecimalType) if !conf.isCastStringToDecimalEnabled => + // FIXME: https://github.com/NVIDIA/spark-rapids/issues/2019 + willNotWorkOnGpu("Currently string to decimal type on the GPU might produce " + + "results which slightly differed from the correct results when the string represents " + + "any number exceeding the max precision that CAST_STRING_TO_FLOAT can keep. For " + + "instance, the GPU returns 99999999999999987 given input string " + + "\"99999999999999999\". The cause of divergence is that we can not cast strings " + + "containing scientific notation to decimal directly. So, we have to cast strings " + + "to floats firstly. Then, cast floats to decimals. The first step may lead to " + + "precision loss. To enable this operation on the GPU, set " + + s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.") + case (structType: StructType, _) => + val checks = rule.getChecks.get.asInstanceOf[CastChecks] + structType.foreach { field => + recursiveTagExprForGpuCheck(field.dataType) + if (toDataType == StringType) { + if (!checks.gpuCanCast(field.dataType, toDataType)) { + willNotWorkOnGpu(s"Unsupported type ${field.dataType} found in Struct column. " + + s"Casting ${field.dataType} to $toDataType not currently supported. Refer to " + + "`cast` documentation for more details.") + } } } - } - } - (fromDataType, toDataType) match { - case ( - ArrayType(nestedFrom@( - FloatType | - DoubleType | - IntegerType | - ArrayType(_, _)), _), - ArrayType(nestedTo@( - FloatType | - DoubleType | - IntegerType | - ArrayType(_, _)), _)) => recursiveTagExprForGpuCheck(nestedFrom, nestedTo) - - case (nestedFrom@ArrayType(_, _), nestedTo@ArrayType(_, _)) => - willNotWorkOnGpu(s"casting from $nestedFrom to $nestedTo is not supported") - - case _ => () + case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => + recursiveTagExprForGpuCheck(nestedFrom, nestedTo) + + case (MapType(keyFrom, valueFrom, _), MapType(keyTo, valueTo, _)) => + recursiveTagExprForGpuCheck(keyFrom, keyTo) + recursiveTagExprForGpuCheck(valueFrom, valueTo) + + case _ => } } @@ -155,10 +137,10 @@ object GpuCast extends Arm { "[0-9]{2}:[0-9]{2}:[0-9]{2})" + "(.[1-9]*(?:0)?[1-9]+)?(.0*[1-9]+)?(?:.0*)?$" - val INVALID_INPUT_MESSAGE = "Column contains at least one value that is not in the " + + val INVALID_INPUT_MESSAGE: String = "Column contains at least one value that is not in the " + "required range" - val INVALID_FLOAT_CAST_MSG = "At least one value is either null or is an invalid number" + val INVALID_FLOAT_CAST_MSG: String = "At least one value is either null or is an invalid number" /** * Sanitization step for CAST string to date. 
Inserts a single zero before any @@ -372,36 +354,23 @@ case class GpuCast( case _ => s"CAST(${child.sql} AS ${dataType.sql})" } - override def doColumnar(input: GpuColumnVector): ColumnVector = { - (input.dataType(), dataType) match { - // Filter out casts to Decimal that utilize the ColumnVector to avoid a copy - case (ShortType | IntegerType | LongType, dt: DecimalType) => - castIntegralsToDecimal(input.getBase, dt) - - case (FloatType | DoubleType, dt: DecimalType) => - castFloatsToDecimal(input.getBase, dt) - - case (from: DecimalType, to: DecimalType) => - castDecimalToDecimal(input.getBase, from, to) - - case _ => - recursiveDoColumnar(input.getBase, input.dataType()) - } - } + override def doColumnar(input: GpuColumnVector): ColumnVector = + recursiveDoColumnar(input.getBase, input.dataType()) private def recursiveDoColumnar( input: ColumnView, fromDataType: DataType, - toDataType: DataType = dataType) - : ColumnVector = { + toDataType: DataType = dataType): ColumnVector = { (fromDataType, toDataType) match { case (NullType, to) => GpuColumnVector.columnVectorFromNull(input.getRowCount.toInt, to) + case (DateType, BooleanType | _: NumericType) => // casts from date type to numerics are always null GpuColumnVector.columnVectorFromNull(input.getRowCount.toInt, toDataType) case (DateType, StringType) => input.asStrings("%Y-%m-%d") + case (TimestampType, FloatType | DoubleType) => withResource(input.castTo(DType.INT64)) { asLongs => withResource(Scalar.fromDouble(1000000)) { microsPerSec => @@ -440,6 +409,7 @@ case class GpuCast( } case (TimestampType, StringType) => castTimestampToString(input) + case (StructType(fields), StringType) => castStructToString(input, fields) @@ -508,6 +478,10 @@ case class GpuCast( } } } + case (FloatType | DoubleType, dt: DecimalType) => + castFloatsToDecimal(input, dt) + case (from: DecimalType, to: DecimalType) => + castDecimalToDecimal(input, from, to) case (BooleanType, TimestampType) => // cudf requires casting to a long first. 
withResource(input.castTo(DType.INT64)) { longs => @@ -561,6 +535,8 @@ case class GpuCast( castFloatsToDecimal(fp, dt) } } + case (ShortType | IntegerType | LongType, dt: DecimalType) => + castIntegralsToDecimal(input, dt) case (ShortType | IntegerType | LongType | ByteType | StringType, BinaryType) => input.asByteList(true) @@ -581,23 +557,10 @@ case class GpuCast( case (_: DecimalType, StringType) => input.castTo(DType.STRING) - case ( - ArrayType(nestedFrom@( - FloatType | - DoubleType | - IntegerType | - ArrayType(_, _)), _), - ArrayType(nestedTo@( - FloatType | - DoubleType | - IntegerType | - ArrayType(_, _)), _)) => { - + case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => withResource(input.getChildColumnView(0))(childView => withResource(recursiveDoColumnar(childView, nestedFrom, nestedTo))(childColumnVector => withResource(input.replaceListChild(childColumnVector))(_.copyToColumnVector()))) - } - case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } @@ -657,7 +620,7 @@ case class GpuCast( * @param inclusiveMin Whether the min value is included in the valid range or not * @param inclusiveMax Whether the max value is included in the valid range or not */ - private def replaceOutOfRangeValues(values: ColumnVector, + private def replaceOutOfRangeValues(values: ColumnView, minValue: => Scalar, maxValue: => Scalar, replaceValue: => Scalar, @@ -744,18 +707,17 @@ case class GpuCast( } } - withResource( - Seq(emptyStr, nullStr, separatorStr, spaceStr, leftStr, rightStr) - .safeMap(Scalar.fromString _) - ) { case Seq(emptyScalar, nullScalar, columnScalars@_*) => + withResource(Seq(emptyStr, nullStr, separatorStr, spaceStr, leftStr, rightStr) + .safeMap(Scalar.fromString)) { + case Seq(emptyScalar, nullScalar, columnScalars@_*) => - withResource( - columnScalars.safeMap(s => ColumnVector.fromScalar(s, numRows)) - ) { case Seq(sepColumn, spaceColumn, leftColumn, rightColumn) => + withResource( + columnScalars.safeMap(s => ColumnVector.fromScalar(s, numRows)) + ) { case Seq(sepColumn, spaceColumn, leftColumn, rightColumn) => - doCastStructToString(emptyScalar, nullScalar, sepColumn, - spaceColumn, leftColumn, rightColumn) - } + doCastStructToString(emptyScalar, nullScalar, sepColumn, + spaceColumn, leftColumn, rightColumn) + } } } @@ -1191,8 +1153,25 @@ case class GpuCast( } } - private def castIntegralsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = { + private def castIntegralsToDecimalAfterCheck(input: ColumnView, dt: DecimalType): ColumnVector = { + if (dt.scale < 0) { + // Rounding is essential when scale is negative, + // so we apply HALF_UP rounding manually to keep align with CpuCast. + withResource(input.castTo(DecimalUtil.createCudfDecimal(dt.precision, 0))) { + scaleZero => scaleZero.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP) + } + } else if (dt.scale > 0) { + // Integer will be enlarged during casting if scale > 0, so we cast input to INT64 + // before casting it to decimal in case of overflow. + withResource(input.castTo(DType.INT64)) { long => + long.castTo(DecimalUtil.createCudfDecimal(dt.precision, dt.scale)) + } + } else { + input.castTo(DecimalUtil.createCudfDecimal(dt.precision, dt.scale)) + } + } + private def castIntegralsToDecimal(input: ColumnView, dt: DecimalType): ColumnVector = { // Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison. 
val (lowBound, upBound) = math.pow(10, dt.precision - dt.scale) match { case bound if bound > Long.MaxValue => (Long.MinValue, Long.MaxValue) @@ -1200,38 +1179,23 @@ case class GpuCast( } // At first, we conduct overflow check onto input column. // Then, we cast checked input into target decimal type. - val checkedInput = if (ansiMode) { + if (ansiMode) { assertValuesInRange(input, minValue = Scalar.fromLong(lowBound), maxValue = Scalar.fromLong(upBound)) - input.incRefCount() + castIntegralsToDecimalAfterCheck(input, dt) } else { - replaceOutOfRangeValues(input, + val checkedInput = replaceOutOfRangeValues(input, minValue = Scalar.fromLong(lowBound), maxValue = Scalar.fromLong(upBound), replaceValue = Scalar.fromNull(input.getType)) - } - - withResource(checkedInput) { checked => - if (dt.scale < 0) { - // Rounding is essential when scale is negative, - // so we apply HALF_UP rounding manually to keep align with CpuCast. - withResource(checked.castTo(DecimalUtil.createCudfDecimal(dt.precision, 0))) { - scaleZero => scaleZero.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP) - } - } else if (dt.scale > 0) { - // Integer will be enlarged during casting if scale > 0, so we cast input to INT64 - // before casting it to decimal in case of overflow. - withResource(checked.castTo(DType.INT64)) { long => - long.castTo(DecimalUtil.createCudfDecimal(dt.precision, dt.scale)) - } - } else { - checked.castTo(DecimalUtil.createCudfDecimal(dt.precision, dt.scale)) + withResource(checkedInput) { checked => + castIntegralsToDecimalAfterCheck(checked, dt) } } } - private def castFloatsToDecimal(input: ColumnVector, dt: DecimalType): ColumnVector = { + private def castFloatsToDecimal(input: ColumnView, dt: DecimalType): ColumnVector = { // Approach to minimize difference between CPUCast and GPUCast: // step 1. 
cast input to FLOAT64 (if necessary) @@ -1269,7 +1233,7 @@ case class GpuCast( val casted = if (DType.DECIMAL64_MAX_PRECISION == dt.scale) { checked.castTo(targetType) } else { - val containerType = DecimalUtil.createCudfDecimal(dt.precision, (dt.scale + 1)) + val containerType = DecimalUtil.createCudfDecimal(dt.precision, dt.scale + 1) withResource(checked.castTo(containerType)) { container => container.round(dt.scale, ai.rapids.cudf.RoundMode.HALF_UP) } @@ -1285,7 +1249,7 @@ case class GpuCast( } } - private def castDecimalToDecimal(input: ColumnVector, + private def castDecimalToDecimal(input: ColumnView, from: DecimalType, to: DecimalType): ColumnVector = { @@ -1293,10 +1257,11 @@ case class GpuCast( val isTo32Bit = DecimalType.is32BitDecimalType(to) val cudfDecimal = DecimalUtil.createCudfDecimal(to.precision, to.scale) - def castCheckedDecimal(checkedInput: ColumnVector): ColumnVector = { + def castCheckedDecimal(checkedInput: ColumnView): ColumnVector = { if (to.scale == from.scale) { if (isFrom32Bit == isTo32Bit) { - checkedInput.incRefCount() + // If the input is a ColumnVector already this will just inc the reference count + checkedInput.copyToColumnVector() } else { // the input is already checked, just cast it checkedInput.castTo(cudfDecimal) @@ -1333,7 +1298,7 @@ case class GpuCast( } def checkForOverflow( - input: ColumnVector, + input: ColumnView, to: DecimalType, isFrom32Bit: Boolean): ColumnVector = { @@ -1377,7 +1342,7 @@ case class GpuCast( // if (isFrom32Bit && prec > Decimal.MAX_INT_DIGITS || // !isFrom32Bit && prec > Decimal.MAX_LONG_DIGITS) if (isFrom32Bit && absBoundPrecision > Decimal.MAX_INT_DIGITS) { - return input.incRefCount() + return input.copyToColumnVector() } val (minValueScalar, maxValueScalar) = if (!isFrom32Bit) { val absBound = math.pow(10, absBoundPrecision).toLong @@ -1391,7 +1356,7 @@ case class GpuCast( minValue = minValueScalar, maxValue = maxValueScalar, inclusiveMin = false, inclusiveMax = false) - input.incRefCount() + input.copyToColumnVector() } else { replaceOutOfRangeValues(input, minValue = minValueScalar, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index f74d0ae5ebf..fe8d7cb68be 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1046,33 +1046,35 @@ class CastChecks extends ExprChecks { val sparkNullSig: TypeSig = all val booleanChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + STRING - val sparkBooleanSig: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + val sparkBooleanSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING val integralChecks: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + BINARY - val sparkIntegralSig: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + BINARY + val sparkIntegralSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING + BINARY val fpChecks: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING - val sparkFpSig: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + val sparkFpSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING val dateChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + DATE + STRING - val sparkDateSig: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + DATE + STRING + val sparkDateSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + DATE + STRING val timestampChecks: TypeSig = integral + fp + BOOLEAN + TIMESTAMP + DATE + STRING - val sparkTimestampSig: TypeSig = gpuNumeric 
+ BOOLEAN + TIMESTAMP + DATE + STRING + val sparkTimestampSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + DATE + STRING val stringChecks: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + DATE + STRING + BINARY - val sparkStringSig: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + DATE + CALENDAR + STRING + BINARY + val sparkStringSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + DATE + CALENDAR + STRING + BINARY val binaryChecks: TypeSig = none val sparkBinarySig: TypeSig = STRING + BINARY val decimalChecks: TypeSig = DECIMAL_64 + STRING - val sparkDecimalSig: TypeSig = gpuNumeric + BOOLEAN + TIMESTAMP + STRING + val sparkDecimalSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING val calendarChecks: TypeSig = none val sparkCalendarSig: TypeSig = CALENDAR + STRING - val arrayChecks: TypeSig = ARRAY.nested(FLOAT + DOUBLE + INT + ARRAY) + val arrayChecks: TypeSig = ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY) + + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + + "to desired child type") val sparkArraySig: TypeSig = STRING + ARRAY.nested(all) From 861f60ab16c1406fef012d8cc4b5ebd40ff97af2 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 10:41:47 -0500 Subject: [PATCH 2/8] Struct to Struct casting support --- docs/supported_ops.md | 8 +-- .../src/main/python/cast_test.py | 4 ++ .../rapids/shims/spark311/Spark311Shims.scala | 10 +++- .../com/nvidia/spark/rapids/GpuCast.scala | 60 ++++++++++++++----- .../com/nvidia/spark/rapids/TypeChecks.scala | 8 ++- 5 files changed, 67 insertions(+), 23 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 4548f2be41c..b7dfe3b19d7 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -20587,7 +20587,7 @@ and the accelerator produces the same result. -PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, STRUCT, UDT
+PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
@@ -20631,7 +20631,7 @@ and the accelerator produces the same result. -NS +PS
the struct's children must also support being cast to the desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
@@ -20991,7 +20991,7 @@ and the accelerator produces the same result. -PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, STRUCT, UDT
+PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
@@ -21035,7 +21035,7 @@ and the accelerator produces the same result. -NS +PS
the struct's children must also support being cast to the desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
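
Illustration (not part of the patch set): a minimal sketch of the struct-to-struct cast this patch enables, assuming a made-up object name and sample data. The struct's children are matched up by position, so each positional child cast must itself be supported; the target field names may differ, as the new integration test also exercises.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, struct}
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    object StructCastSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("struct-cast-sketch").getOrCreate()
        import spark.implicits._

        // Build a STRUCT<a: INT, b: INT> column.
        val df = Seq((1, 10), (2, 20)).toDF("a", "b")
          .select(struct(col("a"), col("b")).alias("s"))

        // Cast it to a struct whose children have different types (and names);
        // the children are cast positionally: a -> x (LONG), b -> y (STRING).
        val target = StructType(Seq(
          StructField("x", LongType),
          StructField("y", StringType)))

        df.select(col("s").cast(target)).show(truncate = false)

        spark.stop()
      }
    }
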
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index 2a361ffb413..27fa5fca087 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -41,6 +41,10 @@ def test_cast_empty_string_to_int(): (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(IntegerType()))), (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(StringType()))), (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(DecimalType(6, 2)))), + (StructGen([('a', byte_gen)]), StructType([StructField('a', IntegerType())])), + (StructGen([('a', byte_gen), ('c', short_gen)]), StructType([StructField('b', IntegerType()), StructField('c', ShortType())])), + (StructGen([('a', ArrayGen(byte_gen)), ('c', short_gen)]), StructType([StructField('a', ArrayType(IntegerType())), StructField('c', LongType())])), + (ArrayGen(StructGen([('a', byte_gen), ('b', byte_gen)])), ArrayType(StringType())) ], ids=idfn) def test_cast_nested(data_gen, to_type): assert_gpu_and_cpu_are_equal_collect( diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala index 44033dc8ae7..18e67809bef 100644 --- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala +++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala @@ -147,13 +147,19 @@ class Spark311Shims extends Spark301Shims { // calendarChecks are the same - override val arrayChecks: TypeSig = none + override val arrayChecks: TypeSig = + ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + + "to desired child type") override val sparkArraySig: TypeSig = ARRAY.nested(all) override val mapChecks: TypeSig = none override val sparkMapSig: TypeSig = MAP.nested(all) - override val structChecks: TypeSig = none + override val structChecks: TypeSig = + STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + + "desired child type") override val sparkStructSig: TypeSig = STRUCT.nested(all) override val udtChecks: TypeSig = none diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index f765a9e19da..ef5304b2626 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -84,19 +84,15 @@ class CastExprMeta[INPUT <: CastBase]( "to floats firstly. Then, cast floats to decimals. The first step may lead to " + "precision loss. To enable this operation on the GPU, set " + s" ${RapidsConf.ENABLE_CAST_STRING_TO_FLOAT} to true.") - case (structType: StructType, _) => - val checks = rule.getChecks.get.asInstanceOf[CastChecks] + case (structType: StructType, StringType) => structType.foreach { field => - recursiveTagExprForGpuCheck(field.dataType) - if (toDataType == StringType) { - if (!checks.gpuCanCast(field.dataType, toDataType)) { - willNotWorkOnGpu(s"Unsupported type ${field.dataType} found in Struct column. " + - s"Casting ${field.dataType} to $toDataType not currently supported. 
Refer to " + - "`cast` documentation for more details.") - } - } + recursiveTagExprForGpuCheck(field.dataType, StringType) + } + case (fromStructType: StructType, toStructType: StructType) => + fromStructType.zip(toStructType).foreach { + case (fromChild, toChild) => + recursiveTagExprForGpuCheck(fromChild.dataType, toChild.dataType) } - case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => recursiveTagExprForGpuCheck(nestedFrom, nestedTo) @@ -355,12 +351,17 @@ case class GpuCast( } override def doColumnar(input: GpuColumnVector): ColumnVector = - recursiveDoColumnar(input.getBase, input.dataType()) + recursiveDoColumnar(input.getBase, input.dataType(), dataType) private def recursiveDoColumnar( input: ColumnView, fromDataType: DataType, - toDataType: DataType = dataType): ColumnVector = { + toDataType: DataType): ColumnVector = { + + if (DataType.equalsStructurally(fromDataType, toDataType)) { + return input.copyToColumnVector() + } + (fromDataType, toDataType) match { case (NullType, to) => GpuColumnVector.columnVectorFromNull(input.getRowCount.toInt, to) @@ -561,6 +562,10 @@ case class GpuCast( withResource(input.getChildColumnView(0))(childView => withResource(recursiveDoColumnar(childView, nestedFrom, nestedTo))(childColumnVector => withResource(input.replaceListChild(childColumnVector))(_.copyToColumnVector()))) + + case (from: StructType, to: StructType) => + castStructToStruct(from, to, input) + case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } @@ -683,7 +688,7 @@ case class GpuCast( // 3.1+: {firstCol columns += leftColumn.incRefCount() withResource(input.getChildColumnView(0)) { firstColumnView => - columns += recursiveDoColumnar(firstColumnView, inputSchema.head.dataType) + columns += recursiveDoColumnar(firstColumnView, inputSchema.head.dataType, StringType) } for (nonFirstIndex <- 1 until numInputColumns) { withResource(input.getChildColumnView(nonFirstIndex)) { nonFirstColumnView => @@ -691,7 +696,7 @@ case class GpuCast( // 3.1+: ", " columns += sepColumn.incRefCount() val nonFirstColumn = recursiveDoColumnar(nonFirstColumnView, - inputSchema(nonFirstIndex).dataType) + inputSchema(nonFirstIndex).dataType, StringType) if (legacyCastToString) { // " " if non-null columns += spaceColumn.mergeAndSetValidity(BinaryOp.BITWISE_AND, nonFirstColumnView) @@ -1153,6 +1158,31 @@ case class GpuCast( } } + private def castStructToStruct( + from: StructType, + to: StructType, + input: ColumnView): ColumnVector = { + withResource(new ArrayBuffer[ColumnVector](from.length)) { childColumns => + from.indices.foreach { index => + childColumns += recursiveDoColumnar( + input.getChildColumnView(index), + from(index).dataType, + to(index).dataType) + } + withResource(ColumnView.makeStructView(childColumns: _*)) { casted => + if (input.getNullCount == 0) { + casted.copyToColumnVector() + } else { + withResource(input.isNull) { isNull => + withResource(GpuScalar.from(null, to)) { nullVal => + isNull.ifElse(nullVal, casted) + } + } + } + } + } + } + private def castIntegralsToDecimalAfterCheck(input: ColumnView, dt: DecimalType): ColumnVector = { if (dt.scale < 0) { // Rounding is essential when scale is negative, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index fe8d7cb68be..d968534afda 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1072,7 +1072,8 @@ class 
CastChecks extends ExprChecks { val calendarChecks: TypeSig = none val sparkCalendarSig: TypeSig = CALENDAR + STRING - val arrayChecks: TypeSig = ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY) + + val arrayChecks: TypeSig = ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + + ARRAY + BINARY + STRUCT) + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + "to desired child type") @@ -1082,7 +1083,10 @@ class CastChecks extends ExprChecks { val sparkMapSig: TypeSig = STRING + MAP.nested(all) val structChecks: TypeSig = psNote(TypeEnum.STRING, "the struct's children must also support " + - "being cast to string") + "being cast to string") + + STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + + "desired child type") val sparkStructSig: TypeSig = STRING + STRUCT.nested(all) val udtChecks: TypeSig = none From 674fc574e84244f12a9af7d8d2cd2616e5cad8c7 Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 11:04:51 -0500 Subject: [PATCH 3/8] Map to Map cast support --- docs/supported_ops.md | 12 ++++---- .../src/main/python/cast_test.py | 5 +++- .../rapids/shims/spark311/Spark311Shims.scala | 7 +++-- .../com/nvidia/spark/rapids/GpuCast.scala | 30 +++++++++++++++++++ .../com/nvidia/spark/rapids/TypeChecks.scala | 11 ++++--- 5 files changed, 52 insertions(+), 13 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index b7dfe3b19d7..acaa9cf1f38 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -20587,7 +20587,7 @@ and the accelerator produces the same result. -PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
+PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -20609,7 +20609,7 @@ and the accelerator produces the same result. -NS +PS
the map's kay and value must also support being cast to the desired child types;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -20631,7 +20631,7 @@ and the accelerator produces the same result. -PS
the struct's children must also support being cast to the desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
+PS
the struct's children must also support being cast to the desired child type(s);
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -20991,7 +20991,7 @@ and the accelerator produces the same result. -PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
+PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -21013,7 +21013,7 @@ and the accelerator produces the same result. -NS +PS
the map's kay and value must also support being cast to the desired child types;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -21035,7 +21035,7 @@ and the accelerator produces the same result. -PS
the struct's children must also support being cast to the desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, MAP, UDT
+PS
the struct's children must also support being cast to the desired child type(s);
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
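
Illustration (not part of the patch set): a minimal sketch of the map-to-map cast this patch enables, assuming a made-up object name and sample data. The key cast and the value cast are checked, and executed, independently of each other, which is why castMapToMap below recurses into both children of cudf's underlying list-of-structs representation.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{LongType, MapType, StringType}

    object MapCastSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("map-cast-sketch").getOrCreate()
        import spark.implicits._

        // A MAP<INT, INT> column; keys are cast to STRING and values to LONG.
        val df = Seq(Map(1 -> 10, 2 -> 20), Map(3 -> 30)).toDF("m")

        df.select(col("m").cast(MapType(StringType, LongType))).show(truncate = false)

        spark.stop()
      }
    }
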
diff --git a/integration_tests/src/main/python/cast_test.py b/integration_tests/src/main/python/cast_test.py index 27fa5fca087..0cf329097b9 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -44,7 +44,10 @@ def test_cast_empty_string_to_int(): (StructGen([('a', byte_gen)]), StructType([StructField('a', IntegerType())])), (StructGen([('a', byte_gen), ('c', short_gen)]), StructType([StructField('b', IntegerType()), StructField('c', ShortType())])), (StructGen([('a', ArrayGen(byte_gen)), ('c', short_gen)]), StructType([StructField('a', ArrayType(IntegerType())), StructField('c', LongType())])), - (ArrayGen(StructGen([('a', byte_gen), ('b', byte_gen)])), ArrayType(StringType())) + (ArrayGen(StructGen([('a', byte_gen), ('b', byte_gen)])), ArrayType(StringType())), + (MapGen(ByteGen(nullable=False), byte_gen), MapType(StringType(), StringType())), + (MapGen(ShortGen(nullable=False), ArrayGen(byte_gen)), MapType(IntegerType(), ArrayType(ShortType()))), + (MapGen(ShortGen(nullable=False), ArrayGen(StructGen([('a', byte_gen)]))), MapType(IntegerType(), ArrayType(StructType([StructField('b', ShortType())])))) ], ids=idfn) def test_cast_nested(data_gen, to_type): assert_gpu_and_cpu_are_equal_collect( diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala index 18e67809bef..4484292cc8b 100644 --- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala +++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala @@ -153,13 +153,16 @@ class Spark311Shims extends Spark301Shims { "to desired child type") override val sparkArraySig: TypeSig = ARRAY.nested(all) - override val mapChecks: TypeSig = none + override val mapChecks: TypeSig = + MAP.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) + + psNote(TypeEnum.MAP, "the map's kay and value must also support being cast to the " + + "desired child types") override val sparkMapSig: TypeSig = MAP.nested(all) override val structChecks: TypeSig = STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + - "desired child type") + "desired child type(s)") override val sparkStructSig: TypeSig = STRUCT.nested(all) override val udtChecks: TypeSig = none diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index ef5304b2626..e55dbf4b694 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -566,6 +566,9 @@ case class GpuCast( case (from: StructType, to: StructType) => castStructToStruct(from, to, input) + case (from: MapType, to: MapType) => + castMapToMap(from, to, input) + case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) } @@ -1158,6 +1161,33 @@ case class GpuCast( } } + private def castMapToMap( + from: MapType, + to: MapType, + input: ColumnView): ColumnVector = { + // For cudf a map is a list of (key, value) structs, but lets keep it in ColumnView as much + // as possible + withResource(input.getChildColumnView(0)) { kvStructColumn => + val castKey = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn => + recursiveDoColumnar(keyColumn, 
from.keyType, to.keyType) + } + withResource(castKey) { castKey => + val castValue = withResource(kvStructColumn.getChildColumnView(1)) { valueColumn => + recursiveDoColumnar(valueColumn, from.valueType, to.valueType) + } + withResource(castValue) { castValue => + withResource(ColumnView.makeStructView(castKey, castValue)) { castKvStructColumn => + // We don't have to worry about null in the key/value struct because they are not + // allowed for maps in Spark + withResource(input.replaceListChild(castKvStructColumn)) { replacedView => + replacedView.copyToColumnVector() + } + } + } + } + } + } + private def castStructToStruct( from: StructType, to: StructType, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index d968534afda..334c60c69bd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1073,20 +1073,23 @@ class CastChecks extends ExprChecks { val sparkCalendarSig: TypeSig = CALENDAR + STRING val arrayChecks: TypeSig = ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + - ARRAY + BINARY + STRUCT) + + ARRAY + BINARY + STRUCT + MAP) + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + "to desired child type") val sparkArraySig: TypeSig = STRING + ARRAY.nested(all) - val mapChecks: TypeSig = none + val mapChecks: TypeSig = MAP.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + + STRUCT + MAP) + + psNote(TypeEnum.MAP, "the map's kay and value must also support being cast to the " + + "desired child types") val sparkMapSig: TypeSig = STRING + MAP.nested(all) val structChecks: TypeSig = psNote(TypeEnum.STRING, "the struct's children must also support " + "being cast to string") + - STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + + STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) + psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " + - "desired child type") + "desired child type(s)") val sparkStructSig: TypeSig = STRING + STRUCT.nested(all) val udtChecks: TypeSig = none From c68b7d4a1fa70b16b30dc3cf8843e63eebd4a57b Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 11:14:26 -0500 Subject: [PATCH 4/8] Missed signoff Signed-off-by: Robert (Bobby) Evans From a878c0d83a43cbeded5e4df640884904fa82611a Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 11:55:30 -0500 Subject: [PATCH 5/8] Addressed review comments --- .../nvidia/spark/rapids/shims/spark311/Spark311Shims.scala | 4 ++-- .../src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala index 4484292cc8b..a4ad64b48ba 100644 --- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala +++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala @@ -150,12 +150,12 @@ class Spark311Shims extends Spark301Shims { override val arrayChecks: TypeSig = ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + - "to desired child type") + "the desired 
child type") override val sparkArraySig: TypeSig = ARRAY.nested(all) override val mapChecks: TypeSig = MAP.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) + - psNote(TypeEnum.MAP, "the map's kay and value must also support being cast to the " + + psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " + "desired child types") override val sparkMapSig: TypeSig = MAP.nested(all) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 334c60c69bd..fe5d03ada0c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -1075,13 +1075,13 @@ class CastChecks extends ExprChecks { val arrayChecks: TypeSig = ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) + psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " + - "to desired child type") + "the desired child type") val sparkArraySig: TypeSig = STRING + ARRAY.nested(all) val mapChecks: TypeSig = MAP.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) + - psNote(TypeEnum.MAP, "the map's kay and value must also support being cast to the " + + psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " + "desired child types") val sparkMapSig: TypeSig = STRING + MAP.nested(all) From b6272aacc9f9e039499e2f4374890294ec40c5cd Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 12:20:43 -0500 Subject: [PATCH 6/8] Forgot to regenerate docs --- docs/supported_ops.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/supported_ops.md b/docs/supported_ops.md index acaa9cf1f38..25973533021 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -20587,7 +20587,7 @@ and the accelerator produces the same result. -PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
+PS
The array's child type must also support being cast to the desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -20609,7 +20609,7 @@ and the accelerator produces the same result. -PS
the map's kay and value must also support being cast to the desired child types;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
+PS
the map's key and value must also support being cast to the desired child types;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -20991,7 +20991,7 @@ and the accelerator produces the same result. -PS
The array's child type must also support being cast to to desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
+PS
The array's child type must also support being cast to the desired child type;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
@@ -21013,7 +21013,7 @@ and the accelerator produces the same result. -PS
the map's kay and value must also support being cast to the desired child types;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
+PS
the map's key and value must also support being cast to the desired child types;
max nested DECIMAL precision of 18;
UTC is only supported TZ for nested TIMESTAMP;
missing nested CALENDAR, UDT
From dfd783da3eadb0af67256f37e2e825ebe5aeb58a Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 12:35:16 -0500 Subject: [PATCH 7/8] Forgot to remove tests that are now not needed --- .../src/main/python/array_test.py | 35 +------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 298918ed4b4..260bd20c4f2 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -165,39 +165,6 @@ def test_array_element_at_all_null_ansi_not_fail(data_gen): conf={'spark.sql.ansi.enabled':True, 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) - -@pytest.mark.parametrize('child_gen', [ - float_gen, - double_gen, - int_gen -], ids=idfn) -@pytest.mark.parametrize('child_to_type', [ - FloatType(), - DoubleType(), - IntegerType(), -], ids=idfn) -@pytest.mark.parametrize('depth', [1, 2, 3], ids=idfn) -def test_array_cast_recursive(child_gen, child_to_type, depth): - def cast_func(spark): - depth_rng = range(0, depth) - nested_gen = reduce(lambda dg, i: ArrayGen(dg, max_length=int(max(1, 16 / (2 ** i)))), - depth_rng, child_gen) - nested_type = reduce(lambda t, _: ArrayType(t), depth_rng, child_to_type) - df = two_col_df(spark, int_gen, nested_gen) - res = df.select(df.b.cast(nested_type)) - return res - assert_gpu_and_cpu_are_equal_collect(cast_func) - - -@allow_non_gpu('ProjectExec', 'Alias', 'Cast') -def test_array_cast_fallback(): - def cast_float_to_double(spark): - df = two_col_df(spark, int_gen, ArrayGen(int_gen)) - res = df.select(df.b.cast(ArrayType(StringType()))) - return res - assert_gpu_and_cpu_are_equal_collect(cast_float_to_double) - - @pytest.mark.parametrize('child_gen', [ byte_gen, string_gen, @@ -214,4 +181,4 @@ def cast_array(spark): df = two_col_df(spark, int_gen, ArrayGen(child_gen)) res = df.select(df.b.cast(ArrayType(child_to_type))) return res - assert_gpu_and_cpu_are_equal_collect(cast_array) \ No newline at end of file + assert_gpu_and_cpu_are_equal_collect(cast_array) From ba1715c94d46d18020664668dee0f33e96b75e4d Mon Sep 17 00:00:00 2001 From: "Robert (Bobby) Evans" Date: Fri, 6 Aug 2021 14:20:10 -0500 Subject: [PATCH 8/8] Fixed a bug and made sure no more like it exist --- .../src/main/python/array_test.py | 17 -- .../src/main/python/cast_test.py | 3 +- .../com/nvidia/spark/rapids/GpuCast.scala | 271 ++++++++++-------- 3 files changed, 150 insertions(+), 141 deletions(-) diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 260bd20c4f2..655db6e9353 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -165,20 +165,3 @@ def test_array_element_at_all_null_ansi_not_fail(data_gen): conf={'spark.sql.ansi.enabled':True, 'spark.sql.legacy.allowNegativeScaleOfDecimal': True}) -@pytest.mark.parametrize('child_gen', [ - byte_gen, - string_gen, - decimal_gen_default, -], ids=idfn) -@pytest.mark.parametrize('child_to_type', [ - FloatType(), - DoubleType(), - IntegerType(), -], ids=idfn) -@allow_non_gpu('ProjectExec', 'Alias', 'Cast') -def test_array_cast_bad_from_good_to_fallback(child_gen, child_to_type): - def cast_array(spark): - df = two_col_df(spark, int_gen, ArrayGen(child_gen)) - res = df.select(df.b.cast(ArrayType(child_to_type))) - return res - assert_gpu_and_cpu_are_equal_collect(cast_array) diff --git a/integration_tests/src/main/python/cast_test.py 
b/integration_tests/src/main/python/cast_test.py index 0cf329097b9..3e2f28ca1f7 100644 --- a/integration_tests/src/main/python/cast_test.py +++ b/integration_tests/src/main/python/cast_test.py @@ -36,6 +36,7 @@ def test_cast_empty_string_to_int(): # pick child types that are simple to cast. Upcasting integer values and casting them to strings @pytest.mark.parametrize('data_gen,to_type', [ (ArrayGen(byte_gen), ArrayType(IntegerType())), + (ArrayGen(StringGen('[0-9]{1,5}')), ArrayType(IntegerType())), (ArrayGen(byte_gen), ArrayType(StringType())), (ArrayGen(byte_gen), ArrayType(DecimalType(6, 2))), (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(IntegerType()))), @@ -52,5 +53,3 @@ def test_cast_empty_string_to_int(): def test_cast_nested(data_gen, to_type): assert_gpu_and_cpu_are_equal_collect( lambda spark : unary_op_df(spark, data_gen).select(f.col('a').cast(to_type))) - - diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index e55dbf4b694..a574b0124cd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -285,78 +285,14 @@ object GpuCast extends Arm { } } } -} - -/** - * Casts using the GPU - */ -case class GpuCast( - child: Expression, - dataType: DataType, - ansiMode: Boolean = false, - timeZoneId: Option[String] = None, - legacyCastToString: Boolean = false) - extends GpuUnaryExpression with TimeZoneAwareExpression with NullIntolerant { - - import GpuCast._ - - override def toString: String = if (ansiMode) { - s"ansi_cast($child as ${dataType.simpleString})" - } else { - s"cast($child as ${dataType.simpleString})" - } - - override def checkInputDataTypes(): TypeCheckResult = { - if (Cast.canCast(child.dataType, dataType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}") - } - } - - override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable - - override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = - copy(timeZoneId = Option(timeZoneId)) - - /** - * Under certain conditions during hash partitioning, Spark will attempt to replace casts - * with semantically equivalent expressions. This method is overridden to prevent Spark - * from substituting non-GPU expressions. - */ - override def semanticEquals(other: Expression): Boolean = other match { - case g: GpuExpression => - if (this == g) { - true - } else { - super.semanticEquals(g) - } - case _ => false - } - - // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; - // Otherwise behave like Expression.resolved. - override lazy val resolved: Boolean = - childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) - private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) - - override def sql: String = dataType match { - // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, - // this type of casting can only be introduced by the analyzer, and can be omitted when - // converting back to SQL query string. 
- case _: ArrayType | _: MapType | _: StructType => child.sql - case _ => s"CAST(${child.sql} AS ${dataType.sql})" - } - - override def doColumnar(input: GpuColumnVector): ColumnVector = - recursiveDoColumnar(input.getBase, input.dataType(), dataType) private def recursiveDoColumnar( input: ColumnView, fromDataType: DataType, - toDataType: DataType): ColumnVector = { + toDataType: DataType, + ansiMode: Boolean, + legacyCastToString: Boolean): ColumnVector = { if (DataType.equalsStructurally(fromDataType, toDataType)) { return input.copyToColumnVector() @@ -412,7 +348,7 @@ case class GpuCast( castTimestampToString(input) case (StructType(fields), StringType) => - castStructToString(input, fields) + castStructToString(input, fields, ansiMode, legacyCastToString) // ansi cast from larger-than-integer integral types, to integer case (LongType, IntegerType) if ansiMode => @@ -463,7 +399,7 @@ case class GpuCast( withResource(FloatUtils.infinityToNulls(inputWithNansToNull)) { inputWithoutNanAndInfinity => if (fromDataType == FloatType && - ShimLoader.getSparkShims.hasCastFloatTimestampUpcast) { + ShimLoader.getSparkShims.hasCastFloatTimestampUpcast) { withResource(inputWithoutNanAndInfinity.castTo(DType.FLOAT64)) { doubles => withResource(doubles.mul(microsPerSec, DType.INT64)) { inputTimesMicrosCv => @@ -480,9 +416,9 @@ case class GpuCast( } } case (FloatType | DoubleType, dt: DecimalType) => - castFloatsToDecimal(input, dt) + castFloatsToDecimal(input, dt, ansiMode) case (from: DecimalType, to: DecimalType) => - castDecimalToDecimal(input, from, to) + castDecimalToDecimal(input, from, to, ansiMode) case (BooleanType, TimestampType) => // cudf requires casting to a long first. withResource(input.castTo(DType.INT64)) { longs => @@ -519,7 +455,7 @@ case class GpuCast( case DateType => castStringToDate(trimmed) case TimestampType => - castStringToTimestamp(trimmed) + castStringToTimestamp(trimmed, ansiMode) case FloatType | DoubleType => castStringToFloats(trimmed, ansiMode, GpuColumnVector.getNonNestedRapidsType(toDataType)) @@ -533,41 +469,44 @@ case class GpuCast( // string to fp64. Then, cast fp64 to target decimal type to enforce HALF_UP rounding. 
withResource(input.strip()) { trimmed => withResource(castStringToFloats(trimmed, ansiMode, DType.FLOAT64)) { fp => - castFloatsToDecimal(fp, dt) + castFloatsToDecimal(fp, dt, ansiMode) } } case (ShortType | IntegerType | LongType, dt: DecimalType) => - castIntegralsToDecimal(input, dt) + castIntegralsToDecimal(input, dt, ansiMode) case (ShortType | IntegerType | LongType | ByteType | StringType, BinaryType) => input.asByteList(true) case (ShortType | IntegerType | LongType, dt: DecimalType) => withResource(input.copyToColumnVector()) { inputVector => - castIntegralsToDecimal(inputVector, dt) + castIntegralsToDecimal(inputVector, dt, ansiMode) } case (FloatType | DoubleType, dt: DecimalType) => withResource(input.copyToColumnVector()) { inputVector => - castFloatsToDecimal(inputVector, dt) + castFloatsToDecimal(inputVector, dt, ansiMode) } case (from: DecimalType, to: DecimalType) => - castDecimalToDecimal(input.copyToColumnVector(), from, to) + castDecimalToDecimal(input.copyToColumnVector(), from, to, ansiMode) case (_: DecimalType, StringType) => input.castTo(DType.STRING) case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => - withResource(input.getChildColumnView(0))(childView => - withResource(recursiveDoColumnar(childView, nestedFrom, nestedTo))(childColumnVector => - withResource(input.replaceListChild(childColumnVector))(_.copyToColumnVector()))) + withResource(input.getChildColumnView(0)) { childView => + withResource(recursiveDoColumnar(childView, nestedFrom, nestedTo, + ansiMode, legacyCastToString)) { childColumnVector => + withResource(input.replaceListChild(childColumnVector))(_.copyToColumnVector()) + } + } case (from: StructType, to: StructType) => - castStructToStruct(from, to, input) + castStructToStruct(from, to, input, ansiMode, legacyCastToString) case (from: MapType, to: MapType) => - castMapToMap(from, to, input) + castMapToMap(from, to, input, ansiMode, legacyCastToString) case _ => input.castTo(GpuColumnVector.getNonNestedRapidsType(toDataType)) @@ -585,10 +524,10 @@ case class GpuCast( * @throws IllegalStateException if any values in the column are not within the specified range */ private def assertValuesInRange(values: ColumnView, - minValue: => Scalar, - maxValue: => Scalar, - inclusiveMin: Boolean = true, - inclusiveMax: Boolean = true): Unit = { + minValue: => Scalar, + maxValue: => Scalar, + inclusiveMin: Boolean = true, + inclusiveMax: Boolean = true): Unit = { def throwIfAny(cv: ColumnView): Unit = { withResource(cv) { cv => @@ -629,11 +568,11 @@ case class GpuCast( * @param inclusiveMax Whether the max value is included in the valid range or not */ private def replaceOutOfRangeValues(values: ColumnView, - minValue: => Scalar, - maxValue: => Scalar, - replaceValue: => Scalar, - inclusiveMin: Boolean = true, - inclusiveMax: Boolean = true): ColumnVector = { + minValue: => Scalar, + maxValue: => Scalar, + replaceValue: => Scalar, + inclusiveMin: Boolean = true, + inclusiveMax: Boolean = true): ColumnVector = { withResource(minValue) { minValue => withResource(maxValue) { maxValue => @@ -668,8 +607,11 @@ case class GpuCast( } } - private def castStructToString(input: ColumnView, - inputSchema: Array[StructField]): ColumnVector = { + private def castStructToString( + input: ColumnView, + inputSchema: Array[StructField], + ansiMode: Boolean, + legacyCastToString: Boolean): ColumnVector = { val (leftStr, rightStr) = if (legacyCastToString) ("[", "]") else ("{", "}") val emptyStr = "" @@ -680,18 +622,19 @@ case class GpuCast( val numInputColumns = 
input.getNumChildren def doCastStructToString( - emptyScalar: Scalar, - nullScalar: Scalar, - sepColumn: ColumnVector, - spaceColumn: ColumnVector, - leftColumn: ColumnVector, - rightColumn: ColumnVector) = { + emptyScalar: Scalar, + nullScalar: Scalar, + sepColumn: ColumnVector, + spaceColumn: ColumnVector, + leftColumn: ColumnVector, + rightColumn: ColumnVector): ColumnVector = { withResource(ArrayBuffer.empty[ColumnVector]) { columns => // legacy: [firstCol // 3.1+: {firstCol columns += leftColumn.incRefCount() withResource(input.getChildColumnView(0)) { firstColumnView => - columns += recursiveDoColumnar(firstColumnView, inputSchema.head.dataType, StringType) + columns += recursiveDoColumnar(firstColumnView, inputSchema.head.dataType, StringType, + ansiMode, legacyCastToString) } for (nonFirstIndex <- 1 until numInputColumns) { withResource(input.getChildColumnView(nonFirstIndex)) { nonFirstColumnView => @@ -699,7 +642,7 @@ case class GpuCast( // 3.1+: ", " columns += sepColumn.incRefCount() val nonFirstColumn = recursiveDoColumnar(nonFirstColumnView, - inputSchema(nonFirstIndex).dataType, StringType) + inputSchema(nonFirstIndex).dataType, StringType, ansiMode, legacyCastToString) if (legacyCastToString) { // " " if non-null columns += spaceColumn.mergeAndSetValidity(BinaryOp.BITWISE_AND, nonFirstColumnView) @@ -799,7 +742,7 @@ case class GpuCast( } } withResource(sanitized.castTo(dType)) { parsedInt => - withResource(GpuScalar.from(null, dataType)) { nullVal => + withResource(Scalar.fromNull(dType)) { nullVal => isInt.ifElse(parsedInt, nullVal) } } @@ -1072,7 +1015,7 @@ case class GpuCast( } } - private def castStringToTimestamp(input: ColumnVector): ColumnVector = { + private def castStringToTimestamp(input: ColumnVector, ansiMode: Boolean): ColumnVector = { // special timestamps val today = DateUtils.currentDate() @@ -1164,16 +1107,19 @@ case class GpuCast( private def castMapToMap( from: MapType, to: MapType, - input: ColumnView): ColumnVector = { + input: ColumnView, + ansiMode: Boolean, + legacyCastToString: Boolean): ColumnVector = { // For cudf a map is a list of (key, value) structs, but lets keep it in ColumnView as much // as possible withResource(input.getChildColumnView(0)) { kvStructColumn => val castKey = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn => - recursiveDoColumnar(keyColumn, from.keyType, to.keyType) + recursiveDoColumnar(keyColumn, from.keyType, to.keyType, ansiMode, legacyCastToString) } withResource(castKey) { castKey => val castValue = withResource(kvStructColumn.getChildColumnView(1)) { valueColumn => - recursiveDoColumnar(valueColumn, from.valueType, to.valueType) + recursiveDoColumnar(valueColumn, from.valueType, to.valueType, + ansiMode, legacyCastToString) } withResource(castValue) { castValue => withResource(ColumnView.makeStructView(castKey, castValue)) { castKvStructColumn => @@ -1191,13 +1137,17 @@ case class GpuCast( private def castStructToStruct( from: StructType, to: StructType, - input: ColumnView): ColumnVector = { + input: ColumnView, + ansiMode: Boolean, + legacyCastToString: Boolean): ColumnVector = { withResource(new ArrayBuffer[ColumnVector](from.length)) { childColumns => from.indices.foreach { index => childColumns += recursiveDoColumnar( input.getChildColumnView(index), from(index).dataType, - to(index).dataType) + to(index).dataType, + ansiMode, + legacyCastToString) } withResource(ColumnView.makeStructView(childColumns: _*)) { casted => if (input.getNullCount == 0) { @@ -1231,7 +1181,10 @@ case class GpuCast( } } - 
private def castIntegralsToDecimal(input: ColumnView, dt: DecimalType): ColumnVector = { + private def castIntegralsToDecimal( + input: ColumnView, + dt: DecimalType, + ansiMode: Boolean): ColumnVector = { // Use INT64 bounds instead of FLOAT64 bounds, which enables precise comparison. val (lowBound, upBound) = math.pow(10, dt.precision - dt.scale) match { case bound if bound > Long.MaxValue => (Long.MinValue, Long.MaxValue) @@ -1255,7 +1208,10 @@ case class GpuCast( } } - private def castFloatsToDecimal(input: ColumnView, dt: DecimalType): ColumnVector = { + private def castFloatsToDecimal( + input: ColumnView, + dt: DecimalType, + ansiMode: Boolean): ColumnVector = { // Approach to minimize difference between CPUCast and GPUCast: // step 1. cast input to FLOAT64 (if necessary) @@ -1309,9 +1265,11 @@ case class GpuCast( } } - private def castDecimalToDecimal(input: ColumnView, + private def castDecimalToDecimal( + input: ColumnView, from: DecimalType, - to: DecimalType): ColumnVector = { + to: DecimalType, + ansiMode: Boolean): ColumnVector = { val isFrom32Bit = DecimalType.is32BitDecimalType(from) val isTo32Bit = DecimalType.is32BitDecimalType(to) @@ -1329,38 +1287,39 @@ case class GpuCast( } else if (to.scale > from.scale) { checkedInput.castTo(cudfDecimal) } else { - withResource(checkedInput.round(to.scale, ai.rapids.cudf.RoundMode.HALF_UP)) { - rounded => rounded.castTo(cudfDecimal) - } + withResource(checkedInput.round(to.scale, ai.rapids.cudf.RoundMode.HALF_UP)) { + rounded => rounded.castTo(cudfDecimal) + } } } if (to.scale <= from.scale) { if (!isFrom32Bit && isTo32Bit) { // check for overflow when 64bit => 32bit - withResource(checkForOverflow(input, to, isFrom32Bit)) { checkedInput => + withResource(checkForOverflow(input, to, isFrom32Bit, ansiMode)) { checkedInput => castCheckedDecimal(checkedInput) } } else { if (to.scale < 0 && !SQLConf.get.allowNegativeScaleOfDecimalEnabled) { throw new IllegalStateException(s"Negative scale is not allowed: ${to.scale}. 
" + - s"You can use spark.sql.legacy.allowNegativeScaleOfDecimal=true " + - s"to enable legacy mode to allow it.") + s"You can use spark.sql.legacy.allowNegativeScaleOfDecimal=true " + + s"to enable legacy mode to allow it.") } castCheckedDecimal(input) } } else { // from.scale > to.scale - withResource(checkForOverflow(input, to, isFrom32Bit)) { checkedInput => + withResource(checkForOverflow(input, to, isFrom32Bit, ansiMode)) { checkedInput => castCheckedDecimal(checkedInput) } } } def checkForOverflow( - input: ColumnView, - to: DecimalType, - isFrom32Bit: Boolean): ColumnVector = { + input: ColumnView, + to: DecimalType, + isFrom32Bit: Boolean, + ansiMode: Boolean): ColumnVector = { // Decimal numbers in general terms have two parts, a part before decimal (whole number) // and a part after decimal (fractional number) @@ -1428,3 +1387,71 @@ case class GpuCast( checkedInput } } + +/** + * Casts using the GPU + */ +case class GpuCast( + child: Expression, + dataType: DataType, + ansiMode: Boolean = false, + timeZoneId: Option[String] = None, + legacyCastToString: Boolean = false) + extends GpuUnaryExpression with TimeZoneAwareExpression with NullIntolerant { + + import GpuCast._ + + override def toString: String = if (ansiMode) { + s"ansi_cast($child as ${dataType.simpleString})" + } else { + s"cast($child as ${dataType.simpleString})" + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (Cast.canCast(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}") + } + } + + override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + /** + * Under certain conditions during hash partitioning, Spark will attempt to replace casts + * with semantically equivalent expressions. This method is overridden to prevent Spark + * from substituting non-GPU expressions. + */ + override def semanticEquals(other: Expression): Boolean = other match { + case g: GpuExpression => + if (this == g) { + true + } else { + super.semanticEquals(g) + } + case _ => false + } + + // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; + // Otherwise behave like Expression.resolved. + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) + + private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, + // this type of casting can only be introduced by the analyzer, and can be omitted when + // converting back to SQL query string. + case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" + } + + override def doColumnar(input: GpuColumnVector): ColumnVector = + recursiveDoColumnar(input.getBase, input.dataType(), dataType, ansiMode, legacyCastToString) + +}