diff --git a/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala b/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
index 797e2bb5ef2..b518a14014b 100644
--- a/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
+++ b/shims/spark320/src/test/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
@@ -17,7 +17,6 @@
 package com.nvidia.spark.rapids.shims.spark320;
 
 import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion, TypeSig}
-import com.nvidia.spark.rapids.shims.v2.TypeSig320
 import org.scalatest.FunSuite
 
 import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}
@@ -34,7 +33,7 @@ class Spark320ShimsSuite extends FunSuite {
   }
 
   test("TypeSig320") {
-    val check = TypeSig320(TypeSig.DAYTIME + TypeSig.YEARMONTH)
+    val check = TypeSig.DAYTIME + TypeSig.YEARMONTH
     assert(check.isSupportedByPlugin(DayTimeIntervalType(), false) == true)
     assert(check.isSupportedByPlugin(YearMonthIntervalType(), false) == true)
   }
diff --git a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtilUntil320.scala b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtilUntil320.scala
new file mode 100644
index 00000000000..9de2fefba99
--- /dev/null
+++ b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtilUntil320.scala
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtil}
+
+import org.apache.spark.sql.types.DataType
+
+/**
+ * TypeSigUtil for Spark versions in the range [3.0.1, 3.2.0)
+ */
+object TypeSigUtilUntil320 extends TypeSigUtil {
+  /**
+   * Check if this Spark-specific data type is supported by the plugin or not.
+   *
+   * @param check the supported types
+   * @param dataType the data type to be checked
+   * @param allowDecimal whether decimal support is enabled or not
+   * @return true if it is supported, false otherwise
+   */
+  override def isSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType,
+      allowDecimal: Boolean): Boolean = false
+
+  /**
+   * Get all types supported by this Spark version
+   *
+   * @return the set of all supported types
+   */
+  override def getAllSupportedTypes(): TypeEnum.ValueSet =
+    TypeEnum.values - TypeEnum.DAYTIME - TypeEnum.YEARMONTH
+
+  /**
+   * Return the reason why this type is not supported.
+   *
+   * @param check the supported types
+   * @param dataType the data type to be checked
+   * @param allowDecimal whether decimal support is enabled or not
+   * @param notSupportedReason the default reason for not supporting
+   * @return the reason
+   */
+  override def reasonNotSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType,
+      allowDecimal: Boolean,
+      notSupportedReason: Seq[String]): Seq[String] = notSupportedReason
+
+  /**
+   * Get the cast checks and TypeSigs for a TypeEnum value
+   *
+   * @param from the TypeEnum to be matched
+   * @return the TypeSigs
+   */
+  override def getCastChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) =
+    throw new RuntimeException("Unsupported " + from)
+
+  /**
+   * Get the cast checks and TypeSigs for a DataType
+   *
+   * @param from the data type to be matched
+   * @param default the default TypeSig
+   * @param sparkDefault the default Spark TypeSig
+   * @return the TypeSigs
+   */
+  override def getCastChecksAndSigs(
+      from: DataType,
+      default: TypeSig,
+      sparkDefault: TypeSig): (TypeSig, TypeSig) = (default, sparkDefault)
+}
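Illustrative only, not part of the patch: a minimal sketch of what this stub means for callers on Spark [3.0.1, 3.2.0), assuming the plugin classes are on the classpath. The assertions just restate the overrides above.

```scala
import com.nvidia.spark.rapids.TypeEnum
import com.nvidia.spark.rapids.shims.v2.TypeSigUtilUntil320
import org.apache.spark.sql.types.CalendarIntervalType

// The interval enums exist in the common TypeEnum on every build, but the
// pre-3.2.0 shim excludes them from the supported set...
val supported = TypeSigUtilUntil320.getAllSupportedTypes()
assert(!supported.contains(TypeEnum.DAYTIME))
assert(!supported.contains(TypeEnum.YEARMONTH))

// ...and any data type the common checks do not recognize is rejected
// outright, because isSupported is hard-wired to false here.
assert(!TypeSigUtilUntil320.isSupported(supported, CalendarIntervalType, allowDecimal = false))
```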
diff --git a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala
index 2ff021f6193..d045340f755 100644
--- a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala
+++ b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2
 
 import scala.collection.mutable.ListBuffer
 
-import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
+import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig, TypeSigUtil}
 import com.nvidia.spark.rapids.GpuOverrides.exec
 import org.apache.hadoop.fs.FileStatus
 
@@ -126,4 +126,5 @@ trait Spark30XShims extends SparkShims {
     ss.sparkContext.defaultParallelism
   }
 
+  override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilUntil320
 }
diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala
index 3e5d74f5ee0..343815ae012 100644
--- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala
+++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala
@@ -137,6 +137,7 @@ trait Spark32XShims extends SparkShims {
     Spark32XShimsUtils.leafNodeDefaultParallelism(ss)
   }
 
+  override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilFrom320
 }
 
 // TODO dedupe utils inside shims
diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TypeSig320.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TypeSig320.scala
deleted file mode 100644
index 0c1440eebda..00000000000
--- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TypeSig320.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (c) 2021, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.nvidia.spark.rapids.shims.v2
-
-import ai.rapids.cudf.DType
-import com.nvidia.spark.rapids.{TypeEnum, TypeSig}
-
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}
-
-/** TypeSig for Spark 3.2.0+ which adds DayTimeIntervalType and YearMonthIntervalType support */
-final class TypeSig320(
-    override val initialTypes: TypeEnum.ValueSet,
-    override val maxAllowedDecimalPrecision: Int = DType.DECIMAL64_MAX_PRECISION,
-    override val childTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(),
-    override val litOnlyTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(),
-    override val notes: Map[TypeEnum.Value, String] = Map.empty)
-  extends TypeSig(initialTypes, maxAllowedDecimalPrecision, childTypes, litOnlyTypes, notes) {
-
-  override protected[this] def isLitOnly(dataType: DataType): Boolean = {
-    dataType match {
-      case _: DayTimeIntervalType => litOnlyTypes.contains(TypeEnum.DAYTIME)
-      case _: YearMonthIntervalType => litOnlyTypes.contains(TypeEnum.YEARMONTH)
-      case _ => super.isLitOnly(dataType)
-    }
-  }
-
-  override protected[this] def isSupported(
-      check: TypeEnum.ValueSet,
-      dataType: DataType,
-      allowDecimal: Boolean): Boolean = {
-    dataType match {
-      case _: DayTimeIntervalType => check.contains(TypeEnum.DAYTIME)
-      case _: YearMonthIntervalType => check.contains(TypeEnum.YEARMONTH)
-      case _ => super.isSupported(check, dataType, allowDecimal)
-    }
-  }
-
-  override protected[this] def reasonNotSupported(
-      check: TypeEnum.ValueSet,
-      dataType: DataType,
-      isChild: Boolean,
-      allowDecimal: Boolean): Seq[String] = {
-    dataType match {
-      case _: DayTimeIntervalType =>
-        basicNotSupportedMessage(dataType, TypeEnum.DAYTIME, check, isChild)
-      case _: YearMonthIntervalType =>
-        basicNotSupportedMessage(dataType, TypeEnum.YEARMONTH, check, isChild)
-      case _ => super.reasonNotSupported(check, dataType, isChild, allowDecimal)
    }
-  }
-}
-
-object TypeSig320 {
-
-  /** Convert TypeSig to TypeSig320 */
-  def apply(typeSig: TypeSig): TypeSig320 = {
-    new TypeSig320(typeSig.initialTypes, typeSig.maxAllowedDecimalPrecision, typeSig.childTypes,
-      typeSig.litOnlyTypes, typeSig.notes)
-  }
-
-}
diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtilFrom320.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtilFrom320.scala
new file mode 100644
index 00000000000..47ef93375ad
--- /dev/null
+++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtilFrom320.scala
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.v2
+
+import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtil}
+
+import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}
+
+/**
+ * Add DayTimeIntervalType and YearMonthIntervalType support
+ */
+object TypeSigUtilFrom320 extends TypeSigUtil {
+
+  override def isSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType,
+      allowDecimal: Boolean): Boolean = {
+    dataType match {
+      case _: DayTimeIntervalType => check.contains(TypeEnum.DAYTIME)
+      case _: YearMonthIntervalType => check.contains(TypeEnum.YEARMONTH)
+      case _ => false
+    }
+  }
+
+  override def getAllSupportedTypes(): TypeEnum.ValueSet = TypeEnum.values
+
+  override def reasonNotSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType,
+      allowDecimal: Boolean,
+      notSupportedReason: Seq[String]): Seq[String] = {
+    dataType match {
+      case _: DayTimeIntervalType =>
+        if (check.contains(TypeEnum.DAYTIME)) Seq.empty else notSupportedReason
+      case _: YearMonthIntervalType =>
+        if (check.contains(TypeEnum.YEARMONTH)) Seq.empty else notSupportedReason
+      case _ => notSupportedReason
+    }
+  }
+
+  override def getCastChecksAndSigs(
+      from: DataType,
+      default: TypeSig,
+      sparkDefault: TypeSig): (TypeSig, TypeSig) = {
+    from match {
+      case _: DayTimeIntervalType => (daytimeChecks, sparkDaytimeSig)
+      case _: YearMonthIntervalType => (yearmonthChecks, sparkYearmonthSig)
+      case _ => (default, sparkDefault)
+    }
+  }
+
+  override def getCastChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) = {
+    from match {
+      case TypeEnum.DAYTIME => (daytimeChecks, sparkDaytimeSig)
+      case TypeEnum.YEARMONTH => (yearmonthChecks, sparkYearmonthSig)
+    }
+  }
+
+  def daytimeChecks: TypeSig = TypeSig.none
+  def sparkDaytimeSig: TypeSig = TypeSig.DAYTIME + TypeSig.STRING
+
+  def yearmonthChecks: TypeSig = TypeSig.none
+  def sparkYearmonthSig: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
index 39cac5edc99..17623d54cc2 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -266,6 +266,9 @@ trait SparkShims {
   def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean
 
   def leafNodeDefaultParallelism(ss: SparkSession): Int
+
+  def getTypeSigUtil(): TypeSigUtil
+
 }
 
 abstract class SparkCommonShims extends SparkShims {
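Illustrative only, not part of the patch: a sketch of how the new hook is meant to be consumed. It assumes a Spark 3.2.0+ runtime, so ShimLoader resolves Spark32XShims and the util is TypeSigUtilFrom320.

```scala
import com.nvidia.spark.rapids.{ShimLoader, TypeSig}
import org.apache.spark.sql.types.DayTimeIntervalType

// Common code never names a concrete shim object: the loader picks
// TypeSigUtilFrom320 on 3.2.0+ and TypeSigUtilUntil320 on older versions.
val util = ShimLoader.getSparkShims.getTypeSigUtil()

// On 3.2.0+ an interval type resolves to the shim's own signatures instead
// of the (caller-supplied) defaults passed in here.
val (checks, sparkSig) = util.getCastChecksAndSigs(
  DayTimeIntervalType(), default = TypeSig.none, sparkDefault = TypeSig.none)
```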
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index ba3dd2d0741..e3ef2c3c260 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -25,6 +25,58 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition}
 import org.apache.spark.sql.types._
 
+/** TypeSigUtil for different shim layers */
+trait TypeSigUtil {
+
+  /**
+   * Check if this Spark-specific data type is supported by the plugin or not.
+   * @param check the supported types
+   * @param dataType the data type to be checked
+   * @param allowDecimal whether decimal support is enabled or not
+   * @return true if it is supported, false otherwise
+   */
+  def isSupported(check: TypeEnum.ValueSet, dataType: DataType, allowDecimal: Boolean): Boolean
+
+  /**
+   * Get all types supported by this Spark version
+   * @return the set of all supported types
+   */
+  def getAllSupportedTypes(): TypeEnum.ValueSet
+
+  /**
+   * Return the reason why this type is not supported.
+   * @param check the supported types
+   * @param dataType the data type to be checked
+   * @param allowDecimal whether decimal support is enabled or not
+   * @param notSupportedReason the default reason for not supporting
+   * @return the reason
+   */
+  def reasonNotSupported(
+      check: TypeEnum.ValueSet,
+      dataType: DataType,
+      allowDecimal: Boolean,
+      notSupportedReason: Seq[String]): Seq[String]
+
+  /**
+   * Get the cast checks and TypeSigs for a DataType
+   * @param from the data type to be matched
+   * @param default the default TypeSig
+   * @param sparkDefault the default Spark TypeSig
+   * @return the TypeSigs
+   */
+  def getCastChecksAndSigs(
+      from: DataType,
+      default: TypeSig,
+      sparkDefault: TypeSig): (TypeSig, TypeSig)
+
+  /**
+   * Get the cast checks and TypeSigs for a TypeEnum value
+   * @param from the TypeEnum to be matched
+   * @return the TypeSigs
+   */
+  def getCastChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig)
+
+}
 
 /**
  * The level of support that the plugin has for a given type. Used for documentation generation.
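The trait is deliberately a fallback hook: the common code computes its own answer first and only delegates types it does not recognize. Illustrative only, not part of the patch, a sketch of the reasonNotSupported contract:

```scala
import com.nvidia.spark.rapids.{ShimLoader, TypeEnum}
import org.apache.spark.sql.types.CalendarIntervalType

// The caller threads its default answer through; a shim either returns it
// untouched or clears it for a type the check set actually supports.
val defaultReason = Seq("CalendarIntervalType is not supported")
val reason = ShimLoader.getSparkShims.getTypeSigUtil()
  .reasonNotSupported(TypeEnum.values, CalendarIntervalType, false, defaultReason)
// Every shim returns defaultReason here: CalendarIntervalType is not a
// shim-specific type, so the caller's answer stands.
```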
@@ -116,7 +168,7 @@ object TypeEnum extends Enumeration {
   val STRUCT: Value = Value
   val UDT: Value = Value
   val DAYTIME: Value = Value
-  val YEARMONTH:Value = Value
+  val YEARMONTH: Value = Value
 }
 
 /**
@@ -124,12 +176,12 @@
  * a set of base types and a separate set of types that can be nested under the base types
  * (child types). It can also express if a particular base type has to be a literal or not.
  */
-class TypeSig (
-    protected[rapids] val initialTypes: TypeEnum.ValueSet,
-    protected[rapids] val maxAllowedDecimalPrecision: Int = DType.DECIMAL64_MAX_PRECISION,
-    protected[rapids] val childTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(),
-    protected[rapids] val litOnlyTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(),
-    protected[rapids] val notes: Map[TypeEnum.Value, String] = Map.empty) {
+final class TypeSig private(
+    private val initialTypes: TypeEnum.ValueSet,
+    private val maxAllowedDecimalPrecision: Int = DType.DECIMAL64_MAX_PRECISION,
+    private val childTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(),
+    private val litOnlyTypes: TypeEnum.ValueSet = TypeEnum.ValueSet(),
+    private val notes: Map[TypeEnum.Value, String] = Map.empty) {
 
   /**
    * Add a literal restriction to the signature
@@ -259,7 +311,7 @@ class TypeSig (
   def isSupportedByPlugin(dataType: DataType, allowDecimal: Boolean): Boolean =
     isSupported(initialTypes, dataType, allowDecimal)
 
-  protected [this] def isLitOnly(dataType: DataType): Boolean = dataType match {
+  private[this] def isLitOnly(dataType: DataType): Boolean = dataType match {
     case BooleanType => litOnlyTypes.contains(TypeEnum.BOOLEAN)
     case ByteType => litOnlyTypes.contains(TypeEnum.BYTE)
     case ShortType => litOnlyTypes.contains(TypeEnum.SHORT)
@@ -277,13 +329,13 @@ class TypeSig (
     case _: ArrayType => litOnlyTypes.contains(TypeEnum.ARRAY)
     case _: MapType => litOnlyTypes.contains(TypeEnum.MAP)
     case _: StructType => litOnlyTypes.contains(TypeEnum.STRUCT)
-    case _ => false
+    case _ => ShimLoader.getSparkShims.getTypeSigUtil().isSupported(litOnlyTypes, dataType, false)
   }
 
   def isSupportedBySpark(dataType: DataType): Boolean =
     isSupported(initialTypes, dataType, allowDecimal = true)
 
-  protected[this] def isSupported(
+  private[this] def isSupported(
       check: TypeEnum.ValueSet,
       dataType: DataType,
       allowDecimal: Boolean): Boolean =
@@ -314,7 +366,7 @@ class TypeSig (
         fields.map(_.dataType).forall { t =>
           isSupported(childTypes, t, allowDecimal)
         }
-      case _ => false
+      case _ => ShimLoader.getSparkShims.getTypeSigUtil().isSupported(check, dataType, allowDecimal)
 
   def reasonNotSupported(dataType: DataType, allowDecimal: Boolean): Seq[String] =
@@ -326,7 +378,7 @@ class TypeSig (
       msg
     }
 
-  protected[this] def basicNotSupportedMessage(dataType: DataType,
+  private[this] def basicNotSupportedMessage(dataType: DataType,
       te: TypeEnum.Value, check: TypeEnum.ValueSet, isChild: Boolean): Seq[String] = {
     if (check.contains(te)) {
       Seq.empty
@@ -335,7 +387,7 @@ class TypeSig (
     }
   }
 
-  protected[this] def reasonNotSupported(
+  private[this] def reasonNotSupported(
       check: TypeEnum.ValueSet,
       dataType: DataType,
       isChild: Boolean,
@@ -410,8 +462,8 @@ class TypeSig (
         } else {
           basicNotSupportedMessage(dataType, TypeEnum.STRUCT, check, isChild)
         }
-      case _ =>
-        Seq(withChild(isChild, s"$dataType is not supported"))
+      case _ => ShimLoader.getSparkShims.getTypeSigUtil().reasonNotSupported(check, dataType,
+        allowDecimal, Seq(withChild(isChild, s"$dataType is not supported")))
 
   def areAllSupportedByPlugin(types: Seq[DataType], allowDecimal: Boolean): Boolean =
@@ -505,7 +557,10 @@ object TypeSig {
   /**
    * All types nested and not nested
    */
-  val all: TypeSig = new TypeSig(TypeEnum.values, DecimalType.MAX_PRECISION, TypeEnum.values)
+  val all: TypeSig = {
+    val allSupportedTypes = ShimLoader.getSparkShims.getTypeSigUtil().getAllSupportedTypes()
+    new TypeSig(allSupportedTypes, DecimalType.MAX_PRECISION, allSupportedTypes)
+  }
 
   /**
    * No types supported at all
   */
@@ -555,12 +610,12 @@
   val UDT: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.UDT))
 
   /**
-   * DayTimeIntervalType support from Spark 3.2.0+
+   * DayTimeIntervalType, supported from Spark 3.2.0+
    */
   val DAYTIME: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.DAYTIME))
 
   /**
-   * YearMonthIntervalType support from Spark 3.2.0+
+   * YearMonthIntervalType, supported from Spark 3.2.0+
    */
   val YEARMONTH: TypeSig = new TypeSig(TypeEnum.ValueSet(TypeEnum.YEARMONTH))
 
@@ -1283,7 +1338,8 @@ class CastChecks extends ExprChecks {
     case _: ArrayType => (arrayChecks, sparkArraySig)
     case _: MapType => (mapChecks, sparkMapSig)
     case _: StructType => (structChecks, sparkStructSig)
-    case _ => (udtChecks, sparkUdtSig)
+    case _ =>
+      ShimLoader.getSparkShims.getTypeSigUtil().getCastChecksAndSigs(from, udtChecks, sparkUdtSig)
   }
 
   private[this] def getChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) = from match {
@@ -1302,6 +1358,7 @@ class CastChecks extends ExprChecks {
     case TypeEnum.MAP => (mapChecks, sparkMapSig)
     case TypeEnum.STRUCT => (structChecks, sparkStructSig)
     case TypeEnum.UDT => (udtChecks, sparkUdtSig)
+    case _ => ShimLoader.getSparkShims.getTypeSigUtil().getCastChecksAndSigs(from)
   }
 
   override def tagAst(meta: BaseExprMeta[_]): Unit = {
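Illustrative only, not part of the patch: how the TypeEnum-keyed overload behaves once the 3.2.0 shim is loaded.

```scala
import com.nvidia.spark.rapids.TypeEnum
import com.nvidia.spark.rapids.shims.v2.TypeSigUtilFrom320

// This overload backs the new `case _ =>` in getChecksAndSigs: the common
// match already handles every common enum value, so only shim-specific
// values (DAYTIME, YEARMONTH) can reach it. That is also why the pre-3.2.0
// implementation can simply throw.
val (checks, sparkSig) = TypeSigUtilFrom320.getCastChecksAndSigs(TypeEnum.DAYTIME)
// checks == daytimeChecks (TypeSig.none: no GPU casts from intervals yet);
// sparkSig == sparkDaytimeSig (DAYTIME + STRING).
```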
(" + l.mkString(";
") + ")") .getOrElse(input) println(s"$named") - TypeEnum.values.foreach { t => + allSupportedTypes.foreach { t => println(allData(t)(input).htmlTag) } println("") @@ -1833,7 +1893,7 @@ object SupportedOpsDocs { ConfHelper.getSqlFunctionsForClass(rule.tag.runtimeClass).map(_.mkString(", ")) val exprChecks = checks.get.asInstanceOf[ExprChecks] // Params can change between contexts, but should not - val allData = TypeEnum.values.map { t => + val allData = allSupportedTypes.map { t => (t, exprChecks.support(t)) }.toMap // Now we should get the same keys for each type, so we are only going to look at the first @@ -1855,7 +1915,7 @@ object SupportedOpsDocs { println("" + s"$context") data.keys.foreach { param => println(s"$param") - TypeEnum.values.foreach { t => + allSupportedTypes.foreach { t => println(allData(t)(context)(param).htmlTag) } println("") @@ -1890,19 +1950,19 @@ object SupportedOpsDocs { println(s"### `${rule.tag.runtimeClass.getSimpleName}`") println() println("") - val numTypes = TypeEnum.values.size + val numTypes = allSupportedTypes.size println("") println("") - TypeEnum.values.foreach { t => + allSupportedTypes.foreach { t => println(s"") } println("") println("") var count = 0 - TypeEnum.values.foreach { from => + allSupportedTypes.foreach { from => println(s"") - TypeEnum.values.foreach { to => + allSupportedTypes.foreach { to => println(cc.support(from, to).htmlTag) } println("") @@ -1937,7 +1997,7 @@ object SupportedOpsDocs { nextOutputAt = totalCount + headerEveryNLines } val partChecks = checks.get.asInstanceOf[PartChecks] - val allData = TypeEnum.values.map { t => + val allData = allSupportedTypes.map { t => (t, partChecks.support(t)) }.toMap // Now we should get the same keys for each type, so we are only going to look at the first @@ -1953,7 +2013,7 @@ object SupportedOpsDocs { var count = 0 representative.keys.foreach { param => println(s"") - TypeEnum.values.foreach { t => + allSupportedTypes.foreach { t => println(allData(t)(param).htmlTag) } println("") @@ -1970,7 +2030,7 @@ object SupportedOpsDocs { println(s"") println(s"") println(NotApplicable.htmlTag) // param - TypeEnum.values.foreach { _ => + allSupportedTypes.foreach { _ => println(NotApplicable.htmlTag) } println("") @@ -2000,13 +2060,13 @@ object SupportedOpsDocs { println("") println("") println("") - TypeEnum.values.foreach { t => + allSupportedTypes.foreach { t => println(read.support(t).htmlTag) } println("") println("") println("") - TypeEnum.values.foreach { t => + allSupportedTypes.foreach { t => println(write.support(t).htmlTag) } println("") @@ -2028,11 +2088,14 @@ object SupportedOpsDocs { object SupportedOpsForTools { + private lazy val allSupportedTypes = + ShimLoader.getSparkShims.getTypeSigUtil().getAllSupportedTypes() + private def outputSupportIO() { // Look at what we have for defaults for some configs because if the configs are off // it likely means something isn't completely compatible. val conf = new RapidsConf(Map.empty[String, String]) - val types = TypeEnum.values.toSeq + val types = allSupportedTypes.toSeq val header = Seq("Format", "Direction") ++ types val writeOps: Array[String] = Array.fill(types.size)("NA") println(header.mkString(","))
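Taken together, the behavior the updated Spark320ShimsSuite locks in can be sketched as follows on a Spark 3.2.0 runtime (illustrative only; TypeSig.all is now seeded from getAllSupportedTypes() per the TypeChecks.scala change above):

```scala
import com.nvidia.spark.rapids.TypeSig
import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

// No TypeSig320 wrapper is needed anymore: plain TypeSig composition works
// because interval checks are routed through the shim-provided TypeSigUtil.
val check = TypeSig.DAYTIME + TypeSig.YEARMONTH
assert(check.isSupportedByPlugin(DayTimeIntervalType(), false))
assert(check.isSupportedByPlugin(YearMonthIntervalType(), false))

// TypeSig.all is version-aware as well, so the interval enums (and the doc
// columns generated from them) only appear on 3.2.0+.
assert(TypeSig.all.isSupportedBySpark(DayTimeIntervalType()))
```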