Skip to content

Commit

Permalink
rework GpuHashAgg.setupRef
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed Jul 14, 2021
1 parent 1a09322 commit 3dcce50
Showing 1 changed file with 54 additions and 84 deletions.
138 changes: 54 additions & 84 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -627,111 +627,81 @@ class GpuHashAggregateIterator(
val aggModeCudfAggregates =
AggregateUtils.computeAggModeCudfAggregates(aggregateExpressions, aggBufferAttributes)

//
// expressions to pick input to the aggregate, and finalize the output to the result projection.
//
// Pick update distinct attributes or input projections for Partial
val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition(
_.isDistinct)
val updateExpressionsDistinct =
distinctAggExpressions.flatMap(
_.aggregateFunction.updateExpressions)
val updateAttributesDistinct =
distinctAggExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
val inputProjectionsDistinct =
distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection)

// Pick merge non-distinct for PartialMerge
val mergeExpressionsNonDistinct =
nonDistinctAggExpressions
.flatMap(_.aggregateFunction.mergeExpressions)
.map(_.asInstanceOf[CudfAggregate].ref)
val mergeAttributesNonDistinct =
nonDistinctAggExpressions.flatMap(
_.aggregateFunction.aggBufferAttributes)

// Partial with no distinct or when modes are empty
val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions
.flatMap(_.aggregateFunction.inputProjection)

var distinctAttributes = Seq[Attribute]()
var distinctExpressions = Seq[Expression]()
var nonDistinctAttributes = Seq[Attribute]()
var nonDistinctExpressions = Seq[Expression]()
modeInfo.uniqueModes.foreach {
case PartialMerge =>
nonDistinctAttributes = mergeAttributesNonDistinct
nonDistinctExpressions = mergeExpressionsNonDistinct
case Partial =>
// Partial with distinct case
val updateExpressionsCudfAggsDistinct =
updateExpressionsDistinct.filter(_.isInstanceOf[CudfAggregate])
.map(_.asInstanceOf[CudfAggregate].ref)
if (inputProjectionsDistinct.exists(p => !p.isInstanceOf[NamedExpression])) {
// In the case of a distinct average, we need to evaluate the "GpuCast and GpuIsNotNull" columns.
// Refer to how input projections are set up for GpuAverage.
// In the case where we have expressions to evaluate, pick the unique attribute
// references from them, as you only have one column for each before you start evaluating.
distinctExpressions = inputProjectionsDistinct
distinctAttributes = inputProjectionsDistinct.flatMap(ref =>
ref.references.toSeq).distinct
} else {
distinctAttributes = updateAttributesDistinct
distinctExpressions = updateExpressionsCudfAggsDistinct
}
case _ =>
}
val inputBindExpressions = groupingExpressions ++ nonDistinctExpressions ++ distinctExpressions
val resultingBindAttributes = groupingAttributes ++ distinctAttributes ++ nonDistinctAttributes

val finalProjections = groupingExpressions ++
aggregateExpressions.map(_.aggregateFunction.evaluateExpression)

// boundInputReferences is used to pick out of the input batch the appropriate columns
// for aggregation
// - Partial Merge mode: we use the inputBindExpressions which can be only
// non distinct merge expressions.
// - Partial or Complete mode: we use the inputProjections or distinct update expressions.
// - Partial, PartialMerge mode: we use the inputProjections or distinct update expressions
// - PartialMerge with Partial mode: we use the inputProjections or distinct update expressions
// for Partial and non distinct merge expressions for PartialMerge.
// - Final mode: we pick the columns in the order as handed to us.
val boundInputReferences = if (modeInfo.hasPartialMerge) {
GpuBindReferences.bindGpuReferences(inputBindExpressions, resultingBindAttributes)
} else if (modeInfo.hasFinalMode) {
// - Final or PartialMerge-only mode: we pick the columns in the order as handed to us.
// - Partial or Complete mode: we use the inputProjections or distinct update expressions.
val boundInputReferences = if (modeInfo.hasPartialMerge && modeInfo.hasPartialMode) {
// The 3rd stage of AggWithOneDistinct, which combines (partial) reduce-side
// nonDistinctAggExpressions and map-side distinctAggExpressions. For this stage, we need to
// switch the position of distinctAttributes and nonDistinctAttributes.
//
// The layout of the 2nd stage's outputs:
// groupingAttributes ++ distinctAttributes ++ nonDistinctAggBufferAttributes
//
// The layout of the 3rd stage's expressions:
// groupingAttributes ++ nonDistinctMergeAggAttributes ++ distinctPartialAggAttributes

val (distinctAggExpressions, nonDistinctAggExpressions) = aggregateExpressions.partition(
_.isDistinct)

// Pick merge non-distinct for PartialMerge
val nonDistinctExpressions = nonDistinctAggExpressions
.flatMap(_.aggregateFunction.mergeExpressions)
.map(_.asInstanceOf[CudfAggregate].ref)
val nonDistinctAttributes = nonDistinctAggExpressions
.flatMap(_.aggregateFunction.aggBufferAttributes)

// Pick update distinct attributes or input projections for Partial
val distinctExpressions = distinctAggExpressions.flatMap(_.aggregateFunction.inputProjection)
val distinctAttributes = distinctExpressions.flatMap(_.references.toSeq).distinct

val inputProjections = groupingExpressions ++ nonDistinctExpressions ++ distinctExpressions
val inputAttributes = groupingAttributes ++ distinctAttributes ++ nonDistinctAttributes
GpuBindReferences.bindGpuReferences(inputProjections, inputAttributes)
} else if (modeInfo.hasFinalMode || modeInfo.hasPartialMerge) {
// two possible conditions:
// 1. The Final stage, including the 2nd stage of NoDistinctAgg and 4th stage of
// AggWithOneDistinct, which needs no input projections. Because the child outputs are
// internal aggregation buffers, which are aligned for the final stage.
//
// 2. The 2nd stage (PartialMerge) of AggWithOneDistinct, which works like the final stage
// taking the child outputs as inputs without any projections.
GpuBindReferences.bindGpuReferences(childAttr.attrs.asInstanceOf[Seq[Expression]], childAttr)
} else {
GpuBindReferences.bindGpuReferences(inputProjections, childAttr)
}
// The first aggregation stage (including Partial or Complete modes), whose child node
// is not an AggregateExec. Therefore, input projections are essential.
val inputProjections: Seq[Expression] = groupingExpressions ++ aggregateExpressions
.flatMap(_.aggregateFunction.inputProjection)
GpuBindReferences.bindGpuReferences(inputProjections, childAttr) }

val boundFinalProjections = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) {
val finalProjections = groupingExpressions ++
aggregateExpressions.map(_.aggregateFunction.evaluateExpression)
Some(GpuBindReferences.bindGpuReferences(finalProjections, aggBufferAttributes))
} else {
None
}

// allAttributes can be different things, depending on aggregation mode:
// - Partial mode: grouping key + cudf aggregates (e.g. no avg, instead sum::count)
// - Final mode: grouping key + spark aggregates (e.g. avg)
val finalAttributes = groupingAttributes ++ aggregateAttributes

// boundResultReferences is used to project the aggregated input batch(es) for the result.
// - Partial mode: it's a pass through. We take whatever was aggregated and let it come
// out of the node as is.
// - Final or Complete mode: we use resultExpressions to pick out the correct columns that
// finalReferences has pre-processed for us
val boundResultReferences = if (modeInfo.hasPartialMode) {
// - For final stages (Final or Complete mode): we use resultExpressions to pick out the
// correct columns that finalReferences has pre-processed for us
// - For non-Final stages (Partial or PartialMerge mode): it's just a pass through. We take
// whatever was aggregated and let it come out of the node as is.
val boundResultReferences = if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) {
GpuBindReferences.bindGpuReferences(
resultExpressions,
resultExpressions.map(_.toAttribute))
} else if (modeInfo.hasFinalMode || modeInfo.hasCompleteMode) {
GpuBindReferences.bindGpuReferences(
resultExpressions,
finalAttributes)
groupingAttributes ++ aggregateAttributes)
} else {
GpuBindReferences.bindGpuReferences(
resultExpressions,
groupingAttributes)
resultExpressions.map(_.toAttribute))
}

BoundExpressionsModeAggregates(boundInputReferences, boundFinalProjections,
boundResultReferences, aggModeCudfAggregates)
}
Expand Down

0 comments on commit 3dcce50

Please sign in to comment.