From 87fc89772424668f7726045db4a4f069562e1765 Mon Sep 17 00:00:00 2001
From: "Robert (Bobby) Evans"
Date: Fri, 16 Apr 2021 09:16:55 -0500
Subject: [PATCH] Auto generate the supported types for the file formats (#2149)

Signed-off-by: Robert (Bobby) Evans
---
 docs/supported_ops.md                         |  87 ++++---
 .../spark/rapids/GpuBatchScanExec.scala       |   4 +-
 .../com/nvidia/spark/rapids/GpuOrcScan.scala  |   7 +-
 .../nvidia/spark/rapids/GpuOverrides.scala    |  35 +++
 .../spark/rapids/GpuParquetFileFormat.scala   |   8 +-
 .../nvidia/spark/rapids/GpuParquetScan.scala  |  12 +-
 .../com/nvidia/spark/rapids/TypeChecks.scala  | 227 ++++++++----------
 .../spark/sql/rapids/GpuOrcFileFormat.scala   |   7 +-
 8 files changed, 201 insertions(+), 186 deletions(-)

diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 4de98fc44cb..127770c97b1 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -20026,12 +20026,11 @@ dates or timestamps, or for a lack of type coercion support.
 <th>ARRAY</th>
 <th>MAP</th>
 <th>STRUCT</th>
+<th>UDT</th>
 </tr>
 <tr>
-<th rowSpan=2>Parquet</th>
-<th>Input</th>
-<td>S</td>
-<td>S</td>
+<th rowSpan=2>CSV</th>
+<th>Read</th>
 <td>S</td>
 <td>S</td>
 <td>S</td>
@@ -20040,18 +20039,41 @@ dates or timestamps, or for a lack of type coercion support.
 <td>S</td>
 <td>S</td>
 <td>S</td>
+<td>S*</td>
 <td>S</td>
-<td> </td>
 <td>NS</td>
-<td> </td>
-<td>PS (missing nested BINARY)</td>
-<td>PS (missing nested BINARY)</td>
-<td>PS (missing nested BINARY)</td>
+<td> </td>
+<td>NS</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
 </tr>
 <tr>
-<th>Output</th>
-<td>S</td>
-<td>S</td>
+<th>Write</th>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td>NS</td>
+<td> </td>
+<td>NS</td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+<td> </td>
+</tr>
+<tr>
+<th rowSpan=2>ORC</th>
+<th>Read</th>
 <td>S</td>
 <td>S</td>
 <td>S</td>
@@ -20060,18 +20082,19 @@ dates or timestamps, or for a lack of type coercion support.
 <td>S</td>
 <td>S</td>
 <td>S</td>
+<td>S*</td>
 <td>S</td>
-<td> </td>
 <td>NS</td>
-<td> </td>
+<td> </td>
+<td>NS</td>
+<td> </td>
+<td>NS</td>
 <td>NS</td>
 <td>NS</td>
 <td>NS</td>
 </tr>
 <tr>
-<th rowSpan=2>ORC</th>
-<th>Input</th>
-<td>S</td>
+<th>Write</th>
 <td>S</td>
 <td>S</td>
 <td>S</td>
@@ -20080,18 +20103,20 @@ dates or timestamps, or for a lack of type coercion support.
 <td>S</td>
 <td>S</td>
 <td>S</td>
+<td>S*</td>
 <td>S</td>
 <td>NS</td>
+<td> </td>
 <td>NS</td>
+<td> </td>
 <td>NS</td>
-<td> </td>
 <td>NS</td>
 <td>NS</td>
 <td>NS</td>
 </tr>
 <tr>
-<th>Output</th>
-<td>S</td>
+<th rowSpan=2>Parquet</th>
+<th>Read</th>
 <td>S</td>
 <td>S</td>
 <td>S</td>
@@ -20100,19 +20125,19 @@ dates or timestamps, or for a lack of type coercion support.
 <td>S</td>
 <td>S</td>
 <td>S</td>
+<td>S*</td>
 <td>S</td>
+<td>S*</td>
+<td> </td>
 <td>NS</td>
-<td>NS</td>
-<td>NS</td>
-<td> </td>
-<td>NS</td>
-<td>NS</td>
+<td> </td>
+<td>PS* (missing nested BINARY, UDT)</td>
+<td>PS* (missing nested BINARY, UDT)</td>
+<td>PS* (missing nested BINARY, UDT)</td>
 <td>NS</td>
 </tr>
 <tr>
-<th rowSpan=2>CSV</th>
-<th>Input</th>
-<td>S</td>
+<th>Write</th>
 <td>S</td>
 <td>S</td>
 <td>S</td>
@@ -20121,10 +20146,12 @@ dates or timestamps, or for a lack of type coercion support.
 <td>S</td>
 <td>S</td>
 <td>S</td>
+<td>S*</td>
 <td>S</td>
+<td>S*</td>
+<td> </td>
 <td>NS</td>
-<td>NS</td>
-<td>NS</td>
+<td> </td>
 <td>NS</td>
 <td>NS</td>
 <td>NS</td>

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
index d13fcf7fbe7..495c325fc88 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala
@@ -281,9 +281,7 @@ object GpuCSVScan {
     }
 
     // TODO parsedOptions.emptyValueInRead
-    if (readSchema.exists(_.dataType.isInstanceOf[DecimalType])) {
-      meta.willNotWorkOnGpu("DecimalType is not supported")
-    }
+    FileFormatChecks.tag(meta, readSchema, CsvFormatType, ReadFileOp)
   }
 }

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
index 7f8b5853c15..30820ab4569 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -109,15 +109,12 @@ object GpuOrcScanBase {
         s"${RapidsConf.ENABLE_ORC_READ} to true")
     }
 
+    FileFormatChecks.tag(meta, schema, OrcFormatType, ReadFileOp)
+
     if (sparkSession.conf
       .getOption("spark.sql.orc.mergeSchema").exists(_.toBoolean)) {
       meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet")
     }
-    schema.foreach { field =>
-      if (!GpuOverrides.isSupportedType(field.dataType)) {
-        meta.willNotWorkOnGpu(s"GpuOrcScan does not support fields of type ${field.dataType}")
-      }
-    }
   }
 }

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 6dd67b4e6fa..9542fedad2e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -410,6 +410,25 @@ trait GpuOverridesListener {
       costOptimizations: Seq[Optimization])
 }
 
+sealed trait FileFormatType
+object CsvFormatType extends FileFormatType {
+  override def toString = "CSV"
+}
+object ParquetFormatType extends FileFormatType {
+  override def toString = "Parquet"
+}
+object OrcFormatType extends FileFormatType {
+  override def toString = "ORC"
+}
+
+sealed trait FileFormatOp
+object ReadFileOp extends FileFormatOp {
+  override def toString = "read"
+}
+object WriteFileOp extends FileFormatOp {
+  override def toString = "write"
+}
+
 object GpuOverrides {
   val FLOAT_DIFFERS_GROUP_INCOMPAT =
     "when enabling these, there may be extra groups produced for floating point grouping " +
@@ -735,6 +754,22 @@ object GpuOverrides {
      .map(r => r.wrap(expr, conf, parent, r).asInstanceOf[BaseExprMeta[INPUT]])
      .getOrElse(new RuleNotFoundExprMeta(expr, conf, parent))
 
+  val fileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map(
+    (CsvFormatType, FileFormatChecks(
+      cudfRead = TypeSig.commonCudfTypes,
+      cudfWrite = TypeSig.none,
+      sparkSig = TypeSig.atomics)),
+    (ParquetFormatType, FileFormatChecks(
+      cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT + TypeSig.ARRAY +
+          TypeSig.MAP).nested(),
+      cudfWrite = TypeSig.commonCudfTypes + TypeSig.DECIMAL,
+      sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
+          TypeSig.UDT).nested())),
+    (OrcFormatType, FileFormatChecks(
+      cudfReadWrite = TypeSig.commonCudfTypes,
+      sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
+          TypeSig.UDT).nested())))
+
   val commonExpressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
expr[Literal]( "Holds a static value from the query", diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index d1a8c3a979e..90f78fdddb1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -42,12 +42,6 @@ object GpuParquetFileFormat { options: Map[String, String], schema: StructType): Option[GpuParquetFileFormat] = { - val unSupportedTypes = - schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType, allowDecimal = true)) - if (unSupportedTypes.nonEmpty) { - meta.willNotWorkOnGpu(s"These types aren't supported for parquet $unSupportedTypes") - } - val sqlConf = spark.sessionState.conf val parquetOptions = new ParquetOptions(options, sqlConf) @@ -61,6 +55,8 @@ object GpuParquetFileFormat { s"${RapidsConf.ENABLE_PARQUET_WRITE} to true") } + FileFormatChecks.tag(meta, schema, ParquetFormatType, WriteFileOp) + parseCompressionType(parquetOptions.compressionCodecClassName) .getOrElse(meta.willNotWorkOnGpu( s"compression codec ${parquetOptions.compressionCodecClassName} is not supported")) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index 5d346edf9ec..90f627c8103 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -134,17 +134,7 @@ object GpuParquetScanBase { s"${RapidsConf.ENABLE_PARQUET_READ} to true") } - for (field <- readSchema) { - if (!GpuOverrides.isSupportedType( - field.dataType, - allowMaps = true, - allowArray = true, - allowStruct = true, - allowNesting = true, - allowDecimal = meta.conf.decimalTypeEnabled)) { - meta.willNotWorkOnGpu(s"GpuParquetScan does not support fields of type ${field.dataType}") - } - } + FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp) val schemaHasStrings = readSchema.exists { field => TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[StringType]) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index ee7acc2e1b1..a505e960dc1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -442,6 +442,11 @@ object TypeSig { */ val numeric: TypeSig = integral + fp + DECIMAL + /** + * All values that correspond to Spark's AtomicType + */ + val atomics: TypeSig = numeric + BINARY + BOOLEAN + DATE + STRING + TIMESTAMP + /** * numeric + CALENDAR */ @@ -553,6 +558,65 @@ case class ContextChecks( } } +/** + * Checks for either a read or a write of a given file format. 
+ */
+class FileFormatChecks private (
+    sig: TypeSig,
+    sparkSig: TypeSig)
+  extends TypeChecks[SupportLevel] {
+
+  def tag(meta: RapidsMeta[_, _, _],
+      schema: StructType,
+      fileType: FileFormatType,
+      op: FileFormatOp): Unit = {
+    val allowDecimal = meta.conf.decimalTypeEnabled
+
+    val unsupportedOutputTypes = schema.fields
+      .filterNot(attr => sig.isSupportedByPlugin(attr.dataType, allowDecimal))
+      .toSet
+
+    if (unsupportedOutputTypes.nonEmpty) {
+      meta.willNotWorkOnGpu("unsupported data types " +
+          unsupportedOutputTypes.mkString(", ") + s" in $op for $fileType")
+    }
+  }
+
+  override def support(dataType: TypeEnum.Value): SupportLevel =
+    sig.getSupportLevel(dataType, sparkSig)
+
+  override def tag(meta: RapidsMeta[_, _, _]): Unit =
+    throw new IllegalStateException("Internal Error not supported")
+}
+
+object FileFormatChecks {
+  /**
+   * File format checks with separate read and write signatures for cudf.
+   */
+  def apply(
+      cudfRead: TypeSig,
+      cudfWrite: TypeSig,
+      sparkSig: TypeSig): Map[FileFormatOp, FileFormatChecks] = Map(
+    (ReadFileOp, new FileFormatChecks(cudfRead, sparkSig)),
+    (WriteFileOp, new FileFormatChecks(cudfWrite, sparkSig))
+  )
+
+  /**
+   * File format checks where read and write have the same signature for cudf.
+   */
+  def apply(
+      cudfReadWrite: TypeSig,
+      sparkSig: TypeSig): Map[FileFormatOp, FileFormatChecks] =
+    apply(cudfReadWrite, cudfReadWrite, sparkSig)
+
+  def tag(meta: RapidsMeta[_, _, _],
+      schema: StructType,
+      fileType: FileFormatType,
+      op: FileFormatOp): Unit = {
+    GpuOverrides.fileFormats(fileType)(op).tag(meta, schema, fileType, op)
+  }
+}
+
 /**
  * Checks the input and output types supported by a SparkPlan node. We don't currently separate
  * input checks from output checks. We can add this in if something needs it.
@@ -777,7 +841,7 @@ object CreateNamedStructCheck extends ExprChecks {
       valueSig.tagExprParam(meta, expr, "value")
     }
     if (!resultSig.isSupportedByPlugin(origExpr.dataType, meta.conf.decimalTypeEnabled)) {
-      meta.willNotWorkOnGpu(s"unsupported data types in output: ${origExpr.dataType}")
+      meta.willNotWorkOnGpu(s"unsupported data type in output: ${origExpr.dataType}")
     }
   }
 }
@@ -1144,6 +1208,16 @@ object SupportedOpsDocs {
     println("</table>")
   }
 
+  private def ioChecksHeaderLine(): Unit = {
+    println("<tr>")
+    println("<th>Format</th>")
+    println("<th>Direction</th>")
+    TypeEnum.values.foreach { t =>
+      println(s"<th>$t</th>")
+    }
+    println("</tr>")
+  }
+
   def help(): Unit = {
     val headerEveryNLines = 15
     // scalastyle:off line.size.limit
@@ -1458,132 +1532,33 @@
     println("This table tries to clarify that. Be aware that some types may be disabled in some")
     println("cases for either reads or writes because of processing limitations, like rebasing")
     println("dates or timestamps, or for a lack of type coercion support.")
-    // TODO this should be automatically generated
     println("<table>")
-    println("<tr>")
-    println("<th>Format</th>")
-    println("<th>Direction</th>")
-    println("<th>BOOLEAN</th>")
-    println("<th>BYTE</th>")
-    println("<th>SHORT</th>")
-    println("<th>INT</th>")
-    println("<th>LONG</th>")
-    println("<th>FLOAT</th>")
-    println("<th>DOUBLE</th>")
-    println("<th>DATE</th>")
-    println("<th>TIMESTAMP</th>")
-    println("<th>STRING</th>")
-    println("<th>DECIMAL</th>")
-    println("<th>NULL</th>")
-    println("<th>BINARY</th>")
-    println("<th>CALENDAR</th>")
-    println("<th>ARRAY</th>")
-    println("<th>MAP</th>")
-    println("<th>STRUCT</th>")
-    println("</tr>")
-    println("<tr>")
-    println("<th rowSpan=2>Parquet</th>")
-    println("<th>Input</th>")
-    println("<td>S</td>") // BOOLEAN
-    println("<td>S</td>") // BYTE
-    println("<td>S</td>") // SHORT
-    println("<td>S</td>") // INT
-    println("<td>S</td>") // LONG
-    println("<td>S</td>") // FLOAT
-    println("<td>S</td>") // DOUBLE
-    println("<td>S</td>") // DATE
-    println("<td>S</td>") // TIMESTAMP
-    println("<td>S</td>") // STRING
-    println("<td>S</td>") // DECIMAL
-    println("<td> </td>") // NULL
-    println("<td>NS</td>") // BINARY
-    println("<td> </td>") // CALENDAR
-    println("<td>PS (missing nested BINARY)</td>") // ARRAY
-    println("<td>PS (missing nested BINARY)</td>") // MAP
-    println("<td>PS (missing nested BINARY)</td>") // STRUCT
-    println("</tr>")
-    println("<tr>")
-    println("<th>Output</th>")
-    println("<td>S</td>") // BOOLEAN
-    println("<td>S</td>") // BYTE
-    println("<td>S</td>") // SHORT
-    println("<td>S</td>") // INT
-    println("<td>S</td>") // LONG
-    println("<td>S</td>") // FLOAT
-    println("<td>S</td>") // DOUBLE
-    println("<td>S</td>") // DATE
-    println("<td>S</td>") // TIMESTAMP
-    println("<td>S</td>") // STRING
-    println("<td>S</td>") // DECIMAL
-    println("<td> </td>") // NULL
-    println("<td>NS</td>") // BINARY
-    println("<td> </td>") // CALENDAR
-    println("<td>NS</td>") // ARRAY
-    println("<td>NS</td>") // MAP
-    println("<td>NS</td>") // STRUCT
-    println("</tr>")
-    println("<tr>")
-    println("<th rowSpan=2>ORC</th>")
-    println("<th>Input</th>")
-    println("<td>S</td>") // BOOLEAN
-    println("<td>S</td>") // BYTE
-    println("<td>S</td>") // SHORT
-    println("<td>S</td>") // INT
-    println("<td>S</td>") // LONG
-    println("<td>S</td>") // FLOAT
-    println("<td>S</td>") // DOUBLE
-    println("<td>S</td>") // DATE
-    println("<td>S</td>") // TIMESTAMP
-    println("<td>S</td>") // STRING
-    println("<td>NS</td>") // DECIMAL
-    println("<td>NS</td>") // NULL
-    println("<td>NS</td>") // BINARY
-    println("<td> </td>") // CALENDAR
-    println("<td>NS</td>") // ARRAY
-    println("<td>NS</td>") // MAP
-    println("<td>NS</td>") // STRUCT
-    println("</tr>")
-    println("<tr>")
-    println("<th>Output</th>")
-    println("<td>S</td>") // BOOLEAN
-    println("<td>S</td>") // BYTE
-    println("<td>S</td>") // SHORT
-    println("<td>S</td>") // INT
-    println("<td>S</td>") // LONG
-    println("<td>S</td>") // FLOAT
-    println("<td>S</td>") // DOUBLE
-    println("<td>S</td>") // DATE
-    println("<td>S</td>") // TIMESTAMP
-    println("<td>S</td>") // STRING
-    println("<td>NS</td>") // DECIMAL
-    println("<td>NS</td>") // NULL
-    println("<td>NS</td>") // BINARY
-    println("<td> </td>") // CALENDAR
-    println("<td>NS</td>") // ARRAY
-    println("<td>NS</td>") // MAP
-    println("<td>NS</td>") // STRUCT
-    println("</tr>")
-    println("<tr>")
-    println("<th rowSpan=2>CSV</th>")
-    println("<th>Input</th>")
-    println("<td>S</td>") // BOOLEAN
-    println("<td>S</td>") // BYTE
-    println("<td>S</td>") // SHORT
-    println("<td>S</td>") // INT
-    println("<td>S</td>") // LONG
-    println("<td>S</td>") // FLOAT
-    println("<td>S</td>") // DOUBLE
-    println("<td>S</td>") // DATE
-    println("<td>S</td>") // TIMESTAMP
-    println("<td>S</td>") // STRING
-    println("<td>NS</td>") // DECIMAL
-    println("<td>NS</td>") // NULL
-    println("<td>NS</td>") // BINARY
-    println("<td>NS</td>") // CALENDAR
-    println("<td>NS</td>") // ARRAY
-    println("<td>NS</td>") // MAP
-    println("<td>NS</td>") // STRUCT
-    println("</tr>")
+    ioChecksHeaderLine()
+    totalCount = 0
+    nextOutputAt = headerEveryNLines
+    GpuOverrides.fileFormats.toSeq.sortBy(_._1.toString).foreach {
+      case (format, ioMap) =>
+        if (totalCount >= nextOutputAt) {
+          ioChecksHeaderLine()
+          nextOutputAt = totalCount + headerEveryNLines
+        }
+        val read = ioMap(ReadFileOp)
+        val write = ioMap(WriteFileOp)
+        println("<tr>")
+        println("<th rowSpan=2>" + s"$format</th>")
+        println("<th>Read</th>")
+        TypeEnum.values.foreach { t =>
+          println(read.support(t).htmlTag)
+        }
+        println("</tr>")
+        println("<tr>")
+        println("<th>Write</th>")
+        TypeEnum.values.foreach { t =>
+          println(write.support(t).htmlTag)
+        }
+        println("</tr>")
+        totalCount += 2
+    }
     println("</table>")
     // scalastyle:on line.size.limit
   }

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
index 8bc99822d77..c4f0f95b43b 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala
@@ -47,11 +47,6 @@ object GpuOrcFileFormat extends Logging {
       options: Map[String, String],
       schema: StructType): Option[GpuOrcFileFormat] = {
 
-    val unSupportedTypes = schema.filterNot(field => GpuOverrides.isSupportedType(field.dataType))
-    if (unSupportedTypes.nonEmpty) {
-      meta.willNotWorkOnGpu(s"These types aren't supported for orc $unSupportedTypes")
-    }
-
     if (!meta.conf.isOrcEnabled) {
       meta.willNotWorkOnGpu("ORC input and output has been disabled. To enable set" +
         s"${RapidsConf.ENABLE_ORC} to true")
@@ -62,6 +57,8 @@ object GpuOrcFileFormat extends Logging {
         s"${RapidsConf.ENABLE_ORC_WRITE} to true")
     }
 
+    FileFormatChecks.tag(meta, schema, OrcFormatType, WriteFileOp)
+
     val sqlConf = spark.sessionState.conf
     val parameters = CaseInsensitiveMap(options)