Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Yohahaha committed May 12, 2023
1 parent 9562244 commit 5684321
Showing 1 changed file with 20 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,33 +231,27 @@ case class GlutenHashAggregateExecTransformer(
// This is a special handling for PartialMerge in the execution of distinct.
// Use Partial phase instead for this aggregation.
val modeKeyWord = modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode)
var sigName = ExpressionMappings.aggregate_functions_map.getOrElse(
aggregateFunction.getClass, ExpressionMappings.getAggSigOther(aggregateFunction.prettyName))
// Check whether Gluten supports this aggregate function.
if (sigName.isEmpty) {
throw new UnsupportedOperationException(s"not currently supported: $aggregateFunction.")
}

def generateMergeCompanionNode(): Unit = {
aggregateMode match {
case Partial =>
val partialNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, sigName),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
aggregateNodeList.add(partialNode)
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, sigName, mixedPartialAndMerge),
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
aggregateNodeList.add(aggFunctionNode)
case Final =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, sigName),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable))
Expand All @@ -271,18 +265,12 @@ case class GlutenHashAggregateExecTransformer(
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
generateMergeCompanionNode()
case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop |
_: Corr | _: CovPopulation | _: CovSample =>
generateMergeCompanionNode()
case First(_, ignoreNulls) =>
if (ignoreNulls) sigName = ExpressionMappings.FIRST_IGNORE_NULL
generateMergeCompanionNode()
case Last(_, ignoreNulls) =>
if (ignoreNulls) sigName = ExpressionMappings.LAST_IGNORE_NULL
_: Corr | _: CovPopulation | _: CovSample | _: First | _: Last =>
generateMergeCompanionNode()
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(
args, aggregateFunction, sigName,
args, aggregateFunction,
aggregateMode == PartialMerge && mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
Expand Down Expand Up @@ -756,15 +744,29 @@ object VeloxAggregateFunctionsBuilder {
* @param forMergeCompanion: whether this is a special case to solve mixed aggregation phases.
* @return
*/
def create(args: java.lang.Object, aggregateFunc: AggregateFunction, sigName: String,
def create(args: java.lang.Object, aggregateFunc: AggregateFunction,
forMergeCompanion: Boolean = false): Long = {
val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]

var sigName = ExpressionMappings.aggregate_functions_map.getOrElse(
aggregateFunc.getClass, ExpressionMappings.getAggSigOther(aggregateFunc.prettyName))
// Check whether Gluten supports this aggregate function.
if (sigName.isEmpty) {
throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.")
}
// Check whether each backend supports this aggregate function.
if (!BackendsApiManager.getValidatorApiInstance
.doAggregateFunctionValidate(sigName, aggregateFunc)) {
throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.")
}

aggregateFunc match {
case First(_, ignoreNulls) =>
if (ignoreNulls) sigName = ExpressionMappings.FIRST_IGNORE_NULL
case Last(_, ignoreNulls) =>
if (ignoreNulls) sigName = ExpressionMappings.LAST_IGNORE_NULL
}

// Use companion function for partial-merge aggregation functions on count distinct.
val substraitAggFuncName = if (!forMergeCompanion) sigName else sigName + "_merge"

Expand Down

0 comments on commit 5684321

Please sign in to comment.