From 2ed051d4a5db51dd922b6e66d762b71b86cbffab Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 20 Aug 2020 06:18:22 +0900 Subject: [PATCH 1/4] Fix --- .../sql/catalyst/analysis/TypeCoercion.scala | 76 +++++++++++++------ .../catalyst/analysis/TypeCoercionSuite.scala | 17 ++++- .../resources/sql-tests/inputs/except.sql | 19 +++++ .../sql-tests/inputs/intersect-all.sql | 15 ++++ .../test/resources/sql-tests/inputs/union.sql | 14 ++++ .../sql-tests/results/except.sql.out | 58 +++++++++++++- .../sql-tests/results/intersect-all.sql.out | 42 +++++++++- .../resources/sql-tests/results/union.sql.out | 43 ++++++++++- 8 files changed, 254 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 604a082be4e55..a86c906fee1a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -328,27 +328,48 @@ object TypeCoercion { */ object WidenSetOperationTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case s @ Except(left, right, isAll) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) - assert(newChildren.length == 2) - Except(newChildren.head, newChildren.last, isAll) - - case s @ Intersect(left, right, isAll) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) - assert(newChildren.length == 2) - Intersect(newChildren.head, newChildren.last, isAll) - - case s: Union if s.childrenResolved && !s.byName && + def apply(plan: LogicalPlan): LogicalPlan = { + val exprIdMapArray = mutable.ArrayBuffer[(ExprId, (ExprId, DataType))]() + val newPlan = plan resolveOperatorsUp { + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil) + exprIdMapArray ++= newExprIds + assert(newChildren.length == 2) + Except(newChildren.head, newChildren.last, isAll) + + case s @ Intersect(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil) + exprIdMapArray ++= newExprIds + assert(newChildren.length == 2) + Intersect(newChildren.head, newChildren.last, isAll) + + case s: Union if s.childrenResolved && !s.byName && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) - s.copy(children = newChildren) + val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(s.children) + exprIdMapArray ++= newExprIds + s.copy(children = newChildren) + } + + // Re-maps existing references to the new ones (exprId and dataType) + // for aliases added when widening columns' data types. 
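+      // For example, if widening added "Alias(Cast(v#1, decimal(11,0)), v)" with a
+      // fresh exprId v#3 on one side of a Union, a parent operator that still
+      // refers to v#1 would report it under missingInput; the pass below redirects
+      // such references to v#3 and its widened data type. (The exprIds v#1/v#3
+      // here are illustrative only, not taken from a real plan.)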
+ val exprIdMap = exprIdMapArray.toMap + newPlan resolveOperatorsUp { + case p if p.childrenResolved && p.missingInput.nonEmpty => + p.mapExpressions { _.transform { + case a: AttributeReference if p.missingInput.contains(a) && + exprIdMap.contains(a.exprId) => + val (exprId, dt) = exprIdMap(a.exprId) + AttributeReference(a.name, dt, a.nullable, a.metadata)(exprId, a.qualifier) + } + } + } } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]) + : (Seq[LogicalPlan], Seq[(ExprId, (ExprId, DataType))]) = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute @@ -358,10 +379,11 @@ object TypeCoercion { if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. - children.map(widenTypes(_, targetTypes)) + val (newChildren, newExprIds) = children.map(widenTypes(_, targetTypes)).unzip + (newChildren, newExprIds.flatten) } else { // Unable to find a target type to widen, then just return the original set. - children + (children, Nil) } } @@ -385,12 +407,16 @@ object TypeCoercion { } /** Given a plan, add an extra project on top to widen some columns' data types. */ - private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = { - val casted = plan.output.zip(targetTypes).map { - case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)() - case (e, _) => e - } - Project(casted, plan) + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]) + : (LogicalPlan, Seq[(ExprId, (ExprId, DataType))]) = { + val (casted, newExprIds) = plan.output.zip(targetTypes).map { + case (e, dt) if e.dataType != dt => + val alias = Alias(Cast(e, dt), e.name)() + (alias, Some(e.exprId -> (alias.exprId, dt))) + case (e, _) => + (e, None) + }.unzip + (Project(casted, plan), newExprIds.flatten) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1ea1ddb8bbd08..1af562fd1a061 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -21,13 +21,12 @@ import java.sql.Timestamp import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval class TypeCoercionSuite extends AnalysisTest { import TypeCoercionSuite._ @@ -1417,6 +1416,20 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") { + val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))()) + val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))()) + val p1 = t1.select(t1.output.head) + val p2 = 
t2.select(t2.output.head)
+    val union = p1.union(p2)
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.missingInput.isEmpty)
+    val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
+    assert(wp2.isInstanceOf[Aggregate])
+    assert(wp2.missingInput.isEmpty)
+  }
+
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.
diff --git a/sql/core/src/test/resources/sql-tests/inputs/except.sql b/sql/core/src/test/resources/sql-tests/inputs/except.sql
index 1d579e65f3473..ffdf1f4f3d24d 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/except.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/except.sql
@@ -55,3 +55,22 @@ FROM   t1
 WHERE  t1.v >= (SELECT   min(t2.v)
                 FROM     t2
                 WHERE    t2.k = t1.k);
+
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
+SELECT t.v FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;
+DROP VIEW IF EXISTS t3;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
index b0b2244048caa..077caa5dd44a0 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/intersect-all.sql
@@ -155,6 +155,21 @@ SELECT * FROM tab2;
 -- Restore the property
 SET spark.sql.legacy.setopsPrecedence.enabled = false;
 
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW tab3 AS VALUES (decimal(1)), (decimal(2)) tbl3(v);
+SELECT t.v FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t;
+
 -- Clean-up
 DROP VIEW IF EXISTS tab1;
 DROP VIEW IF EXISTS tab2;
+DROP VIEW IF EXISTS tab3;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
index 6da1b9b49b226..8a5b6c50fc1e3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/union.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -45,10 +45,24 @@ SELECT array(1, 2), 'str'
 UNION ALL
 SELECT array(1, 2, 3, NULL), 1;
 
+-- SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v);
+SELECT t.v FROM (
+  SELECT v FROM t3
+  UNION ALL
+  SELECT v + v AS v FROM t3
+) t;
+
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  UNION
+  SELECT v + v AS v FROM t3
+) t;
 
 -- Clean-up
 DROP VIEW IF EXISTS t1;
 DROP VIEW IF EXISTS t2;
+DROP VIEW IF EXISTS t3;
 DROP VIEW IF EXISTS p1;
 DROP VIEW IF EXISTS p2;
 DROP VIEW IF EXISTS p3;
diff --git a/sql/core/src/test/resources/sql-tests/results/except.sql.out b/sql/core/src/test/resources/sql-tests/results/except.sql.out
index 62d695219d01d..061b122eac7cf 100644
--- a/sql/core/src/test/resources/sql-tests/results/except.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/except.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 15
 
 
 -- !query
@@ -103,3 +103,59 @@ WHERE  t1.v >= (SELECT   min(t2.v)
 struct<k:string>
 -- !query output
 two
+
+
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t.v FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<v:decimal(11,0)>
+-- !query output
+1
+
+
+-- !query
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  EXCEPT
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<sum(v):decimal(21,0)>
+-- !query output
+1
+
+
+-- !query
+DROP VIEW IF EXISTS t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IF EXISTS t3
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
index 4762082dc3be2..b99f63393cc4d 100644
--- a/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/intersect-all.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 22
+-- Number of queries: 26
 
 
 -- !query
@@ -291,6 +291,38 @@ struct<key:string,value:string>
 spark.sql.legacy.setopsPrecedence.enabled	false
 
 
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW tab3 AS VALUES (decimal(1)), (decimal(2)) tbl3(v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t.v FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t
+-- !query schema
+struct<v:decimal(11,0)>
+-- !query output
+2
+
+
+-- !query
+SELECT SUM(t.v) FROM (
+  SELECT v FROM tab3
+  INTERSECT
+  SELECT v + v AS v FROM tab3
+) t
+-- !query schema
+struct<sum(v):decimal(21,0)>
+-- !query output
+2
+
+
 -- !query
 DROP VIEW IF EXISTS tab1
 -- !query schema
@@ -305,3 +337,11 @@ DROP VIEW IF EXISTS tab2
 struct<>
 -- !query output
 
+
+
+-- !query
+DROP VIEW IF EXISTS tab3
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out
index 44002406836a4..ce3c761bc5d2d 100644
--- a/sql/core/src/test/resources/sql-tests/results/union.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 16
+-- Number of queries: 20
 
 
 -- !query
@@ -126,6 +126,39 @@ struct<array(1, 2):array<int>,str:string>
 [1,2]	str
 
 
+-- !query
+CREATE OR REPLACE TEMPORARY VIEW t3 AS VALUES (decimal(1)) tbl(v)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t.v FROM (
+  SELECT v FROM t3
+  UNION ALL
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<v:decimal(11,0)>
+-- !query output
+1
+2
+
+
+-- !query
+SELECT SUM(t.v) FROM (
+  SELECT v FROM t3
+  UNION
+  SELECT v + v AS v FROM t3
+) t
+-- !query schema
+struct<sum(v):decimal(21,0)>
+-- !query output
+3
+
+
 -- !query
 DROP VIEW IF EXISTS t1
 -- !query schema
@@ -142,6 +175,14 @@ struct<>
 
 
 
+-- !query
+DROP VIEW IF EXISTS t3
+-- !query schema
+struct<>
+-- !query output
+
+
+
 -- !query
 DROP VIEW IF EXISTS p1
 -- !query schema

From b0b55317f7e2bbf3d94ef2339343ad83d188fa1a Mon Sep 17 00:00:00 2001
From: Takeshi Yamamuro 
Date: Fri, 21 Aug 2020 08:57:02 +0900
Subject: [PATCH 2/4] review

---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index a86c906fee1a2..f2fd07ada1d99 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -329,7 +329,7 @@ object TypeCoercion { object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val exprIdMapArray = mutable.ArrayBuffer[(ExprId, (ExprId, DataType))]() + val exprIdMapArray = mutable.ArrayBuffer[(ExprId, Attribute)]() val newPlan = plan resolveOperatorsUp { case s @ Except(left, right, isAll) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => @@ -359,9 +359,7 @@ object TypeCoercion { case p if p.childrenResolved && p.missingInput.nonEmpty => p.mapExpressions { _.transform { case a: AttributeReference if p.missingInput.contains(a) && - exprIdMap.contains(a.exprId) => - val (exprId, dt) = exprIdMap(a.exprId) - AttributeReference(a.name, dt, a.nullable, a.metadata)(exprId, a.qualifier) + exprIdMap.contains(a.exprId) => exprIdMap(a.exprId) } } } @@ -369,7 +367,7 @@ object TypeCoercion { /** Build new children with the widest types for each attribute among all the children */ private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]) - : (Seq[LogicalPlan], Seq[(ExprId, (ExprId, DataType))]) = { + : (Seq[LogicalPlan], Seq[(ExprId, Attribute)]) = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute @@ -408,11 +406,11 @@ object TypeCoercion { /** Given a plan, add an extra project on top to widen some columns' data types. */ private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]) - : (LogicalPlan, Seq[(ExprId, (ExprId, DataType))]) = { + : (LogicalPlan, Seq[(ExprId, Attribute)]) = { val (casted, newExprIds) = plan.output.zip(targetTypes).map { case (e, dt) if e.dataType != dt => val alias = Alias(Cast(e, dt), e.name)() - (alias, Some(e.exprId -> (alias.exprId, dt))) + (alias, Some(e.exprId -> alias.toAttribute)) case (e, _) => (e, None) }.unzip From 2340afee80657dd50e931c56245a9aff01e1ffb8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 25 Aug 2020 22:14:59 +0900 Subject: [PATCH 3/4] review --- .../sql/catalyst/analysis/TypeCoercion.scala | 78 +++++++------------ .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../execution/RemoveRedundantProjects.scala | 16 ++-- 3 files changed, 39 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index f2fd07ada1d99..7e4d89790dee1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -326,48 +326,29 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - object WidenSetOperationTypes extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = { - val exprIdMapArray = mutable.ArrayBuffer[(ExprId, Attribute)]() - val newPlan = plan resolveOperatorsUp { - case s @ Except(left, right, isAll) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => - val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil) - exprIdMapArray ++= newExprIds - assert(newChildren.length == 2) - Except(newChildren.head, newChildren.last, isAll) - - case s @ 
Intersect(left, right, isAll) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => - val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil) - exprIdMapArray ++= newExprIds - assert(newChildren.length == 2) - Intersect(newChildren.head, newChildren.last, isAll) - - case s: Union if s.childrenResolved && !s.byName && + object WidenSetOperationTypes extends TypeCoercionRule { + + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + Except(newChildren.head, newChildren.last, isAll) + + case s @ Intersect(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) + assert(newChildren.length == 2) + Intersect(newChildren.head, newChildren.last, isAll) + + case s: Union if s.childrenResolved && !s.byName && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(s.children) - exprIdMapArray ++= newExprIds - s.copy(children = newChildren) - } - - // Re-maps existing references to the new ones (exprId and dataType) - // for aliases added when widening columns' data types. - val exprIdMap = exprIdMapArray.toMap - newPlan resolveOperatorsUp { - case p if p.childrenResolved && p.missingInput.nonEmpty => - p.mapExpressions { _.transform { - case a: AttributeReference if p.missingInput.contains(a) && - exprIdMap.contains(a.exprId) => exprIdMap(a.exprId) - } - } - } + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) + s.copy(children = newChildren) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]) - : (Seq[LogicalPlan], Seq[(ExprId, Attribute)]) = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute @@ -377,11 +358,10 @@ object TypeCoercion { if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. - val (newChildren, newExprIds) = children.map(widenTypes(_, targetTypes)).unzip - (newChildren, newExprIds.flatten) + children.map(widenTypes(_, targetTypes)) } else { // Unable to find a target type to widen, then just return the original set. - (children, Nil) + children } } @@ -405,16 +385,12 @@ object TypeCoercion { } /** Given a plan, add an extra project on top to widen some columns' data types. 
*/ - private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]) - : (LogicalPlan, Seq[(ExprId, Attribute)]) = { - val (casted, newExprIds) = plan.output.zip(targetTypes).map { - case (e, dt) if e.dataType != dt => - val alias = Alias(Cast(e, dt), e.name)() - (alias, Some(e.exprId -> alias.toAttribute)) - case (e, _) => - (e, None) - }.unzip - (Project(casted, plan), newExprIds.flatten) + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = { + val casted = plan.output.zip(targetTypes).map { + case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)(exprId = e.exprId) + case (e, _) => e + } + Project(casted, plan) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bcdc5cd942e35..a4e07515aec0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -477,7 +477,8 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { object RemoveNoopOperators extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Eliminate no-op Projects - case p @ Project(_, child) if child.sameOutput(p) => child + case Project(projList, child) if projList.length == child.output.length && + projList.zip(child.output).forall { case (e1, e2) => e1.semanticEquals(e2) } => child // Eliminate no-op Window case w: Window if w.windowExpressions.isEmpty => w.child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala index ecb4ad0f6e8dd..7a9b60d57e25b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, PartialMerge} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase -import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf /** @@ -85,14 +84,19 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] { // to convert the rows to UnsafeRow. See DataSourceV2Strategy for more details. 
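       // Keeping the ProjectExec over such a row-based scan preserves that
       // conversion, so it is never treated as redundant here.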
case d: DataSourceV2ScanExecBase if !d.supportsColumnar => false case _ => + def semanticEquals(exprs1: Seq[Expression], exprs2: Seq[Expression]): Boolean = { + exprs1.length == exprs2.length && exprs1.zip(exprs2).forall { + case (e1, e2) => e1.semanticEquals(e2) + } + } if (requireOrdering) { - project.output.map(_.exprId.id) == child.output.map(_.exprId.id) && + semanticEquals(project.projectList, child.output) && checkNullability(project.output, child.output) } else { - val orderedProjectOutput = project.output.sortBy(_.exprId.id) + val orderedProjectList = project.projectList.sortBy(_.exprId.id) val orderedChildOutput = child.output.sortBy(_.exprId.id) - orderedProjectOutput.map(_.exprId.id) == orderedChildOutput.map(_.exprId.id) && - checkNullability(orderedProjectOutput, orderedChildOutput) + semanticEquals(orderedProjectList, orderedChildOutput) && + checkNullability(orderedProjectList.map(_.toAttribute), orderedChildOutput) } } } From 6372515769c1d3c299b865d4ec4b8378bdc5e84c Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 27 Aug 2020 16:54:13 +0900 Subject: [PATCH 4/4] review --- .../sql/catalyst/analysis/Analyzer.scala | 225 ++++++++++-------- .../sql/catalyst/analysis/TypeCoercion.scala | 75 ++++-- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../execution/RemoveRedundantProjects.scala | 16 +- 4 files changed, 180 insertions(+), 139 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0e81f48fc7ebb..7c6d0fcd9c8c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -123,6 +123,127 @@ object AnalysisContext { } } +object Analyzer { + + /** + * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones. + * This method also updates all the related references in the `plan` accordingly. + * + * @param plan to rewrite + * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`. + * @return a rewritten plan and updated references related to a root node of + * the given `plan` for rewriting it. + */ + def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan]) + : (LogicalPlan, Seq[(Attribute, Attribute)]) = { + if (plan.resolved) { + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + val newChildren = plan.children.map { child => + // If not, we'd rewrite child plan recursively until we find the + // conflict node or reach the leaf node. + val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap) + attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => + // `attrMapping` is not only used to replace the attributes of the current `plan`, + // but also to be propagated to the parent plans of the current `plan`. Therefore, + // the `oldAttr` must be part of either `plan.references` (so that it can be used to + // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be + // used by those parent plans). 
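+        // A mapping whose old attribute is neither produced nor referenced at
+        // this node is dead here and can safely be dropped from the buffer.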
+ (plan.outputSet ++ plan.references).contains(oldAttr) + } + newChild + } + + val newPlan = if (rewritePlanMap.contains(plan)) { + rewritePlanMap(plan).withNewChildren(newChildren) + } else { + plan.withNewChildren(newChildren) + } + + assert(!attrMapping.groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") + + val attributeRewrites = AttributeMap(attrMapping) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + val p = newPlan.transformExpressions { + case a: Attribute => + updateAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites)) + } + attrMapping ++= plan.output.zip(p.output) + .filter { case (a1, a2) => a1.exprId != a2.exprId } + p -> attrMapping + } else { + // Just passes through unresolved nodes + plan.mapChildren { + rewritePlan(_, rewritePlanMap)._1 + } -> Nil + } + } + + private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { + val exprId = attrMap.getOrElse(attr, attr).exprId + attr.withExprId(exprId) + } + + /** + * The outer plan may have old references and the function below updates the + * outer references to refer to the new attributes. + * + * For example (SQL): + * {{{ + * SELECT * FROM t1 + * INTERSECT + * SELECT * FROM t1 + * WHERE EXISTS (SELECT 1 + * FROM t2 + * WHERE t1.c1 = t2.c1) + * }}} + * Plan before resolveReference rule. + * 'Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- 'Project [*] + * +- Filter exists#257 [c1#245] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#245) = c1#251) + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#245,c2#246] parquet + * Plan after the resolveReference rule. + * Intersect + * :- Project [c1#245, c2#246] + * : +- SubqueryAlias t1 + * : +- Relation[c1#245,c2#246] parquet + * +- Project [c1#259, c2#260] + * +- Filter exists#257 [c1#259] + * : +- Project [1 AS 1#258] + * : +- Filter (outer(c1#259) = c1#251) => Updated + * : +- SubqueryAlias t2 + * : +- Relation[c1#251,c2#252] parquet + * +- SubqueryAlias t1 + * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten. + */ + private def updateOuterReferencesInSubquery( + plan: LogicalPlan, + attrMap: AttributeMap[Attribute]): LogicalPlan = { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformDown { case currentFragment => + currentFragment transformExpressions { + case OuterReference(a: Attribute) => + OuterReference(updateAttr(a, attrMap)) + case s: SubqueryExpression => + s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap)) + } + } + } + } +} + /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. @@ -1251,109 +1372,7 @@ class Analyzer( if (conflictPlans.isEmpty) { right } else { - rewritePlan(right, conflictPlans.toMap)._1 - } - } - - private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan]) - : (LogicalPlan, Seq[(Attribute, Attribute)]) = { - if (conflictPlanMap.contains(plan)) { - // If the plan is the one that conflict the with left one, we'd - // just replace it with the new plan and collect the rewrite - // attributes for the parent node. 
- val newRelation = conflictPlanMap(plan) - newRelation -> plan.output.zip(newRelation.output) - } else { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - val newPlan = plan.mapChildren { child => - // If not, we'd rewrite child plan recursively until we find the - // conflict node or reach the leaf node. - val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap) - attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => - // `attrMapping` is not only used to replace the attributes of the current `plan`, - // but also to be propagated to the parent plans of the current `plan`. Therefore, - // the `oldAttr` must be part of either `plan.references` (so that it can be used to - // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be - // used by those parent plans). - (plan.outputSet ++ plan.references).contains(oldAttr) - } - newChild - } - - if (attrMapping.isEmpty) { - newPlan -> attrMapping.toSeq - } else { - assert(!attrMapping.groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - "Found duplicate rewrite attributes") - val attributeRewrites = AttributeMap(attrMapping.toSeq) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. - newPlan.transformExpressions { - case a: Attribute => - dedupAttr(a, attributeRewrites) - case s: SubqueryExpression => - s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) - } -> attrMapping.toSeq - } - } - } - - private def dedupAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - val exprId = attrMap.getOrElse(attr, attr).exprId - attr.withExprId(exprId) - } - - /** - * The outer plan may have been de-duplicated and the function below updates the - * outer references to refer to the de-duplicated attributes. - * - * For example (SQL): - * {{{ - * SELECT * FROM t1 - * INTERSECT - * SELECT * FROM t1 - * WHERE EXISTS (SELECT 1 - * FROM t2 - * WHERE t1.c1 = t2.c1) - * }}} - * Plan before resolveReference rule. - * 'Intersect - * :- Project [c1#245, c2#246] - * : +- SubqueryAlias t1 - * : +- Relation[c1#245,c2#246] parquet - * +- 'Project [*] - * +- Filter exists#257 [c1#245] - * : +- Project [1 AS 1#258] - * : +- Filter (outer(c1#245) = c1#251) - * : +- SubqueryAlias t2 - * : +- Relation[c1#251,c2#252] parquet - * +- SubqueryAlias t1 - * +- Relation[c1#245,c2#246] parquet - * Plan after the resolveReference rule. - * Intersect - * :- Project [c1#245, c2#246] - * : +- SubqueryAlias t1 - * : +- Relation[c1#245,c2#246] parquet - * +- Project [c1#259, c2#260] - * +- Filter exists#257 [c1#259] - * : +- Project [1 AS 1#258] - * : +- Filter (outer(c1#259) = c1#251) => Updated - * : +- SubqueryAlias t2 - * : +- Relation[c1#251,c2#252] parquet - * +- SubqueryAlias t1 - * +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are de-duplicated. 
- */ - private def dedupOuterReferencesInSubquery( - plan: LogicalPlan, - attrMap: AttributeMap[Attribute]): LogicalPlan = { - plan transformDown { case currentFragment => - currentFragment transformExpressions { - case OuterReference(a: Attribute) => - OuterReference(dedupAttr(a, attrMap)) - case s: SubqueryExpression => - s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attrMap)) - } + Analyzer.rewritePlan(right, conflictPlans.toMap)._1 } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 7e4d89790dee1..861eddedc0e1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -328,27 +328,51 @@ object TypeCoercion { */ object WidenSetOperationTypes extends TypeCoercionRule { - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case s @ Except(left, right, isAll) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) - assert(newChildren.length == 2) - Except(newChildren.head, newChildren.last, isAll) - - case s @ Intersect(left, right, isAll) if s.childrenResolved && - left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) - assert(newChildren.length == 2) - Intersect(newChildren.head, newChildren.last, isAll) - - case s: Union if s.childrenResolved && !s.byName && + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = { + val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]() + val newPlan = plan resolveOperatorsUp { + case s @ Except(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil) + if (newChildren.nonEmpty) { + rewritePlanMap ++= newChildren + Except(newChildren.head._1, newChildren.last._1, isAll) + } else { + s + } + + case s @ Intersect(left, right, isAll) if s.childrenResolved && + left.output.length == right.output.length && !s.resolved => + val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil) + if (newChildren.nonEmpty) { + rewritePlanMap ++= newChildren + Intersect(newChildren.head._1, newChildren.last._1, isAll) + } else { + s + } + + case s: Union if s.childrenResolved && !s.byName && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) - s.copy(children = newChildren) + val newChildren = buildNewChildrenWithWiderTypes(s.children) + if (newChildren.nonEmpty) { + rewritePlanMap ++= newChildren + s.copy(children = newChildren.map(_._1)) + } else { + s + } + } + + if (rewritePlanMap.nonEmpty) { + assert(!plan.fastEquals(newPlan)) + Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1 + } else { + plan + } } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]) + : Seq[(LogicalPlan, LogicalPlan)] = { require(children.forall(_.output.length == 
children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute @@ -360,8 +384,7 @@ object TypeCoercion { // Add an extra Project if the targetTypes are different from the original types. children.map(widenTypes(_, targetTypes)) } else { - // Unable to find a target type to widen, then just return the original set. - children + Nil } } @@ -385,12 +408,16 @@ object TypeCoercion { } /** Given a plan, add an extra project on top to widen some columns' data types. */ - private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = { + private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]) + : (LogicalPlan, LogicalPlan) = { val casted = plan.output.zip(targetTypes).map { - case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)(exprId = e.exprId) - case (e, _) => e - } - Project(casted, plan) + case (e, dt) if e.dataType != dt => + val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId) + alias -> alias.newInstance() + case (e, _) => + e -> e + }.unzip + Project(casted._1, plan) -> Project(casted._2, plan) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4e07515aec0f..bcdc5cd942e35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -477,8 +477,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { object RemoveNoopOperators extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Eliminate no-op Projects - case Project(projList, child) if projList.length == child.output.length && - projList.zip(child.output).forall { case (e1, e2) => e1.semanticEquals(e2) } => child + case p @ Project(_, child) if child.sameOutput(p) => child // Eliminate no-op Window case w: Window if w.windowExpressions.isEmpty => w.child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala index 7a9b60d57e25b..ecb4ad0f6e8dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, PartialMerge} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf /** @@ -84,19 +85,14 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] { // to convert the rows to UnsafeRow. See DataSourceV2Strategy for more details. 
case d: DataSourceV2ScanExecBase if !d.supportsColumnar => false case _ => - def semanticEquals(exprs1: Seq[Expression], exprs2: Seq[Expression]): Boolean = { - exprs1.length == exprs2.length && exprs1.zip(exprs2).forall { - case (e1, e2) => e1.semanticEquals(e2) - } - } if (requireOrdering) { - semanticEquals(project.projectList, child.output) && + project.output.map(_.exprId.id) == child.output.map(_.exprId.id) && checkNullability(project.output, child.output) } else { - val orderedProjectList = project.projectList.sortBy(_.exprId.id) + val orderedProjectOutput = project.output.sortBy(_.exprId.id) val orderedChildOutput = child.output.sortBy(_.exprId.id) - semanticEquals(orderedProjectList, orderedChildOutput) && - checkNullability(orderedProjectList.map(_.toAttribute), orderedChildOutput) + orderedProjectOutput.map(_.exprId.id) == orderedChildOutput.map(_.exprId.id) && + checkNullability(orderedProjectOutput, orderedChildOutput) } } }
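
A minimal sketch of the scenario this series fixes, written with the same
logical-plan DSL as the new TypeCoercionSuite test (the relations and types
mirror that test; everything else is illustrative only, not part of the
patches):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.DecimalType

    // Two union sides whose "v" columns differ only in decimal precision.
    val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
    val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
    val union = t1.select(t1.output.head).union(t2.select(t2.output.head))

    // WidenSetOperationTypes must cast t1's side to decimal(11, 0). Before this
    // series, that cast was wrapped in an Alias with a fresh exprId, so an outer
    // operator built against t1's original attribute ended up with missingInput:
    val query = union.select(t1.output.head)

    // With PATCH 4/4 applied, the added Project reuses the original exprId on
    // the rewritten side (Alias(Cast(e, dt), e.name)(exprId = e.exprId)) and
    // Analyzer.rewritePlan redirects any remaining references, so `query`
    // resolves with missingInput empty.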