Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed Nov 6, 2020
1 parent 7a4bf25 commit 68d61f3
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,8 @@ private static DType toRapidsOrNull(DataType type) {
return DType.STRING;
} else if (type instanceof DecimalType) {
DecimalType decType = (DecimalType) type;
if (decType.precision() <= DType.DECIMAL32_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL32, -decType.scale());
} else if (decType.precision() <= DType.DECIMAL64_MAX_PRECISION) {
// Currently, maps all DecimalTypes to DType.DECIMAL64 to reduce the complexity.
if (decType.precision() <= DType.DECIMAL64_MAX_PRECISION) {
return DType.create(DType.DTypeEnum.DECIMAL64, -decType.scale());
} else {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

import java.util.ArrayList;
import java.util.Optional;
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
* A GPU accelerated version of the Spark ColumnVector.
Expand Down Expand Up @@ -169,7 +169,9 @@ public ColumnarMap getMap(int ordinal) {

/**
 * Returns the value stored at {@code rowId} as a Spark {@link Decimal}.
 *
 * NOTE(review): this span comes from a commit diff, so it shows the removed body (the
 * unconditional throw) directly followed by the added body (the three statements after
 * it). Only one of the two exists in any real revision of the file.
 *
 * @param rowId     row to read from the backing cudf column ({@code cudfCv})
 * @param precision upper bound for the precision of the result; only checked via {@code assert}
 * @param scale     scale the value is converted to before being wrapped
 * @return the value at {@code rowId}, expressed at {@code scale}
 */
@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
// Removed by this commit: decimal reads used to be rejected outright.
throw new IllegalStateException("The decimal type is currently not supported by rapids cudf");
// Added by this commit. RoundingMode.UNNECESSARY makes setScale fail with an
// ArithmeticException rather than silently round when the value does not fit the scale.
BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY);
// Java assertions only run when enabled (-ea), so this is a debug-time sanity check only.
assert bigDec.precision() <= precision;
return Decimal.apply(bigDec);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

import java.math.BigDecimal;
import java.math.RoundingMode;


/**
* A GPU accelerated version of the Spark ColumnVector.
Expand Down Expand Up @@ -112,7 +115,9 @@ public ColumnarMap getMap(int ordinal) {

/**
 * Reads row {@code rowId} of the backing cudf column ({@code cudfCv}) as a Spark
 * {@link Decimal}.
 *
 * NOTE(review): rendered from a commit diff. The {@code throw} statement is the line this
 * commit removed, and the statements after it are the replacement body; they never
 * coexist in the real file.
 *
 * @param rowId     index of the row to read
 * @param precision maximum precision expected by the caller; verified by {@code assert} only
 * @param scale     scale to convert the stored value to
 * @return the stored value rescaled to {@code scale}
 */
@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
// Old behavior (removed): every decimal read failed.
throw new IllegalStateException("The decimal type is currently not supported by rapids cudf");
// New behavior (added): rescale exactly. UNNECESSARY means setScale throws an
// ArithmeticException if the conversion would require rounding.
BigDecimal bigDec = cudfCv.getBigDecimal(rowId).setScale(scale, RoundingMode.UNNECESSARY);
// Skipped unless the JVM runs with assertions enabled.
assert bigDec.precision() <= precision;
return Decimal.apply(bigDec);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,29 @@ object GpuOverrides {
}
}

// NOTE(review): per the hunk line counts this is the pre-change definition that the commit
// removes. It accepts a type list when every element passes isSupportedType and has no
// decimal-specific handling.
def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupportedType)
// Set of expression node names that support DecimalType
private val decimalSupportedExpr: Set[String] = Set()

/**
 * Checks whether all of the given types are supported.
 *
 * Without any DecimalType in `types` this is simply `isSupportedType` applied to every
 * element. Once a DecimalType is present, the answer depends on `meta`:
 *  - no meta, or a scan / data-writing-command / partitioning meta: not supported;
 *  - an expression meta: supported only if the wrapped expression's node name is listed
 *    in `decimalSupportedExpr` and every type passes `isSupportedTypeWithDecimal`;
 *  - any other meta: supported if every type passes `isSupportedTypeWithDecimal`.
 *
 * @param types the data types to check
 * @param meta  the meta asking the question, if any
 * @return true when every type is supported in the context of `meta`
 */
def areAllSupportedTypes[I <: B, B, O <: B](
    types: Seq[DataType],
    meta: Option[RapidsMeta[I, B, O]] = None): Boolean = {
  val containsDecimal = types.exists {
    case _: DecimalType => true
    case _ => false
  }
  if (!containsDecimal) {
    types.forall(isSupportedType)
  } else {
    // Evaluated only by the branches that actually allow decimals.
    def allSupportedWithDecimal: Boolean = types.forall(isSupportedTypeWithDecimal)

    meta match {
      case Some(exprMeta: BaseExprMeta[I]) =>
        // Expressions have to opt in to decimal support by node name.
        val exprName = exprMeta.wrapped.asInstanceOf[Expression].nodeName
        decimalSupportedExpr.contains(exprName) && allSupportedWithDecimal
      case None |
           Some(_: ScanMeta[I]) |
           Some(_: DataWritingCommandMeta[I]) |
           Some(_: PartMeta[I]) =>
        false
      // Every remaining kind of meta is assumed to handle decimals.
      case Some(_) => allSupportedWithDecimal
    }
  }
}

def isSupportedType(dataType: DataType): Boolean = dataType match {
case BooleanType => true
Expand All @@ -457,9 +479,9 @@ object GpuOverrides {
* A workaround method to include DecimalType for expressions that support Decimal.
*/
// NOTE(review): this hunk is a whitespace-only change, so the diff rendering lists the two
// case lines and the closing brace twice (old copy first, re-indented copy second). The
// real method contains them once.
def isSupportedTypeWithDecimal(dataType: DataType): Boolean = dataType match {
// A decimal is accepted only while its precision stays within cudf's DECIMAL64 limit.
case dt: DecimalType => dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION
// Anything else falls back to the regular support check.
case dt => isSupportedType(dt)
}
case dt: DecimalType => dt.precision <= ai.rapids.cudf.DType.DECIMAL64_MAX_PRECISION
case dt => isSupportedType(dt)
}

/**
* Checks to see if any expressions are a String Literal
Expand Down Expand Up @@ -786,14 +808,17 @@ object GpuOverrides {
expr[IsNull](
  "Checks if a value is null",
  (isNullExpr, rapidsConf, parentMeta, confRule) =>
    new UnaryExprMeta[IsNull](isNullExpr, rapidsConf, parentMeta, confRule) {
      // Accept decimals in addition to the regular types by checking every type with
      // isSupportedTypeWithDecimal instead of the default check.
      override def areAllSupportedTypes(types: DataType*): Boolean =
        types.forall(isSupportedTypeWithDecimal)

      override def convertToGpu(child: Expression): GpuExpression = GpuIsNull(child)
    }),
expr[IsNotNull](
"Checks if a value is not null",
(a, conf, p, r) => new UnaryExprMeta[IsNotNull](a, conf, p, r) {
def isSupported(t: DataType) = t match {
case MapType(StringType, StringType, _) => true
case _ => isSupportedType(t)
case _ => isSupportedTypeWithDecimal(t)
}
override def areAllSupportedTypes(types: DataType*): Boolean = types.forall(isSupported)
override def convertToGpu(child: Expression): GpuExpression = GpuIsNotNull(child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](
* Check if all the types are supported in this Meta
*/
def areAllSupportedTypes(types: DataType*): Boolean = {
// NOTE(review): diff rendering. The first call is the line this commit removed, and the
// second is its replacement; only one exists in a real revision.
// Removed: forwarded the varargs unchanged.
GpuOverrides.areAllSupportedTypes(types: _*)
// Added: passes the types as a Seq together with this meta, so that GpuOverrides can take
// the kind of meta into account when deciding.
GpuOverrides.areAllSupportedTypes(types, Some(this))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ object GpuScalar {
case b: Boolean => Scalar.fromBool(b)
case s: String => Scalar.fromString(s)
case s: UTF8String => Scalar.fromString(s.toString)
case dec: Decimal => Scalar.fromBigDecimal(dec.toBigDecimal.bigDecimal)
case dec: BigDecimal => Scalar.fromBigDecimal(dec.bigDecimal)
case dec: Decimal => Scalar.fromDecimal(dec.toBigDecimal.bigDecimal)
case dec: BigDecimal => Scalar.fromDecimal(dec.bigDecimal)
case _ =>
throw new IllegalStateException(s"${v.getClass} '${v}' is not supported as a scalar yet")
}
Expand All @@ -114,7 +114,7 @@ object GpuScalar {
if (bigDec.precision() > t.asInstanceOf[DecimalType].precision) {
throw new IllegalArgumentException(s"BigDecimal $bigDec exceeds precision constraint of $t")
}
Scalar.fromBigDecimal(bigDec)
Scalar.fromDecimal(bigDec)
case l: Long => t match {
case LongType => Scalar.fromLong(l)
case TimestampType => Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,9 @@ object InternalColumnarRddConverter extends Logging {

def convert(df: DataFrame): RDD[Table] = {
val schema = df.schema
if (!GpuOverrides.areAllSupportedTypes(schema.map(_.dataType) :_*)) {
val unsupported = schema.map(_.dataType).filter(!GpuOverrides.areAllSupportedTypes(_)).toSet
if (!GpuOverrides.areAllSupportedTypes(schema.map(_.dataType))) {
val unsupported = schema.map(_.dataType)
.filter(dt => !GpuOverrides.areAllSupportedTypes(Seq(dt))).toSet
throw new IllegalArgumentException(s"Cannot convert $df to GPU columnar $unsupported are " +
s"not currently supported data types for columnar.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
package com.nvidia.spark.rapids.unit

import scala.util.Random
import ai.rapids.cudf.DType
import com.nvidia.spark.rapids.{GpuAlias, GpuLiteral, GpuOverrides, GpuScalar, GpuUnitTests, RapidsConf, TestUtils}

import ai.rapids.cudf.{ColumnVector, DType}
import com.nvidia.spark.rapids.{GpuAlias, GpuColumnVector, GpuIsNotNull, GpuIsNull, GpuLiteral, GpuOverrides, GpuScalar, GpuUnaryExpression, GpuUnitTests, RapidsConf, RapidsHostColumnVector, TestUtils}
import org.scalatest.Matchers

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
import org.apache.spark.sql.types.{Decimal, DecimalType}

class DecimalUnitTest extends GpuUnitTests with Matchers {
Random.setSeed(1234L)

// Random decimal fixtures (ten values each) shared by the tests in this suite.
// NOTE(review): diff rendering. For each val the first generator line is the one the
// commit removed and the second is its replacement.
private val dec32Data = Array.fill[Decimal](10)(
// Removed: unscaled value nextInt() / 10 with a scale of Random.nextInt(5).
Decimal.fromDecimal(BigDecimal(Random.nextInt() / 10, Random.nextInt(5))))
// Added: unscaled value nextInt() / 1000 with a scale of 3 + Random.nextInt(3).
Decimal.fromDecimal(BigDecimal(Random.nextInt() / 1000, 3 + Random.nextInt(3))))
private val dec64Data = Array.fill[Decimal](10)(
// Removed: unscaled value nextLong() / 1000 with a scale of Random.nextInt(10).
Decimal.fromDecimal(BigDecimal(Random.nextLong() / 1000, Random.nextInt(10))))
// Added: unscaled value nextLong() / 1000 with a scale of 7 + Random.nextInt(3).
Decimal.fromDecimal(BigDecimal(Random.nextLong() / 1000, 7 + Random.nextInt(3))))

test("test decimal as scalar") {
Array(dec32Data, dec64Data).flatten.foreach { dec =>
Expand Down Expand Up @@ -64,6 +66,58 @@ class DecimalUnitTest extends GpuUnitTests with Matchers {
}
}

test("test decimal as column vector") {
  // Part 1: a null-free column built from dec32Data, read back through getInt / getDecimal.
  val dec32Values = dec32Data.map(_.toJavaBigDecimal)
  withResource(
    GpuColumnVector.from(ColumnVector.fromDecimals(dec32Values: _*),
      DecimalType(DType.DECIMAL32_MAX_PRECISION, 5))) { gpuCol: GpuColumnVector =>

    gpuCol.getRowCount shouldEqual dec32Data.length
    val (prec, targetScale) = gpuCol.dataType() match {
      case decType: DecimalType => (decType.precision, decType.scale)
    }
    withResource(gpuCol.copyToHost()) { hostCol: RapidsHostColumnVector =>
      dec32Values.indices.foreach { row =>
        // Bring the source value to the column's scale before comparing.
        val expected = dec32Values(row).setScale(targetScale)
        hostCol.getInt(row) shouldEqual expected.unscaledValue().intValueExact()
        hostCol.getDecimal(row, prec, targetScale) shouldEqual Decimal(expected)
      }
    }
  }

  // Part 2: dec64Data with one leading and two trailing nulls, read back through
  // getLong / getDecimal, with null rows checked via isNull.
  val dec64Values = Array(null) ++ dec64Data.map(_.toJavaBigDecimal) ++ Array(null, null)
  withResource(
    GpuColumnVector.from(ColumnVector.fromDecimals(dec64Values: _*),
      DecimalType(DType.DECIMAL64_MAX_PRECISION, 9))) { gpuCol: GpuColumnVector =>
    gpuCol.getRowCount shouldEqual dec64Values.length
    gpuCol.hasNull shouldBe true
    gpuCol.numNulls() shouldEqual 3
    val (prec, targetScale) = gpuCol.dataType() match {
      case decType: DecimalType => (decType.precision, decType.scale)
    }
    withResource(gpuCol.copyToHost()) { hostCol: RapidsHostColumnVector =>
      dec64Values.indices.foreach { row =>
        Option(dec64Values(row)) match {
          case None =>
            hostCol.getBase.isNull(row) shouldBe true
          case Some(value) =>
            val expected = value.setScale(targetScale)
            hostCol.getLong(row) shouldEqual expected.unscaledValue().longValueExact()
            hostCol.getDecimal(row, prec, targetScale) shouldEqual Decimal(expected)
        }
      }
    }
  }
  // TODO: support fromScalar(cudf.ColumnVector cv, int rows) for fixed-point decimal in cuDF
  /*
  withResource(GpuScalar.from(dec64Data(0))) { scalar =>
    withResource(GpuColumnVector.from(scalar, 10)) { cv =>
      withResource(cv.copyToHost()) { hcv =>
        (0 until 10).foreach { i =>
          hcv.getDecimal(i, dec64Data(0).precision, dec64Data(0).scale) shouldEqual dec64Data(0)
        }
      }
    }
  }
  */
}

test("test basic expressions with decimal data") {
val rapidsConf = new RapidsConf(Map[String, String]())

Expand Down Expand Up @@ -97,4 +151,21 @@ class DecimalUnitTest extends GpuUnitTests with Matchers {
wrp.tagForGpu()
wrp.canExprTreeBeReplaced shouldBe false
}

test("test gpu null check operators with decimal data") {
  // Three rows: a value, a null, a value.
  val decArray = Array(BigDecimal(0).bigDecimal, null, BigDecimal(1).bigDecimal)
  withResource(GpuColumnVector.from(ColumnVector.fromDecimals(decArray: _*), DecimalType(1, 0))
  ) { cv =>
    // The column returned by doColumnar is a separate closeable resource. Previously only
    // its host copy was closed, so the device-side result leaked. Close both.
    withResource(GpuIsNull(null).doColumnar(cv)) { isNullCol =>
      withResource(isNullCol.copyToHost()) { ret =>
        ret.getBoolean(0) shouldBe false
        ret.getBoolean(1) shouldBe true
        ret.getBoolean(2) shouldBe false
      }
    }
    withResource(GpuIsNotNull(null).doColumnar(cv)) { isNotNullCol =>
      withResource(isNotNullCol.copyToHost()) { ret =>
        ret.getBoolean(0) shouldBe true
        ret.getBoolean(1) shouldBe false
        ret.getBoolean(2) shouldBe true
      }
    }
  }
}
}

0 comments on commit 68d61f3

Please sign in to comment.