Fix unit tests when AQE is enabled #558

Merged: 7 commits, Aug 18, 2020
BroadcastHashJoinSuite.scala

@@ -16,6 +16,8 @@

 package com.nvidia.spark.rapids

+import com.nvidia.spark.rapids.TestUtils.{findOperator, operatorCount}
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.execution.joins.HashJoin
 import org.apache.spark.sql.functions.broadcast
@@ -35,13 +37,14 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite {
       val df5 = df4.join(df3, Seq("longs"), "inner")

       val plan = df5.queryExecution.executedPlan
Contributor (review comment on the line above):
sorry, just one more nit. It would be better to get the reference to the executedPlan after calling collect rather than before. The goal is to get the final plan, so it would be clearer in this order. It isn't obvious to people looking at this code that the executedPlan gets updated by the call to collect. Other than this, LGTM.
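A sketch of the reordering the reviewer suggests, reusing the names from the test above:

    // Run the query first, then take the plan reference, so the variable
    // unambiguously holds the final plan once AQE has re-optimized it.
    df5.collect()
    val plan = df5.queryExecution.executedPlan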

+      // execute the plan so that the final adaptive plan is available when AQE is on
+      df5.collect()
+
+      val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
+      assert(bhjCount.size === 1)
-      assert(plan.collect {
-        case p if ShimLoader.getSparkShims.isGpuBroadcastHashJoin(p) => p
-      }.size === 1)
-      assert(plan.collect {
-        case p if ShimLoader.getSparkShims.isGpuShuffledHashJoin(p) => p
-      }.size === 1)
+      val shjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin)
+      assert(shjCount.size === 1)
     }, conf)
   }

@@ -52,17 +55,21 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite {

       for (name <- Seq("BROADCAST", "BROADCASTJOIN", "MAPJOIN")) {
         val plan1 = spark.sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.longs = u.longs")
-          .queryExecution.executedPlan
         val plan2 = spark.sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.longs = u.longs")
-          .queryExecution.executedPlan

-        val res1 = plan1.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
-        val res2 = plan2.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
-
-        assert(ShimLoader.getSparkShims.getBuildSide(res1.get.asInstanceOf[HashJoin]).toString ==
-          "GpuBuildLeft")
-        assert(ShimLoader.getSparkShims.getBuildSide(res2.get.asInstanceOf[HashJoin]).toString ==
-          "GpuBuildRight")
+        val initialPlan1 = plan1.queryExecution.executedPlan
+        // execute the plan so that the final adaptive plan is available when AQE is on
+        plan1.collect()
+        val finalPlan1 = findOperator(initialPlan1, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
+        assert(ShimLoader.getSparkShims.getBuildSide(
+          finalPlan1.get.asInstanceOf[HashJoin]).toString == "GpuBuildLeft")
+
+        val initialPlan2 = plan2.queryExecution.executedPlan
Contributor (review comment on the line above):
Same comment here re removing intermediate variable

+        // execute the plan so that the final adaptive plan is available when AQE is on
+        plan2.collect()
+        val finalPlan2 = findOperator(initialPlan2, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
+        assert(ShimLoader.getSparkShims.getBuildSide(
+          finalPlan2.get.asInstanceOf[HashJoin]).toString == "GpuBuildRight")
       }
     })
   }
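For context, a sketch of the behavior these tests work around: with AQE enabled, a query's executedPlan is an AdaptiveSparkPlanExec whose final child plan is only settled once the query actually runs. This is illustrative only, assuming a SparkSession named `spark` with spark.sql.adaptive.enabled set to true:

    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

    val df = spark.range(100).join(spark.range(100), "id")
    df.queryExecution.executedPlan match {
      case a: AdaptiveSparkPlanExec =>
        df.collect()                        // run the query so AQE can finalize the stages
        println(a.executedPlan.treeString)  // the final adaptive plan
      case plan =>
        println(plan.treeString)            // AQE off: the plan is already final
    }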
HashSortOptimizeSuite.scala

@@ -16,6 +16,7 @@

 package com.nvidia.spark.rapids

+import com.nvidia.spark.rapids.TestUtils.{findOperator, getFinalPlan}
 import org.scalatest.FunSuite

 import org.apache.spark.SparkConf
@@ -69,7 +70,10 @@ class HashSortOptimizeSuite extends FunSuite {
     val df2 = buildDataFrame2(spark)
     val rdf = df1.join(df2, df1("a") === df2("x"))
     val plan = rdf.queryExecution.executedPlan
-    val joinNode = plan.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
+    // execute the plan so that the final adaptive plan is available when AQE is on
+    rdf.collect()
+
+    val joinNode = findOperator(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
     assert(joinNode.isDefined, "No broadcast join node found")
     validateOptimizeSort(plan, joinNode.get)
   })
@@ -82,7 +86,9 @@
     val df2 = buildDataFrame2(spark)
     val rdf = df1.join(df2, df1("a") === df2("x"))
     val plan = rdf.queryExecution.executedPlan
-    val joinNode = plan.find(ShimLoader.getSparkShims.isGpuShuffledHashJoin(_))
+    // execute the plan so that the final adaptive plan is available when AQE is on
+    rdf.collect()
+    val joinNode = findOperator(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin(_))
     assert(joinNode.isDefined, "No shuffled hash join node found")
     validateOptimizeSort(plan, joinNode.get)
   })
@@ -106,7 +112,10 @@
     val df2 = buildDataFrame2(spark)
     val rdf = df1.join(df2, df1("a") === df2("x")).orderBy(df1("a"))
     val plan = rdf.queryExecution.executedPlan
-    val numSorts = plan.map {
+    // Get the final executed plan whether AQE is enabled or disabled.
+    val finalPlan = getFinalPlan(plan)
+
+    val numSorts = finalPlan.map {
       case _: SortExec | _: GpuSortExec => 1
       case _ => 0
     }.sum
tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala (44 additions, 0 deletions)
@@ -20,9 +20,12 @@ import java.io.File

 import ai.rapids.cudf.{ColumnVector, DType, Table}
 import org.scalatest.Assertions
+import scala.collection.mutable.ListBuffer

 import org.apache.spark.SparkConf
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.rapids.GpuShuffleEnv
 import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -51,6 +54,47 @@ object TestUtils extends Assertions with Arm {
}
}

  /** Recursively search the plan for the first operator matching the predicate,
   * descending into adaptive plans and query stages. */
  def findOperator(plan: SparkPlan, predicate: SparkPlan => Boolean): Option[SparkPlan] = {
    plan match {
      case _ if predicate(plan) => Some(plan)
      case a: AdaptiveSparkPlanExec => findOperator(a.executedPlan, predicate)
      case qs: BroadcastQueryStageExec => findOperator(qs.broadcast, predicate)
      case qs: ShuffleQueryStageExec => findOperator(qs.shuffle, predicate)
      case other => other.children.flatMap(p => findOperator(p, predicate)).headOption
    }
  }
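Usage mirrors the suites above; a small sketch, assuming `plan` holds a query's executedPlan:

    // Locate the GPU broadcast hash join even when it is nested inside
    // an AdaptiveSparkPlanExec or a broadcast/shuffle query stage
    val joinNode = findOperator(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
    assert(joinNode.isDefined, "No broadcast join node found")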

  /** Return all operators in the plan that match the predicate */
  def operatorCount(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = {
    def recurse(
        plan: SparkPlan,
        predicate: SparkPlan => Boolean,
        accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = {
      plan match {
        case _ if predicate(plan) =>
          accum += plan
          plan.children.foreach(p => recurse(p, predicate, accum))
        case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum)
        case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum)
        case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum)
        case other => other.children.foreach(p => recurse(p, predicate, accum))
      }
      accum
    }

    recurse(plan, predicate, new ListBuffer[SparkPlan]())
  }
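As used in BroadcastHashJoinSuite above. Note that, despite its name, operatorCount returns the matching operators themselves rather than an Int, so callers assert on the size of the result:

    val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin)
    assert(bhjCount.size === 1)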

  /** Return the final executed plan, unwrapping the top-level adaptive plan when AQE is on */
  def getFinalPlan(plan: SparkPlan): SparkPlan = {
    plan match {
      case a: AdaptiveSparkPlanExec => a.executedPlan
      case _ => plan
    }
  }
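A minimal usage sketch mirroring HashSortOptimizeSuite above, assuming `rdf` is a joined DataFrame whose query has been executed:

    import org.apache.spark.sql.execution.SortExec

    // Walk the final plan (adaptive or not) and count the sort operators
    val finalPlan = getFinalPlan(rdf.queryExecution.executedPlan)
    val numSorts = finalPlan.map {
      case _: SortExec | _: GpuSortExec => 1
      case _ => 0
    }.sum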

  /** Compare two `ColumnVector` instances for equality */
  def compareColumns(expected: ColumnVector, actual: ColumnVector): Unit = {
    assertResult(expected.getType)(actual.getType)