[SC-3623][BRANCH-2.1] Dynamic partition pruning
## What changes were proposed in this pull request?
This PR ports databricks/runtime#31 over to DB Spark branch-2.1. This adds dynamic partition pruning (see the original PR for more details on the feature).

This was non-trivial to port because the read path has changed significantly in Spark 2.1. We only support partition pruning for `HadoopFsRelation`. This relation is read exclusively through `FileSourceScanExec`, so I have only implemented dynamic partition pruning for that scan operator (in 2.0 we support dynamic partition pruning for both `RowScanExec` and `FileSourceScanExec`).
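
For context, here is a hedged sketch of the kind of query this feature targets (it mirrors the tests added below; the path, session, and data are illustrative):

```scala
// Sketch only: assumes a local SparkSession named `spark`; the /tmp path is made up.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("dpp-sketch").getOrCreate()

// A table partitioned by `k` on disk (one directory per distinct key).
spark.range(100).selectExpr("id", "id as k")
  .write.mode("overwrite").partitionBy("k").parquet("/tmp/dpp_fact")

// Joining it against a highly selective side: with dynamic partition pruning the scan of
// /tmp/dpp_fact should only read the partitions whose `k` values can survive `id < 2`,
// instead of reading files for all 100 partitions.
val dim = spark.range(100).selectExpr("id", "id as k").filter("id < 2")
val joined = spark.read.parquet("/tmp/dpp_fact").join(dim, "k")
joined.show()
```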

## How was this patch tested?
Added a test to `SQLQuerySuite`.

Author: Herman van Hovell <[email protected]>
Author: Davies Liu <[email protected]>

Closes apache#131 from hvanhovell/dynamic_partition_pruning.
Davies Liu authored and hvanhovell committed Dec 8, 2016
1 parent 27e7e3d commit 53e0bb3
Showing 8 changed files with 185 additions and 10 deletions.
@@ -54,6 +54,9 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {

      // Filter the plan by applying left semi and left anti joins.
      withSubquery.foldLeft(newFilter) {
        case (p, PredicateSubquery(_, Seq(e: Expression), _, _)) if !e.isInstanceOf[Predicate] =>
          // This predicate subquery was inserted by the PartitionPruning rule and should not be
          // rewritten into a join.
          p
        case (p, PredicateSubquery(sub, conditions, _, _)) =>
          val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
          Join(outerPlan, sub, LeftSemi, joinCond)
@@ -156,7 +156,30 @@ case class FileSourceScanExec(
    false
  }

  @transient private lazy val selectedPartitions = relation.location.listFiles(partitionFilters)
  private def isDynamicPartitionFilter(e: Expression): Boolean =
    e.find(_.isInstanceOf[PlanExpression[_]]).isDefined

  @transient private lazy val selectedPartitions =
    relation.location.listFiles(partitionFilters.filterNot(isDynamicPartitionFilter))

  // We can only determine the actual partitions at runtime when a dynamic partition filter is
  // present. This is because such a filter relies on information that is only available at run
  // time (for instance the keys used on the other side of a join).
  @transient private lazy val dynamicallySelectedPartitions = {
    val dynamicPartitionFilters = partitionFilters.filter(isDynamicPartitionFilter)
    if (dynamicPartitionFilters.nonEmpty) {
      val predicate = dynamicPartitionFilters.reduce(And)
      val partitionColumns = relation.partitionSchema
      val boundPredicate = newPredicate(predicate.transform {
        case a: AttributeReference =>
          val index = partitionColumns.indexWhere(a.name == _.name)
          BoundReference(index, partitionColumns(index).dataType, nullable = true)
      }, Nil)
      selectedPartitions.filter(p => boundPredicate.eval(p.values))
    } else {
      selectedPartitions
    }
  }

  override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = {
    val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) {
@@ -261,9 +284,9 @@ case class FileSourceScanExec(

    relation.bucketSpec match {
      case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
        createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
        createBucketedReadRDD(bucketing, readFile, dynamicallySelectedPartitions, relation)
      case _ =>
        createNonBucketedReadRDD(readFile, selectedPartitions, relation)
        createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation)
    }
  }

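The run-time step above can be pictured with a small standalone sketch (plain Scala, not Spark's `BoundReference`/`newPredicate` machinery; all names are made up for illustration): partitions that were listed up front are filtered again once the dynamic predicate, e.g. the set of join keys from the other side, is known.

```scala
// Minimal sketch of runtime partition pruning; PartitionDirectory and the key set are stand-ins.
object RuntimePruningSketch extends App {
  case class PartitionDirectory(values: Map[String, Any], files: Seq[String])

  // The dynamic predicate is only known at run time.
  def pruneAtRuntime(
      partitions: Seq[PartitionDirectory],
      dynamicPredicate: Map[String, Any] => Boolean): Seq[PartitionDirectory] =
    partitions.filter(p => dynamicPredicate(p.values))

  val allPartitions = Seq(
    PartitionDirectory(Map("k" -> 0L), Seq("k=0/part-00000.parquet")),
    PartitionDirectory(Map("k" -> 1L), Seq("k=1/part-00000.parquet")),
    PartitionDirectory(Map("k" -> 5L), Seq("k=5/part-00000.parquet")))

  // Keys produced by the selective side of the join, collected at run time.
  val runtimeKeys = Set(0L, 1L)
  val pruned = pruneAtRuntime(allPartitions, v => runtimeKeys(v("k").asInstanceOf[Long]))

  // Only the k=0 and k=1 partitions remain, so only their files are read.
  println(pruned.map(_.files.mkString(",")))
}
```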
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{CombineFilters, Optimizer, PushDownPredicate, PushPredicateThroughJoin}
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.closure.TranslateClosureOptimizerRule
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
import org.apache.spark.sql.internal.SQLConf

@@ -43,10 +45,109 @@ class SparkOptimizer(
      // Java closure to Catalyst expressions
      Batch("Translate Closure", Once, new TranslateClosureOptimizerRule(conf))) ++
    defaultOptimizers :+
    Batch("PartitionPruning", Once,
      PartitionPruning(conf),
      OptimizeSubqueries) :+
    Batch("Pushdown pruning subquery", fixedPoint,
      PushPredicateThroughJoin,
      PushDownPredicate,
      CombineFilters) :+
    Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog, conf)) :+
    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
    Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
    Batch("User Provided Optimizers", fixedPoint,
      experimentalMethods.extraOptimizations ++ extraOptimizationRules: _*)
  }
}

/**
 * Inserts a predicate for a partitioned table when a partition column is used as a join key.
 */
case class PartitionPruning(conf: SQLConf) extends Rule[LogicalPlan] with PredicateHelper {

  /**
   * Returns whether an expression only references partition columns of a HadoopFsRelation.
   */
  private def isPartitioned(a: Expression, plan: LogicalPlan): Boolean = {
    plan.foreach {
      case l: LogicalRelation if a.references.subsetOf(l.outputSet) =>
        l.relation match {
          case fs: HadoopFsRelation =>
            val partitionColumns = AttributeSet(
              l.resolve(fs.partitionSchema, fs.sparkSession.sessionState.analyzer.resolver))
            if (a.references.subsetOf(partitionColumns)) {
              return true
            }
          case _ =>
        }
      case _ =>
    }
    false
  }

  private def insertPredicate(
      partitionedPlan: LogicalPlan,
      partitioned: Expression,
      otherPlan: LogicalPlan,
      value: Expression): LogicalPlan = {
    val alias = value match {
      case a: Attribute => a
      case o => Alias(o, o.toString)()
    }
    Filter(
      PredicateSubquery(Aggregate(Seq(alias), Seq(alias), otherPlan), Seq(partitioned)),
      partitionedPlan)
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    if (!conf.partitionPruning) {
      return plan
    }
    plan transformUp {
      case join @ Join(left, right, joinType, Some(condition)) =>
        var newLeft = left
        var newRight = right
        splitConjunctivePredicates(condition).foreach {
          case e @ EqualTo(a: Expression, b: Expression) =>
            // The two sides of the equality should come from different sides of the join;
            // otherwise the predicate would simply have been pushed down.
            val (l, r) = if (a.references.subsetOf(left.outputSet) &&
                b.references.subsetOf(right.outputSet)) {
              a -> b
            } else {
              b -> a
            }
            if (isPartitioned(l, left) && hasHighlySelectivePredicate(right) &&
                (joinType == Inner || joinType == LeftSemi || joinType == RightOuter) &&
                r.references.subsetOf(right.outputSet)) {
              newLeft = insertPredicate(newLeft, l, right, r)
            } else if (isPartitioned(r, right) && hasHighlySelectivePredicate(left) &&
                (joinType == Inner || joinType == LeftOuter) &&
                l.references.subsetOf(left.outputSet)) {
              newRight = insertPredicate(newRight, r, left, l)
            }
          case _ =>
        }
        Join(newLeft, newRight, joinType, Some(condition))
    }
  }

  /**
   * Returns whether an expression is highly selective or not.
   */
  def isHighlySelective(e: Expression): Boolean = e match {
    case Not(expr) => isHighlySelective(expr)
    case And(l, r) => isHighlySelective(l) || isHighlySelective(r)
    case Or(l, r) => isHighlySelective(l) && isHighlySelective(r)
    case _: BinaryComparison => true
    case _: In | _: InSet => true
    case _: StringPredicate => true
    case _ => false
  }

  def hasHighlySelectivePredicate(plan: LogicalPlan): Boolean = {
    plan.find {
      case f: Filter => isHighlySelective(f.condition)
      case _ => false
    }.isDefined
  }
}
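
To make the effect of the rule concrete, here is a hedged illustration in SQL terms (it assumes a `spark` session and the `df1`/`df2` tables created in the Hive test further below, both partitioned by `k`; the rule itself operates on logical plans, not SQL text):

```scala
// The optimizer rewrite is roughly equivalent to adding an IN-subquery on the partition
// column of the partitioned side by hand.
val original = spark.sql(
  "select df1.id, df2.k from df1 join df2 on df1.k = df2.k and df2.id < 2")

val manuallyPruned = spark.sql(
  """select df1.id, df2.k
    |from df1
    |join df2 on df1.k = df2.k and df2.id < 2
    |where df1.k in (select k from df2 where id < 2)""".stripMargin)

// Both queries return the same rows; the second form spells out the partition filter on df1
// that PartitionPruning inserts automatically (as a PredicateSubquery over an Aggregate),
// so that FileSourceScanExec can skip the partitions of df1 whose keys cannot match.
```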
@@ -67,7 +67,7 @@ object FileSourceStrategy extends Strategy with Logging {
      val normalizedFilters = filters.map { e =>
        e transform {
          case a: AttributeReference =>
            a.withName(l.output.find(_.semanticEquals(a)).get.name)
            a.withName(l.output.find(_.semanticEquals(a)).getOrElse(a).name)
        }
      }

@@ -52,8 +52,9 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
        logicalRelation.resolve(
          partitionSchema, sparkSession.sessionState.analyzer.resolver)
      val partitionSet = AttributeSet(partitionColumns)
      val partitionKeyFilters =
        ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
      val partitionKeyFilters = ExpressionSet(normalizedFilters.filter { f =>
        f.references.subsetOf(partitionSet) && f.find(_.isInstanceOf[SubqueryExpression]).isEmpty
      })

      if (partitionKeyFilters.nonEmpty) {
        val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
@@ -417,6 +417,12 @@ object SQLConf {
    .booleanConf
    .createWithDefault(true)

  val PARTITION_PRUNING = SQLConfigBuilder("spark.sql.dynamicPartitionPruning")
    .internal()
    .doc("When true, we will generate a predicate for a partition column when it is used as a join key")
    .booleanConf
    .createWithDefault(true)

  val WHOLESTAGE_CODEGEN_ENABLED = SQLConfigBuilder("spark.sql.codegen.wholeStage")
    .internal()
    .doc("When true, the whole stage (of multiple operators) will be compiled into single java" +
@@ -790,6 +796,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

  def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)

  def partitionPruning: Boolean = getConf(PARTITION_PRUNING)

  def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)

  def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
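
For completeness, the new flag can be toggled per session; a hedged usage sketch (assumes a `spark` session; the key is the one defined above and defaults to `true`):

```scala
// Disable dynamic partition pruning for the current session, then re-enable it.
spark.conf.set("spark.sql.dynamicPartitionPruning", "false")
spark.conf.set("spark.sql.dynamicPartitionPruning", "true")
```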
17 changes: 16 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -22,8 +22,9 @@ import java.math.MathContext
import java.sql.Timestamp

import org.apache.spark.{AccumulatorSuite, SparkException}
import org.apache.spark.sql.catalyst.expressions.PlanExpression
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.{aggregate, FileSourceScanExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -2086,6 +2087,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
    }
  }

  test("dynamic partition pruning") {
    withTempDir { dir =>
      val df = spark.range(100).selectExpr("id", "id as k")
      df.write.mode("overwrite").partitionBy("k").parquet(dir.toString)
      val df2 = spark.read.parquet(dir.toString).join(df.filter("id < 2"), "k")
      assert(df2.queryExecution.executedPlan.find {
        case s: FileSourceScanExec =>
          s.partitionFilters.exists(_.find(_.isInstanceOf[PlanExpression[_]]).isDefined)
        case o => false
      }.isDefined, "Parquet scan should have partition predicate")
      checkAnswer(df2, Row(0, 0, 0) :: Row(1, 1, 1) :: Nil)
    }
  }

  test("SPARK-14986: Outer lateral view with empty generate expression") {
    checkAnswer(
      sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"),
@@ -2011,6 +2011,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
    }
  }

  test("dynamic partition pruning") {
    withTable("df1", "df2") {
      spark.range(100)
        .select($"id", $"id".as("k"))
        .write
        .partitionBy("k")
        .format("parquet")
        .mode("overwrite")
        .saveAsTable("df1")

      spark.range(100)
        .select($"id", $"id".as("k"))
        .write
        .partitionBy("k")
        .format("parquet")
        .mode("overwrite")
        .saveAsTable("df2")

      checkAnswer(
        sql("select df1.id, df2.k from df1 join df2 on df1.k = df2.k and df2.id < 2"),
        Row(0, 0) :: Row(1, 1) :: Nil)
    }
  }

  def testCommandAvailable(command: String): Boolean = {
    val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue())
    attempt.isSuccess && attempt.get == 0