Skip to content

Commit

Permalink
new solution
Browse files Browse the repository at this point in the history
Signed-off-by: Bobby Wang <[email protected]>
  • Loading branch information
wbo4958 committed Sep 14, 2021
1 parent 26adbf6 commit 0a49903
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.nvidia.spark.rapids.shims.spark320;

import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion, TypeSig}
import com.nvidia.spark.rapids.shims.v2.TypeSig320
import org.scalatest.FunSuite

import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}
Expand All @@ -34,7 +33,7 @@ class Spark320ShimsSuite extends FunSuite {
}

test("TypeSig320") {
val check = TypeSig320(TypeSig.DAYTIME + TypeSig.YEARMONTH)
val check = TypeSig.DAYTIME + TypeSig.YEARMONTH
assert(check.isSupportedByPlugin(DayTimeIntervalType(), false) == true)
assert(check.isSupportedByPlugin(YearMonthIntervalType(), false) == true)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtil}

import org.apache.spark.sql.types.DataType

/**
* This TypeSigUtil is for [spark 3.0.1, spark 3.2.0)
*/
/**
 * TypeSigUtil implementation for Spark versions in the range [3.0.1, 3.2.0).
 *
 * Spark versions before 3.2.0 have no version-specific data types (no
 * DayTimeIntervalType / YearMonthIntervalType), so every hook here is a
 * no-op or a straight pass-through of the caller's defaults.
 */
object TypeSigUtilUntil320 extends TypeSigUtil {

  /**
   * Checks whether a Spark-version-specific type is supported by the plugin.
   *
   * @param check the supported types
   * @param dataType the data type to be checked
   * @param allowDecimal whether decimal support is enabled or not
   * @return always false — there are no version-specific types before 3.2.0
   */
  override def isSupported(
      check: TypeEnum.ValueSet,
      dataType: DataType,
      allowDecimal: Boolean): Boolean = false

  /**
   * Gets all types supported on this Spark version.
   *
   * @return every known type except the interval types introduced in 3.2.0
   */
  override def getAllSupportedTypes(): TypeEnum.ValueSet =
    TypeEnum.values -- Seq(TypeEnum.DAYTIME, TypeEnum.YEARMONTH)

  /**
   * Returns the reason why a type is not supported.
   *
   * @param check the supported types
   * @param dataType the data type to be checked
   * @param allowDecimal whether decimal support is enabled or not
   * @param notSupportedReason the reason accumulated so far
   * @return the reason, unchanged — this version adds no extra type handling
   */
  override def reasonNotSupported(
      check: TypeEnum.ValueSet,
      dataType: DataType,
      allowDecimal: Boolean,
      notSupportedReason: Seq[String]): Seq[String] = notSupportedReason

  /**
   * Gets the cast checks and TypeSigs for a version-specific TypeEnum value.
   *
   * No such values exist before 3.2.0, so any call is a programming error.
   *
   * @param from the TypeEnum to be matched
   * @return never returns; always throws
   */
  override def getCastChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) =
    throw new RuntimeException(s"Unsupported $from")

  /**
   * Gets the cast checks and TypeSigs for a DataType.
   *
   * @param from the data type to be matched
   * @param default the default TypeSig
   * @param sparkDefault the default Spark TypeSig
   * @return the caller-supplied defaults — no version-specific types to match
   */
  override def getCastChecksAndSigs(
      from: DataType,
      default: TypeSig,
      sparkDefault: TypeSig): (TypeSig, TypeSig) = (default, sparkDefault)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids.shims.v2

import scala.collection.mutable.ListBuffer

import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig}
import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuExec, SparkPlanMeta, SparkShims, TypeSig, TypeSigUtil}
import com.nvidia.spark.rapids.GpuOverrides.exec
import org.apache.hadoop.fs.FileStatus

Expand Down Expand Up @@ -126,4 +126,5 @@ trait Spark30XShims extends SparkShims {
ss.sparkContext.defaultParallelism
}

override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilUntil320
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ trait Spark32XShims extends SparkShims {
Spark32XShimsUtils.leafNodeDefaultParallelism(ss)
}

override def getTypeSigUtil(): TypeSigUtil = TypeSigUtilFrom320
}

// TODO dedupe utils inside shims
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.v2

import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtil}

import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, YearMonthIntervalType}

/**
* Add DayTimeIntervalType and YearMonthIntervalType support
*/
/**
 * TypeSigUtil implementation for Spark 3.2.0+.
 *
 * Adds support for the interval types introduced in Spark 3.2.0:
 * DayTimeIntervalType and YearMonthIntervalType.
 */
object TypeSigUtilFrom320 extends TypeSigUtil {

  /**
   * Checks whether a Spark-version-specific type is supported by the plugin.
   *
   * @param check the supported types
   * @param dataType the data type to be checked
   * @param allowDecimal whether decimal support is enabled or not
   * @return true only for the 3.2.0 interval types when present in `check`
   */
  override def isSupported(
      check: TypeEnum.ValueSet,
      dataType: DataType,
      allowDecimal: Boolean): Boolean = {
    dataType match {
      case _: DayTimeIntervalType => check.contains(TypeEnum.DAYTIME)
      case _: YearMonthIntervalType => check.contains(TypeEnum.YEARMONTH)
      case _ => false
    }
  }

  /**
   * Gets all types supported on this Spark version.
   *
   * @return every known type — 3.2.0 knows about all of them
   */
  override def getAllSupportedTypes(): TypeEnum.ValueSet = TypeEnum.values

  /**
   * Returns the reason why a type is not supported.
   *
   * @param check the supported types
   * @param dataType the data type to be checked
   * @param allowDecimal whether decimal support is enabled or not
   * @param notSupportedReason the reason accumulated so far
   * @return empty if the interval type is supported, otherwise the given reason
   */
  override def reasonNotSupported(
      check: TypeEnum.ValueSet,
      dataType: DataType,
      allowDecimal: Boolean,
      notSupportedReason: Seq[String]): Seq[String] = {
    dataType match {
      case _: DayTimeIntervalType =>
        if (check.contains(TypeEnum.DAYTIME)) Seq.empty else notSupportedReason
      case _: YearMonthIntervalType =>
        if (check.contains(TypeEnum.YEARMONTH)) Seq.empty else notSupportedReason
      case _ => notSupportedReason
    }
  }

  /**
   * Gets the cast checks and TypeSigs for a DataType.
   *
   * @param from the data type to be matched
   * @param default the default TypeSig
   * @param sparkDefault the default Spark TypeSig
   * @return interval-specific sigs for the 3.2.0 types, otherwise the defaults
   */
  override def getCastChecksAndSigs(
      from: DataType,
      default: TypeSig,
      sparkDefault: TypeSig): (TypeSig, TypeSig) = {
    from match {
      case _: DayTimeIntervalType => (daytimeChecks, sparkDaytimeSig)
      case _: YearMonthIntervalType => (yearmonthChecks, sparkYearmonthSig)
      case _ => (default, sparkDefault)
    }
  }

  /**
   * Gets the cast checks and TypeSigs for a version-specific TypeEnum value.
   *
   * Only the interval types added in 3.2.0 are valid here; any other value
   * raises a descriptive error instead of a bare MatchError, matching the
   * convention used by TypeSigUtilUntil320.
   *
   * @param from the TypeEnum to be matched
   * @return the cast checks and Spark TypeSig for the interval type
   */
  override def getCastChecksAndSigs(from: TypeEnum.Value): (TypeSig, TypeSig) = {
    from match {
      case TypeEnum.DAYTIME => (daytimeChecks, sparkDaytimeSig)
      case TypeEnum.YEARMONTH => (yearmonthChecks, sparkYearmonthSig)
      // Was a non-exhaustive match: any other enum value produced an opaque
      // MatchError. Fail with the same descriptive message the pre-3.2.0
      // implementation uses.
      case _ => throw new RuntimeException("Unsupported " + from)
    }
  }

  /** Plugin-side cast checks for DayTimeIntervalType (none supported on GPU yet). */
  def daytimeChecks: TypeSig = TypeSig.none
  /** What Spark itself can cast a DayTimeIntervalType to. */
  def sparkDaytimeSig: TypeSig = TypeSig.DAYTIME + TypeSig.STRING

  /** Plugin-side cast checks for YearMonthIntervalType (none supported on GPU yet). */
  def yearmonthChecks: TypeSig = TypeSig.none
  /** What Spark itself can cast a YearMonthIntervalType to. */
  def sparkYearmonthSig: TypeSig = TypeSig.YEARMONTH + TypeSig.STRING
}
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ trait SparkShims {
def skipAssertIsOnTheGpu(plan: SparkPlan): Boolean

def leafNodeDefaultParallelism(ss: SparkSession): Int

def getTypeSigUtil(): TypeSigUtil

}

abstract class SparkCommonShims extends SparkShims {
Expand Down
Loading

0 comments on commit 0a49903

Please sign in to comment.