Skip to content

Commit

Permalink
Updated the type check
Browse files Browse the repository at this point in the history
Signed-off-by: Firestarman <[email protected]>
  • Loading branch information
firestarman committed Jan 28, 2021
1 parent 724a87f commit c5c982d
Showing 1 changed file with 12 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,11 @@ object GpuOverrides {
"Calculates a return value for every input row of a table based on a group (or " +
"\"window\") of rows",
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested + TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck("windowFunction",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested +
TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all),
ParamCheck("windowSpec",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL,
Expand Down Expand Up @@ -1644,13 +1643,13 @@ object GpuOverrides {
expr[AggregateExpression](
"Aggregate expression",
ExprChecks.fullAgg(
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.STRUCT),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested() + TypeSig.NULL +
TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck(
"aggFunc",
TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.STRUCT),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested() + TypeSig.NULL +
TypeSig.ARRAY.nested(TypeSig.STRUCT),
TypeSig.all)),
Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
(a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) {
Expand All @@ -1675,17 +1674,6 @@ object GpuOverrides {
GpuAggregateExpression(childExprs(0).convertToGpu().asInstanceOf[GpuAggregateFunction],
a.mode, a.isDistinct, filter.map(_.convertToGpu()), resultId)
}

// NOTE: Will remove this once all aggregates support array type.
override def tagExprForGpu(): Unit = {
// Only allow Array type for function "CollectList", since other aggregate functions
// have not been verified.
wrapped.dataType match {
case _: ArrayType if !wrapped.aggregateFunction.isInstanceOf[CollectList] =>
willNotWorkOnGpu("Now only 'collect_list' supports type of array.")
case _ =>
}
}
}),
expr[SortOrder](
"Sort order",
Expand Down Expand Up @@ -2185,13 +2173,13 @@ object GpuOverrides {
GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
}),
expr[CollectList](
"Collect a list of elements",
"Collect a list of elements, now only supported by windowing.",
/* It should be 'fullAgg' eventually but now only support windowing, so 'windowOnly' */
ExprChecks.windowOnly(
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.ARRAY.nested(TypeSig.all),
Seq(ParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.STRUCT.nested(TypeSig.commonCudfTypes),
(TypeSig.commonCudfTypes + TypeSig.DECIMAL).nested() + TypeSig.STRUCT,
TypeSig.all))),
(c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) {
override def convertToGpu(): GpuExpression = GpuCollectList(
Expand Down Expand Up @@ -2500,8 +2488,8 @@ object GpuOverrides {
(expand, conf, p, r) => new GpuExpandExecMeta(expand, conf, p, r)),
exec[WindowExec](
"Window-operator backend",
ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT +
TypeSig.ARRAY.nested(TypeSig.STRUCT + TypeSig.commonCudfTypes),
ExecChecks(
(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT).nested() + TypeSig.ARRAY,
TypeSig.all),
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
Expand Down

0 comments on commit c5c982d

Please sign in to comment.