
Commit a0101d1
Add in neverReplaceExec and several rules for it (#660)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Sep 4, 2020
1 parent 9c4070c commit a0101d1
Showing 2 changed files with 85 additions and 44 deletions.
125 changes: 83 additions & 42 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -31,25 +31,23 @@ import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand}
+import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.UTF8String

/**
* Base class for all ReplacementRules
@@ -71,6 +69,7 @@ abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPU

private var _incompatDoc: Option[String] = None
private var _disabledDoc: Option[String] = None
+  private var _visible: Boolean = true

override def incompatDoc: Option[String] = _incompatDoc
override def disabledMsg: Option[String] = _disabledDoc
@@ -95,6 +94,11 @@ abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPU
this
}

+  final def invisible(): this.type = {
+    _visible = false
+    this
+  }

/**
* Provide a function that will wrap a spark type in a [[RapidsMeta]] instance that is used for
* conversion to a GPU version.
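The invisible() builder added above is the hook that neverReplaceExec (added further down in this diff) uses to keep its rules out of the generated configuration docs. A minimal, self-contained sketch of the fluent-flag pattern it follows; the class name here is illustrative, only invisible() and _visible come from the diff:

    // Toy model of ReplacementRule's visibility flag: invisible() flips the
    // flag and returns `this` so it can be chained onto a freshly built rule.
    class RuleLike {
      private var _visible: Boolean = true
      final def invisible(): this.type = { _visible = false; this }
      def visible: Boolean = _visible
    }

    val r = (new RuleLike).invisible()
    assert(!r.visible) // confHelp() consults this flag before printing anything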
@@ -139,32 +143,34 @@ abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPU
}

  def confHelp(asTable: Boolean = false, sparkSQLFunctions: Option[String] = None): Unit = {
-    val notesMsg = notes()
-    if (asTable) {
-      import ConfHelper.makeConfAnchor
-      print(s"${makeConfAnchor(confKey)}")
-      if (sparkSQLFunctions.isDefined) {
-        print(s"|${sparkSQLFunctions.get}")
-      }
-      print(s"|$desc|${notesMsg.isEmpty}|")
-      if (notesMsg.isDefined) {
-        print(s"${notesMsg.get}")
+    if (_visible) {
+      val notesMsg = notes()
+      if (asTable) {
+        import ConfHelper.makeConfAnchor
+        print(s"${makeConfAnchor(confKey)}")
+        if (sparkSQLFunctions.isDefined) {
+          print(s"|${sparkSQLFunctions.get}")
+        }
+        print(s"|$desc|${notesMsg.isEmpty}|")
+        if (notesMsg.isDefined) {
+          print(s"${notesMsg.get}")
+        } else {
+          print("None")
+        }
+        println("|")
      } else {
-        print("None")
-      }
-      println("|")
-    } else {
-      println(s"$confKey:")
-      println(s"\tEnable (true) or disable (false) the $tag $operationName.")
-      if (sparkSQLFunctions.isDefined) {
-        println(s"\tsql function: ${sparkSQLFunctions.get}")
-      }
-      println(s"\t$desc")
-      if (notesMsg.isDefined) {
-        println(s"\t${notesMsg.get}")
+        println(s"$confKey:")
+        println(s"\tEnable (true) or disable (false) the $tag $operationName.")
+        if (sparkSQLFunctions.isDefined) {
+          println(s"\tsql function: ${sparkSQLFunctions.get}")
+        }
+        println(s"\t$desc")
+        if (notesMsg.isDefined) {
+          println(s"\t${notesMsg.get}")
+        }
+        println(s"\tdefault: ${notesMsg.isEmpty}")
+        println()
+      }
-      println(s"\tdefault: ${notesMsg.isEmpty}")
-      println()
    }
  }

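With the _visible guard in place, an invisible rule contributes nothing to either output mode of confHelp. A simplified, hypothetical model of the gating; the real logic is the method above, and the confKey and desc values here are illustrative:

    // Simplified model: documentation output is emitted only for visible rules.
    def docEntry(visible: Boolean, asTable: Boolean, confKey: String, desc: String): String =
      if (!visible) ""                               // invisible rules: no docs at all
      else if (asTable) s"$confKey|$desc|true|None|" // one markdown table row
      else s"$confKey:\n\t$desc"                     // plain-text listing

    docEntry(visible = false, asTable = true,
      confKey = "spark.rapids.sql.exec.ShowTablesExec", desc = "Table metadata operation")
    // returns "": rules built via neverReplaceExec never appear in the docs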
@@ -436,6 +442,22 @@ object GpuOverrides {
new PartRule[INPUT](doWrap, desc, tag)
}

+  /**
+   * Create an exec rule that should never be replaced, because it is something that should always
+   * run on the CPU, or should just be ignored totally for whatever reason.
+   */
+  def neverReplaceExec[INPUT <: SparkPlan](desc: String)
+      (implicit tag: ClassTag[INPUT]): ExecRule[INPUT] = {
+    assert(desc != null)
+    def doWrap(
+        exec: INPUT,
+        conf: RapidsConf,
+        p: Option[RapidsMeta[_, _, _]],
+        cc: ConfKeysAndIncompat) =
+      new DoNotReplaceOrWarnSparkPlanMeta[INPUT](exec, conf, p)
+    new ExecRule[INPUT](doWrap, desc, tag).invisible()
+  }

def exec[INPUT <: SparkPlan](
desc: String,
doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], ConfKeysAndIncompat)
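The new helper is exercised later in this same commit; registering a metadata-only operator is a one-liner. Two examples taken verbatim from the rule list added below:

    neverReplaceExec[ShowTablesExec]("Table metadata operation"),
    neverReplaceExec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan"),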
@@ -879,7 +901,7 @@ object GpuOverrides {
GpuKnownFloatingPointNormalized(child)
}),
expr[DateDiff](
"Returns the number of days from startDate to endDate",
"Returns the number of days from startDate to endDate",
(a, conf, p, r) => new BinaryExprMeta[DateDiff](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuDateDiff(lhs, rhs)
@@ -1544,7 +1566,7 @@ object GpuOverrides {
override def convertToGpu(): GpuPartitioning =
GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
}),
-    part[RangePartitioning](
+    part[RangePartitioning](
"Range partitioning",
(rp, conf, p, r) => new PartMeta[RangePartitioning](rp, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] =
@@ -1565,14 +1587,14 @@
}
}
}),
-    part[RoundRobinPartitioning](
+    part[RoundRobinPartitioning](
"Round robin partitioning",
(rrp, conf, p, r) => new PartMeta[RoundRobinPartitioning](rrp, conf, p, r) {
override def convertToGpu(): GpuPartitioning = {
GpuRoundRobinPartitioning(rrp.numPartitions)
}
}),
-    part[SinglePartition.type](
+    part[SinglePartition.type](
"Single partitioning",
(sp, conf, p, r) => new PartMeta[SinglePartition.type](sp, conf, p, r) {
override val childExprs: Seq[ExprMeta[_]] = Seq.empty[ExprMeta[_]]
@@ -1736,12 +1758,28 @@ object GpuOverrides {
exec.partitionSpecs)
}
}),
-    exec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan", (exec, conf, p, _) =>
-      new DoNotReplaceSparkPlanMeta[AdaptiveSparkPlanExec](exec, conf, p)),
-    exec[BroadcastQueryStageExec]("Broadcast query stage", (exec, conf, p, _) =>
-      new DoNotReplaceSparkPlanMeta[BroadcastQueryStageExec](exec, conf, p)),
-    exec[ShuffleQueryStageExec]("Shuffle query stage", (exec, conf, p, _) =>
-      new DoNotReplaceSparkPlanMeta[ShuffleQueryStageExec](exec, conf, p))
+    neverReplaceExec[AlterNamespaceSetPropertiesExec]("Namespace metadata operation"),
+    neverReplaceExec[CreateNamespaceExec]("Namespace metadata operation"),
+    neverReplaceExec[DescribeNamespaceExec]("Namespace metadata operation"),
+    neverReplaceExec[DropNamespaceExec]("Namespace metadata operation"),
+    neverReplaceExec[SetCatalogAndNamespaceExec]("Namespace metadata operation"),
+    neverReplaceExec[ShowCurrentNamespaceExec]("Namespace metadata operation"),
+    neverReplaceExec[ShowNamespacesExec]("Namespace metadata operation"),
+    neverReplaceExec[ExecutedCommandExec]("Table metadata operation"),
+    neverReplaceExec[AlterTableExec]("Table metadata operation"),
+    neverReplaceExec[CreateTableExec]("Table metadata operation"),
+    neverReplaceExec[DeleteFromTableExec]("Table metadata operation"),
+    neverReplaceExec[DescribeTableExec]("Table metadata operation"),
+    neverReplaceExec[DropTableExec]("Table metadata operation"),
+    neverReplaceExec[AtomicReplaceTableExec]("Table metadata operation"),
+    neverReplaceExec[RefreshTableExec]("Table metadata operation"),
+    neverReplaceExec[RenameTableExec]("Table metadata operation"),
+    neverReplaceExec[ReplaceTableExec]("Table metadata operation"),
+    neverReplaceExec[ShowTablePropertiesExec]("Table metadata operation"),
+    neverReplaceExec[ShowTablesExec]("Table metadata operation"),
+    neverReplaceExec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan"),
+    neverReplaceExec[BroadcastQueryStageExec]("Broadcast query stage"),
+    neverReplaceExec[ShuffleQueryStageExec]("Shuffle query stage")
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
commonExecs ++ ShimLoader.getSparkShims.getExecs
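Each rule is keyed by its concrete SparkPlan class, so resolving the rule for a plan node is a single map lookup. A sketch under that assumption; the wrapping code itself is collapsed in this diff, and ruleFor is a hypothetical name:

    // Hypothetical lookup sketch: a ShowTablesExec node resolves to the
    // neverReplaceExec rule registered for it above, while unknown plan
    // classes fall back to a "rule not found" meta elsewhere in this file.
    def ruleFor(plan: SparkPlan): Option[ExecRule[_ <: SparkPlan]] =
      execs.get(plan.getClass)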
@@ -1766,7 +1804,10 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
wrap.runAfterTagRules()
val exp = conf.explain
if (!exp.equalsIgnoreCase("NONE")) {
-      logWarning(s"\n${wrap.explain(exp.equalsIgnoreCase("ALL"))}")
+      val explain = wrap.explain(exp.equalsIgnoreCase("ALL"))
+      if (!explain.isEmpty) {
+        logWarning(s"\n$explain")
+      }
}
val convertedPlan = wrap.convertIfNeeded()
addSortsIfNeeded(convertedPlan, conf)
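This guard pairs with the new do-not-replace metas: their explain output is empty, so metadata-only commands no longer trigger a blank warning. A standalone illustration of the resulting behavior, assuming explain yields an empty string for such plans (the values below are illustrative, not from the diff):

    // Assumes a never-replace plan, e.g. SHOW TABLES, produces "" as explain text.
    val exp = "NOT_ON_GPU"   // stand-in for the spark.rapids.sql.explain setting
    val explainText = ""     // what a never-replace plan is assumed to yield
    if (!exp.equalsIgnoreCase("NONE") && !explainText.isEmpty) {
      println(s"\n$explainText") // skipped: no empty-warning noise in the logs
    }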
@@ -573,9 +573,9 @@ final class RuleNotFoundSparkPlanMeta[INPUT <: SparkPlan](
}

/**
- * Metadata for `SparkPlan` that should not be replaced.
+ * Metadata for `SparkPlan` that should not be replaced or have any kind of warning generated for it.
  */
-final class DoNotReplaceSparkPlanMeta[INPUT <: SparkPlan](
+final class DoNotReplaceOrWarnSparkPlanMeta[INPUT <: SparkPlan](
plan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
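Only the signature and doc comment of the renamed class are visible here; its body is collapsed in this view. A loose, hypothetical stand-in for the role such a meta plays, consistent with the doc comment and with how neverReplaceExec instantiates it (class name and body below are assumptions, not the real implementation):

    // Hypothetical stand-in for illustration only: a meta that leaves the
    // wrapped plan on the CPU and contributes no explain or warning text.
    final class QuietCpuOnlyPlanMeta[INPUT <: SparkPlan](
        plan: INPUT,
        conf: RapidsConf,
        parent: Option[RapidsMeta[_, _, _]]) {
      def convertIfNeeded(): SparkPlan = plan // never swaps in a GPU version
      def explain(all: Boolean): String = ""  // stays silent in driver logs
    }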
