diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java index e624852e19d..12236dcd54d 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/iceberg/spark/source/GpuIcebergReader.java @@ -23,6 +23,7 @@ import java.util.NoSuchElementException; import ai.rapids.cudf.Scalar; +import com.nvidia.spark.rapids.CastOptions$; import com.nvidia.spark.rapids.GpuCast; import com.nvidia.spark.rapids.GpuColumnVector; import com.nvidia.spark.rapids.GpuScalar; @@ -158,7 +159,7 @@ static ColumnarBatch addUpcastsIfNeeded(ColumnarBatch batch, Schema expectedSche GpuColumnVector oldColumn = columns[i]; columns[i] = GpuColumnVector.from( GpuCast.doCast(oldColumn.getBase(), oldColumn.dataType(), expectedSparkType, - false, false, false), expectedSparkType); + CastOptions$.MODULE$.DEFAULT_CAST_OPTIONS()), expectedSparkType); } ColumnarBatch newBatch = new ColumnarBatch(columns, batch.numRows()); columns = null; diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala index 2c3c98d20b9..1a188d47660 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala @@ -144,9 +144,7 @@ case class ApproxPercentileFromTDigestExpr( // array and return that (after converting from Double to finalDataType) withResource(cv.getBase.approxPercentile(Array(p))) { percentiles => withResource(percentiles.extractListElement(0)) { childView => - withResource(doCast(childView, DataTypes.DoubleType, finalDataType, - ansiMode = false, legacyCastToString = false, - stringToDateAnsiModeEnabled = false)) { childCv => + 
withResource(doCast(childView, DataTypes.DoubleType, finalDataType)) { childCv => GpuColumnVector.from(childCv.copyToColumnVector(), dataType) } } @@ -159,9 +157,7 @@ case class ApproxPercentileFromTDigestExpr( GpuColumnVector.from(percentiles.incRefCount(), dataType) } else { withResource(percentiles.getChildColumnView(0)) { childView => - withResource(doCast(childView, DataTypes.DoubleType, finalDataType, - ansiMode = false, legacyCastToString = false, - stringToDateAnsiModeEnabled = false)) { childCv => + withResource(doCast(childView, DataTypes.DoubleType, finalDataType)) { childCv => withResource(percentiles.replaceListChild(childCv)) { x => GpuColumnVector.from(x.copyToColumnVector(), dataType) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 573f3ced068..7a4bdb592b4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -163,6 +163,82 @@ final class CastExprMeta[INPUT <: UnaryExpression with TimeZoneAwareExpression w override protected val needTimezoneTagging: Boolean = false } +object CastOptions { + val DEFAULT_CAST_OPTIONS = new CastOptions(false, false, false) + val ARITH_ANSI_OPTIONS = new CastOptions(false, true, false) + + def getArithmeticCastOptions(failOnError: Boolean): CastOptions = + if (failOnError) ARITH_ANSI_OPTIONS else DEFAULT_CAST_OPTIONS +} + +/** + * This class is used to encapsulate parameters to use to help determine how to + * cast + * + * @param legacyCastComplexTypesToString If we should use legacy casting method + * @param ansiMode Whether the cast should be ANSI compliant + * @param stringToDateAnsiMode Whether to cast String to Date using ANSI compliance + */ +class CastOptions( + legacyCastComplexTypesToString: Boolean, + ansiMode: Boolean, + stringToDateAnsiMode: Boolean) extends Serializable { + + /** + * Returns the left bracket to 
use when surrounding brackets when converting + * map or struct types to string + * example: + * [ "a" -> "b"] when legacyCastComplexTypesToString is enabled + * otherwise { "a" -> "b" } + */ + val leftBracket: String = if (legacyCastComplexTypesToString) "[" else "{" + + /** + * Returns the right bracket to use when surrounding brackets when converting + * map or struct types to string + * example: + * [ "a" -> "b"] when legacyCastComplexTypesToString is enabled + * otherwise { "a" -> "b" } + */ + val rightBracket: String = if (legacyCastComplexTypesToString) "]" else "}" + + /** + * Returns the string value to use to represent null elements in array/struct/map. + */ + val nullString: String = if (legacyCastComplexTypesToString) "" else "null" + + /** + * Returns whether a decimal value with exponents should be + * converted to a plain string, exactly like Java BigDecimal.toPlainString() + * example: + * plain string value of decimal 1.23E+7 is 12300000 + */ + val useDecimalPlainString: Boolean = ansiMode + + /** + * Returns whether the binary data should be printed as hex values + * instead of ascii values + */ + val useHexFormatForBinary: Boolean = false + + /** + * Returns whether we should cast using ANSI compliance + */ + val isAnsiMode: Boolean = ansiMode + + /** + * Returns whether we should use ANSI compliance when casting a String + * to Date + */ + val useAnsiStringToDateMode: Boolean = stringToDateAnsiMode + + /** + * Returns whether we should use legacy behavior to convert complex types + * like structs/maps to a String + */ + val useLegacyComplexTypesToString: Boolean = legacyCastComplexTypesToString +} + object GpuCast { private val DATE_REGEX_YYYY_MM_DD = "\\A\\d{4}\\-\\d{1,2}\\-\\d{1,2}([ T](:?[\\r\\n]|.)*)?\\Z" @@ -191,14 +267,13 @@ object GpuCast { input: ColumnView, fromDataType: DataType, toDataType: DataType, - ansiMode: Boolean, - legacyCastToString: Boolean, - stringToDateAnsiModeEnabled: Boolean): ColumnVector = { - + options: CastOptions = 
CastOptions.DEFAULT_CAST_OPTIONS): ColumnVector = { if (DataType.equalsStructurally(fromDataType, toDataType)) { return input.copyToColumnVector() } + val ansiMode = options.isAnsiMode + (fromDataType, toDataType) match { case (NullType, to) => GpuColumnVector.columnVectorFromNull(input.getRowCount.toInt, to) @@ -249,8 +324,7 @@ object GpuCast { castTimestampToString(input) case (StructType(fields), StringType) => - castStructToString(input, fields, ansiMode, legacyCastToString, - stringToDateAnsiModeEnabled) + castStructToString(input, fields, options) // ansi cast from larger-than-long integral-like types, to long case (dt: DecimalType, LongType) if ansiMode => @@ -441,7 +515,7 @@ object GpuCast { case BooleanType => castStringToBool(trimmed, ansiMode) case DateType => - if (stringToDateAnsiModeEnabled) { + if (options.useAnsiStringToDateMode) { castStringToDateAnsi(trimmed, ansiMode) } else { castStringToDate(trimmed) @@ -467,26 +541,22 @@ object GpuCast { case (ArrayType(nestedFrom, _), ArrayType(nestedTo, _)) => withResource(input.getChildColumnView(0)) { childView => - withResource(doCast(childView, nestedFrom, nestedTo, - ansiMode, legacyCastToString, stringToDateAnsiModeEnabled)) { childColumnVector => + withResource(doCast(childView, nestedFrom, nestedTo, options)) { childColumnVector => withResource(input.replaceListChild(childColumnVector))(_.copyToColumnVector()) } } case (ArrayType(elementType, _), StringType) => - castArrayToString( - input, elementType, ansiMode, legacyCastToString, stringToDateAnsiModeEnabled - ) + castArrayToString(input, elementType, options) case (from: StructType, to: StructType) => - castStructToStruct(from, to, input, ansiMode, legacyCastToString, - stringToDateAnsiModeEnabled) + castStructToStruct(from, to, input, options) case (from: MapType, to: MapType) => - castMapToMap(from, to, input, ansiMode, legacyCastToString, stringToDateAnsiModeEnabled) + castMapToMap(from, to, input, options) case (from: MapType, _: StringType) 
=> - castMapToString(input, from, ansiMode, legacyCastToString, stringToDateAnsiModeEnabled) + castMapToString(input, from, options) case (dayTime: DataType, _: StringType) if GpuTypeShims.isSupportedDayTimeType(dayTime) => GpuIntervalUtils.toDayTimeIntervalString(input, dayTime) @@ -684,16 +754,21 @@ object GpuCast { */ private def concatenateStringArrayElements( input: ColumnView, - legacyCastToString: Boolean): ColumnVector = { + options: CastOptions, + castingBinaryData: Boolean = false): ColumnVector = { + + import options._ + val emptyStr = "" val spaceStr = " " - val nullStr = if (legacyCastToString) "" else "null" - val sepStr = if (legacyCastToString) "," else ", " + val sepStr = if (useHexFormatForBinary && castingBinaryData) spaceStr + else if (useLegacyComplexTypesToString) "," else ", " + withResource( - Seq(emptyStr, spaceStr, nullStr, sepStr).safeMap(Scalar.fromString) - ){ case Seq(empty, space, nullRep, sep) => + Seq(emptyStr, spaceStr, nullString, sepStr).safeMap(Scalar.fromString) + ) { case Seq(empty, space, nullRep, sep) => - val withSpacesIfLegacy = if (!legacyCastToString) { + val withSpacesIfLegacy = if (!useLegacyComplexTypesToString) { withResource(input.getChildColumnView(0)) { _.replaceNulls(nullRep) } @@ -724,7 +799,7 @@ object GpuCast { val strCol = withResource(concatenated) { _.replaceNulls(empty) } - if (!legacyCastToString) { + if (!useLegacyComplexTypesToString) { strCol } else { // If the first char of a string is ' ', remove it (only for legacyCastToString = true) @@ -741,26 +816,24 @@ object GpuCast { private def castArrayToString(input: ColumnView, elementType: DataType, - ansiMode: Boolean, - legacyCastToString: Boolean, - stringToDateAnsiModeEnabled: Boolean): ColumnVector = { + options: CastOptions): ColumnVector = { - val (leftStr, rightStr) = ("[", "]") + // We use square brackets for arrays regardless + val (leftStr, rightStr) = ("[", "]") val emptyStr = "" - val nullStr = if (legacyCastToString) "" else "null" val 
numRows = input.getRowCount.toInt withResource( - Seq(leftStr, rightStr, emptyStr, nullStr).safeMap(Scalar.fromString) + Seq(leftStr, rightStr, emptyStr, options.nullString).safeMap(Scalar.fromString) ){ case Seq(left, right, empty, nullRep) => val strChildContainsNull = withResource(input.getChildColumnView(0)) {child => doCast( - child, elementType, StringType, ansiMode, legacyCastToString, stringToDateAnsiModeEnabled) + child, elementType, StringType, options) } val concatenated = withResource(strChildContainsNull) { _ => withResource(input.replaceListChild(strChildContainsNull)) { - concatenateStringArrayElements(_, legacyCastToString) + concatenateStringArrayElements(_, options) } } @@ -782,45 +855,45 @@ object GpuCast { private def castMapToString( input: ColumnView, from: MapType, - ansiMode: Boolean, - legacyCastToString: Boolean, - stringToDateAnsiModeEnabled: Boolean): ColumnVector = { + options: CastOptions): ColumnVector = { val numRows = input.getRowCount.toInt val (arrowStr, emptyStr, spaceStr) = ("->", "", " ") - val (leftStr, rightStr, nullStr) = - if (legacyCastToString) ("[", "]", "") else ("{", "}", "null") // cast the key column and value column to string columns val (strKey, strValue) = withResource(input.getChildColumnView(0)) { kvStructColumn => val strKey = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn => doCast( - keyColumn, from.keyType, StringType, ansiMode, - legacyCastToString, stringToDateAnsiModeEnabled) + keyColumn, from.keyType, StringType, options) } val strValue = closeOnExcept(strKey) {_ => withResource(kvStructColumn.getChildColumnView(1)) { valueColumn => doCast( - valueColumn, from.valueType, StringType, ansiMode, - legacyCastToString, stringToDateAnsiModeEnabled) + valueColumn, from.valueType, StringType, options) } } (strKey, strValue) } + import options._ // concatenate the key-value pairs to string // Example: ("key", "value") -> "key -> value" withResource( - Seq(leftStr, rightStr, arrowStr, emptyStr, 
nullStr, spaceStr).safeMap(Scalar.fromString) + Seq(leftBracket, + rightBracket, + arrowStr, + emptyStr, + nullString, + spaceStr).safeMap(Scalar.fromString) ) { case Seq(leftScalar, rightScalar, arrowScalar, emptyScalar, nullScalar, spaceScalar) => val strElements = withResource(Seq(strKey, strValue)) { case Seq(strKey, strValue) => val numElements = strKey.getRowCount.toInt withResource(Seq(spaceScalar, arrowScalar).safeMap(ColumnVector.fromScalar(_, numElements)) - ) {case Seq(spaceCol, arrowCol) => - if (legacyCastToString) { + ) { case Seq(spaceCol, arrowCol) => + if (useLegacyComplexTypesToString) { withResource( spaceCol.mergeAndSetValidity(BinaryOp.BITWISE_AND, strValue) - ) {spaceBetweenSepAndVal => + ) { spaceBetweenSepAndVal => ColumnVector.stringConcatenate( emptyScalar, nullScalar, Array(strKey, spaceCol, arrowCol, spaceBetweenSepAndVal, strValue)) @@ -835,7 +908,7 @@ object GpuCast { // concatenate elements val strCol = withResource(strElements) { _ => withResource(input.replaceListChild(strElements)) { - concatenateStringArrayElements(_, legacyCastToString) + concatenateStringArrayElements(_, options) } } val resPreValidityFix = withResource(strCol) { _ => @@ -855,14 +928,12 @@ object GpuCast { private def castStructToString( input: ColumnView, inputSchema: Array[StructField], - ansiMode: Boolean, - legacyCastToString: Boolean, - stringToDateAnsiModeEnabled: Boolean): ColumnVector = { + options: CastOptions): ColumnVector = { + + import options._ - val (leftStr, rightStr) = if (legacyCastToString) ("[", "]") else ("{", "}") val emptyStr = "" - val nullStr = if (legacyCastToString) "" else "null" - val separatorStr = if (legacyCastToString) "," else ", " + val separatorStr = if (useLegacyComplexTypesToString) "," else ", " val spaceStr = " " val numRows = input.getRowCount.toInt val numInputColumns = input.getNumChildren @@ -879,8 +950,7 @@ object GpuCast { // 3.1+: {firstCol columns += leftColumn.incRefCount() 
withResource(input.getChildColumnView(0)) { firstColumnView => - columns += doCast(firstColumnView, inputSchema.head.dataType, StringType, - ansiMode, legacyCastToString, stringToDateAnsiModeEnabled) + columns += doCast(firstColumnView, inputSchema.head.dataType, StringType, options) } for (nonFirstIndex <- 1 until numInputColumns) { withResource(input.getChildColumnView(nonFirstIndex)) { nonFirstColumnView => @@ -888,9 +958,8 @@ object GpuCast { // 3.1+: ", " columns += sepColumn.incRefCount() val nonFirstColumn = doCast(nonFirstColumnView, - inputSchema(nonFirstIndex).dataType, StringType, ansiMode, legacyCastToString, - stringToDateAnsiModeEnabled) - if (legacyCastToString) { + inputSchema(nonFirstIndex).dataType, StringType, options) + if (useLegacyComplexTypesToString) { // " " if non-null columns += spaceColumn.mergeAndSetValidity(BinaryOp.BITWISE_AND, nonFirstColumnView) } @@ -905,8 +974,8 @@ object GpuCast { } } - withResource(Seq(emptyStr, nullStr, separatorStr, spaceStr, leftStr, rightStr) - .safeMap(Scalar.fromString)) { + withResource(Seq(emptyStr, nullString, separatorStr, spaceStr, leftBracket, rightBracket) + .safeMap(Scalar.fromString)) { case Seq(emptyScalar, nullScalar, columnScalars@_*) => withResource( @@ -1216,20 +1285,16 @@ object GpuCast { from: MapType, to: MapType, input: ColumnView, - ansiMode: Boolean, - legacyCastToString: Boolean, - stringToDateAnsiModeEnabled: Boolean): ColumnVector = { + options: CastOptions): ColumnVector = { // For cudf a map is a list of (key, value) structs, but lets keep it in ColumnView as much // as possible withResource(input.getChildColumnView(0)) { kvStructColumn => val castKey = withResource(kvStructColumn.getChildColumnView(0)) { keyColumn => - doCast(keyColumn, from.keyType, to.keyType, ansiMode, legacyCastToString, - stringToDateAnsiModeEnabled) + doCast(keyColumn, from.keyType, to.keyType, options) } withResource(castKey) { castKey => val castValue = withResource(kvStructColumn.getChildColumnView(1)) { 
valueColumn => - doCast(valueColumn, from.valueType, to.valueType, - ansiMode, legacyCastToString, stringToDateAnsiModeEnabled) + doCast(valueColumn, from.valueType, to.valueType, options) } withResource(castValue) { castValue => withResource(ColumnView.makeStructView(castKey, castValue)) { castKvStructColumn => @@ -1248,17 +1313,13 @@ object GpuCast { from: StructType, to: StructType, input: ColumnView, - ansiMode: Boolean, - legacyCastToString: Boolean, - stringToDateAnsiModeEnabled: Boolean): ColumnVector = { + options: CastOptions): ColumnVector = { withResource(new ArrayBuffer[ColumnVector](from.length)) { childColumns => from.indices.foreach { index => childColumns += doCast( input.getChildColumnView(index), from(index).dataType, - to(index).dataType, - ansiMode, - legacyCastToString, stringToDateAnsiModeEnabled) + to(index).dataType, options) } withResource(ColumnView.makeStructView(childColumns: _*)) { casted => if (input.getNullCount == 0) { @@ -1492,12 +1553,15 @@ case class GpuCast( dataType: DataType, ansiMode: Boolean = false, timeZoneId: Option[String] = None, - legacyCastToString: Boolean = false, + legacyCastComplexTypesToString: Boolean = false, stringToDateAnsiModeEnabled: Boolean = false) extends GpuUnaryExpression with TimeZoneAwareExpression with NullIntolerant { import GpuCast._ + private val options: CastOptions = + new CastOptions(legacyCastComplexTypesToString, ansiMode, stringToDateAnsiModeEnabled) + // when ansi mode is enabled, some cast expressions can throw exceptions on invalid inputs override def hasSideEffects: Boolean = super.hasSideEffects || { (child.dataType, dataType) match { @@ -1563,7 +1627,5 @@ case class GpuCast( } override def doColumnar(input: GpuColumnVector): ColumnVector = - doCast(input.getBase, input.dataType(), dataType, ansiMode, legacyCastToString, - stringToDateAnsiModeEnabled) - + doCast(input.getBase, input.dataType(), dataType, options) } \ No newline at end of file diff --git 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala index 6e48d7bb8d0..f971ddd2aa4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala @@ -383,8 +383,7 @@ object GpuOrcScan { case (f: DType, t: DType) if f.isDecimalType && t.isDecimalType => val fromDataType = DecimalType(f.getDecimalMaxPrecision, -f.getScale) val toDataType = DecimalType(t.getDecimalMaxPrecision, -t.getScale) - GpuCast.doCast(col, fromDataType, toDataType, ansiMode=false, legacyCastToString = false, - stringToDateAnsiModeEnabled = false) + GpuCast.doCast(col, fromDataType, toDataType) case (DType.STRING, DType.STRING) if originalFromDt.isInstanceOf[CharType] => // Trim trailing whitespace off of output strings, to match CPU output. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala index af9cfc8f717..3b5aea39bf0 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuJsonToStructs.scala @@ -197,7 +197,7 @@ case class GpuJsonToStructs( } else { val col = rawTable.getColumn(i) // getSparkType is only used to get the from type for cast - doCast(col, getSparkType(col), dtype, false, false, false) + doCast(col, getSparkType(col), dtype) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala index e1cbc6b0c06..2e8d46e6f6f 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -401,13 +401,16 @@ trait GpuDecimalMultiplyBase extends GpuExpression { def regularMultiply(batch: ColumnarBatch): 
GpuColumnVector = { val castLhs = withResource(left.columnarEval(batch)) { lhs => - GpuCast.doCast(lhs.getBase, lhs.dataType(), intermediateLhsType, ansiMode = failOnError, - legacyCastToString = false, stringToDateAnsiModeEnabled = false) + GpuCast.doCast( + lhs.getBase, + lhs.dataType(), + intermediateLhsType, + CastOptions.getArithmeticCastOptions(failOnError)) } val ret = withResource(castLhs) { castLhs => val castRhs = withResource(right.columnarEval(batch)) { rhs => - GpuCast.doCast(rhs.getBase, rhs.dataType(), intermediateRhsType, ansiMode = failOnError, - legacyCastToString = false, stringToDateAnsiModeEnabled = false) + GpuCast.doCast(rhs.getBase, rhs.dataType(), intermediateRhsType, + CastOptions.getArithmeticCastOptions(failOnError)) } withResource(castRhs) { castRhs => withResource(castLhs.mul(castRhs, @@ -436,7 +439,7 @@ trait GpuDecimalMultiplyBase extends GpuExpression { } withResource(ret) { ret => GpuColumnVector.from(GpuCast.doCast(ret, intermediateResultType, dataType, - ansiMode = failOnError, legacyCastToString = false, stringToDateAnsiModeEnabled = false), + CastOptions.getArithmeticCastOptions(failOnError)), dataType) } } @@ -851,14 +854,18 @@ trait GpuDecimalDivideBase extends GpuExpression { def regularDivide(batch: ColumnarBatch): GpuColumnVector = { val castLhs = withResource(left.columnarEval(batch)) { lhs => - GpuCast.doCast(lhs.getBase, lhs.dataType(), intermediateLhsType, ansiMode = failOnError, - legacyCastToString = false, stringToDateAnsiModeEnabled = false) + GpuCast.doCast( + lhs.getBase, + lhs.dataType(), + intermediateLhsType, + CastOptions.getArithmeticCastOptions(failOnError)) + } val ret = withResource(castLhs) { castLhs => val castRhs = withResource(right.columnarEval(batch)) { rhs => withResource(divByZeroFixes(rhs.getBase)) { fixed => - GpuCast.doCast(fixed, rhs.dataType(), intermediateRhsType, ansiMode = failOnError, - legacyCastToString = false, stringToDateAnsiModeEnabled = false) + GpuCast.doCast(fixed, 
rhs.dataType(), intermediateRhsType, + CastOptions.getArithmeticCastOptions(failOnError)) } } withResource(castRhs) { castRhs => @@ -871,7 +878,7 @@ trait GpuDecimalDivideBase extends GpuExpression { // in the common case with us. It will also handle rounding the result to the final scale // to match what Spark does. GpuColumnVector.from(GpuCast.doCast(ret, intermediateResultType, dataType, - ansiMode = failOnError, legacyCastToString = false, stringToDateAnsiModeEnabled = false), + CastOptions.getArithmeticCastOptions(failOnError)), dataType) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala index d373b2459c1..e0fe58b0857 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala @@ -460,7 +460,7 @@ case class GpuSecondsToTimestamp(child: Expression) extends GpuNumberToTimestamp } case DoubleType | FloatType => (input: GpuColumnVector) => { - GpuCast.doCast(input.getBase, input.dataType, TimestampType, false, false, false) + GpuCast.doCast(input.getBase, input.dataType, TimestampType) } case dt: DecimalType => (input: GpuColumnVector) => { diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/arithmetic.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/arithmetic.scala index ce978508826..2a4ec2cc902 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/arithmetic.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -95,11 +95,11 @@ trait GpuAddSub extends CudfBinaryArithmetic { } else { // eval operands using the output precision val castLhs = withResource(left.columnarEval(batch)) { lhs => - GpuCast.doCast(lhs.getBase(), leftInputType, resultType, false, false, false) + GpuCast.doCast(lhs.getBase(), 
leftInputType, resultType) } val castRhs = closeOnExcept(castLhs){ _ => withResource(right.columnarEval(batch)) { rhs => - GpuCast.doCast(rhs.getBase(), rightInputType, resultType, false, false, false) + GpuCast.doCast(rhs.getBase(), rightInputType, resultType) } } @@ -342,14 +342,14 @@ case class GpuDecimalRemainder(left: Expression, right: Expression) private def regularRemainder(batch: ColumnarBatch): GpuColumnVector = { val castLhs = withResource(left.columnarEval(batch)) { lhs => - GpuCast.doCast(lhs.getBase, lhs.dataType(), intermediateLhsType, ansiMode = failOnError, - legacyCastToString = false, stringToDateAnsiModeEnabled = false) + GpuCast.doCast(lhs.getBase, lhs.dataType(), intermediateLhsType, + CastOptions.getArithmeticCastOptions(failOnError)) } withResource(castLhs) { castLhs => val castRhs = withResource(right.columnarEval(batch)) { rhs => withResource(divByZeroFixes(rhs.getBase)) { fixed => - GpuCast.doCast(fixed, rhs.dataType(), intermediateRhsType, ansiMode = failOnError, - legacyCastToString = false, stringToDateAnsiModeEnabled = false) + GpuCast.doCast(fixed, rhs.dataType(), intermediateRhsType, + CastOptions.getArithmeticCastOptions(failOnError)) } } withResource(castRhs) { castRhs =>