Skip to content

Commit

Permalink
Refine GpuHashAggregateExec.setupReference (#2917)
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
Co-authored-by: Alessandro Bellina <[email protected]>

Reorganized the code of boundInputReferences in a more general way
  • Loading branch information
sperlingxx authored Jul 16, 2021
1 parent 55a2cee commit 25197b8
Showing 1 changed file with 64 additions and 75 deletions.
139 changes: 64 additions & 75 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,16 @@ case class GpuHashAggregateMetrics(
case class AggregateModeInfo(
uniqueModes: Seq[AggregateMode],
hasPartialMode: Boolean,
hasPartialMerge: Boolean,
hasPartialMergeMode: Boolean,
hasFinalMode: Boolean,
hasCompleteMode: Boolean)

object AggregateModeInfo {
def apply(uniqueModes: Seq[AggregateMode]): AggregateModeInfo = {
val hasPartialMerge = uniqueModes.contains(PartialMerge)
AggregateModeInfo(
uniqueModes = uniqueModes,
hasPartialMode = hasPartialMerge || uniqueModes.contains(Partial),
hasPartialMerge = hasPartialMerge,
hasPartialMode = uniqueModes.contains(Partial),
hasPartialMergeMode = uniqueModes.contains(PartialMerge),
hasFinalMode = uniqueModes.contains(Final),
hasCompleteMode = uniqueModes.contains(Complete)
)
Expand Down Expand Up @@ -624,83 +623,73 @@ class GpuHashAggregateIterator(
val aggModeCudfAggregates =
AggregateUtils.computeAggModeCudfAggregates(aggregateExpressions, aggBufferAttributes)

//
// expressions to pick input to the aggregate, and finalize the output to the result projection.
//
// Pick update distinct attributes or input projections for Partial
val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition(
_.isDistinct)
val updateExpressionsDistinct =
distinctAggExpressions.flatMap(
_.aggregateFunction.updateExpressions)
val updateAttributesDistinct =
distinctAggExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val inputProjectionsDistinct =
distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection)

// Pick merge non-distinct for PartialMerge
val mergeExpressionsNonDistinct =
nonDistinctAggExpressions
.flatMap(_.aggregateFunction.mergeExpressions)
.map(_.asInstanceOf[CudfAggregate].ref)
val mergeAttributesNonDistinct =
nonDistinctAggExpressions.flatMap(
_.aggregateFunction.aggBufferAttributes)

// Partial with no distinct or when modes are empty
val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions
.flatMap(_.aggregateFunction.inputProjection)

var distinctAttributes = Seq[Attribute]()
var distinctExpressions = Seq[Expression]()
var nonDistinctAttributes = Seq[Attribute]()
var nonDistinctExpressions = Seq[Expression]()
modeInfo.uniqueModes.foreach {
case PartialMerge =>
nonDistinctAttributes = mergeAttributesNonDistinct
nonDistinctExpressions = mergeExpressionsNonDistinct
case Partial =>
// Partial with distinct case
val updateExpressionsCudfAggsDistinct =
updateExpressionsDistinct.filter(_.isInstanceOf[CudfAggregate])
.map(_.asInstanceOf[CudfAggregate].ref)
if (inputProjectionsDistinct.exists(p => !p.isInstanceOf[NamedExpression])) {
// Case of distinct average we need to evaluate the "GpuCast and GpuIsNotNull" columns.
// Refer to how input projections are setup for GpuAverage.
// In the case where we have expressions to evaluate, pick the unique attributes
// references from them as you only have one column for it before you start evaluating.
distinctExpressions = inputProjectionsDistinct
distinctAttributes = inputProjectionsDistinct.flatMap(ref =>
ref.references.toSeq).distinct
} else {
distinctAttributes = updateAttributesDistinct
distinctExpressions = updateExpressionsCudfAggsDistinct
}
case _ =>
}
val inputBindExpressions = groupingExpressions ++ nonDistinctExpressions ++ distinctExpressions
val resultingBindAttributes = groupingAttributes ++ distinctAttributes ++ nonDistinctAttributes

val finalProjections = groupingExpressions ++
aggregateExpressions.map(_.aggregateFunction.evaluateExpression)

// boundInputReferences is used to pick out of the input batch the appropriate columns
// for aggregation
// - Partial Merge mode: we use the inputBindExpressions which can be only
// non distinct merge expressions.
// - Partial or Complete mode: we use the inputProjections or distinct update expressions.
// - Partial, PartialMerge mode: we use the inputProjections or distinct update expressions
// for aggregation.
//
// - PartialMerge with Partial mode: we use the inputProjections
// for Partial and non distinct merge expressions for PartialMerge.
// - Final mode: we pick the columns in the order as handed to us.
val boundInputReferences = if (modeInfo.hasPartialMerge) {
GpuBindReferences.bindGpuReferences(inputBindExpressions, resultingBindAttributes)
} else if (modeInfo.hasFinalMode) {
// - Final or PartialMerge-only mode: we pick the columns in the order as handed to us.
// - Partial or Complete mode: we use the inputProjections
val boundInputReferences =
if (modeInfo.hasPartialMergeMode && modeInfo.hasPartialMode) {
// The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side
// nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to
// switch the position of distinctAttributes and nonDistinctAttributes.
//
// The schema of the 2nd stage's outputs:
// groupingAttributes ++ distinctAttributes ++ nonDistinctAggBufferAttributes
//
// The schema of the 3rd stage's expressions:
// nonDistinctMergeAggExpressions ++ distinctPartialAggExpressions

val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition(
_.isDistinct)

// The schema of childAttr: [groupAttr, distinctAttr, nonDistinctAttr].
// With the size of nonDistinctAttr, we can easily extract distinctAttr and nonDistinctAttr
// from childAttr.
val sizeOfNonDistAttr = nonDistinctAggExpressions
.map(_.aggregateFunction.aggBufferAttributes.length).sum
val nonDistinctAttributes = childAttr.attrs.takeRight(sizeOfNonDistAttr)
val distinctAttributes = childAttr.attrs.slice(
groupingAttributes.length, childAttr.attrs.length - sizeOfNonDistAttr)

// With PartialMerge mode, we just pass the corresponding attributes of the child plan through
// unchanged as nonDistinctExpressions.
val nonDistinctExpressions = nonDistinctAttributes.asInstanceOf[Seq[Expression]]
// With Partial modes, the input projections are necessary for distinctExpressions.
val distinctExpressions = distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection)

// Align the expressions of input projections and input attributes
val inputProjections = groupingExpressions ++ nonDistinctExpressions ++ distinctExpressions
val inputAttributes = groupingAttributes ++ distinctAttributes ++ nonDistinctAttributes
GpuBindReferences.bindGpuReferences(inputProjections, inputAttributes)
} else if (modeInfo.hasFinalMode ||
(modeInfo.hasPartialMergeMode && modeInfo.uniqueModes.length == 1)) {
// Two possible cases:
// 1. The Final stage, including the 2nd stage of NoDistinctAgg and 4th stage of
// AggWithOneDistinct, which needs no input projections. Because the child outputs are
// internal aggregation buffers, which are aligned for the final stage.
//
// 2. The 2nd stage (PartialMerge) of AggWithOneDistinct, which works like the final stage
// taking the child outputs as inputs without any projections.
GpuBindReferences.bindGpuReferences(childAttr.attrs.asInstanceOf[Seq[Expression]], childAttr)
} else {
} else if (modeInfo.hasPartialMode || modeInfo.hasCompleteMode ||
modeInfo.uniqueModes.isEmpty) {
// The first aggregation stage (including Partial or Complete or no aggExpression),
// whose child node is not an AggregateExec. Therefore, input projections are essential.
val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions
.flatMap(_.aggregateFunction.inputProjection)
GpuBindReferences.bindGpuReferences(inputProjections, childAttr)
} else {
// This branch should NOT be reached.
throw new IllegalStateException(
s"Unable to handle aggregate with modes: ${modeInfo.uniqueModes}")
}

val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) {
val finalProjections = groupingExpressions ++
aggregateExpressions.map(_.aggregateFunction.evaluateExpression)
Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes))
} else {
None
Expand All @@ -716,7 +705,7 @@ class GpuHashAggregateIterator(
// out of the node as is.
// - Final or Complete mode: we use resultExpressions to pick out the correct columns that
// finalReferences has pre-processed for us
val boundResultReferences = if (modeInfo.hasPartialMode) {
val boundResultReferences = if (modeInfo.hasPartialMode || modeInfo.hasPartialMergeMode) {
GpuBindReferences.bindGpuReferences(
resultExpressions,
resultExpressions.map(_.toAttribute))
Expand Down

0 comments on commit 25197b8

Please sign in to comment.