diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 86aa1f2cd61d9..d1aea17538859 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1580,36 +1580,25 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
  * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
  */
 object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
-    val newPlan = plan transform {
-      case Deduplicate(keys, child) if !child.isStreaming =>
-        val keyExprIds = keys.map(_.exprId)
-        val aggCols = child.output.map { attr =>
-          if (keyExprIds.contains(attr.exprId)) {
-            attr -> attr
-          } else {
-            val alias = Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
-            alias -> alias.newInstance()
-          }
-        }.unzip
-        // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
-        // aggregations by checking the number of grouping keys. The key difference here is that a
-        // global aggregation always returns at least one row even if there are no input rows. Here
-        // we append a literal when the grouping key list is empty so that the result aggregate
-        // operator is properly treated as a grouping aggregation.
-        val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
-        val newAgg = Aggregate(nonemptyKeys, aggCols._1, child)
-        rewritePlanMap += newAgg -> Aggregate(nonemptyKeys, aggCols._2, child)
-        newAgg
-    }
-
-    if (rewritePlanMap.nonEmpty) {
-      assert(!plan.fastEquals(newPlan))
-      Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
-    } else {
-      plan
-    }
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
+    case d @ Deduplicate(keys, child) if !child.isStreaming =>
+      val keyExprIds = keys.map(_.exprId)
+      val aggCols = child.output.map { attr =>
+        if (keyExprIds.contains(attr.exprId)) {
+          attr
+        } else {
+          Alias(new First(attr).toAggregateExpression(), attr.name)()
+        }
+      }
+      // SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
+      // aggregations by checking the number of grouping keys. The key difference here is that a
+      // global aggregation always returns at least one row even if there are no input rows. Here
+      // we append a literal when the grouping key list is empty so that the result aggregate
+      // operator is properly treated as a grouping aggregation.
+      val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
+      val newAgg = Aggregate(nonemptyKeys, aggCols, child)
+      val attrMapping = d.output.zip(newAgg.output)
+      newAgg -> attrMapping
   }
 }
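To make the effect of the new `transformUpWithNewOutput` contract concrete, here is a minimal sketch of the rewrite this rule now performs. It assumes spark-catalyst (including its test DSL under `org.apache.spark.sql.catalyst.dsl`) on the classpath; the wrapper object is hypothetical and only the rule invocation itself comes from the patch:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.optimizer.ReplaceDeduplicateWithAggregate
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LocalRelation}

// Hypothetical driver object, for illustration only.
object DeduplicateRewriteSketch extends App {
  val relation = LocalRelation('a.int, 'b.int)
  val Seq(a, b) = relation.output

  // Deduplicate on `a` becomes Aggregate([a], [a, first(b) AS b], relation).
  val rewritten = ReplaceDeduplicateWithAggregate(Deduplicate(Seq(a), relation))
  println(rewritten.treeString)
}
```

The rule keeps the key columns as grouping expressions, wraps every other column in `first(...)`, and returns the new `Aggregate` together with `d.output.zip(newAgg.output)`; `transformUpWithNewOutput` then rewrites references to the old output attributes in the operators above.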
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index aaef7a49a5472..016352a47b302 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, CleanupAliases}
+import org.apache.spark.sql.catalyst.analysis.CleanupAliases
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -513,9 +513,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    */
   private def constructLeftJoins(
       child: LogicalPlan,
-      subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, Seq[(LogicalPlan, LogicalPlan)]) = {
-    val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
-    val newPlan = subqueries.foldLeft(child) {
+      subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
+    subqueries.foldLeft(child) {
       case (currentChild, ScalarSubquery(query, conditions, _)) =>
         val origOutput = query.output.head
 
@@ -543,23 +542,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
 
         if (havingNode.isEmpty) {
           // CASE 2: Subquery with no HAVING clause
-          val joinPlan = Join(currentChild,
-            Project(query.output :+ alwaysTrueExpr, query),
-            LeftOuter, conditions.reduceOption(And), JoinHint.NONE)
-
-          def buildPlan(exprId: ExprId): LogicalPlan = {
-            Project(
-              currentChild.output :+
-                Alias(
-                  If(IsNull(alwaysTrueRef),
-                    resultWithZeroTups.get,
-                    aggValRef), origOutput.name)(exprId),
-              joinPlan)
-          }
+          Project(
+            currentChild.output :+
+              Alias(
+                If(IsNull(alwaysTrueRef),
+                  resultWithZeroTups.get,
+                  aggValRef), origOutput.name)(exprId = origOutput.exprId),
+            Join(currentChild,
+              Project(query.output :+ alwaysTrueExpr, query),
+              LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
-
-          val newPlan = buildPlan(origOutput.exprId)
-          rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
-          newPlan
         } else {
           // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
           // Need to modify any operators below the join to pass through all columns
@@ -575,85 +567,66 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
             case op => sys.error(s"Unexpected operator $op in corelated subquery")
           }
-          val joinPlan = Join(currentChild,
-            Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-            LeftOuter, conditions.reduceOption(And), JoinHint.NONE)
-
-          def buildPlan(exprId: ExprId): LogicalPlan = {
-            // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
-            //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
-            //      ELSE (aggregate value) END AS (original column name)
-            val caseExpr = Alias(CaseWhen(Seq(
-              (IsNull(alwaysTrueRef), resultWithZeroTups.get),
-              (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
-              aggValRef),
-              origOutput.name)(exprId)
-
-            Project(
-              currentChild.output :+ caseExpr,
-              joinPlan)
-          }
+          // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
+          //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
+          //      ELSE (aggregate value) END AS (original column name)
+          val caseExpr = Alias(CaseWhen(Seq(
+            (IsNull(alwaysTrueRef), resultWithZeroTups.get),
+            (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
+            aggValRef),
+            origOutput.name)(exprId = origOutput.exprId)
+
+          Project(
+            currentChild.output :+ caseExpr,
+            Join(currentChild,
+              Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
+              LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
-
-          val newPlan = buildPlan(origOutput.exprId)
-          rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
-          newPlan
         }
       }
     }
-
-    (newPlan, rewritePlanMap)
   }
 
   /**
    * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
    * subqueries.
    */
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
-    val newPlan = plan transform {
-      case a @ Aggregate(grouping, expressions, child) =>
-        val subqueries = ArrayBuffer.empty[ScalarSubquery]
-        val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
-        if (subqueries.nonEmpty) {
-          // We currently only allow correlated subqueries in an aggregate if they are part of the
-          // grouping expressions. As a result we need to replace all the scalar subqueries in the
-          // grouping expressions by their result.
-          val newGrouping = grouping.map { e =>
-            subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
-          }
-          val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
-          rewritePlanMap ++= rewriteMap
-          Aggregate(newGrouping, newExpressions, newChild)
-        } else {
-          a
-        }
-      case p @ Project(expressions, child) =>
-        val subqueries = ArrayBuffer.empty[ScalarSubquery]
-        val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
-        if (subqueries.nonEmpty) {
-          val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
-          rewritePlanMap ++= rewriteMap
-          Project(newExpressions, newChild)
-        } else {
-          p
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
+    case a @ Aggregate(grouping, expressions, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      if (subqueries.nonEmpty) {
+        // We currently only allow correlated subqueries in an aggregate if they are part of the
+        // grouping expressions. As a result we need to replace all the scalar subqueries in the
+        // grouping expressions by their result.
+        val newGrouping = grouping.map { e =>
+          subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
         }
-      case f @ Filter(condition, child) =>
-        val subqueries = ArrayBuffer.empty[ScalarSubquery]
-        val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
-        if (subqueries.nonEmpty) {
-          val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
-          rewritePlanMap ++= rewriteMap
-          Project(f.output, Filter(newCondition, newChild))
-        } else {
-          f
-        }
-    }
-
-    if (rewritePlanMap.nonEmpty) {
-      assert(!plan.fastEquals(newPlan))
-      Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
-    } else {
-      newPlan
-    }
+        val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        val attrMapping = a.output.zip(newAgg.output)
+        newAgg -> attrMapping
+      } else {
+        a -> Nil
+      }
+    case p @ Project(expressions, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      if (subqueries.nonEmpty) {
+        val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
+        val attrMapping = p.output.zip(newProj.output)
+        newProj -> attrMapping
+      } else {
+        p -> Nil
+      }
+    case f @ Filter(condition, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
+      if (subqueries.nonEmpty) {
+        val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+        val attrMapping = f.output.zip(newProj.output)
+        newProj -> attrMapping
+      } else {
+        f -> Nil
+      }
   }
 }
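Previously each rule had to thread a `rewritePlanMap` through the traversal and call `Analyzer.rewritePlan` at the end; with `transformUpWithNewOutput` a rule simply returns the replacement node plus old-to-new attribute pairs, and the traversal repairs ancestor references itself. Below is a self-contained toy analogue of that mechanism in plain Scala (no Spark dependency). Every name in it (`Attr`, `Plan`, `Dedup`, `freshId`, ...) is a hypothetical stand-in, not Catalyst's API; it only illustrates the shape of the contract:

```scala
// Toy analogue of transformUpWithNewOutput -- all names are hypothetical
// stand-ins for Catalyst's LogicalPlan/Attribute machinery.
object NewOutputSketch extends App {
  final case class Attr(name: String, id: Long)

  sealed trait Plan { def output: Seq[Attr] }
  final case class Leaf(output: Seq[Attr]) extends Plan
  final case class Dedup(keys: Seq[Attr], child: Plan) extends Plan {
    def output: Seq[Attr] = child.output
  }
  final case class Agg(keys: Seq[Attr], cols: Seq[Attr], child: Plan) extends Plan {
    def output: Seq[Attr] = cols
  }
  final case class Proj(cols: Seq[Attr], child: Plan) extends Plan {
    def output: Seq[Attr] = cols
  }

  private var nextId = 100L
  private def freshId(): Long = { nextId += 1; nextId }

  // Bottom-up driver: rewrite children first, substitute the accumulated
  // old-id -> new-attribute mapping into the parent's references, then let
  // the rule rewrite the (already remapped) parent and extend the mapping.
  def transformUpWithNewOutput(plan: Plan)(
      rule: PartialFunction[Plan, (Plan, Seq[(Attr, Attr)])]): Plan = {
    def remap(a: Attr, m: Map[Long, Attr]): Attr = m.getOrElse(a.id, a)
    def applyRule(p: Plan, m: Map[Long, Attr]): (Plan, Map[Long, Attr]) =
      rule.lift(p) match {
        case Some((newP, pairs)) => (newP, m ++ pairs.map { case (o, n) => o.id -> n })
        case None                => (p, m)
      }
    def go(p: Plan): (Plan, Map[Long, Attr]) = p match {
      case l: Leaf => applyRule(l, Map.empty)
      case Dedup(keys, child) =>
        val (c, m) = go(child)
        applyRule(Dedup(keys.map(remap(_, m)), c), m)
      case Agg(keys, cols, child) =>
        val (c, m) = go(child)
        applyRule(Agg(keys.map(remap(_, m)), cols.map(remap(_, m)), c), m)
      case Proj(cols, child) =>
        val (c, m) = go(child)
        applyRule(Proj(cols.map(remap(_, m)), c), m)
    }
    go(plan)._1
  }

  // Toy version of ReplaceDeduplicateWithAggregate: non-key columns come back
  // with fresh ids (new aggregate aliases), and the returned mapping is what
  // lets the driver repair the stale reference to `b` in the Proj above.
  val a = Attr("a", 1)
  val b = Attr("b", 2)
  val plan = Proj(Seq(a, b), Dedup(Seq(a), Leaf(Seq(a, b))))

  val rewritten = transformUpWithNewOutput(plan) {
    case d @ Dedup(keys, child) =>
      val keyIds = keys.map(_.id).toSet
      val cols = child.output.map(c => if (keyIds(c.id)) c else c.copy(id = freshId()))
      val newAgg = Agg(keys, cols, child)
      newAgg -> d.output.zip(newAgg.output)
  }
  println(rewritten) // Proj now references b's fresh id, not the stale one
}
```

Catalyst's real `transformUpWithNewOutput` (defined on `QueryPlan`) additionally works over full expression trees and exprIds; the sketch only shows why returning old-to-new attribute pairs makes the manual `rewritePlanMap` bookkeeping removed by this patch unnecessary.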