Use static way instead of shim layer way
Signed-off-by: Bobby Wang <[email protected]>
wbo4958 committed Sep 15, 2021
1 parent 583e8c5 commit 6edb050
Showing 8 changed files with 23 additions and 26 deletions.
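
What the change does: rather than fetching a per-version helper through ShimLoader.getSparkShims.getTypeSigUtil(), every shim source tree now defines an object TypeSigUtil under the one fixed path com.nvidia.spark.rapids.shims.v2, and common code binds to that name at compile time. A minimal sketch of the pattern, with the trait cut down to a single simplified method and illustrative file locations; the real trait has several methods, and which shim source set gets compiled is decided by the plugin's build (presumably its per-version Maven profiles), not shown in this diff:

// --- common source tree: the shared contract, compiled for every Spark version
package com.nvidia.spark.rapids

import org.apache.spark.sql.types.DataType

trait TypeSigUtil {
  def isSupported(dataType: DataType): Boolean
}

// --- pre-3.2.0 shim source tree (replaces TypeSigUtilUntil320),
// --- covering [Spark 3.0.1, Spark 3.2.0)
package com.nvidia.spark.rapids.shims.v2

import org.apache.spark.sql.types.DataType

object TypeSigUtil extends com.nvidia.spark.rapids.TypeSigUtil {
  // No version-specific extra types before Spark 3.2.0.
  override def isSupported(dataType: DataType): Boolean = false
}

// --- 3.2.0+ shim source tree (replaces TypeSigUtilFrom320):
// --- same package, same object name, different body
package com.nvidia.spark.rapids.shims.v2

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}

object TypeSigUtil extends com.nvidia.spark.rapids.TypeSigUtil {
  // Spark 3.2.0 adds the day-time and year-month interval types.
  override def isSupported(dataType: DataType): Boolean = dataType match {
    case _: DayTimeIntervalType | _: YearMonthIntervalType => true
    case _ => false
  }
}

Only one of the two shim files is on the compile path of a given build, so a reference to com.nvidia.spark.rapids.shims.v2.TypeSigUtil from common code resolves statically, with no lookup through SparkShims.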
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2

 import scala.collection.mutable.ListBuffer

-import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig, TypeSigUtil, TypeSigUtilUntil320}
+import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
 import com.nvidia.spark.rapids.GpuOverrides.exec
 import org.apache.hadoop.fs.FileStatus

@@ -117,5 +117,4 @@ trait Spark30XShims extends SparkShims {

   override def shouldFailDivOverflow(): Boolean = false

-  override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilUntil320
 }
@@ -14,14 +14,14 @@
  * limitations under the License.
  */

-package com.nvidia.spark.rapids
+package com.nvidia.spark.rapids.shims.v2
+
+import com.nvidia.spark.rapids.{TypeEnum, TypeSig}

 import org.apache.spark.sql.types.DataType

 /**
  * This TypeSigUtil is for [spark 3.0.1, spark 3.2.0)
  */
-object TypeSigUtilUntil320 extends TypeSigUtil {
+object TypeSigUtil extends com.nvidia.spark.rapids.TypeSigUtil {

   /**
    * Check if this type of Spark-specific is supported by the plugin or not.
    *
@@ -78,4 +78,5 @@ object TypeSigUtilUntil320 extends TypeSigUtil {
       from: DataType,
       default: TypeSig,
       sparkDefault: TypeSig): (TypeSig, TypeSig) = (default, sparkDefault)
+
 }
@@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2

 import scala.collection.mutable.ListBuffer

-import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig, TypeSigUtil, TypeSigUtilUntil320}
+import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
 import com.nvidia.spark.rapids.GpuOverrides.exec
 import org.apache.hadoop.fs.FileStatus

@@ -126,5 +126,4 @@ trait Spark30XShims extends SparkShims {
     ss.sparkContext.defaultParallelism
   }

-  override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilUntil320
 }
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids.shims.v2

-import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig, TypeSigUtil, TypeSigUtilUntil320}
+import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
 import com.nvidia.spark.rapids.GpuOverrides.exec
 import org.apache.hadoop.fs.FileStatus

@@ -94,5 +94,4 @@ trait Spark30XShims extends SparkShims {

   override def shouldFailDivOverflow(): Boolean = false

-  override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilUntil320
 }
@@ -137,7 +137,6 @@ trait Spark32XShims extends SparkShims {
     Spark32XShimsUtils.leafNodeDefaultParallelism(ss)
   }

-  override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilFrom320
 }

 // TODO dedupe utils inside shims
@@ -16,14 +16,14 @@

 package com.nvidia.spark.rapids.shims.v2

-import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtil}
+import com.nvidia.spark.rapids.{TypeEnum, TypeSig}

 import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}

 /**
  * Add DayTimeIntervalType and YearMonthIntervalType support
  */
-object TypeSigUtilFrom320 extends TypeSigUtil {
+object TypeSigUtil extends com.nvidia.spark.rapids.TypeSigUtil {

   override def isSupported(
     check: TypeEnum.ValueSet,
@@ -267,8 +267,6 @@ trait SparkShims {

   def leafNodeDefaultParallelism(ss: SparkSession): Int

-  def getTypeSigUtil(): TypeSigUtil
-
 }

 abstract class SparkCommonShims extends SparkShims {
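
With getTypeSigUtil() gone from SparkShims, the call sites in TypeChecks.scala below change from a runtime lookup into the shim layer to a direct static reference. A self-contained toy model of the two dispatch styles; ShimLoader, SparkShims, and the method bodies here are simplified stand-ins for the plugin's real classes, not its actual code:

// Toy model: why the static form needs no shim lookup.
object DispatchSketch {
  trait TypeSigUtil { def getAllSupportedTypes(): Set[String] }

  // Stand-in for the statically linked shims.v2.TypeSigUtil object.
  object StaticTypeSigUtil extends TypeSigUtil {
    override def getAllSupportedTypes(): Set[String] = Set("INT", "STRING")
  }

  // Old style: the helper is reached through a provider chosen at runtime.
  trait SparkShims { def getTypeSigUtil(): TypeSigUtil }
  object ShimLoader {
    def getSparkShims: SparkShims = new SparkShims {
      override def getTypeSigUtil(): TypeSigUtil = StaticTypeSigUtil
    }
  }

  def main(args: Array[String]): Unit = {
    // Before: two indirections on every call.
    val viaShim = ShimLoader.getSparkShims.getTypeSigUtil().getAllSupportedTypes()
    // After: one direct reference, resolved by the compiler.
    val direct = StaticTypeSigUtil.getAllSupportedTypes()
    assert(viaShim == direct)
  }
}

The trade-off: common code gains a compile-time dependency on the shims.v2 package, and every shim source set must supply a TypeSigUtil object under that exact name, but the SparkShims trait shrinks and each call site drops the indirection.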
sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala (22 changes: 12 additions & 10 deletions)

@@ -19,13 +19,15 @@ package com.nvidia.spark.rapids
 import java.io.{File, FileOutputStream}
 import java.time.ZoneId

-import ai.rapids.cudf.DType
+import scala.collection.mutable

+import ai.rapids.cudf.DType
+import com.nvidia.spark.rapids.shims.v2.TypeSigUtil

 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition}
 import org.apache.spark.sql.types._

-/** TypeSigUtil for different shim layers */
+/** TypeSigUtil for different spark versions */
 trait TypeSigUtil {

@@ -329,7 +331,7 @@ final class TypeSig private(
     case _: ArrayType => litOnlyTypes.contains(TypeEnum.ARRAY)
     case _: MapType => litOnlyTypes.contains(TypeEnum.MAP)
     case _: StructType => litOnlyTypes.contains(TypeEnum.STRUCT)
-    case _ => ShimLoader.getSparkShims.getTypeSigUtil().isSupported(litOnlyTypes, dataType, false)
+    case _ => TypeSigUtil.isSupported(litOnlyTypes, dataType, false)
   }

   def isSupportedBySpark(dataType: DataType): Boolean =

@@ -366,7 +368,7 @@ final class TypeSig private(
       fields.map(_.dataType).forall { t =>
         isSupported(childTypes, t, allowDecimal)
       }
-    case _ => ShimLoader.getSparkShims.getTypeSigUtil().isSupported(check, dataType, allowDecimal)
+    case _ => TypeSigUtil.isSupported(check, dataType, allowDecimal)
   }

   def reasonNotSupported(dataType: DataType, allowDecimal: Boolean): Seq[String] =

@@ -462,7 +464,7 @@ final class TypeSig private(
       } else {
         basicNotSupportedMessage(dataType, TypeEnum.STRUCT, check, isChild)
       }
-    case _ => ShimLoader.getSparkShims.getTypeSigUtil().reasonNotSupported(check, dataType,
+    case _ => TypeSigUtil.reasonNotSupported(check, dataType,
       allowDecimal, Seq(withChild(isChild, s"$dataType is not supported")))
   }

@@ -558,7 +560,7 @@ object TypeSig {
    * All types nested and not nested
    */
   val all: TypeSig = {
-    val allSupportedTypes = ShimLoader.getSparkShims.getTypeSigUtil().getAllSupportedTypes()
+    val allSupportedTypes = TypeSigUtil.getAllSupportedTypes()
     new TypeSig(allSupportedTypes, DecimalType.MAX_PRECISION, allSupportedTypes)
   }

@@ -1339,7 +1341,7 @@ class CastChecks extends ExprChecks {
     case _: MapType => (mapChecks, sparkMapSig)
     case _: StructType => (structChecks, sparkStructSig)
     case _ =>
-      ShimLoader.getSparkShims.getTypeSigUtil().getCastChecksAndSigs(from, udtChecks, sparkUdtSig)
+      TypeSigUtil.getCastChecksAndSigs(from, udtChecks, sparkUdtSig)
   }

   private[this] def getChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) = from match {

@@ -1358,7 +1360,7 @@
     case TypeEnum.MAP => (mapChecks, sparkMapSig)
     case TypeEnum.STRUCT => (structChecks, sparkStructSig)
     case TypeEnum.UDT => (udtChecks, sparkUdtSig)
-    case _ => ShimLoader.getSparkShims.getTypeSigUtil().getCastChecksAndSigs(from)
+    case _ => TypeSigUtil.getCastChecksAndSigs(from)
   }

   override def tagAst(meta: BaseExprMeta[_]): Unit = {

@@ -1636,7 +1638,7 @@ object ExprChecks {
  */
 object SupportedOpsDocs {
   private lazy val allSupportedTypes =
-    ShimLoader.getSparkShims.getTypeSigUtil().getAllSupportedTypes()
+    TypeSigUtil.getAllSupportedTypes()

   private def execChecksHeaderLine(): Unit = {
     println("<tr>")

@@ -2089,7 +2091,7 @@ object SupportedOpsDocs {
 object SupportedOpsForTools {

   private lazy val allSupportedTypes =
-    ShimLoader.getSparkShims.getTypeSigUtil().getAllSupportedTypes()
+    TypeSigUtil.getAllSupportedTypes()

   private def outputSupportIO() {
     // Look at what we have for defaults for some configs because if the configs are off
