From 25197b89655d81c46d011892d9bd2952233e4bee Mon Sep 17 00:00:00 2001 From: Alfred Xu Date: Fri, 16 Jul 2021 11:38:49 +0800 Subject: [PATCH] Refine GpuHashAggregateExec.setupReference (#2917) Signed-off-by: sperlingxx Co-authored-by: Alessandro Bellina Reorganized the code of boundInputReferences in a more general way --- .../com/nvidia/spark/rapids/aggregate.scala | 139 ++++++++---------- 1 file changed, 64 insertions(+), 75 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index 9bcfec7a108..883c82911a4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -163,17 +163,16 @@ case class GpuHashAggregateMetrics( case class AggregateModeInfo( uniqueModes: Seq[AggregateMode], hasPartialMode: Boolean, - hasPartialMerge: Boolean, + hasPartialMergeMode: Boolean, hasFinalMode: Boolean, hasCompleteMode: Boolean) object AggregateModeInfo { def apply(uniqueModes: Seq[AggregateMode]): AggregateModeInfo = { - val hasPartialMerge = uniqueModes.contains(PartialMerge) AggregateModeInfo( uniqueModes = uniqueModes, - hasPartialMode = hasPartialMerge || uniqueModes.contains(Partial), - hasPartialMerge = hasPartialMerge, + hasPartialMode = uniqueModes.contains(Partial), + hasPartialMergeMode = uniqueModes.contains(PartialMerge), hasFinalMode = uniqueModes.contains(Final), hasCompleteMode = uniqueModes.contains(Complete) ) @@ -624,83 +623,73 @@ class GpuHashAggregateIterator( val aggModeCudfAggregates = AggregateUtils.computeAggModeCudfAggregates(aggregateExpressions, aggBufferAttributes) - // - // expressions to pick input to the aggregate, and finalize the output to the result projection. - // - // Pick update distinct attributes or input projections for Partial - val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition( - _.isDistinct) - val updateExpressionsDistinct = - distinctAggExpressions.flatMap( - _.aggregateFunction.updateExpressions) - val updateAttributesDistinct = - distinctAggExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - val inputProjectionsDistinct = - distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection) - - // Pick merge non-distinct for PartialMerge - val mergeExpressionsNonDistinct = - nonDistinctAggExpressions - .flatMap(_.aggregateFunction.mergeExpressions) - .map(_.asInstanceOf[CudfAggregate].ref) - val mergeAttributesNonDistinct = - nonDistinctAggExpressions.flatMap( - _.aggregateFunction.aggBufferAttributes) - - // Partial with no distinct or when modes are empty - val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions - .flatMap(_.aggregateFunction.inputProjection) - - var distinctAttributes = Seq[Attribute]() - var distinctExpressions = Seq[Expression]() - var nonDistinctAttributes = Seq[Attribute]() - var nonDistinctExpressions = Seq[Expression]() - modeInfo.uniqueModes.foreach { - case PartialMerge => - nonDistinctAttributes = mergeAttributesNonDistinct - nonDistinctExpressions = mergeExpressionsNonDistinct - case Partial => - // Partial with distinct case - val updateExpressionsCudfAggsDistinct = - updateExpressionsDistinct.filter(_.isInstanceOf[CudfAggregate]) - .map(_.asInstanceOf[CudfAggregate].ref) - if (inputProjectionsDistinct.exists(p => !p.isInstanceOf[NamedExpression])) { - // Case of distinct average we need to evaluate the "GpuCast and GpuIsNotNull" columns. - // Refer to how input projections are setup for GpuAverage. - // In the case where we have expressions to evaluate, pick the unique attributes - // references from them as you only have one column for it before you start evaluating. - distinctExpressions = inputProjectionsDistinct - distinctAttributes = inputProjectionsDistinct.flatMap(ref => - ref.references.toSeq).distinct - } else { - distinctAttributes = updateAttributesDistinct - distinctExpressions = updateExpressionsCudfAggsDistinct - } - case _ => - } - val inputBindExpressions = groupingExpressions ++ nonDistinctExpressions ++ distinctExpressions - val resultingBindAttributes = groupingAttributes ++ distinctAttributes ++ nonDistinctAttributes - - val finalProjections = groupingExpressions ++ - aggregateExpressions.map(_.aggregateFunction.evaluateExpression) - // boundInputReferences is used to pick out of the input batch the appropriate columns - // for aggregation - // - Partial Merge mode: we use the inputBindExpressions which can be only - // non distinct merge expressions. - // - Partial or Complete mode: we use the inputProjections or distinct update expressions. - // - Partial, PartialMerge mode: we use the inputProjections or distinct update expressions + // for aggregation. + // + // - PartialMerge with Partial mode: we use the inputProjections // for Partial and non distinct merge expressions for PartialMerge. - // - Final mode: we pick the columns in the order as handed to us. - val boundInputReferences = if (modeInfo.hasPartialMerge) { - GpuBindReferences.bindGpuReferences(inputBindExpressions, resultingBindAttributes) - } else if (modeInfo.hasFinalMode) { + // - Final or PartialMerge-only mode: we pick the columns in the order as handed to us. + // - Partial or Complete mode: we use the inputProjections + val boundInputReferences = + if (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) { + // The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side + // nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to + // switch the position of distinctAttributes and nonDistinctAttributes. + // + // The schema of the 2nd stage's outputs: + // groupingAttributes ++ distinctAttributes ++ nonDistinctAggBufferAttributes + // + // The schema of the 3rd stage's expressions: + // nonDistinctMergeAggExpressions ++ distinctPartialAggExpressions + + val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition( + _.isDistinct) + + // The schema of childAttr: [groupAttr, distinctAttr, nonDistinctAttr]. + // With the size of nonDistinctAttr, we can easily extract distinctAttr and nonDistinctAttr + // from childAttr. + val sizeOfNonDistAttr = nonDistinctAggExpressions + .map(_.aggregateFunction.aggBufferAttributes.length).sum + val nonDistinctAttributes = childAttr.attrs.takeRight(sizeOfNonDistAttr) + val distinctAttributes = childAttr.attrs.slice( + groupingAttributes.length, childAttr.attrs.length - sizeOfNonDistAttr) + + // With PartialMerge modes, we just pass through corresponding attributes of child plan into + // nonDistinctExpressions. + val nonDistinctExpressions = nonDistinctAttributes.asInstanceOf[Seq[Expression]] + // With Partial modes, the input projections are necessary for distinctExpressions. + val distinctExpressions = distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection) + + // Align the expressions of input projections and input attributes + val inputProjections = groupingExpressions ++ nonDistinctExpressions ++ distinctExpressions + val inputAttributes = groupingAttributes ++ distinctAttributes ++ nonDistinctAttributes + GpuBindReferences.bindGpuReferences(inputProjections, inputAttributes) + } else if (modeInfo.hasFinalMode || + (modeInfo.hasPartialMergeMode && modeInfo.uniqueModes.length == 1)) { + // two possible conditions: + // 1. The Final stage, including the 2nd stage of NoDistinctAgg and 4th stage of + // AggWithOneDistinct, which needs no input projections. Because the child outputs are + // internal aggregation buffers, which are aligned for the final stage. + // + // 2. The 2nd stage (PartialMerge) of AggWithOneDistinct, which works like the final stage + // taking the child outputs as inputs without any projections. GpuBindReferences.bindGpuReferences(childAttr.attrs.asInstanceOf[Seq[Expression]], childAttr) - } else { + } else if (modeInfo.hasPartialMode || modeInfo.hasCompleteMode || + modeInfo.uniqueModes.isEmpty) { + // The first aggregation stage (including Partial or Complete or no aggExpression), + // whose child node is not an AggregateExec. Therefore, input projections are essential. + val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions + .flatMap(_.aggregateFunction.inputProjection) GpuBindReferences.bindGpuReferences(inputProjections, childAttr) + } else { + // This branch should NOT be reached. + throw new IllegalStateException( + s"Unable to handle aggregate with modes: ${modeInfo.uniqueModes}") } val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) { + val finalProjections = groupingExpressions ++ + aggregateExpressions.map(_.aggregateFunction.evaluateExpression) Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes)) } else { None @@ -716,7 +705,7 @@ class GpuHashAggregateIterator( // out of the node as is. // - Final or Complete mode: we use resultExpressions to pick out the correct columns that // finalReferences has pre-processed for us - val boundResultReferences = if (modeInfo.hasPartialMode) { + val boundResultReferences = if (modeInfo.hasPartialMode || modeInfo.hasPartialMergeMode) { GpuBindReferences.bindGpuReferences( resultExpressions, resultExpressions.map(_.toAttribute))