diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index f14aaab72a98f..e6681de8d6412 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -54,6 +54,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
       // Filter the plan by applying left semi and left anti joins.
       withSubquery.foldLeft(newFilter) {
+        case (p, PredicateSubquery(_, Seq(e: Expression), _, _)) if !e.isInstanceOf[Predicate] =>
+          // This predicate subquery was inserted by the PartitionPruning rule; do not rewrite it.
+          p
         case (p, PredicateSubquery(sub, conditions, _, _)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
           Join(outerPlan, sub, LeftSemi, joinCond)
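The guard above keeps `RewritePredicateSubquery` from turning the pruning marker back into a left-semi join: a marker carries a single bare, non-Predicate expression, whereas a real IN/EXISTS rewrite carries predicate conditions. Below is a minimal, self-contained sketch of that distinction using hypothetical stand-in types (`Expr`, `Pred`, `ColRef`, `EqualTo`), not the actual Catalyst classes:

```scala
// Hypothetical mini-ADT standing in for Catalyst expressions; only the shape matters here.
sealed trait Expr
trait Pred extends Expr                          // plays the role of Catalyst's Predicate
case class ColRef(name: String) extends Expr     // a bare attribute reference, not a predicate
case class EqualTo(l: Expr, r: Expr) extends Pred

object PruningMarkerSketch {
  // Mirrors the new case: skip rewriting when the only condition is not a Predicate.
  def isPruningMarker(conditions: Seq[Expr]): Boolean = conditions match {
    case Seq(e) if !e.isInstanceOf[Pred] => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(isPruningMarker(Seq(ColRef("k"))))                        // true: pruning marker
    println(isPruningMarker(Seq(EqualTo(ColRef("k"), ColRef("j")))))  // false: real predicate
  }
}
```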
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index e485b52b43f76..8016ef0a2a46d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -156,7 +156,30 @@ case class FileSourceScanExec(
     false
   }
 
-  @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters)
+  private def isDynamicPartitionFilter(e: Expression): Boolean =
+    e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+
+  @transient private lazy val selectedPartitions =
+    relation.location.listFiles(partitionFilters.filterNot(isDynamicPartitionFilter))
+
+  // We can only determine the actual partitions at runtime when a dynamic partition filter is
+  // present, because such a filter relies on information that only becomes available at runtime
+  // (for instance, the join keys produced by the other side of a join).
+  @transient private lazy val dynamicallySelectedPartitions = {
+    val dynamicPartitionFilters = partitionFilters.filter(isDynamicPartitionFilter)
+    if (dynamicPartitionFilters.nonEmpty) {
+      val predicate = dynamicPartitionFilters.reduce(And)
+      val partitionColumns = relation.partitionSchema
+      val boundPredicate = newPredicate(predicate.transform {
+        case a: AttributeReference =>
+          val index = partitionColumns.indexWhere(a.name == _.name)
+          BoundReference(index, partitionColumns(index).dataType, nullable = true)
+      }, Nil)
+      selectedPartitions.filter(p => boundPredicate.eval(p.values))
+    } else {
+      selectedPartitions
+    }
+  }
 
   override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = {
     val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) {
@@ -261,9 +284,9 @@ case class FileSourceScanExec(
 
     relation.bucketSpec match {
      case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
-        createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
+        createBucketedReadRDD(bucketing, readFile, dynamicallySelectedPartitions, relation)
      case _ =>
-        createNonBucketedReadRDD(readFile, selectedPartitions, relation)
+        createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation)
    }
  }
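A simplified, self-contained model of what `dynamicallySelectedPartitions` does, in plain Scala rather than the Spark implementation: the dynamic filter is "bound" by mapping partition-column names to ordinals in each partition's value row and is then evaluated per partition, so only the matching partitions are handed to the read RDD. The `Partition` type, column names, and join keys below are hypothetical sample data.

```scala
object DynamicPartitionFilterSketch {
  // A partition's values (one entry per partition column) plus the files that belong to it.
  case class Partition(values: Seq[Any], files: Seq[String])

  def main(args: Array[String]): Unit = {
    val partitionColumns = Seq("k")                         // schema of the partition values
    val partitions = Seq(
      Partition(Seq(0L), Seq("part-k=0.parquet")),
      Partition(Seq(1L), Seq("part-k=1.parquet")),
      Partition(Seq(2L), Seq("part-k=2.parquet")))

    // Runtime-only information from the other join side, e.g. the distinct join keys.
    val joinKeys: Set[Any] = Set(0L, 1L)

    // "Bind" the predicate: resolve the column name to an ordinal in the partition-value row.
    val keyIndex = partitionColumns.indexOf("k")
    val boundPredicate: Partition => Boolean = p => joinKeys.contains(p.values(keyIndex))

    // Only the surviving partitions would be scanned.
    val selected = partitions.filter(boundPredicate)
    println(selected.map(_.files.head))                     // List(part-k=0.parquet, part-k=1.parquet)
  }
}
```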
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 2c7879286e63b..01c33818fe8ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.optimizer.{CombineFilters, Optimizer, PushDownPredicate, PushPredicateThroughJoin}
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.closure.TranslateClosureOptimizerRule
-import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
 import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf
 
@@ -43,6 +45,13 @@ class SparkOptimizer(
       // Java closure to Catalyst expressions
       Batch("Translate Closure", Once, new TranslateClosureOptimizerRule(conf))) ++
     defaultOptimizers :+
+    Batch("PartitionPruning", Once,
+      PartitionPruning(conf),
+      OptimizeSubqueries) :+
+    Batch("Pushdown pruning subquery", fixedPoint,
+      PushPredicateThroughJoin,
+      PushDownPredicate,
+      CombineFilters) :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+
     Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
@@ -50,3 +59,95 @@ class SparkOptimizer(
     experimentalMethods.extraOptimizations ++ extraOptimizationRules: _*)
   }
 }
+
+/**
+ * Inserts a predicate on a partitioned table when one of its partition columns is used as a join key.
+ */
+case class PartitionPruning(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {
+
+  /**
+   * Returns whether an attribute is a partition column or not.
+   */
+  private def isPartitioned(a: Expression, plan: LogicalPlan): Boolean = {
+    plan.foreach {
+      case l: LogicalRelation if a.references.subsetOf(l.outputSet) =>
+        l.relation match {
+          case fs: HadoopFsRelation =>
+            val partitionColumns = AttributeSet(
+              l.resolve(fs.partitionSchema, fs.sparkSession.sessionState.analyzer.resolver))
+            if (a.references.subsetOf(partitionColumns)) {
+              return true
+            }
+          case _ =>
+        }
+      case _ =>
+    }
+    false
+  }
+
+  private def insertPredicate(
+      partitionedPlan: LogicalPlan,
+      partitioned: Expression,
+      otherPlan: LogicalPlan,
+      value: Expression): LogicalPlan = {
+    val alias = value match {
+      case a: Attribute => a
+      case o => Alias(o, o.toString)()
+    }
+    Filter(
+      PredicateSubquery(Aggregate(Seq(alias), Seq(alias), otherPlan), Seq(partitioned)),
+      partitionedPlan)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!conf.partitionPruning) {
+      return plan
+    }
+    plan transformUp {
+      case join @ Join(left, right, joinType, Some(condition)) =>
+        var newLeft = left
+        var newRight = right
+        splitConjunctivePredicates(condition).foreach {
+          case e @ EqualTo(a: Expression, b: Expression) =>
+            // The sides must come from different join sides; otherwise the predicate is simply pushed down.
+            val (l, r) = if (a.references.subsetOf(left.outputSet) &&
+                b.references.subsetOf(right.outputSet)) {
+              a -> b
+            } else {
+              b -> a
+            }
+            if (isPartitioned(l, left) && hasHighlySelectivePredicate(right) &&
+                (joinType == Inner || joinType == LeftSemi || joinType == RightOuter) &&
+                r.references.subsetOf(right.outputSet)) {
+              newLeft = insertPredicate(newLeft, l, right, r)
+            } else if (isPartitioned(r, right) && hasHighlySelectivePredicate(left) &&
+                (joinType == Inner || joinType == LeftOuter) &&
+                l.references.subsetOf(left.outputSet)) {
+              newRight = insertPredicate(newRight, r, left, l)
+            }
+          case _ =>
+        }
+        Join(newLeft, newRight, joinType, Some(condition))
+    }
+  }
+
+  /**
+   * Returns whether an expression is highly selective or not.
+   */
+  def isHighlySelective(e: Expression): Boolean = e match {
+    case Not(expr) => isHighlySelective(expr)
+    case And(l, r) => isHighlySelective(l) || isHighlySelective(r)
+    case Or(l, r) => isHighlySelective(l) && isHighlySelective(r)
+    case _: BinaryComparison => true
+    case _: In | _: InSet => true
+    case _: StringPredicate => true
+    case _ => false
+  }
+
+  def hasHighlySelectivePredicate(plan: LogicalPlan): Boolean = {
+    plan.find {
+      case f: Filter => isHighlySelective(f.condition)
+      case _ => false
+    }.isDefined
+  }
+}
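The heuristic at the bottom of `PartitionPruning` decides whether the other join side is selective enough to be worth a pruning subquery. Here is a self-contained sketch of the same logic over a hypothetical mini expression ADT (not Catalyst), showing why `And` needs only one selective child while `Or` needs both:

```scala
object SelectivitySketch {
  // Hypothetical expression shapes; Cmp stands in for comparisons / IN / string predicates.
  sealed trait Expr
  case class Cmp(col: String, op: String, value: Any) extends Expr   // e.g. id < 2
  case class Not(e: Expr) extends Expr
  case class And(l: Expr, r: Expr) extends Expr
  case class Or(l: Expr, r: Expr) extends Expr
  case class Unknown(desc: String) extends Expr                      // anything else, e.g. a UDF

  def isHighlySelective(e: Expr): Boolean = e match {
    case Not(x) => isHighlySelective(x)
    case And(l, r) => isHighlySelective(l) || isHighlySelective(r)   // one selective conjunct filters the whole row set
    case Or(l, r) => isHighlySelective(l) && isHighlySelective(r)    // a disjunction filters only if both branches do
    case _: Cmp => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(isHighlySelective(And(Cmp("id", "<", 2), Unknown("udf(x)"))))  // true
    println(isHighlySelective(Or(Cmp("id", "<", 2), Unknown("udf(x)"))))   // false
  }
}
```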
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 55ca4f11068f9..cb45cc2914971 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -67,7 +67,7 @@ object FileSourceStrategy extends Strategy with Logging {
       val normalizedFilters = filters.map { e =>
         e transform {
           case a: AttributeReference =>
-            a.withName(l.output.find(_.semanticEquals(a)).get.name)
+            a.withName(l.output.find(_.semanticEquals(a)).getOrElse(a).name)
         }
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 8566a8061034b..af4673a855cb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -52,8 +52,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         logicalRelation.resolve(
           partitionSchema, sparkSession.sessionState.analyzer.resolver)
       val partitionSet = AttributeSet(partitionColumns)
-      val partitionKeyFilters =
-        ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
+      val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f =>
+        f.references.subsetOf(partitionSet) && f.find(_.isInstanceOf[SubqueryExpression]).isEmpty
+      })
 
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f0375c1ed618f..307b1a53941f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -417,6 +417,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARTITION_PRUNING = SQLConfigBuilder("spark.sql.dynamicPartitionPruning")
+    .internal()
+    .doc("When true, generate a predicate on a partition column when it is used as a join key.")
+    .booleanConf
+    .createWithDefault(true)
+
   val WHOLESTAGE_CODEGEN_ENABLED = SQLConfigBuilder("spark.sql.codegen.wholeStage")
     .internal()
     .doc("When true, the whole stage (of multiple operators) will be compiled into single java" +
@@ -790,6 +796,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)
 
+  def partitionPruning: Boolean = getConf(PARTITION_PRUNING)
+
   def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
 
   def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
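The rule is gated on the new `spark.sql.dynamicPartitionPruning` flag defined above. A usage sketch, assuming a local SparkSession built against this patch; the key string comes from the `PARTITION_PRUNING` entry, and everything else is standard SparkSession configuration API:

```scala
import org.apache.spark.sql.SparkSession

object TogglePartitionPruning {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("dpp-flag").getOrCreate()

    spark.conf.set("spark.sql.dynamicPartitionPruning", "false")   // disable the rule for this session
    println(spark.conf.get("spark.sql.dynamicPartitionPruning"))   // "false"
    spark.conf.set("spark.sql.dynamicPartitionPruning", "true")    // restore the default

    spark.stop()
  }
}
```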
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 806381008aba6..fbda29230cf71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -22,8 +22,9 @@ import java.math.MathContext
 import java.sql.Timestamp
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
+import org.apache.spark.sql.catalyst.expressions.PlanExpression
 import org.apache.spark.sql.catalyst.util.StringUtils
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.{aggregate, FileSourceScanExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2086,6 +2087,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("dynamic partition pruning") {
+    withTempDir { dir =>
+      val df = spark.range(100).selectExpr("id", "id as k")
+      df.write.mode("overwrite").partitionBy("k").parquet(dir.toString)
+      val df2 = spark.read.parquet(dir.toString).join(df.filter("id < 2"), "k")
+      assert(df2.queryExecution.executedPlan.find {
+        case s: FileSourceScanExec =>
+          s.partitionFilters.exists(_.find(_.isInstanceOf[PlanExpression[_]]).isDefined)
+        case o => false
+      }.isDefined, "Parquet scan should have partition predicate")
+      checkAnswer(df2, Row(0, 0, 0) :: Row(1, 1, 1) :: Nil)
+    }
+  }
+
   test("SPARK-14986: Outer lateral view with empty generate expression") {
     checkAnswer(
       sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index e607af67f93e5..b702b87a7cba6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -2011,6 +2011,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     }
   }
 
+  test("dynamic partition pruning") {
+    withTable("df1", "df2") {
+      spark.range(100)
+        .select($"id", $"id".as("k"))
+        .write
+        .partitionBy("k")
+        .format("parquet")
+        .mode("overwrite")
+        .saveAsTable("df1")
+
+      spark.range(100)
+        .select($"id", $"id".as("k"))
+        .write
+        .partitionBy("k")
+        .format("parquet")
+        .mode("overwrite")
+        .saveAsTable("df2")
+
+      checkAnswer(
+        sql("select df1.id, df2.k from df1 join df2 on df1.k = df2.k and df2.id < 2"),
+        Row(0, 0) :: Row(1, 1) :: Nil)
+    }
+  }
+
   def testCommandAvailable(command: String): Boolean = {
     val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue())
     attempt.isSuccess && attempt.get == 0
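Outside the test suites, the same behaviour can be observed from an ad-hoc session. Below is a sketch modelled on the core test above, assuming a build that includes this patch; with pruning in effect, the FileSourceScan in the explained plan should carry a subquery-based partition filter:

```scala
import org.apache.spark.sql.SparkSession

object ObservePruning {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("dpp-demo").getOrCreate()
    val dir = java.nio.file.Files.createTempDirectory("dpp").toString

    // Write a table partitioned by k, then join it against a filtered copy of the same data.
    val df = spark.range(100).selectExpr("id", "id as k")
    df.write.mode("overwrite").partitionBy("k").parquet(dir)

    val joined = spark.read.parquet(dir).join(df.filter("id < 2"), "k")
    joined.explain(true)   // the scan's partition filters should include the injected subquery
    joined.show()

    spark.stop()
  }
}
```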