From b68daae878752b4b673aba0834b88943833e49e3 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Wed, 5 Aug 2020 11:19:30 -0700 Subject: [PATCH 1/6] Fix scala tests when AQE is enabled Signed-off-by: Niranjan Artal --- .../spark/rapids/HashSortOptimizeSuite.scala | 18 +++++++++++++++--- .../com/nvidia/spark/rapids/TestUtils.scala | 13 +++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala index 287a2a1f8e2..0b7979f471e 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala @@ -16,11 +16,13 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.TestUtils.findOperator import org.scalatest.FunSuite import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.execution.{SortExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec /** Test plan modifications to add optimizing sorts after hash joins in the plan */ class HashSortOptimizeSuite extends FunSuite { @@ -69,7 +71,10 @@ class HashSortOptimizeSuite extends FunSuite { val df2 = buildDataFrame2(spark) val rdf = df1.join(df2, df1("a") === df2("x")) val plan = rdf.queryExecution.executedPlan - val joinNode = plan.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + // execute the plan so that the final adaptive plan is available when AQE is on + rdf.collect() + + val joinNode = findOperator(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) assert(joinNode.isDefined, "No broadcast join node found") validateOptimizeSort(plan, joinNode.get) }) @@ -82,7 +87,9 @@ class HashSortOptimizeSuite extends FunSuite { val df2 = buildDataFrame2(spark) val rdf = df1.join(df2, df1("a") === df2("x")) val plan = rdf.queryExecution.executedPlan - val joinNode = plan.find(ShimLoader.getSparkShims.isGpuShuffledHashJoin(_)) + // execute the plan so that the final adaptive plan is available when AQE is on + rdf.collect() + val joinNode = findOperator(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin(_)) assert(joinNode.isDefined, "No broadcast join node found") validateOptimizeSort(plan, joinNode.get) }) @@ -106,7 +113,12 @@ class HashSortOptimizeSuite extends FunSuite { val df2 = buildDataFrame2(spark) val rdf = df1.join(df2, df1("a") === df2("x")).orderBy(df1("a")) val plan = rdf.queryExecution.executedPlan - val numSorts = plan.map { + val finalPlan = plan match { + case a: AdaptiveSparkPlanExec => + a.executedPlan + case _ => plan + } + val numSorts = finalPlan.map { case _: SortExec | _: GpuSortExec => 1 case _ => 0 }.sum diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala index 4be02c8ccdd..9db82737ba4 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala @@ -23,6 +23,8 @@ import org.scalatest.Assertions import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.vectorized.ColumnarBatch @@ -51,6 +53,17 @@ object TestUtils extends Assertions with Arm { } } + /** Recursively check if the predicate matches in the given plan */ + def findOperator(plan: SparkPlan, predicate: SparkPlan => Boolean): Option[SparkPlan] = { + plan match { + case _ if predicate(plan) => Some(plan) + case a: AdaptiveSparkPlanExec => findOperator(a.executedPlan, predicate) + case qs: BroadcastQueryStageExec => findOperator(qs.broadcast, predicate) + case qs: ShuffleQueryStageExec => findOperator(qs.shuffle, predicate) + case other => other.children.flatMap(p => findOperator(p, predicate)).headOption + } + } + /** Compre the equality of two `ColumnVector` instances */ def compareColumns(expected: ColumnVector, actual: ColumnVector): Unit = { assertResult(expected.getType)(actual.getType) From 2e86fa0378426645fbe9c743d7612323fa1de77a Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Wed, 5 Aug 2020 17:18:11 -0700 Subject: [PATCH 2/6] fix broadcasthashjoin tests Signed-off-by: Niranjan Artal --- .../spark/rapids/BroadcastHashJoinSuite.scala | 35 +++++++++++-------- .../com/nvidia/spark/rapids/TestUtils.scala | 24 ++++++++++++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index f343927d055..5d7662f3426 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -16,6 +16,8 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.TestUtils.{findOperator, operatorCount} + import org.apache.spark.SparkConf import org.apache.spark.sql.execution.joins.HashJoin import org.apache.spark.sql.functions.broadcast @@ -35,13 +37,14 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { val df5 = df4.join(df3, Seq("longs"), "inner") val plan = df5.queryExecution.executedPlan + // execute the plan so that the final adaptive plan is available when AQE is on + df5.collect() + + val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + assert(bhjCount.size === 1) - assert(plan.collect { - case p if ShimLoader.getSparkShims.isGpuBroadcastHashJoin(p) => p - }.size === 1) - assert(plan.collect { - case p if ShimLoader.getSparkShims.isGpuShuffledHashJoin(p) => p - }.size === 1) + val shjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin(_)) + assert(shjCount.size === 1) }, conf) } @@ -52,17 +55,21 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { for (name <- Seq("BROADCAST", "BROADCASTJOIN", "MAPJOIN")) { val plan1 = spark.sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.longs = u.longs") - .queryExecution.executedPlan val plan2 = spark.sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.longs = u.longs") - .queryExecution.executedPlan - val res1 = plan1.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) - val res2 = plan2.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + val finalplan1 = plan1.queryExecution.executedPlan + plan1.collect() + val finalPlan1 = findOperator(finalplan1, + ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + assert(ShimLoader.getSparkShims.getBuildSide + (finalPlan1.get.asInstanceOf[HashJoin]).toString == "GpuBuildLeft") - assert(ShimLoader.getSparkShims.getBuildSide(res1.get.asInstanceOf[HashJoin]).toString == - "GpuBuildLeft") - assert(ShimLoader.getSparkShims.getBuildSide(res2.get.asInstanceOf[HashJoin]).toString == - "GpuBuildRight") + val finalplan2 = plan2.queryExecution.executedPlan + plan2.collect() + val finalPlan2 = findOperator(finalplan2, + ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + assert(ShimLoader.getSparkShims. + getBuildSide(finalPlan2.get.asInstanceOf[HashJoin]).toString == "GpuBuildRight") } }) } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala index 9db82737ba4..2c508ce349c 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala @@ -20,7 +20,6 @@ import java.io.File import ai.rapids.cudf.{ColumnVector, DType, Table} import org.scalatest.Assertions - import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.SparkPlan @@ -28,6 +27,8 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, Broadcast import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.vectorized.ColumnarBatch +import scala.collection.mutable.ListBuffer + /** A collection of utility methods useful in tests. */ object TestUtils extends Assertions with Arm { def getTempDir(basename: String): File = new File( @@ -64,6 +65,27 @@ object TestUtils extends Assertions with Arm { } } + /** Return list of matching predicates present in the plan */ + def operatorCount(plan: SparkPlan, predicate: SparkPlan => Boolean): Seq[SparkPlan] = { + def recurse( + plan: SparkPlan, + predicate: SparkPlan => Boolean, + accum: ListBuffer[SparkPlan]): Seq[SparkPlan] = { + plan match { + case _ if predicate(plan) => + accum += plan + plan.children.flatMap(p => recurse(p, predicate, accum)).headOption + case a: AdaptiveSparkPlanExec => recurse(a.executedPlan, predicate, accum) + case qs: BroadcastQueryStageExec => recurse(qs.broadcast, predicate, accum) + case qs: ShuffleQueryStageExec => recurse(qs.shuffle, predicate, accum) + case other => other.children.flatMap(p => recurse(p, predicate, accum)).headOption + } + accum + } + + recurse(plan, predicate, new ListBuffer[SparkPlan]()) + } + /** Compre the equality of two `ColumnVector` instances */ def compareColumns(expected: ColumnVector, actual: ColumnVector): Unit = { assertResult(expected.getType)(actual.getType) From 853ce846dce907314a55ece03cae73b31a0fc2c3 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Thu, 13 Aug 2020 16:54:59 -0700 Subject: [PATCH 3/6] fix indentation Signed-off-by: Niranjan Artal --- tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala index 2c508ce349c..838b1cb4833 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala @@ -20,6 +20,8 @@ import java.io.File import ai.rapids.cudf.{ColumnVector, DType, Table} import org.scalatest.Assertions +import scala.collection.mutable.ListBuffer + import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.SparkPlan @@ -27,8 +29,6 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, Broadcast import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.vectorized.ColumnarBatch -import scala.collection.mutable.ListBuffer - /** A collection of utility methods useful in tests. */ object TestUtils extends Assertions with Arm { def getTempDir(basename: String): File = new File( From 4dff1f15f5c57711379175048795ef51d22eb7a3 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Mon, 17 Aug 2020 16:22:28 -0700 Subject: [PATCH 4/6] addressed review comments Signed-off-by: Niranjan Artal --- .../spark/rapids/BroadcastHashJoinSuite.scala | 16 ++++++++-------- .../spark/rapids/HashSortOptimizeSuite.scala | 11 ++++------- .../com/nvidia/spark/rapids/TestUtils.scala | 9 +++++++++ 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index 5d7662f3426..4cdc7ce4862 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -40,10 +40,10 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { // execute the plan so that the final adaptive plan is available when AQE is on df5.collect() - val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin) assert(bhjCount.size === 1) - val shjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin(_)) + val shjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuShuffledHashJoin) assert(shjCount.size === 1) }, conf) } @@ -57,17 +57,17 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { val plan1 = spark.sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.longs = u.longs") val plan2 = spark.sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.longs = u.longs") - val finalplan1 = plan1.queryExecution.executedPlan + val initialPlan1 = plan1.queryExecution.executedPlan + // execute the plan so that the final adaptive plan is available when AQE is on plan1.collect() - val finalPlan1 = findOperator(finalplan1, - ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + val finalPlan1 = findOperator(initialPlan1, ShimLoader.getSparkShims.isGpuBroadcastHashJoin) assert(ShimLoader.getSparkShims.getBuildSide (finalPlan1.get.asInstanceOf[HashJoin]).toString == "GpuBuildLeft") - val finalplan2 = plan2.queryExecution.executedPlan + val initialPlan2 = plan2.queryExecution.executedPlan + // execute the plan so that the final adaptive plan is available when AQE is on plan2.collect() - val finalPlan2 = findOperator(finalplan2, - ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_)) + val finalPlan2 = findOperator(initialPlan2, ShimLoader.getSparkShims.isGpuBroadcastHashJoin) assert(ShimLoader.getSparkShims. getBuildSide(finalPlan2.get.asInstanceOf[HashJoin]).toString == "GpuBuildRight") } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala index 0b7979f471e..dc92c07e825 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala @@ -16,13 +16,12 @@ package com.nvidia.spark.rapids -import com.nvidia.spark.rapids.TestUtils.findOperator +import com.nvidia.spark.rapids.TestUtils.{findOperator, getFinalPlan} import org.scalatest.FunSuite import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.execution.{SortExec, SparkPlan} -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec /** Test plan modifications to add optimizing sorts after hash joins in the plan */ class HashSortOptimizeSuite extends FunSuite { @@ -113,11 +112,9 @@ class HashSortOptimizeSuite extends FunSuite { val df2 = buildDataFrame2(spark) val rdf = df1.join(df2, df1("a") === df2("x")).orderBy(df1("a")) val plan = rdf.queryExecution.executedPlan - val finalPlan = plan match { - case a: AdaptiveSparkPlanExec => - a.executedPlan - case _ => plan - } + // Get the final executed plan when AQE is either enabled or disabled. + val finalPlan = getFinalPlan(plan) + val numSorts = finalPlan.map { case _: SortExec | _: GpuSortExec => 1 case _ => 0 diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala index 838b1cb4833..3fed04b3503 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala @@ -86,6 +86,15 @@ object TestUtils extends Assertions with Arm { recurse(plan, predicate, new ListBuffer[SparkPlan]()) } + /** Return final executed plan */ + def getFinalPlan(plan: SparkPlan): SparkPlan = { + plan match { + case a: AdaptiveSparkPlanExec => + a.executedPlan + case _ => plan + } + } + /** Compre the equality of two `ColumnVector` instances */ def compareColumns(expected: ColumnVector, actual: ColumnVector): Unit = { assertResult(expected.getType)(actual.getType) From 8a2c3001f5445340d6f0b0572a2a90ed20b93f89 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Tue, 18 Aug 2020 08:55:36 -0700 Subject: [PATCH 5/6] addressed review comments Signed-off-by: Niranjan Artal --- .../com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index 4cdc7ce4862..cbbec564f75 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -57,17 +57,17 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { val plan1 = spark.sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.longs = u.longs") val plan2 = spark.sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.longs = u.longs") - val initialPlan1 = plan1.queryExecution.executedPlan // execute the plan so that the final adaptive plan is available when AQE is on plan1.collect() - val finalPlan1 = findOperator(initialPlan1, ShimLoader.getSparkShims.isGpuBroadcastHashJoin) + val finalPlan1 = findOperator(plan1.queryExecution.executedPlan, + ShimLoader.getSparkShims.isGpuBroadcastHashJoin) assert(ShimLoader.getSparkShims.getBuildSide (finalPlan1.get.asInstanceOf[HashJoin]).toString == "GpuBuildLeft") - val initialPlan2 = plan2.queryExecution.executedPlan // execute the plan so that the final adaptive plan is available when AQE is on plan2.collect() - val finalPlan2 = findOperator(initialPlan2, ShimLoader.getSparkShims.isGpuBroadcastHashJoin) + val finalPlan2 = findOperator(plan2.queryExecution.executedPlan, + ShimLoader.getSparkShims.isGpuBroadcastHashJoin) assert(ShimLoader.getSparkShims. getBuildSide(finalPlan2.get.asInstanceOf[HashJoin]).toString == "GpuBuildRight") } From e169f4ded5129f9559550b2af7dc97b42027fb21 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Tue, 18 Aug 2020 10:32:35 -0700 Subject: [PATCH 6/6] addressed review comments Signed-off-by: Niranjan Artal --- .../scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala index cbbec564f75..b68763be678 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala @@ -36,9 +36,9 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite { val df4 = longsDf(spark) val df5 = df4.join(df3, Seq("longs"), "inner") - val plan = df5.queryExecution.executedPlan // execute the plan so that the final adaptive plan is available when AQE is on df5.collect() + val plan = df5.queryExecution.executedPlan val bhjCount = operatorCount(plan, ShimLoader.getSparkShims.isGpuBroadcastHashJoin) assert(bhjCount.size === 1)