diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index eb89240679b4..c7bc4cfb7473 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -163,9 +163,8 @@ private static DType toRapidsOrNull(DataType type) { return DType.STRING; } else if (type instanceof DecimalType) { DecimalType decType = (DecimalType) type; - if (decType.precision() <= DType.DECIMAL32_MAX_PRECISION) { - return DType.create(DType.DTypeEnum.DECIMAL32, -decType.scale()); - } else if (decType.precision() <= DType.DECIMAL64_MAX_PRECISION) { + // Currently, maps all DecimalTypes to DType.DECIMAL64 to reduce the complexity. + if (decType.precision() <= DType.DECIMAL64_MAX_PRECISION) { return DType.create(DType.DTypeEnum.DECIMAL64, -decType.scale()); } else { return null; diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java index 65cdb11fa627..ec76cf179ba0 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVector.java @@ -29,8 +29,8 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; -import java.util.ArrayList; -import java.util.Optional; +import java.math.BigDecimal; +import java.math.RoundingMode; /** * A GPU accelerated version of the Spark ColumnVector. @@ -169,7 +169,9 @@ public ColumnarMap getMap(int ordinal) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { - throw new IllegalStateException("The decimal type is currently not supported by rapids cudf"); + BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY); + assert bigDec.precision() <= precision; + return Decimal.apply(bigDec); } @Override diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java index 001d856903ab..02dd06c38da2 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/RapidsHostColumnVectorCore.java @@ -24,6 +24,9 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.UTF8String; +import java.math.BigDecimal; +import java.math.RoundingMode; + /** * A GPU accelerated version of the Spark ColumnVector. @@ -112,7 +115,9 @@ public ColumnarMap getMap(int ordinal) { @Override public Decimal getDecimal(int rowId, int precision, int scale) { - throw new IllegalStateException("The decimal type is currently not supported by rapids cudf"); + BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY); + assert bigDec.precision() <= precision; + return Decimal.apply(bigDec); } @Override diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index eed990a232c9..3925fa7f5c14 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -437,7 +437,29 @@ object GpuOverrides { } } - def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupportedType) + // Set of expressions which supports DecimalType + private val decimalSupportedExpr: Set[String] = Set() + + def areAllSupportedTypes[I <: B, B, O <: B]( + types: Seq[DataType], + meta: Option[RapidsMeta[I, B, O]] = None): Boolean = { + if (types.exists(_.isInstanceOf[DecimalType])) { + // For DataTypes with specific RapidsMeta, we match all possible decimal unsupported + // RapidsMetas. The default ones are assumed to support decimal. + meta match { + case None => false + case Some(exprMeta: BaseExprMeta[I]) => + val node = exprMeta.wrapped.asInstanceOf[Expression].nodeName + decimalSupportedExpr.contains(node) && types.forall(isSupportedTypeWithDecimal) + case Some(_: ScanMeta[I]) => false + case Some(_: DataWritingCommandMeta[I]) => false + case Some(_: PartMeta[I]) => false + case Some(_) => types.forall(isSupportedTypeWithDecimal) + } + } else { + types.forall(isSupportedType) + } + } def isSupportedType(dataType: DataType): Boolean = dataType match { case BooleanType => true @@ -457,9 +479,9 @@ object GpuOverrides { * A walkaround method to include DecimalType for expressions who supports Decimal. */ def isSupportedTypeWithDecimal(dataType: DataType): Boolean = dataType match { - case dt: DecimalType => dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION - case dt => isSupportedType(dt) - } + case dt: DecimalType => dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION + case dt => isSupportedType(dt) + } /** * Checks to see if any expressions are a String Literal @@ -786,6 +808,9 @@ object GpuOverrides { expr[IsNull]( "Checks if a value is null", (a, conf, p, r) => new UnaryExprMeta[IsNull](a, conf, p, r) { + override def areAllSupportedTypes(types: DataType*): Boolean = { + types.forall(isSupportedTypeWithDecimal) + } override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child) }), expr[IsNotNull]( @@ -793,7 +818,7 @@ object GpuOverrides { (a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) { def isSupported(t: DataType) = t match { case MapType(StringType, StringType, _) => true - case _ => isSupportedType(t) + case _ => isSupportedTypeWithDecimal(t) } override def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupported) override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index f3f5b7425eb9..93b7753a12b1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -103,7 +103,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE]( * Check if all the types are supported in this Meta */ def areAllSupportedTypes(types: DataType*): Boolean = { - GpuOverrides.areAllSupportedTypes(types: _*) + GpuOverrides.areAllSupportedTypes(types, Some(this)) } /** diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index ddc1d708260c..f5f961ca507a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -89,8 +89,8 @@ object GpuScalar { case b: Boolean => Scalar.fromBool(b) case s: String => Scalar.fromString(s) case s: UTF8String => Scalar.fromString(s.toString) - case dec: Decimal => Scalar.fromBigDecimal(dec.toBigDecimal.bigDecimal) - case dec: BigDecimal => Scalar.fromBigDecimal(dec.bigDecimal) + case dec: Decimal => Scalar.fromDecimal(dec.toBigDecimal.bigDecimal) + case dec: BigDecimal => Scalar.fromDecimal(dec.bigDecimal) case _ => throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet") } @@ -114,7 +114,7 @@ object GpuScalar { if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) { throw new IllegalArgumentException(s"BigDecimal $bigDec exceeds precision constraint of $t") } - Scalar.fromBigDecimal(bigDec) + Scalar.fromDecimal(bigDec) case l: Long => t match { case LongType => Scalar.fromLong(l) case TimestampType => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, l) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala index 361f3ddca700..e4af13b51148 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/InternalColumnarRddConverter.scala @@ -474,8 +474,9 @@ object InternalColumnarRddConverter extends Logging { def convert(df: DataFrame): RDD[Table] = { val schema = df.schema - if (!GpuOverrides.areAllSupportedTypes(schema.map(_.dataType) :_*)) { - val unsupported = schema.map(_.dataType).filter(!GpuOverrides.areAllSupportedTypes(_)).toSet + if (!GpuOverrides.areAllSupportedTypes(schema.map(_.dataType))) { + val unsupported = schema.map(_.dataType) + .filter(dt => !GpuOverrides.areAllSupportedTypes(Seq(dt))).toSet throw new IllegalArgumentException(s"Cannot convert $df to GPU columnar $unsupported are " + s"not currently supported data types for columnar.") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala index 49974ea9fd20..5c933623ba23 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/unit/DecimalUnitTest.scala @@ -17,9 +17,11 @@ package com.nvidia.spark.rapids.unit import scala.util.Random -import ai.rapids.cudf.DType -import com.nvidia.spark.rapids.{GpuAlias, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, RapidsConf, TestUtils} + +import ai.rapids.cudf.{ColumnVector, DType} +import com.nvidia.spark.rapids.{GpuAlias, GpuColumnVector, GpuIsNotNull, GpuIsNull, GpuLiteral, GpuOverrides, GpuScalar, GpuUnaryExpression, GpuUnitTests, RapidsConf, RapidsHostColumnVector, TestUtils} import org.scalatest.Matchers + import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal} import org.apache.spark.sql.types.{Decimal, DecimalType} @@ -27,9 +29,9 @@ class DecimalUnitTest extends GpuUnitTests with Matchers { Random.setSeed(1234L) private val dec32Data = Array.fill[Decimal](10)( - Decimal.fromDecimal(BigDecimal(Random.nextInt() / 10, Random.nextInt(5)))) + Decimal.fromDecimal(BigDecimal(Random.nextInt() / 1000, 3 + Random.nextInt(3)))) private val dec64Data = Array.fill[Decimal](10)( - Decimal.fromDecimal(BigDecimal(Random.nextLong() / 1000, Random.nextInt(10)))) + Decimal.fromDecimal(BigDecimal(Random.nextLong() / 1000, 7 + Random.nextInt(3)))) test("test decimal as scalar") { Array(dec32Data, dec64Data).flatten.foreach { dec => @@ -64,6 +66,58 @@ class DecimalUnitTest extends GpuUnitTests with Matchers { } } + test("test decimal as column vector") { + withResource( + GpuColumnVector.from(ColumnVector.fromDecimals(dec32Data.map(_.toJavaBigDecimal): _*), + DecimalType(DType.DECIMAL32_MAX_PRECISION, 5))) { cv: GpuColumnVector => + + cv.getRowCount shouldEqual dec32Data.length + val (precision, scale) = cv.dataType() match { + case dt: DecimalType => (dt.precision, dt.scale) + } + withResource(cv.copyToHost()) { hostCv: RapidsHostColumnVector => + dec32Data.zipWithIndex.foreach { case (dec, i) => + val rescaled = dec.toJavaBigDecimal.setScale(scale) + hostCv.getInt(i) shouldEqual rescaled.unscaledValue().intValueExact() + hostCv.getDecimal(i, precision, scale) shouldEqual Decimal(rescaled) + } + } + } + val dec64WithNull = Array(null) ++ dec64Data.map(_.toJavaBigDecimal) ++ Array(null, null) + withResource( + GpuColumnVector.from(ColumnVector.fromDecimals(dec64WithNull: _*), + DecimalType(DType.DECIMAL64_MAX_PRECISION, 9))) { cv: GpuColumnVector => + cv.getRowCount shouldEqual dec64WithNull.length + cv.hasNull shouldBe true + cv.numNulls() shouldEqual 3 + val (precision, scale) = cv.dataType() match { + case dt: DecimalType => (dt.precision, dt.scale) + } + withResource(cv.copyToHost()) { hostCv: RapidsHostColumnVector => + dec64WithNull.zipWithIndex.foreach { + case (dec, i) if dec == null => + hostCv.getBase.isNull(i) shouldBe true + case (dec, i) => + val rescaled = dec.setScale(scale) + hostCv.getLong(i) shouldEqual rescaled.unscaledValue().longValueExact() + hostCv.getDecimal(i, precision, scale) shouldEqual Decimal(rescaled) + } + } + } + // TODO: support fromScalar(cudf.ColumnVector cv, int rows) for fixed-point decimal in cuDF + /* + withResource(GpuScalar.from(dec64Data(0))) { scalar => + withResource(GpuColumnVector.from(scalar, 10)) { cv => + withResource(cv.copyToHost()) { hcv => + (0 until 10).foreach { i => + hcv.getDecimal(i, dec64Data(0).precision, dec64Data(0).scale) shouldEqual dec64Data(0) + } + } + } + } + */ + } + test("test basic expressions with decimal data") { val rapidsConf = new RapidsConf(Map[String, String]()) @@ -97,4 +151,21 @@ class DecimalUnitTest extends GpuUnitTests with Matchers { wrp.tagForGpu() wrp.canExprTreeBeReplaced shouldBe false } + + test("test gpu null check operators with decimal data") { + val decArray = Array(BigDecimal(0).bigDecimal, null, BigDecimal(1).bigDecimal) + withResource(GpuColumnVector.from(ColumnVector.fromDecimals(decArray: _*), DecimalType(1, 0)) + ) { cv => + withResource(GpuIsNull(null).doColumnar(cv).copyToHost()) { ret => + ret.getBoolean(0) shouldBe false + ret.getBoolean(1) shouldBe true + ret.getBoolean(2) shouldBe false + } + withResource(GpuIsNotNull(null).doColumnar(cv).copyToHost()) { ret => + ret.getBoolean(0) shouldBe true + ret.getBoolean(1) shouldBe false + ret.getBoolean(2) shouldBe true + } + } + } }