From 04f7f6dac0b9177e11482cca4e7ebf7b7564e45f Mon Sep 17 00:00:00 2001
From: Zhenhua Wang
Date: Mon, 7 Sep 2020 06:26:14 +0000
Subject: [PATCH] [SPARK-32748][SQL] Support local property propagation in
 SubqueryBroadcastExec

### What changes were proposed in this pull request?

Since [SPARK-22590](https://github.com/apache/spark/commit/2854091d12d670b014c41713e72153856f4d3f6a), local property propagation has been supported through `SQLExecution.withThreadLocalCaptured` in both `BroadcastExchangeExec` and `SubqueryExec` when computing `relationFuture`. This PR adds the same support to `SubqueryBroadcastExec`.

### Why are the changes needed?

Local property propagation is missing in `SubqueryBroadcastExec`: its `relationFuture` is computed on a separate thread pool, so local properties set on the submitting thread are otherwise not visible there.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a new test.

Closes #29589 from wzhfy/thread_local.

Authored-by: Zhenhua Wang
Signed-off-by: Wenchen Fan
---
 .../sql/execution/SubqueryBroadcastExec.scala | 16 +++--
 .../internal/ExecutorSideSQLConfSuite.scala   | 63 ++++++++++++++++++-
 2 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
index ddf0b72dd7a96..61ba8a034f445 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SubqueryBroadcastExec.scala
@@ -17,7 +17,9 @@

 package org.apache.spark.sql.execution

-import scala.concurrent.{ExecutionContext, Future}
+import java.util.concurrent.{Future => JFuture}
+
+import scala.concurrent.ExecutionContext
 import scala.concurrent.duration.Duration

 import org.apache.spark.rdd.RDD
@@ -26,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution.joins.{HashedRelation, HashJoin, LongHashedRelation}
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.util.ThreadUtils

 /**
@@ -60,10 +63,12 @@ case class SubqueryBroadcastExec(
   }

   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    Future {
+    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
+      sqlContext.sparkSession,
+      SubqueryBroadcastExec.executionContext) {
       // This will run in another thread. Set the execution id so that we can connect these jobs
       // with the correct execution.
       SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
@@ -89,7 +94,7 @@ case class SubqueryBroadcastExec(
         rows
       }
-    }(SubqueryBroadcastExec.executionContext)
+    }
   }

   protected override def doPrepare(): Unit = {
@@ -110,5 +115,6 @@ case class SubqueryBroadcastExec(

 object SubqueryBroadcastExec {
   private[execution] val executionContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("dynamicpruning", 16))
+    ThreadUtils.newDaemonCachedThreadPool("dynamic-pruning",
+      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
 }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index 567524ac75c2e..e11fe3f274085 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -25,9 +25,9 @@ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, InSubqueryExec, LeafExecNode, QueryExecution, SparkPlan, SubqueryBroadcastExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
 import org.apache.spark.sql.execution.debug.codegenStringSeq
 import org.apache.spark.sql.functions.col
@@ -188,6 +188,65 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
       assert(checks2.forall(_.toSeq == Seq(true, true)))
     }
   }
+
+  test("SPARK-32748: propagate local properties to dynamic pruning thread") {
+    val factTable = "fact_local_prop_dpp"
+    val dimTable = "dim_local_prop_dpp"
+
+    def checkPropertyValueByUdfResult(propKey: String, propValue: String): Unit = {
+      spark.sparkContext.setLocalProperty(propKey, propValue)
+      val df = sql(
+        s"""
+           |SELECT compare_property_value(f.id, '$propKey', '$propValue') as col
+           |FROM $factTable f
+           |INNER JOIN $dimTable s
+           |ON f.id = s.id AND s.value < 3
+         """.stripMargin)
+
+      val subqueryBroadcastSeq = df.queryExecution.executedPlan.flatMap {
+        case s: FileSourceScanExec => s.partitionFilters.collect {
+          case DynamicPruningExpression(InSubqueryExec(_, b: SubqueryBroadcastExec, _, _)) => b
+        }
+        case _ => Nil
+      }
+      assert(subqueryBroadcastSeq.nonEmpty,
+        s"Should trigger DPP with a reused broadcast exchange:\n${df.queryExecution}")
+
+      assert(df.collect().forall(_.toSeq == Seq(true)))
+    }
+
+    withTable(factTable, dimTable) {
+      spark.range(10).select($"id", $"id".as("value"))
+        .write.partitionBy("id").mode("overwrite").saveAsTable(factTable)
+      spark.range(5).select($"id", $"id".as("value"))
+        .write.mode("overwrite").saveAsTable(dimTable)
+
+      withSQLConf(
+        StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1",
+        SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true",
+        SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+
+        try {
+          spark.udf.register(
+            "compare_property_value",
+            (input: Int, propKey: String, propValue: String) =>
+              TaskContext.get().getLocalProperty(propKey) == propValue
+          )
+          val propKey = "spark.sql.subquery.broadcast.prop.key"
+
+          // set local property and assert
+          val propValue1 = UUID.randomUUID().toString()
+          checkPropertyValueByUdfResult(propKey, propValue1)
+
+          // change local property and re-assert
+          val propValue2 = UUID.randomUUID().toString()
+          checkPropertyValueByUdfResult(propKey, propValue2)
+        } finally {
+          spark.sessionState.catalog.dropTempFunction("compare_property_value", true)
+        }
+      }
+    }
+  }
 }

 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
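
For reviewers unfamiliar with the mechanism, below is a minimal, self-contained sketch of the capture-and-restore pattern that `SQLExecution.withThreadLocalCaptured` applies to Spark's local properties. It is an illustration only, not Spark's implementation: the object name, the simplified `withThreadLocalCaptured` signature, and the `my.prop` key are all made up for the example.

```scala
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Sketch only: `localProps` stands in for SparkContext's local properties;
// the real SQLExecution.withThreadLocalCaptured also captures the active
// SparkSession and SQL configurations.
object LocalPropertyPropagationSketch {

  // Thread-local "local properties". An InheritableThreadLocal alone is not
  // enough with a cached thread pool: values are inherited only when a thread
  // is created, and pool threads are created once and then reused.
  private val localProps = new InheritableThreadLocal[Map[String, String]] {
    override def initialValue(): Map[String, String] = Map.empty
  }

  def setLocalProperty(key: String, value: String): Unit =
    localProps.set(localProps.get + (key -> value))

  def getLocalProperty(key: String): Option[String] = localProps.get.get(key)

  // Snapshot the submitting thread's properties, install them on the pool
  // thread for the duration of the body, then restore whatever was there.
  def withThreadLocalCaptured[T](body: => T)(implicit ec: ExecutionContext): Future[T] = {
    val captured = localProps.get // taken on the submitting thread
    Future {
      val previous = localProps.get // whatever the reused pool thread held
      localProps.set(captured)
      try body finally localProps.set(previous)
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutorService(pool)

    setLocalProperty("my.prop", "v1")
    println(Await.result(withThreadLocalCaptured(getLocalProperty("my.prop")), Duration.Inf))
    // -> Some(v1)

    // The pool thread already exists now; without explicit capture it would
    // still see "v1" after the update below.
    setLocalProperty("my.prop", "v2")
    println(Await.result(withThreadLocalCaptured(getLocalProperty("my.prop")), Duration.Inf))
    // -> Some(v2)

    pool.shutdown()
  }
}
```

The second round trip mirrors what the new test asserts: after the submitting thread changes a local property, work handed to the reused pool thread must observe the updated value rather than the stale one the thread held from an earlier submission.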