Add in neverReplaceExec and several rules for it #660

Merged: 1 commit, merged on Sep 4, 2020
sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala (125 changes: 83 additions & 42 deletions)

@@ -31,25 +31,23 @@ import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, CustomShuffleReaderExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
-import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand}
+import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec, ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.{AlterNamespaceSetPropertiesExec, AlterTableExec, AtomicReplaceTableExec, BatchScanExec, CreateNamespaceExec, CreateTableExec, DeleteFromTableExec, DescribeNamespaceExec, DescribeTableExec, DropNamespaceExec, DropTableExec, RefreshTableExec, RenameTableExec, ReplaceTableExec, SetCatalogAndNamespaceExec, ShowCurrentNamespaceExec, ShowNamespacesExec, ShowTablePropertiesExec, ShowTablesExec}
import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta, GpuCustomShuffleReaderExec, GpuShuffleMeta}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.UTF8String

/**
* Base class for all ReplacementRules
@@ -71,6 +69,7 @@ abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPU

private var _incompatDoc: Option[String] = None
private var _disabledDoc: Option[String] = None
+private var _visible: Boolean = true

override def incompatDoc: Option[String] = _incompatDoc
override def disabledMsg: Option[String] = _disabledDoc
@@ -95,6 +94,11 @@ abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPU
this
}

+final def invisible(): this.type = {
+  _visible = false
+  this
+}

/**
* Provide a function that will wrap a spark type in a [[RapidsMeta]] instance that is used for
* conversion to a GPU version.
@@ -139,32 +143,34 @@ abstract class ReplacementRule[INPUT <: BASE, BASE, WRAP_TYPE <: RapidsMeta[INPU
}

   def confHelp(asTable: Boolean = false, sparkSQLFunctions: Option[String] = None): Unit = {
-    val notesMsg = notes()
-    if (asTable) {
-      import ConfHelper.makeConfAnchor
-      print(s"${makeConfAnchor(confKey)}")
-      if (sparkSQLFunctions.isDefined) {
-        print(s"|${sparkSQLFunctions.get}")
-      }
-      print(s"|$desc|${notesMsg.isEmpty}|")
-      if (notesMsg.isDefined) {
-        print(s"${notesMsg.get}")
-      } else {
-        print("None")
-      }
-      println("|")
-    } else {
-      println(s"$confKey:")
-      println(s"\tEnable (true) or disable (false) the $tag $operationName.")
-      if (sparkSQLFunctions.isDefined) {
-        println(s"\tsql function: ${sparkSQLFunctions.get}")
-      }
-      println(s"\t$desc")
-      if (notesMsg.isDefined) {
-        println(s"\t${notesMsg.get}")
-      }
-      println(s"\tdefault: ${notesMsg.isEmpty}")
-      println()
+    if (_visible) {
+      val notesMsg = notes()
+      if (asTable) {
+        import ConfHelper.makeConfAnchor
+        print(s"${makeConfAnchor(confKey)}")
+        if (sparkSQLFunctions.isDefined) {
+          print(s"|${sparkSQLFunctions.get}")
+        }
+        print(s"|$desc|${notesMsg.isEmpty}|")
+        if (notesMsg.isDefined) {
+          print(s"${notesMsg.get}")
+        } else {
+          print("None")
+        }
+        println("|")
+      } else {
+        println(s"$confKey:")
+        println(s"\tEnable (true) or disable (false) the $tag $operationName.")
+        if (sparkSQLFunctions.isDefined) {
+          println(s"\tsql function: ${sparkSQLFunctions.get}")
+        }
+        println(s"\t$desc")
+        if (notesMsg.isDefined) {
+          println(s"\t${notesMsg.get}")
+        }
+        println(s"\tdefault: ${notesMsg.isEmpty}")
+        println()
+      }
     }
   }
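In effect, the visibility flag only gates documentation output: an invisible rule still registers and applies during planning, but confHelp emits nothing for it. A small hedged illustration, assuming a rule built with the neverReplaceExec helper added later in this diff:

// Illustration only: an invisible rule still participates in planning;
// confHelp just prints nothing, so the rule never appears in generated docs.
val rule = GpuOverrides.neverReplaceExec[ShowTablesExec]("Table metadata operation")
rule.confHelp(asTable = true)  // no output, because neverReplaceExec marks the rule invisible()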

@@ -436,6 +442,22 @@ object GpuOverrides {
new PartRule[INPUT](doWrap, desc, tag)
}

+/**
+ * Create an exec rule that should never be replaced, because it is something that should always
+ * run on the CPU, or should just be ignored totally for whatever reason.
+ */
+def neverReplaceExec[INPUT <: SparkPlan](desc: String)
+    (implicit tag: ClassTag[INPUT]): ExecRule[INPUT] = {
+  assert(desc != null)
+  def doWrap(
+      exec: INPUT,
+      conf: RapidsConf,
+      p: Option[RapidsMeta[_, _, _]],
+      cc: ConfKeysAndIncompat) =
+    new DoNotReplaceOrWarnSparkPlanMeta[INPUT](exec, conf, p)
+  new ExecRule[INPUT](doWrap, desc, tag).invisible()
+}

def exec[INPUT <: SparkPlan](
desc: String,
doWrap: (INPUT, RapidsConf, Option[RapidsMeta[_, _, _]], ConfKeysAndIncompat)
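A quick hedged note on how such a rule is registered: the implicit ClassTag recovers the concrete plan class, which (as the map construction further down suggests) becomes the rule's lookup key. A sketch, with getClassFor assumed to read that tag:

// Sketch only: the rule remembers the plan class it was declared for.
val rule = GpuOverrides.neverReplaceExec[DropTableExec]("Table metadata operation")
assert(rule.getClassFor == classOf[DropTableExec])  // assumed behavior of getClassFor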
@@ -879,7 +901,7 @@ object GpuOverrides {
GpuKnownFloatingPointNormalized(child)
}),
expr[DateDiff](
"Returns the number of days from startDate to endDate",
"Returns the number of days from startDate to endDate",
(a, conf, p, r) => new BinaryExprMeta[DateDiff](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
GpuDateDiff(lhs, rhs)
@@ -1544,7 +1566,7 @@ object GpuOverrides {
override def convertToGpu(): GpuPartitioning =
GpuHashPartitioning(childExprs.map(_.convertToGpu()), hp.numPartitions)
}),
part[RangePartitioning](
"Range partitioning",
(rp, conf, p, r) => new PartMeta[RangePartitioning](rp, conf, p, r) {
override val childExprs: Seq[BaseExprMeta[_]] =
@@ -1565,14 +1587,14 @@
}
}
}),
part[RoundRobinPartitioning](
"Round robin partitioning",
(rrp, conf, p, r) => new PartMeta[RoundRobinPartitioning](rrp, conf, p, r) {
override def convertToGpu(): GpuPartitioning = {
GpuRoundRobinPartitioning(rrp.numPartitions)
}
}),
part[SinglePartition.type](
"Single partitioning",
(sp, conf, p, r) => new PartMeta[SinglePartition.type](sp, conf, p, r) {
override val childExprs: Seq[ExprMeta[_]] = Seq.empty[ExprMeta[_]]
@@ -1736,12 +1758,28 @@ object GpuOverrides {
exec.partitionSpecs)
}
}),
-exec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan", (exec, conf, p, _) =>
-  new DoNotReplaceSparkPlanMeta[AdaptiveSparkPlanExec](exec, conf, p)),
-exec[BroadcastQueryStageExec]("Broadcast query stage", (exec, conf, p, _) =>
-  new DoNotReplaceSparkPlanMeta[BroadcastQueryStageExec](exec, conf, p)),
-exec[ShuffleQueryStageExec]("Shuffle query stage", (exec, conf, p, _) =>
-  new DoNotReplaceSparkPlanMeta[ShuffleQueryStageExec](exec, conf, p))
+neverReplaceExec[AlterNamespaceSetPropertiesExec]("Namespace metadata operation"),
+neverReplaceExec[CreateNamespaceExec]("Namespace metadata operation"),
+neverReplaceExec[DescribeNamespaceExec]("Namespace metadata operation"),
+neverReplaceExec[DropNamespaceExec]("Namespace metadata operation"),
+neverReplaceExec[SetCatalogAndNamespaceExec]("Namespace metadata operation"),
+neverReplaceExec[ShowCurrentNamespaceExec]("Namespace metadata operation"),
+neverReplaceExec[ShowNamespacesExec]("Namespace metadata operation"),
+neverReplaceExec[ExecutedCommandExec]("Table metadata operation"),
+neverReplaceExec[AlterTableExec]("Table metadata operation"),
+neverReplaceExec[CreateTableExec]("Table metadata operation"),
+neverReplaceExec[DeleteFromTableExec]("Table metadata operation"),
+neverReplaceExec[DescribeTableExec]("Table metadata operation"),
+neverReplaceExec[DropTableExec]("Table metadata operation"),
+neverReplaceExec[AtomicReplaceTableExec]("Table metadata operation"),
+neverReplaceExec[RefreshTableExec]("Table metadata operation"),
+neverReplaceExec[RenameTableExec]("Table metadata operation"),
+neverReplaceExec[ReplaceTableExec]("Table metadata operation"),
+neverReplaceExec[ShowTablePropertiesExec]("Table metadata operation"),
+neverReplaceExec[ShowTablesExec]("Table metadata operation"),
+neverReplaceExec[AdaptiveSparkPlanExec]("Wrapper for adaptive query plan"),
+neverReplaceExec[BroadcastQueryStageExec]("Broadcast query stage"),
+neverReplaceExec[ShuffleQueryStageExec]("Shuffle query stage")
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
commonExecs ++ ShimLoader.getSparkShims.getExecs
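Because this map is keyed by concrete plan class, here is a hedged sketch of what the new entries buy at lookup time (lookup simplified; the real planning wiring lives elsewhere in the plugin):

// Sketch only: metadata-only execs such as SHOW TABLES now resolve to a
// never-replace rule instead of falling through to the rule-not-found path,
// which is what previously produced spurious warnings.
val maybeRule = GpuOverrides.execs.get(classOf[ShowTablesExec])
// expected: Some(rule) that wraps the plan in DoNotReplaceOrWarnSparkPlanMeta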
@@ -1766,7 +1804,10 @@ case class GpuOverrides() extends Rule[SparkPlan] with Logging {
wrap.runAfterTagRules()
val exp = conf.explain
if (!exp.equalsIgnoreCase("NONE")) {
-logWarning(s"\n${wrap.explain(exp.equalsIgnoreCase("ALL"))}")
+val explain = wrap.explain(exp.equalsIgnoreCase("ALL"))
+if (!explain.isEmpty) {
+  logWarning(s"\n$explain")
+}
}
val convertedPlan = wrap.convertIfNeeded()
addSortsIfNeeded(convertedPlan, conf)
Second changed file:

@@ -573,9 +573,9 @@ final class RuleNotFoundSparkPlanMeta[INPUT <: SparkPlan](
}

/**
- * Metadata for `SparkPlan` that should not be replaced.
+ * Metadata for `SparkPlan` that should not be replaced and should not produce any warnings.
*/
-final class DoNotReplaceSparkPlanMeta[INPUT <: SparkPlan](
+final class DoNotReplaceOrWarnSparkPlanMeta[INPUT <: SparkPlan](
plan: INPUT,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]])
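Taken together with the explain-logging change above, a hedged before/after illustration of the user-visible effect (session setup assumed; NOT_ON_GPU taken to be one of the accepted explain values alongside NONE and ALL):

// Hypothetical illustration: a pure metadata command such as SHOW TABLES
// previously triggered a could-not-replace warning when explain was enabled;
// with these rules it is wrapped by DoNotReplaceOrWarnSparkPlanMeta, the
// explain output comes back empty, and nothing is logged.
spark.conf.set("spark.rapids.sql.explain", "NOT_ON_GPU")
spark.sql("SHOW TABLES").collect()  // no GPU-overrides warning expected here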