From de2c1a32e461f49d070b4da6c104af0ca5344fb1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 9 Sep 2021 09:06:44 -0600 Subject: [PATCH] Normalize expression ids in Partial aggregate expressions to fix a regression when running with Spark 3.2 (#3403) Signed-off-by: Andy Grove --- .../org/apache/spark/sql/rapids/AggregateFunctions.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index 9304e8e3b23..a99e3f37e08 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -149,12 +149,13 @@ case class GpuAggregateExpression(origAggregateFunction: GpuAggregateFunction, // We compute the same thing regardless of our final result. override lazy val canonicalized: Expression = { val normalizedAggFunc = mode match { - // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers, - // and the actual children of `aggregateFunction` is not used, here we normalize the expr id. - case PartialMerge | Final => aggregateFunction.transform { + // For Partial, PartialMerge, or Final mode, the input to the `aggregateFunction` is + // aggregate buffers, and the actual children of `aggregateFunction` is not used, + // here we normalize the expr id. + case Partial | PartialMerge | Final => aggregateFunction.transform { case a: AttributeReference => a.withExprId(ExprId(0)) } - case Partial | Complete => aggregateFunction + case Complete => aggregateFunction } GpuAggregateExpression(