Skip to content

Commit

Permalink
Support decimal type in orc reader (NVIDIA#3239)
Browse files Browse the repository at this point in the history
* Support decimal type in orc reader

Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman authored Aug 20, 2021
1 parent 0f2d575 commit d3ae0d0
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -21513,11 +21513,11 @@ dates or timestamps, or for a lack of type coercion support.
<td>S</td>
<td><em>PS<br/>UTC is only supported TZ for TIMESTAMP</em></td>
<td>S</td>
<td><b>NS</b></td>
<td><em>PS<br/>max DECIMAL precision of 18</em></td>
<td> </td>
<td><b>NS</b></td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested DECIMAL, BINARY, MAP, STRUCT, UDT</em></td>
<td><em>PS<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested BINARY, MAP, STRUCT, UDT</em></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><b>NS</b></td>
Expand Down
4 changes: 3 additions & 1 deletion integration_tests/src/main/python/data_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,11 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
double_gens = [double_gen]
double_n_long_gens = [double_gen, long_gen]
int_n_long_gens = [int_gen, long_gen]
decimal_gens = [decimal_gen_default, decimal_gen_neg_scale, decimal_gen_scale_precision,
decimal_gens_no_neg = [decimal_gen_default, decimal_gen_scale_precision,
decimal_gen_same_scale_precision, decimal_gen_64bit]

decimal_gens = [decimal_gen_neg_scale] + decimal_gens_no_neg

# all of the basic gens
all_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, date_gen, timestamp_gen, null_gen]
Expand Down
13 changes: 11 additions & 2 deletions integration_tests/src/main/python/orc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,23 @@ def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl,
read_func(std_input_path + '/' + name),
conf=all_confs)

# ORC does not support negative scale for decimal, so "decimal_gens_no_neg" is used here.
# Otherwise the read fails with the exception below.
# ...
#E Caused by: java.lang.IllegalArgumentException: Missing integer at
# 'struct<`_c0`:decimal(7,^-3),`_c1`:decimal(7,3),`_c2`:decimal(7,7),`_c3`:decimal(12,2)>'
#E at org.apache.orc.TypeDescription.parseInt(TypeDescription.java:244)
#E at org.apache.orc.TypeDescription.parseType(TypeDescription.java:362)
# ...
orc_basic_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))]
TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))] + decimal_gens_no_neg

# Some array gens, but not all because of nesting
orc_array_gens_sample = [ArrayGen(sub_gen) for sub_gen in orc_basic_gens] + [
ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10)]
ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
ArrayGen(ArrayGen(decimal_gen_default, max_length=10), max_length=10)]

orc_gens_list = [orc_basic_gens,
orc_array_gens_sample,
Expand Down
63 changes: 59 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{Decimal, DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

Expand Down Expand Up @@ -381,6 +381,55 @@ trait OrcCommonFunctions extends OrcCodecWritingHelper {
rawOut.write(postScriptLength.toInt)
}

/**
 * Collect the precision of every DECIMAL type reachable from the given ORC types.
 *
 * DECIMAL entries contribute their own precision, non-primitive entries are searched
 * recursively through their children, and all other primitive entries contribute nothing.
 *
 * @param types the ORC type descriptions to inspect
 * @return the precisions of all DECIMAL types found, in traversal order
 */
protected def getPrecisionsList(types: Seq[TypeDescription]): Seq[Int] = {
  types.flatMap { desc =>
    val category = desc.getCategory
    if (category == TypeDescription.Category.DECIMAL) {
      Seq(desc.getPrecision)
    } else if (category.isPrimitive) {
      // Primitive and not a decimal: nothing to report.
      Seq.empty[Int]
    } else {
      // Compound type: look for decimals among its children.
      getPrecisionsList(desc.getChildren.asScala)
    }
  }
}

/**
 * Convert decimal columns whose cudf type differs from the type implied by the Spark read
 * schema, e.g. columns with a precision that fits in an int are cast to DECIMAL32 to save
 * space.
 *
 * When no decimal in 'schema' has a precision small enough to fit in an int, the input table
 * is returned untouched. Otherwise a new table is built and the input table is closed.
 *
 * @param table the input table; it is closed when a new table is returned
 * @param schema the ORC schema of the table, one entry per column
 * @param readSchema the read schema from Spark, one field per column
 * @return the input table itself, or a new table with the cast columns
 */
protected def typeCastForDecimalIfNeededAndClose(table: Table, schema: Seq[TypeDescription],
    readSchema: StructType): Table = {
  assert(table.getNumberOfColumns == schema.length)
  // 'readSchema' may have more columns than 'schema', but for now ORC reader does not support
  // this case, being tracked by the issue: https://github.com/NVIDIA/spark-rapids/issues/3058.
  assert(table.getNumberOfColumns == readSchema.length)

  // Casting is only worth attempting when at least one decimal fits in an int.
  val hasIntSizedDecimal = getPrecisionsList(schema).exists(_ <= Decimal.MAX_INT_DIGITS)
  if (!hasIntSizedDecimal) {
    table
  } else {
    withResource(table) { src =>
      val numCols = src.getNumberOfColumns
      withResource(new Array[ColumnVector](numCols)) { castCols =>
        for (idx <- 0 until numCols) {
          val sparkField = readSchema(idx)
          val srcCol = src.getColumn(idx)
          castCols(idx) = ColumnCastUtil.ifTrueThenDeepConvertTypeAtoTypeB(
            srcCol,
            sparkField.dataType,
            // Convert only decimal columns whose cudf type does not match the Spark type.
            (dt, cv) => cv.getType.isDecimalType &&
              !GpuColumnVector.getNonNestedRapidsType(dt).equals(cv.getType()),
            (dt, cv) =>
              cv.castTo(DecimalUtil.createCudfDecimal(dt.asInstanceOf[DecimalType])))
        }
        new Table(castCols: _*)
      }
    }
  }
}

}

/**
Expand Down Expand Up @@ -630,7 +679,9 @@ class GpuOrcPartitionReader(
s"but read $numColumns from $partFile")
}
metrics(NUM_OUTPUT_BATCHES) += 1
Some(table)
val colTypes = ctx.updatedReadSchema.getChildren.asScala.toArray
val tableSchema = ctx.requestedMapping.map(_.map(colTypes(_))).getOrElse(colTypes)
Some(typeCastForDecimalIfNeededAndClose(table, tableSchema, readDataSchema))
}
} finally {
if (dataBuffer != null) {
Expand Down Expand Up @@ -1398,7 +1449,9 @@ class MultiFileCloudOrcPartitionReader(
}

metrics(NUM_OUTPUT_BATCHES) += 1
Some(table)
val colTypes = updatedReadSchema.getChildren.asScala.toArray
val tableSchema = requestedMapping.map(_.map(colTypes(_))).getOrElse(colTypes)
Some(typeCastForDecimalIfNeededAndClose(table, tableSchema, readDataSchema))
}

withResource(table) { _ =>
Expand Down Expand Up @@ -1777,7 +1830,9 @@ class MultiFileOrcPartitionReader(
}

metrics(NUM_OUTPUT_BATCHES) += 1
table
val colTypes = clippedSchema.getChildren.asScala.toArray
val tableSchema = extraInfo.requestedMapping.map(_.map(colTypes(_))).getOrElse(colTypes)
typeCastForDecimalIfNeededAndClose(table, tableSchema, readDataSchema)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ object GpuOverrides {
sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
TypeSig.UDT).nested())),
(OrcFormatType, FileFormatChecks(
cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY).nested(),
cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.DECIMAL_64).nested(),
cudfWrite = TypeSig.commonCudfTypes,
sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
TypeSig.UDT).nested())))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class DecimalUnitTest extends GpuUnitTests {

withGpuSparkSession((ss: SparkSession) => {
var rootPlan = frameFromOrc("decimal-test.orc")(ss).queryExecution.executedPlan
assert(rootPlan.map(p => p).exists(_.isInstanceOf[FileSourceScanExec]))
assert(rootPlan.map(p => p).exists(_.isInstanceOf[GpuFileSourceScanExec]))
rootPlan = fromCsvDf("decimal-test.csv", decimalCsvStruct)(ss).queryExecution.executedPlan
assert(rootPlan.map(p => p).exists(_.isInstanceOf[FileSourceScanExec]))
rootPlan = frameFromParquet("decimal-test.parquet")(ss).queryExecution.executedPlan
Expand All @@ -285,7 +285,7 @@ class DecimalUnitTest extends GpuUnitTests {

withGpuSparkSession((ss: SparkSession) => {
var rootPlan = frameFromOrc("decimal-test.orc")(ss).queryExecution.executedPlan
assert(rootPlan.map(p => p).exists(_.isInstanceOf[BatchScanExec]))
assert(rootPlan.map(p => p).exists(_.isInstanceOf[GpuBatchScanExec]))
rootPlan = fromCsvDf("decimal-test.csv", decimalCsvStruct)(ss).queryExecution.executedPlan
assert(rootPlan.map(p => p).exists(_.isInstanceOf[BatchScanExec]))
rootPlan = frameFromParquet("decimal-test.parquet")(ss).queryExecution.executedPlan
Expand Down

0 comments on commit d3ae0d0

Please sign in to comment.