Skip to content

Commit

Permalink
Support collect_list on GPU for windowing.
Browse files Browse the repository at this point in the history
Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman committed Jan 19, 2021
1 parent 4df63a5 commit 2d07f99
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2096,6 +2096,20 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[MakeDecimal](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
}),
expr[CollectList](
"Collect a list of elements",
/* It should be 'fullAgg' eventually but now only support windowing, so 'windowOnly' */
ExprChecks.windowOnly(TypeSig.ARRAY.nested(TypeSig.integral +
TypeSig.STRUCT.nested(TypeSig.commonCudfTypes)),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(ParamCheck("input",
TypeSig.integral + TypeSig.STRUCT.nested(
TypeSig.integral + TypeSig.STRING + TypeSig.TIMESTAMP),
TypeSig.all))),
(c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCollectList(
childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ object ExprChecks {
}

/**
* Window only operations. Spark does not support these operations as anythign but a window
* Window only operations. Spark does not support these operations as anything but a window
* operation.
*/
def windowOnly(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Complete, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, LongType, NumericType, StructType}
import org.apache.spark.sql.types._

trait GpuAggregateFunction extends GpuExpression {
// using the child reference, define the shape of the vectors sent to
Expand Down Expand Up @@ -529,3 +529,42 @@ abstract class GpuLastBase(child: Expression)
override lazy val deterministic: Boolean = false
override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}"
}

/**
* Collects and returns a list of non-unique elements.
*
* FIXME Not sure whether GPU version requires the two offset parameters. Keep it here first.
*/
case class GpuCollectList(child: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends GpuDeclarativeAggregate with GpuAggregateWindowFunction {

def this(child: Expression) = this(child, 0, 0)

override lazy val deterministic: Boolean = false

override def nullable: Boolean = false

override def prettyName: String = "collect_list"

override def dataType: DataType = ArrayType(child.dataType, false)

override def children: Seq[Expression] = child :: Nil

// WINDOW FUNCTION
override val windowInputProjection: Seq[Expression] = Seq(child)
override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn =
Aggregation.collect().onColumn(inputs.head._2)

// Declarative aggregate. But for now 'CollectList' does not support it.
// The functions as below should NOT be called yet, ensured by the "TypeCheck.windowOnly"
// when overriding the expression.
private lazy val cudfList = AttributeReference("collect_list", dataType)()
override val initialValues: Seq[GpuExpression] = throw new UnsupportedOperationException
override val updateExpressions: Seq[Expression] = throw new UnsupportedOperationException
override val mergeExpressions: Seq[GpuExpression] = throw new UnsupportedOperationException
override val evaluateExpression: Expression = throw new UnsupportedOperationException
override val inputProjection: Seq[Expression] = Seq(child)
override def aggBufferAttributes: Seq[AttributeReference] = cudfList :: Nil
}

0 comments on commit 2d07f99

Please sign in to comment.