Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed Nov 6, 2020
1 parent d9b1281 commit 7a4bf25
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
}

def isSupportedByCudf(schema: Seq[Attribute]): Boolean = {
schema.forall(a => GpuParquetScanBase.isSupportedType(a.dataType))
schema.forall(a => GpuColumnVector.isSupportedType(a.dataType))
}

/**
Expand Down
13 changes: 2 additions & 11 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{DataType, DecimalType, StructType}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

Expand Down Expand Up @@ -113,20 +113,11 @@ object GpuOrcScanBase {
meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet")
}
schema.foreach { field =>
if (!isSupportedType(field.dataType)) {
if (!GpuColumnVector.isSupportedType(field.dataType)) {
meta.willNotWorkOnGpu(s"GpuOrcScan does not support fields of type ${field.dataType}")
}
}
}
// We need this specialized type check method because
// R/W ORC data with decimal columns has not supported by cuDF yet.
def isSupportedType(dataType: DataType): Boolean = {
GpuColumnVector.isSupportedType(dataType) match {
case false => false
case true if dataType.isInstanceOf[DecimalType] => false
case _ => true
}
}
}

case class GpuOrcPartitionReaderFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,17 @@ object GpuOverrides {
case DateType => true
case TimestampType => ZoneId.systemDefault().normalized() == GpuOverrides.UTC_TIMEZONE_ID
case StringType => true
case dt: DecimalType if dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION => true
case _ => false
}

/**
* A walkaround method to include DecimalType for expressions who supports Decimal.
*/
def isSupportedTypeWithDecimal(dataType: DataType): Boolean = dataType match {
case dt: DecimalType => dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION
case dt => isSupportedType(dt)
}

/**
* Checks to see if any expressions are a String Literal
*/
Expand Down Expand Up @@ -554,7 +561,7 @@ object GpuOverrides {
*/
override def areAllSupportedTypes(types: DataType*): Boolean = types.forall {
case CalendarIntervalType => true
case x => isSupportedType(x)
case x => isSupportedTypeWithDecimal(x)
}

}),
Expand All @@ -569,7 +576,7 @@ object GpuOverrides {
def isSupported(t: DataType) = t match {
case MapType(StringType, StringType, _) => true
case BinaryType => true
case _ => isSupportedType(t)
case _ => isSupportedTypeWithDecimal(t)
}
override def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupported)
override def convertToGpu(child: Expression): GpuExpression =
Expand All @@ -580,7 +587,7 @@ object GpuOverrides {
(att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) {
def isSupported(t: DataType) = t match {
case MapType(StringType, StringType, _) => true
case _ => isSupportedType(t)
case _ => isSupportedTypeWithDecimal(t)
}
override def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupported)
// This is the only NOOP operator. It goes away when things are bound
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.InputFileUtils
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{DataType, DecimalType, MapType, StringType, StructType, TimestampType}
import org.apache.spark.sql.types.{MapType, StringType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

Expand Down Expand Up @@ -175,16 +175,6 @@ object GpuParquetScanBase {
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}
}

// We need this specialized type check method because
// R/W parquet data with decimal columns has not supported by cuDF yet.
def isSupportedType(dataType: DataType): Boolean = {
GpuColumnVector.isSupportedType(dataType) match {
case false => false
case true if dataType.isInstanceOf[DecimalType] => false
case _ => true
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ object GpuScalar {
case DType.TIMESTAMP_DAYS => v.getInt
case DType.TIMESTAMP_MICROSECONDS => v.getLong
case DType.STRING => v.getJavaString
case dt: DType if dt.isDecimalType => v.getBigDecimal
case dt: DType if dt.isDecimalType => Decimal(v.getBigDecimal)
case t => throw new IllegalStateException(s"$t is not a supported rapids scalar type yet")
}

Expand All @@ -89,6 +89,7 @@ object GpuScalar {
case b: Boolean => Scalar.fromBool(b)
case s: String => Scalar.fromString(s)
case s: UTF8String => Scalar.fromString(s.toString)
case dec: Decimal => Scalar.fromBigDecimal(dec.toBigDecimal.bigDecimal)
case dec: BigDecimal => Scalar.fromBigDecimal(dec.bigDecimal)
case _ =>
throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet")
Expand All @@ -98,6 +99,7 @@ object GpuScalar {
case _ if v == null => Scalar.fromNull(GpuColumnVector.getRapidsType(t))
case _ if t.isInstanceOf[DecimalType] =>
var bigDec = v match {
case vv: Decimal => vv.toBigDecimal.bigDecimal
case vv: BigDecimal => vv.bigDecimal
case vv: Double => BigDecimal(vv).bigDecimal
case vv: Float => BigDecimal(vv).bigDecimal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,85 @@

package com.nvidia.spark.rapids.unit

import java.math.{BigDecimal => BigDec}

import scala.util.Random

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.{GpuScalar, GpuUnitTests}
import com.nvidia.spark.rapids.{GpuAlias, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, RapidsConf, TestUtils}
import org.scalatest.Matchers

import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
import org.apache.spark.sql.types.{Decimal, DecimalType}

class DecimalUnitTest extends GpuUnitTests with Matchers {
Random.setSeed(1234L)

private val dec32Data = Array.fill[BigDecimal](10)(
BigDecimal(Random.nextInt() / 10, Random.nextInt(5)))
private val dec64Data = Array.fill[BigDecimal](10)(
BigDecimal(Random.nextLong() / 1000, Random.nextInt(10)))
private val dec32Data = Array.fill[Decimal](10)(
Decimal.fromDecimal(BigDecimal(Random.nextInt() / 10, Random.nextInt(5))))
private val dec64Data = Array.fill[Decimal](10)(
Decimal.fromDecimal(BigDecimal(Random.nextLong() / 1000, Random.nextInt(10))))

test("test decimal as scalar") {
Array(dec32Data, dec64Data).flatten.foreach { dec =>
// test GpuScalar.from(v: Any)
withResource(GpuScalar.from(dec)) { s =>
s.getType.getScale shouldEqual -dec.scale
GpuScalar.extract(s).asInstanceOf[BigDec] shouldEqual dec.bigDecimal
GpuScalar.extract(s).asInstanceOf[Decimal] shouldEqual dec
}
// test GpuScalar.from(v: Any, t: DataType)
val dt = DecimalType(DType.DECIMAL64_MAX_PRECISION, dec.scale)
val dbl = dec.doubleValue()
withResource(GpuScalar.from(dbl, dt)) { s =>
withResource(GpuScalar.from(dec.toDouble, dt)) { s =>
s.getType.getScale shouldEqual -dt.scale
GpuScalar.extract(s).asInstanceOf[BigDec].doubleValue() shouldEqual dbl
GpuScalar.extract(s).asInstanceOf[Decimal].toDouble shouldEqual dec.toDouble
}
val str = dec.toString()
withResource(GpuScalar.from(str, dt)) { s =>
withResource(GpuScalar.from(dec.toString(), dt)) { s =>
s.getType.getScale shouldEqual -dt.scale
GpuScalar.extract(s).asInstanceOf[BigDec].toString shouldEqual str
GpuScalar.extract(s).asInstanceOf[Decimal].toString shouldEqual dec.toString()
}
val long = dec.longValue()
withResource(GpuScalar.from(long, DecimalType(DType.DECIMAL64_MAX_PRECISION, 0))) { s =>
val long = dec.toLong
withResource(GpuScalar.from(long, DecimalType(dec.precision, 0))) { s =>
s.getType.getScale shouldEqual 0
GpuScalar.extract(s).asInstanceOf[BigDec].longValue() shouldEqual long
GpuScalar.extract(s).asInstanceOf[Decimal].toLong shouldEqual long
}
}
// test exception throwing
assertThrows[IllegalStateException] {
withResource(GpuScalar.from(true, DecimalType(10, 1))) { _ => }
}
assertThrows[IllegalArgumentException] {
val bigDec = BigDecimal(Long.MaxValue / 100, 0)
val bigDec = Decimal(BigDecimal(Long.MaxValue / 100, 0))
withResource(GpuScalar.from(bigDec, DecimalType(15, 1))) { _ => }
}
}

test("test basic expressions with decimal data") {
val rapidsConf = new RapidsConf(Map[String, String]())

val cpuLit = Literal(dec32Data(0), DecimalType(dec32Data(0).precision, dec32Data(0).scale))
val wrapperLit = GpuOverrides.wrapExpr(cpuLit, rapidsConf, None)
wrapperLit.tagForGpu()
wrapperLit.canExprTreeBeReplaced shouldBe true
val gpuLit = wrapperLit.convertToGpu().asInstanceOf[GpuLiteral]
gpuLit.columnarEval(null) shouldEqual cpuLit.eval(null)
gpuLit.sql shouldEqual cpuLit.sql

val cpuAlias = Alias(cpuLit, "A")()
val wrapperAlias = GpuOverrides.wrapExpr(cpuAlias, rapidsConf, None)
wrapperAlias.tagForGpu()
wrapperAlias.canExprTreeBeReplaced shouldBe true
val gpuAlias = wrapperAlias.convertToGpu().asInstanceOf[GpuAlias]
gpuAlias.dataType shouldEqual cpuAlias.dataType
gpuAlias.sql shouldEqual cpuAlias.sql
gpuAlias.columnarEval(null) shouldEqual cpuAlias.eval(null)

val cpuAttrRef = AttributeReference("test123", cpuLit.dataType)()
val wrapperAttrRef = GpuOverrides.wrapExpr(cpuAttrRef, rapidsConf, None)
wrapperAttrRef.tagForGpu()
wrapperAttrRef.canExprTreeBeReplaced shouldBe true
val gpuAttrRef = wrapperAttrRef.convertToGpu().asInstanceOf[AttributeReference]
gpuAttrRef.sql shouldEqual cpuAttrRef.sql
gpuAttrRef.sameRef(cpuAttrRef) shouldBe true

// inconvertible because of precision overflow
val wrp = GpuOverrides.wrapExpr(Literal(Decimal(12345L), DecimalType(38, 10)), rapidsConf, None)
wrp.tagForGpu()
wrp.canExprTreeBeReplaced shouldBe false
}
}

0 comments on commit 7a4bf25

Please sign in to comment.