Skip to content

Commit

Permalink
Preliminary support for keeping broadcast exchanges on GPU when AQE i…
Browse files Browse the repository at this point in the history
…s enabled (NVIDIA#448)
  • Loading branch information
andygrove authored Jul 31, 2020
1 parent 2891ac4 commit 1d1720c
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.spark300.RapidsShuffleManager

import org.apache.spark.SparkEnv
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
Expand Down Expand Up @@ -193,4 +195,11 @@ class Spark300Shims extends SparkShims {
override def getRapidsShuffleManagerClass: String = {
classOf[RapidsShuffleManager].getCanonicalName
}

override def injectQueryStagePrepRule(
extensions: SparkSessionExtensions,
ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
// not supported in 3.0.0 but it doesn't matter because AdaptiveSparkPlanExec in 3.0.0 will
// never allow us to replace an Exchange node, so they just stay on CPU
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.spark300.Spark300Shims
import com.nvidia.spark.rapids.spark301.RapidsShuffleManager

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

class Spark301Shims extends Spark300Shims {

Expand All @@ -47,4 +50,10 @@ class Spark301Shims extends Spark300Shims {
override def getRapidsShuffleManagerClass: String = {
classOf[RapidsShuffleManager].getCanonicalName
}

override def injectQueryStagePrepRule(
extensions: SparkSessionExtensions,
ruleBuilder: SparkSession => Rule[SparkPlan]): Unit = {
extensions.injectQueryStagePrepRule(ruleBuilder)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1744,6 +1744,16 @@ object GpuOverrides {
val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
commonExecs ++ ShimLoader.getSparkShims.getExecs
}
/** Tag the initial plan when AQE is enabled */
case class GpuQueryStagePrepOverrides() extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan) :SparkPlan = {
// Note that we disregard the GPU plan returned here and instead rely on side effects of
// tagging the underlying SparkPlan.
GpuOverrides().apply(plan)
// return the original plan which is now modified as a side-effect of invoking GpuOverrides
plan
}
}

case class GpuOverrides() extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan) :SparkPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
logWarning("Installing extensions to enable rapids GPU SQL support." +
s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false")
extensions.injectColumnar(_ => ColumnarOverrideRules())
ShimLoader.getSparkShims.injectQueryStagePrepRule(extensions, _ => GpuQueryStagePrepOverrides())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.nvidia.spark.rapids.GpuOverrides.isStringLit
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ComplexTypeMergingExpression, Expression, String2TrimExpression, TernaryExpression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
Expand Down Expand Up @@ -116,12 +117,21 @@ abstract class RapidsMeta[INPUT <: BASE, BASE, OUTPUT <: BASE](

private var shouldBeRemovedReasons: Option[mutable.Set[String]] = None

val gpuSupportedTag = TreeNodeTag[String]("rapids.gpu.supported")

/**
* Call this to indicate that this should not be replaced with a GPU enabled version
* @param because why it should not be replaced.
*/
final def willNotWorkOnGpu(because: String): Unit =
final def willNotWorkOnGpu(because: String): Unit = {
cannotBeReplacedReasons.get.add(because)
// annotate the real spark plan with the reason as well so that the information is available
// during query stage planning when AQE is on
wrapped match {
case p: SparkPlan => p.setTagValue(gpuSupportedTag, because)
case _ =>
}
}

final def shouldBeRemoved(because: String): Unit =
shouldBeRemovedReasons.get.add(because)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.nvidia.spark.rapids

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
Expand Down Expand Up @@ -68,4 +70,8 @@ trait SparkShims {
endMapIndex: Int,
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

def injectQueryStagePrepRule(
extensions: SparkSessionExtensions,
rule: SparkSession => Rule[SparkPlan])
}
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,18 @@ class GpuBroadcastMeta(
case _: BroadcastNestedLoopJoinExec => true
case _ => false
}
if (!parent.exists(isSupported)) {
willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
"with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
if (parent.isDefined) {
if (!parent.exists(isSupported)) {
willNotWorkOnGpu("BroadcastExchange only works on the GPU if being used " +
"with a GPU version of BroadcastHashJoinExec or BroadcastNestedLoopJoinExec")
}
} else {
// when AQE is enabled and we are planning a new query stage, parent will be None so
// we need to look at meta-data previously stored on the spark plan
wrapped.getTagValue(gpuSupportedTag) match {
case Some(reason) => willNotWorkOnGpu(reason)
case None => // this broadcast is supported on GPU
}
}
}

Expand Down

0 comments on commit 1d1720c

Please sign in to comment.