diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala index 7ca2921fcb8..a46d66d4f5c 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/aggregateFunctions.scala @@ -653,15 +653,6 @@ case class GpuExtractChunk32( override def children: Seq[Expression] = Seq(data) } -object GpuExtractChunk32 { - /** Build the aggregate expressions for summing the four 32-bit chunks of a 128-bit decimal. */ - def chunkSumExprs(): Seq[CudfAggregate] = (0 until 4).map { i => - // first three chunk columns are UINT32, so they are promoted to UINT64 during aggregation. - val colType = if (i < 3) GpuUnsignedLongType else LongType - new CudfSum(colType) - } -} - /** * Reassembles a 128-bit value from four separate 64-bit sum results * @param chunkAttrs attributes for the four 64-bit sum chunks ordered from least significant to @@ -1100,7 +1091,7 @@ case class GpuDecimal128Sum( chunks :+ GpuIsNull(child) } - private lazy val updateSumChunks = GpuExtractChunk32.chunkSumExprs + private lazy val updateSumChunks = (0 until 4).map(_ => new CudfSum(LongType)) override lazy val updateAggregates: Seq[CudfAggregate] = updateSumChunks :+ updateIsEmpty @@ -1119,7 +1110,7 @@ case class GpuDecimal128Sum( chunks ++ Seq(isEmpty, GpuIsNull(sum)) } - private lazy val mergeSumChunks = GpuExtractChunk32.chunkSumExprs() + private lazy val mergeSumChunks = (0 until 4).map(_ => new CudfSum(LongType)) // To be able to do decimal overflow detection, we need a CudfSum that does **not** ignore nulls. // Cudf does not have such an aggregation, so for merge we have to work around that similar to @@ -1484,7 +1475,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType) chunks :+ forCount } - private lazy val updateSumChunks = GpuExtractChunk32.chunkSumExprs() + private lazy val updateSumChunks = (0 until 4).map(_ => new CudfSum(LongType)) override lazy val updateAggregates: Seq[CudfAggregate] = updateSumChunks :+ updateCount @@ -1502,7 +1493,7 @@ case class GpuDecimal128Average(child: Expression, dt: DecimalType) chunks ++ Seq(count, GpuIsNull(sum)) } - private lazy val mergeSumChunks = GpuExtractChunk32.chunkSumExprs() + private lazy val mergeSumChunks = (0 until 4).map(_ => new CudfSum(LongType)) override lazy val mergeAggregates: Seq[CudfAggregate] = mergeSumChunks ++ Seq(mergeCount, mergeIsOverflow)