diff --git a/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala b/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala index 940f496234a75..af21cea7f435b 100644 --- a/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala +++ b/gluten-data/src/main/scala/io/glutenproject/execution/GlutenHashAggregateExecTransformer.scala @@ -231,18 +231,12 @@ case class GlutenHashAggregateExecTransformer( // This is a special handling for PartialMerge in the execution of distinct. // Use Partial phase instead for this aggregation. val modeKeyWord = modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode) - var sigName = ExpressionMappings.aggregate_functions_map.getOrElse( - aggregateFunction.getClass, ExpressionMappings.getAggSigOther(aggregateFunction.prettyName)) - // Check whether Gluten supports this aggregate function. - if (sigName.isEmpty) { - throw new UnsupportedOperationException(s"not currently supported: $aggregateFunction.") - } def generateMergeCompanionNode(): Unit = { aggregateMode match { case Partial => val partialNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, sigName), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, modeKeyWord, getIntermediateTypeNode(aggregateFunction)) @@ -250,14 +244,14 @@ case class GlutenHashAggregateExecTransformer( case PartialMerge => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( VeloxAggregateFunctionsBuilder - .create(args, aggregateFunction, sigName, mixedPartialAndMerge), + .create(args, aggregateFunction, mixedPartialAndMerge), childrenNodeList, modeKeyWord, getIntermediateTypeNode(aggregateFunction)) aggregateNodeList.add(aggFunctionNode) case Final => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, sigName), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)) @@ -271,18 +265,12 @@ case class GlutenHashAggregateExecTransformer( case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => generateMergeCompanionNode() case _: Average | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | - _: Corr | _: CovPopulation | _: CovSample => - generateMergeCompanionNode() - case First(_, ignoreNulls) => - if (ignoreNulls) sigName = ExpressionMappings.FIRST_IGNORE_NULL - generateMergeCompanionNode() - case Last(_, ignoreNulls) => - if (ignoreNulls) sigName = ExpressionMappings.LAST_IGNORE_NULL + _: Corr | _: CovPopulation | _: CovSample | _: First | _: Last => generateMergeCompanionNode() case _ => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( VeloxAggregateFunctionsBuilder.create( - args, aggregateFunction, sigName, + args, aggregateFunction, aggregateMode == PartialMerge && mixedPartialAndMerge), childrenNodeList, modeKeyWord, @@ -756,15 +744,29 @@ object VeloxAggregateFunctionsBuilder { * @param forMergeCompanion: whether this is a special case to solve mixed aggregation phases. * @return */ - def create(args: java.lang.Object, aggregateFunc: AggregateFunction, sigName: String, + def create(args: java.lang.Object, aggregateFunc: AggregateFunction, forMergeCompanion: Boolean = false): Long = { val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + var sigName = ExpressionMappings.aggregate_functions_map.getOrElse( + aggregateFunc.getClass, ExpressionMappings.getAggSigOther(aggregateFunc.prettyName)) + // Check whether Gluten supports this aggregate function. + if (sigName.isEmpty) { + throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.") + } // Check whether each backend supports this aggregate function. if (!BackendsApiManager.getValidatorApiInstance .doAggregateFunctionValidate(sigName, aggregateFunc)) { throw new UnsupportedOperationException(s"not currently supported: $aggregateFunc.") } + + aggregateFunc match { + case First(_, ignoreNulls) => + if (ignoreNulls) sigName = ExpressionMappings.FIRST_IGNORE_NULL + case Last(_, ignoreNulls) => + if (ignoreNulls) sigName = ExpressionMappings.LAST_IGNORE_NULL + } + // Use companion function for partial-merge aggregation functions on count distinct. val substraitAggFuncName = if (!forMergeCompanion) sigName else sigName + "_merge"