diff --git a/docs/configs.md b/docs/configs.md
index 6d8f74d1c3e..a65ca2db758 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -51,11 +51,13 @@ Name | Description | Default Value
spark.rapids.shuffle.ucx.managementServerHost|The host to be used to start the management server|null
spark.rapids.shuffle.ucx.useWakeup|When set to true, use UCX's event-based progress (epoll) in order to wake up the progress thread when needed, instead of a hot loop.|true
spark.rapids.sql.batchSizeBytes|Set the target number of bytes for a GPU batch. Splits sizes for input data is covered by separate configs. The maximum setting is 2 GB to avoid exceeding the cudf row count limit of a column.|2147483647
+spark.rapids.sql.castDecimalToString.enabled|When set to true, casting from decimal to string is supported on the GPU. The GPU does NOT produce exact same string as spark produces, but producing strings which are semantically equal. For instance, given input BigDecimal(123, -2), the GPU produces "12300", which spark produces "1.23E+4".|false
spark.rapids.sql.castFloatToDecimal.enabled|Casting from floating point types to decimal on the GPU returns results that have tiny difference compared to results returned from CPU.|false
spark.rapids.sql.castFloatToIntegralTypes.enabled|Casting from floating point types to integral types on the GPU supports a slightly different range of values when using Spark 3.1.0 or later. Refer to the CAST documentation for more details.|false
spark.rapids.sql.castFloatToString.enabled|Casting from floating point types to string on the GPU returns results that have a different precision than the default results of Spark.|false
spark.rapids.sql.castStringToDecimal.enabled|When set to true, enables casting from strings to decimal type on the GPU. Currently string to decimal type on the GPU might produce results which slightly differed from the correct results when the string represents any number exceeding the max precision that CAST_STRING_TO_FLOAT can keep. For instance, the GPU returns 99999999999999987 given input string "99999999999999999". The cause of divergence is that we can not cast strings containing scientific notation to decimal directly. So, we have to cast strings to floats firstly. Then, cast floats to decimals. The first step may lead to precision loss.|false
spark.rapids.sql.castStringToFloat.enabled|When set to true, enables casting from strings to float types (float, double) on the GPU. Currently hex values aren't supported on the GPU. Also note that casting from string to float types on the GPU returns incorrect results when the string represents any number "1.7976931348623158E308" <= x < "1.7976931348623159E308" and "-1.7976931348623158E308" >= x > "-1.7976931348623159E308" in both these cases the GPU returns Double.MaxValue while CPU returns "+Infinity" and "-Infinity" respectively|false
+spark.rapids.sql.castStringToInteger.enabled|When set to true, enables casting from strings to integer types (byte, short, int, long) on the GPU. Casting from string to integer types on the GPU returns incorrect results when the string represents a number larger than Long.MaxValue or smaller than Long.MinValue.|false
spark.rapids.sql.castStringToTimestamp.enabled|When set to true, casting from string to timestamp is supported on the GPU. The GPU only supports a subset of formats when casting strings to timestamps. Refer to the CAST documentation for more details.|false
spark.rapids.sql.concurrentGpuTasks|Set the number of tasks that can execute concurrently per GPU. Tasks may temporarily block when the number of concurrent tasks in the executor exceeds this amount. Allowing too many concurrent tasks on the same GPU may lead to GPU out of memory errors.|1
spark.rapids.sql.csvTimestamps.enabled|When set to true, enables the CSV parser to read timestamps. The default output format for Spark includes a timezone at the end. Anything except the UTC timezone is not supported. Timestamps after 2038 and before 1902 are also not supported.|false
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 8bec5a084d7..0e80d01ec1e 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -18117,7 +18117,7 @@ and the accelerator produces the same result.
NS |
|
NS |
-NS |
+S |
S* |
|
|
@@ -18521,7 +18521,7 @@ and the accelerator produces the same result.
NS |
|
NS |
-NS |
+S |
S* |
|
|
diff --git a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
index 9c658a63701..380e302d05b 100644
--- a/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
+++ b/shims/spark311/src/main/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala
@@ -139,7 +139,7 @@ class Spark311Shims extends Spark301Shims {
// stringChecks are the same
// binaryChecks are the same
- override val decimalChecks: TypeSig = none
+ override val decimalChecks: TypeSig = DECIMAL + STRING
override val sparkDecimalSig: TypeSig = numeric + BOOLEAN + STRING
// calendarChecks are the same
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
index f2a24a15a7a..672c1ac2068 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala
@@ -409,6 +409,9 @@ case class GpuCast(
castDecimalToDecimal(inputVector, from, to)
}
+ case (_: DecimalType, StringType) =>
+ input.castTo(DType.STRING)
+
case _ =>
input.castTo(GpuColumnVector.getNonNestedRapidsType(dataType))
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index ae6cfffc3ae..9a2392a2bdb 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -591,6 +591,22 @@ object RapidsConf {
.booleanConf
.createWithDefault(false)
+ val ENABLE_CAST_STRING_TO_INTEGER = conf("spark.rapids.sql.castStringToInteger.enabled")
+ .doc("When set to true, enables casting from strings to integer types (byte, short, " +
+ "int, long) on the GPU. Casting from string to integer types on the GPU returns incorrect " +
+ "results when the string represents a number larger than Long.MaxValue or smaller than " +
+ "Long.MinValue.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val ENABLE_CAST_DECIMAL_TO_STRING = conf("spark.rapids.sql.castDecimalToString.enabled")
+ .doc("When set to true, casting from decimal to string is supported on the GPU. The GPU " +
+ "does NOT produce exact same string as spark produces, but producing strings which are " +
+ "semantically equal. For instance, given input BigDecimal(123, -2), the GPU produces " +
+ "\"12300\", which spark produces \"1.23E+4\".")
+ .booleanConf
+ .createWithDefault(false)
+
val ENABLE_CSV_TIMESTAMPS = conf("spark.rapids.sql.csvTimestamps.enabled")
.doc("When set to true, enables the CSV parser to read timestamps. The default output " +
"format for Spark includes a timezone at the end. Anything except the UTC timezone is not " +
@@ -1200,6 +1216,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val isCastFloatToIntegralTypesEnabled: Boolean = get(ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES)
+ lazy val isCastDecimalToStringEnabled: Boolean = get(ENABLE_CAST_DECIMAL_TO_STRING)
+
lazy val isCsvTimestampEnabled: Boolean = get(ENABLE_CSV_TIMESTAMPS)
lazy val isParquetEnabled: Boolean = get(ENABLE_PARQUET)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index e9469429d32..4539271cb5a 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -772,7 +772,7 @@ class CastChecks extends ExprChecks {
val binaryChecks: TypeSig = none
val sparkBinarySig: TypeSig = STRING + BINARY
- val decimalChecks: TypeSig = DECIMAL
+ val decimalChecks: TypeSig = DECIMAL + STRING
val sparkDecimalSig: TypeSig = numeric + BOOLEAN + TIMESTAMP + STRING
val calendarChecks: TypeSig = none
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
index eaa6224dfbe..fdc8472f1de 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -381,6 +381,20 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {
comparisonFunc = Some(compareStringifiedFloats))
}
+ test("ansi_cast decimal to string") {
+ val sqlCtx = SparkSession.getActiveSession.get.sqlContext
+ sqlCtx.setConf("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
+ sqlCtx.setConf("spark.rapids.sql.castDecimalToString.enabled", "true")
+
+ Seq(10, 15, 18).foreach { precision =>
+ Seq(-precision, -5, 0, 5, precision).foreach { scale =>
+ testCastToString(DataTypes.createDecimalType(precision, scale),
+ ansiMode = true,
+ comparisonFunc = Some(compareStringifiedDecimalsInSemantic))
+ }
+ }
+ }
+
private def castToStringExpectedFun[T]: T => Option[String] = (d: T) => Some(String.valueOf(d))
private def testCastToString[T](dataType: DataType, ansiMode: Boolean,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
index e89f27bb07b..f145c56ff75 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala
@@ -228,6 +228,19 @@ class CastOpSuite extends GpuExpressionTestSuite {
testCastToString[Double](DataTypes.DoubleType, comparisonFunc = Some(compareStringifiedFloats))
}
+ test("cast decimal to string") {
+ val sqlCtx = SparkSession.getActiveSession.get.sqlContext
+ sqlCtx.setConf("spark.sql.legacy.allowNegativeScaleOfDecimal", "true")
+ sqlCtx.setConf("spark.rapids.sql.castDecimalToString.enabled", "true")
+
+ Seq(10, 15, 18).foreach { precision =>
+ Seq(-precision, -5, 0, 5, precision).foreach { scale =>
+ testCastToString(DataTypes.createDecimalType(precision, scale),
+ comparisonFunc = Some(compareStringifiedDecimalsInSemantic))
+ }
+ }
+ }
+
private def testCastToString[T](
dataType: DataType,
comparisonFunc: Option[(String, String) => Boolean] = None) {
@@ -481,6 +494,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
customRandGenerator = Some(new scala.util.Random(1234L)))
testCastToDecimal(DataTypes.createDecimalType(18, 2),
scale = 2,
+ ansiEnabled = true,
customRandGenerator = Some(new scala.util.Random(1234L)))
// fromScale > toScale
@@ -489,6 +503,7 @@ class CastOpSuite extends GpuExpressionTestSuite {
customRandGenerator = Some(new scala.util.Random(1234L)))
testCastToDecimal(DataTypes.createDecimalType(18, 10),
scale = 2,
+ ansiEnabled = true,
customRandGenerator = Some(new scala.util.Random(1234L)))
testCastToDecimal(DataTypes.createDecimalType(18, 18),
scale = 15,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala
index fa0ff51a891..f42d7e3c65f 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuExpressionTestSuite.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
package com.nvidia.spark.rapids
-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{DataType, DataTypes, Decimal, DecimalType, StructType}
abstract class GpuExpressionTestSuite extends SparkQueryCompareTestSuite {
@@ -172,6 +172,11 @@ abstract class GpuExpressionTestSuite extends SparkQueryCompareTestSuite {
}
}
+ def compareStringifiedDecimalsInSemantic(expected: String, actual: String): Boolean = {
+ (expected == null && actual == null) ||
+ (expected != null && actual != null && Decimal(expected) == Decimal(actual))
+ }
+
private def getAs(column: RapidsHostColumnVector, index: Int, dataType: DataType): Option[Any] = {
if (column.isNullAt(index)) {
None