-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Initial work on supporting DecimalType #1063
Changes from all commits
2a7f4cd
5e95ded
bf22f07
c5aa897
d271c29
c0c84f4
a2f59e1
1585a2e
6f381d0
36a5a27
6535b74
ed6c8ea
384c166
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,6 +91,8 @@ private object GpuRowToColumnConverter { | |
case (TimestampType, false) => NotNullLongConverter | ||
case (StringType, true) => StringConverter | ||
case (StringType, false) => NotNullStringConverter | ||
case (dt: DecimalType, true) => new DecimalConverter(dt.precision, dt.scale) | ||
case (dt: DecimalType, false) => new NotNullDecimalConverter(dt.precision, dt.scale) | ||
// NOT SUPPORTED YET | ||
// case CalendarIntervalType => CalendarConverter | ||
case (at: ArrayType, true) => | ||
|
@@ -100,8 +102,6 @@ private object GpuRowToColumnConverter { | |
// NOT SUPPORTED YET | ||
// case st: StructType => new StructConverter(st.fields.map( | ||
// (f) => getConverterForType(f.dataType))) | ||
// NOT SUPPORTED YET | ||
// case dt: DecimalType => new DecimalConverter(dt) | ||
// NOT SUPPORTED YET | ||
case (MapType(k, v, vcn), true) => | ||
MapConverter(getConverterForType(k, nullable = false), | ||
|
@@ -289,6 +289,32 @@ private object GpuRowToColumnConverter { | |
} | ||
} | ||
|
||
// Converter for nullable DecimalType columns. Appends either a null or the decimal value | ||
// read from the row, and returns the per-row size estimate including the validity overhead. | ||
private class DecimalConverter(precision: Int, scale: Int) extends TypeConverter { | ||
  // Infer the storage type via precision, because we can't access DType of builder. | ||
  // Computed once per converter so the per-row data path carries no size conditional. | ||
  private[this] val sizeWithValidity: Double = | ||
    (if (precision > ai.rapids.cudf.DType.DECIMAL32_MAX_PRECISION) 8 else 4) + VALIDITY | ||
 | ||
  override def append( | ||
      row: SpecializedGetters, | ||
      column: Int, | ||
      builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Double = { | ||
    if (row.isNullAt(column)) { | ||
      builder.appendNull() | ||
    } else { | ||
      // Append directly instead of allocating a NotNullDecimalConverter for every row; | ||
      // this is the same call that NotNullDecimalConverter.append performs. | ||
      builder.append(row.getDecimal(column, precision, scale).toJavaBigDecimal) | ||
    } | ||
    sizeWithValidity | ||
  } | ||
} | ||
|
||
// Converter for DecimalType columns whose values are not null-checked here: the decimal at | ||
// `column` is read with the given precision and scale and appended to the builder as a | ||
// java.math.BigDecimal. Returns a per-row size estimate of 8 or 4 (presumably bytes). | ||
private class NotNullDecimalConverter(precision: Int, scale: Int) extends TypeConverter { | ||
override def append( | ||
row: SpecializedGetters, | ||
column: Int, | ||
builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Double = { | ||
builder.append(row.getDecimal(column, precision, scale).toJavaBigDecimal) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer to see us use |
||
// The builder's DType is not accessible here, so the storage width is inferred from the | ||
// precision: 8 when it exceeds DECIMAL32_MAX_PRECISION, otherwise 4. | ||
if (precision > ai.rapids.cudf.DType.DECIMAL32_MAX_PRECISION) 8 else 4 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too it would be great to avoid conditionals in the data path. I would prefer it if we either passed in the size or had separate implementations for DECIMAL32 and DECIMAL64 |
||
} | ||
} | ||
|
||
private[this] def mapConvert( | ||
keyConverter: TypeConverter, | ||
valueConverter: TypeConverter, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -434,6 +434,11 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT, | |
override val childParts: Seq[PartMeta[_]] = Seq.empty | ||
override val childDataWriteCmds: Seq[DataWritingCommandMeta[_]] = Seq.empty | ||
|
||
// Plan nodes accept decimal types by default (allowDecimal = true), on the assumption that | ||
// whether decimals are allowed is mainly determined at the expression level. | ||
override def isSupportedType(t: DataType): Boolean = | ||
GpuOverrides.isSupportedType(t, allowDecimal = true) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we have tests that verify that we can support decimal for all top level spark operations? Have we tested join, expand, generate, filter, project, union, window, sort, or hash aggregate? What about all of the arrow python UDF code where we go to/from arrow? I think it would be much better if we split this big PR up into smaller pieces and put each piece in separately with corresponding tests to show that it works, and we only add decimal to the allow list for those things that we know it works for because we have tested it. If you want me to help with this I am happy to do it. I am already in the middle of doing it for Lists. I am going to add in structs, maps, binary, null type and finally calendar interval based off of how much time I have and priorities. Some of these we will only be able to do very basic things with, but that should be enough to unblock others for using them for more complicated processing. |
||
|
||
override def convertToCpu(): SparkPlan = { | ||
wrapped.withNewChildren(childPlans.map(_.convertIfNeeded())) | ||
} | ||
|
@@ -765,9 +770,13 @@ abstract class BinaryExprMeta[INPUT <: BinaryExpression]( | |
expr: INPUT, | ||
conf: RapidsConf, | ||
parent: Option[RapidsMeta[_, _, _]], | ||
rule: ConfKeysAndIncompat) | ||
rule: ConfKeysAndIncompat, | ||
allowDecimal: Boolean = false) | ||
extends ExprMeta[INPUT](expr, conf, parent, rule) { | ||
|
||
// Delegates the type check to GpuOverrides.isSupportedType. Decimals are accepted only when | ||
// this meta was constructed with allowDecimal = true (the constructor default is false). | ||
override def isSupportedType(t: DataType): Boolean = | ||
GpuOverrides.isSupportedType(t, allowDecimal = allowDecimal) | ||
|
||
override final def convertToGpu(): GpuExpression = | ||
convertToGpu(childExprs(0).convertToGpu(), childExprs(1).convertToGpu()) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,6 +144,8 @@ object GpuDivModLike { | |
case DType.INT64 => Scalar.fromLong(0L) | ||
case DType.FLOAT32 => Scalar.fromFloat(0f) | ||
case DType.FLOAT64 => Scalar.fromDouble(0) | ||
case dt if dt.isDecimalType && dt.isBackedByInt => Scalar.fromDecimal(0, 0) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure this is going to work in all cases. I think this might be another place where we have some tech debt to pay off and need to pass in a DataType instead of a DType. |
||
case dt if dt.isDecimalType && dt.isBackedByLong => Scalar.fromDecimal(0, 0L) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the div/mod to work properly we also need to update |
||
case t => throw new IllegalArgumentException(s"Unexpected type: $t") | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is on the data path. I would like us to avoid object creation if at all possible to speed up the data path. Please make a static method instead, or use inheritance, which I think is less ideal.