Skip to content

Commit

Permalink
introduce DecimalType and decimal scalar
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed Nov 6, 2020
1 parent 857074d commit d9b1281
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class ParquetCachedBatchSerializer extends CachedBatchSerializer with Arm {
}

def isSupportedByCudf(schema: Seq[Attribute]): Boolean = {
schema.forall(a => GpuColumnVector.isSupportedType(a.dataType))
schema.forall(a => GpuParquetScanBase.isSupportedType(a.dataType))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ private static DType toRapidsOrNull(DataType type) {
return DType.TIMESTAMP_MICROSECONDS;
} else if (type instanceof StringType) {
return DType.STRING;
} else if (type instanceof DecimalType) {
DecimalType decType = (DecimalType) type;
if (decType.precision() <= DType.DECIMAL32_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL32, -decType.scale());
} else if (decType.precision() <= DType.DECIMAL64_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL64, -decType.scale());
} else {
return null;
}
}
return null;
}
Expand Down Expand Up @@ -204,6 +213,10 @@ static DataType getSparkType(DType type) {
return DataTypes.TimestampType;
case STRING:
return DataTypes.StringType;
case DECIMAL32:
return new DecimalType(DType.DECIMAL32_MAX_PRECISION, -type.getScale());
case DECIMAL64:
return new DecimalType(DType.DECIMAL64_MAX_PRECISION, -type.getScale());
default:
throw new IllegalArgumentException(type + " is not supported by spark yet.");
}
Expand Down
13 changes: 11 additions & 2 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataType, DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

Expand Down Expand Up @@ -113,11 +113,20 @@ object GpuOrcScanBase {
meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet")
}
schema.foreach { field =>
if (!GpuColumnVector.isSupportedType(field.dataType)) {
if (!isSupportedType(field.dataType)) {
meta.willNotWorkOnGpu(s"GpuOrcScan does not support fields of type ${field.dataType}")
}
}
}
// We need this specialized type check method because
// R/W ORC data with decimal columns has not supported by cuDF yet.
def isSupportedType(dataType: DataType): Boolean = {
GpuColumnVector.isSupportedType(dataType) match {
case false => false
case true if dataType.isInstanceOf[DecimalType] => false
case _ => true
}
}
}

case class GpuOrcPartitionReaderFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ object GpuOverrides {
case DateType => true
case TimestampType => ZoneId.systemDefault().normalized() == GpuOverrides.UTC_TIMEZONE_ID
case StringType => true
case dt: DecimalType if dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION => true
case _ => false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.InputFileUtils
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{MapType, StringType, StructType, TimestampType}
import org.apache.spark.sql.types.{DataType, DecimalType, MapType, StringType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

Expand Down Expand Up @@ -175,6 +175,16 @@ object GpuParquetScanBase {
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}
}

// We need this specialized type check method because
// R/W parquet data with decimal columns has not supported by cuDF yet.
def isSupportedType(dataType: DataType): Boolean = {
GpuColumnVector.isSupportedType(dataType) match {
case false => false
case true if dataType.isInstanceOf[DecimalType] => false
case _ => true
}
}
}

/**
Expand Down
19 changes: 19 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ object GpuScalar {
case DType.TIMESTAMP_DAYS => v.getInt
case DType.TIMESTAMP_MICROSECONDS => v.getLong
case DType.STRING => v.getJavaString
case dt: DType if dt.isDecimalType => v.getBigDecimal
case t => throw new IllegalStateException(s"$t is not a supported rapids scalar type yet")
}

Expand All @@ -88,12 +89,30 @@ object GpuScalar {
case b: Boolean => Scalar.fromBool(b)
case s: String => Scalar.fromString(s)
case s: UTF8String => Scalar.fromString(s.toString)
case dec: BigDecimal => Scalar.fromBigDecimal(dec.bigDecimal)
case _ =>
throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet")
}

def from(v: Any, t: DataType): Scalar = v match {
case _ if v == null => Scalar.fromNull(GpuColumnVector.getRapidsType(t))
case _ if t.isInstanceOf[DecimalType] =>
var bigDec = v match {
case vv: BigDecimal => vv.bigDecimal
case vv: Double => BigDecimal(vv).bigDecimal
case vv: Float => BigDecimal(vv).bigDecimal
case vv: String => BigDecimal(vv).bigDecimal
case vv: Double => BigDecimal(vv).bigDecimal
case vv: Long => BigDecimal(vv).bigDecimal
case vv: Int => BigDecimal(vv).bigDecimal
case vv => throw new IllegalStateException(
s"${vv.getClass} '${vv}' is not supported as a scalar yet")
}
bigDec = bigDec.setScale(t.asInstanceOf[DecimalType].scale)
if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) {
throw new IllegalArgumentException(s"BigDecimal $bigDec exceeds precision constraint of $t")
}
Scalar.fromBigDecimal(bigDec)
case l: Long => t match {
case LongType => Scalar.fromLong(l)
case TimestampType => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, l)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.unit

import java.math.{BigDecimal => BigDec}

import scala.util.Random

import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.{GpuScalar, GpuUnitTests}
import org.scalatest.Matchers

import org.apache.spark.sql.types.DecimalType

class DecimalUnitTest extends GpuUnitTests with Matchers {
Random.setSeed(1234L)

private val dec32Data = Array.fill[BigDecimal](10)(
BigDecimal(Random.nextInt() / 10, Random.nextInt(5)))
private val dec64Data = Array.fill[BigDecimal](10)(
BigDecimal(Random.nextLong() / 1000, Random.nextInt(10)))

test("test decimal as scalar") {
Array(dec32Data, dec64Data).flatten.foreach { dec =>
// test GpuScalar.from(v: Any)
withResource(GpuScalar.from(dec)) { s =>
s.getType.getScale shouldEqual -dec.scale
GpuScalar.extract(s).asInstanceOf[BigDec] shouldEqual dec.bigDecimal
}
// test GpuScalar.from(v: Any, t: DataType)
val dt = DecimalType(DType.DECIMAL64_MAX_PRECISION, dec.scale)
val dbl = dec.doubleValue()
withResource(GpuScalar.from(dbl, dt)) { s =>
s.getType.getScale shouldEqual -dt.scale
GpuScalar.extract(s).asInstanceOf[BigDec].doubleValue() shouldEqual dbl
}
val str = dec.toString()
withResource(GpuScalar.from(str, dt)) { s =>
s.getType.getScale shouldEqual -dt.scale
GpuScalar.extract(s).asInstanceOf[BigDec].toString shouldEqual str
}
val long = dec.longValue()
withResource(GpuScalar.from(long, DecimalType(DType.DECIMAL64_MAX_PRECISION, 0))) { s =>
s.getType.getScale shouldEqual 0
GpuScalar.extract(s).asInstanceOf[BigDec].longValue() shouldEqual long
}
}
// test exception throwing
assertThrows[IllegalStateException] {
withResource(GpuScalar.from(true, DecimalType(10, 1))) { _ => }
}
assertThrows[IllegalArgumentException] {
val bigDec = BigDecimal(Long.MaxValue / 100, 0)
withResource(GpuScalar.from(bigDec, DecimalType(15, 1))) { _ => }
}
}
}

0 comments on commit d9b1281

Please sign in to comment.