diff --git a/docs/configs.md b/docs/configs.md
index d3b47b3ad3e..f18f55a2bff 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -249,6 +249,7 @@ Name | SQL Function(s) | Description | Default Value | Notes
spark.rapids.sql.expression.Year|`year`|Returns the year from a date or timestamp|true|None|
spark.rapids.sql.expression.AggregateExpression| |Aggregate expression|true|None|
spark.rapids.sql.expression.Average|`avg`, `mean`|Average aggregate operator|true|None|
+spark.rapids.sql.expression.CollectList|`collect_list`|Collect a list of elements, now only supported by windowing.|false|This is disabled by default because for now the GPU collects null values to a list, but Spark does not. This will be fixed in future releases.|
spark.rapids.sql.expression.Count|`count`|Count aggregate operator|true|None|
spark.rapids.sql.expression.First|`first_value`, `first`|first aggregate operator|true|None|
spark.rapids.sql.expression.Last|`last`, `last_value`|last aggregate operator|true|None|
diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 7241384a559..3931dd15e18 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -745,9 +745,9 @@ Accelerator supports are described below.
NS |
NS |
NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
-NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
@@ -15449,7 +15449,7 @@ Accelerator support is described below.
NS |
NS |
NS |
-PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15491,7 +15491,7 @@ Accelerator support is described below.
NS |
NS |
NS |
-PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15675,7 +15675,7 @@ Accelerator support is described below.
S |
NS |
NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15717,7 +15717,7 @@ Accelerator support is described below.
S |
NS |
NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15739,7 +15739,7 @@ Accelerator support is described below.
S |
NS |
NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15781,7 +15781,7 @@ Accelerator support is described below.
S |
NS |
NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15803,7 +15803,7 @@ Accelerator support is described below.
S |
NS |
NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15845,7 +15845,7 @@ Accelerator support is described below.
S |
NS |
NS |
-NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
NS |
NS |
NS |
@@ -15984,6 +15984,139 @@ Accelerator support is described below.
|
+CollectList |
+`collect_list` |
+Collect a list of elements, now only supported by windowing. |
+This is disabled by default because for now the GPU collects null values to a list, but Spark does not. This will be fixed in future releases. |
+aggregation |
+input |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+ |
+ |
+ |
+
+
+reduction |
+input |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+NS |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+NS |
+ |
+ |
+ |
+
+
+window |
+input |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+S |
+S* |
+S |
+S* |
+NS |
+NS |
+NS |
+NS |
+NS |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, STRUCT, UDT) |
+NS |
+
+
+result |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+ |
+PS* (missing nested NULL, BINARY, CALENDAR, ARRAY, MAP, UDT) |
+ |
+ |
+ |
+
+
Count |
`count` |
Count aggregate operator |
diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index 64127f7d497..129548b6c99 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -223,3 +223,66 @@ def test_window_aggs_for_ranges_of_dates(data_gen):
' range between 1 preceding and 1 following) as sum_c_asc '
'from window_agg_table'
)
+
+
+def _gen_data_for_collect(nullable=True):
+ return [
+ ('a', RepeatSeqGen(LongGen(), length=20)),
+ ('b', IntegerGen()),
+ ('c_int', IntegerGen(nullable=nullable)),
+ ('c_long', LongGen(nullable=nullable)),
+ ('c_time', DateGen(nullable=nullable)),
+ ('c_string', StringGen(nullable=nullable)),
+ ('c_float', FloatGen(nullable=nullable)),
+ ('c_decimal', DecimalGen(nullable=nullable, precision=8, scale=3)),
+ ('c_struct', StructGen(nullable=nullable, children=[
+ ['child_int', IntegerGen()],
+ ['child_time', DateGen()],
+ ['child_string', StringGen()],
+ ['child_decimal', DecimalGen(precision=8, scale=3)]]))]
+
+
+_collect_sql_string =\
+ '''
+ select
+ collect_list(c_int) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_int,
+ collect_list(c_long) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_long,
+ collect_list(c_time) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_time,
+ collect_list(c_string) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_string,
+ collect_list(c_float) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_float,
+ collect_list(c_decimal) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_decimal,
+ collect_list(c_struct) over
+ (partition by a order by b,c_int rows between UNBOUNDED preceding and CURRENT ROW) as collect_struct
+ from window_collect_table
+ '''
+
+# SortExec does not support array type, so sort the result locally.
+@ignore_order(local=True)
+@pytest.mark.xfail(reason="https://github.com/NVIDIA/spark-rapids/issues/1638")
+def test_window_aggs_for_rows_collect_list():
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark : gen_df(spark, _gen_data_for_collect(), length=2048),
+ "window_collect_table",
+ _collect_sql_string,
+ {'spark.rapids.sql.expression.CollectList': 'true'})
+
+
+'''
+ Spark will drop nulls when collecting, but seems GPU does not yet, so exceptions come up.
+ Now set nullable to false to verify the current functionality without null values.
+ Once native supports dropping nulls, will enable the tests above and remove this one.
+'''
+# SortExec does not support array type, so sort the result locally.
+@ignore_order(local=True)
+def test_window_aggs_for_rows_collect_list_no_nulls():
+ assert_gpu_and_cpu_are_equal_sql(
+ lambda spark : gen_df(spark, _gen_data_for_collect(False), length=2048),
+ "window_collect_table",
+ _collect_sql_string,
+ {'spark.rapids.sql.expression.CollectList': 'true'})
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index f66f683ba50..cbacfb52a73 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -775,11 +775,11 @@ object GpuOverrides {
"\"window\") of rows",
ExprChecks.windowOnly(
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
- TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
+ TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck("windowFunction",
TypeSig.commonCudfTypes + TypeSig.DECIMAL +
- TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
+ TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all),
ParamCheck("windowSpec",
TypeSig.CALENDAR + TypeSig.NULL + TypeSig.integral + TypeSig.DECIMAL,
@@ -1716,11 +1716,13 @@ object GpuOverrides {
expr[AggregateExpression](
"Aggregate expression",
ExprChecks.fullAgg(
- TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
+ TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+ TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all,
Seq(ParamCheck(
"aggFunc",
- TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL,
+ TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL +
+ TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
TypeSig.all)),
Some(RepeatingParamCheck("filter", TypeSig.BOOLEAN, TypeSig.BOOLEAN))),
(a, conf, p, r) => new ExprMeta[AggregateExpression](a, conf, p, r) {
@@ -2296,6 +2298,22 @@ object GpuOverrides {
override def convertToGpu(child: Expression): GpuExpression =
GpuMakeDecimal(child, a.precision, a.scale, a.nullOnOverflow)
}),
+ expr[CollectList](
+ "Collect a list of elements, now only supported by windowing.",
+ // This should be 'fullAgg' eventually, but only windowing is supported for now,
+ // so use 'aggNotGroupByOrReduction'.
+ ExprChecks.aggNotGroupByOrReduction(
+ TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
+ TypeSig.ARRAY.nested(TypeSig.all),
+ Seq(ParamCheck("input",
+ TypeSig.commonCudfTypes + TypeSig.DECIMAL +
+ TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL),
+ TypeSig.all))),
+ (c, conf, p, r) => new ExprMeta[CollectList](c, conf, p, r) {
+ override def convertToGpu(): GpuExpression = GpuCollectList(
+ childExprs.head.convertToGpu(), c.mutableAggBufferOffset, c.inputAggBufferOffset)
+ }).disabledByDefault("for now the GPU collects null values to a list, but Spark does not." +
+ " This will be fixed in future releases."),
expr[ScalarSubquery](
"Subquery that will return only one row and one column",
ExprChecks.projectOnly(
@@ -2636,7 +2654,11 @@ object GpuOverrides {
(expand, conf, p, r) => new GpuExpandExecMeta(expand, conf, p, r)),
exec[WindowExec](
"Window-operator backend",
- ExecChecks(TypeSig.commonCudfTypes + TypeSig.DECIMAL, TypeSig.all),
+ ExecChecks(
+ TypeSig.commonCudfTypes + TypeSig.DECIMAL +
+ TypeSig.STRUCT.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL) +
+ TypeSig.ARRAY.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL + TypeSig.STRUCT),
+ TypeSig.all),
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
),
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
index 9700d462b0e..564a1c469bb 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala
@@ -199,13 +199,18 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
}
- val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
- if (expectedType != aggColumn.getType) {
- withResource(aggColumn) { aggColumn =>
- GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
- }
- } else {
- GpuColumnVector.from(aggColumn, windowFunc.dataType)
+ // For nested type, do not cast
+ aggColumn.getType match {
+ case dType if dType.isNestedType =>
+ GpuColumnVector.from(aggColumn, windowFunc.dataType)
+ case _ =>
+ val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
+ // The API 'castTo' will take care of the 'from' type and 'to' type, and
+ // just increases the reference count by one when they are the same,
+ // so it is OK to always call it here.
+ withResource(aggColumn) { aggColumn =>
+ GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
+ }
}
}
@@ -230,13 +235,18 @@ case class GpuWindowExpression(windowFunction: Expression, windowSpec: GpuWindow
}
}
}
- val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
- if (expectedType != aggColumn.getType) {
- withResource(aggColumn) { aggColumn =>
- GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
- }
- } else {
- GpuColumnVector.from(aggColumn, windowFunc.dataType)
+ // For nested type, do not cast
+ aggColumn.getType match {
+ case dType if dType.isNestedType =>
+ GpuColumnVector.from(aggColumn, windowFunc.dataType)
+ case _ =>
+ val expectedType = GpuColumnVector.getNonNestedRapidsType(windowFunc.dataType)
+ // The API 'castTo' will take care of the 'from' type and 'to' type, and
+ // just increases the reference count by one when they are the same,
+ // so it is OK to always call it here.
+ withResource(aggColumn) { aggColumn =>
+ GpuColumnVector.from(aggColumn.castTo(expectedType), windowFunc.dataType)
+ }
}
}
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
index 76915acaa70..4a98c089e33 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala
@@ -985,7 +985,7 @@ object ExprChecks {
}
/**
- * Window only operations. Spark does not support these operations as anythign but a window
+ * Window only operations. Spark does not support these operations as anything but a window
* operation.
*/
def windowOnly(
@@ -996,6 +996,31 @@ object ExprChecks {
ExprChecksImpl(Map(
(WindowAggExprContext,
ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck))))
+
+ /**
+ * Checks for an aggregation that the plugin supports only as a window operation, even though
+ * Spark also supports it as a group-by or reduction aggregation; those two contexts are
+ * registered with 'TypeSig.none'. For now this is only used for 'collect_list'.
+ */
+ def aggNotGroupByOrReduction(
+ outputCheck: TypeSig,
+ sparkOutputSig: TypeSig,
+ paramCheck: Seq[ParamCheck] = Seq.empty,
+ repeatingParamCheck: Option[RepeatingParamCheck] = None): ExprChecks = {
+ val notWindowParamCheck = paramCheck.map { pc =>
+ ParamCheck(pc.name, TypeSig.none, pc.spark)
+ }
+ val notWindowRepeat = repeatingParamCheck.map { pc =>
+ RepeatingParamCheck(pc.name, TypeSig.none, pc.spark)
+ }
+ ExprChecksImpl(Map(
+ (GroupByAggExprContext,
+ ContextChecks(TypeSig.none, sparkOutputSig, notWindowParamCheck, notWindowRepeat)),
+ (ReductionAggExprContext,
+ ContextChecks(TypeSig.none, sparkOutputSig, notWindowParamCheck, notWindowRepeat)),
+ (WindowAggExprContext,
+ ContextChecks(outputCheck, sparkOutputSig, paramCheck, repeatingParamCheck))))
+ }
}
/**
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index e76051ef36d..32da9614133 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -22,9 +22,9 @@ import com.nvidia.spark.rapids._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Complete, Final, Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BooleanType, DataType, DoubleType, LongType, NumericType, StructType}
+import org.apache.spark.sql.types._
trait GpuAggregateFunction extends GpuExpression {
// using the child reference, define the shape of the vectors sent to
@@ -534,3 +534,46 @@ abstract class GpuLastBase(child: Expression)
override lazy val deterministic: Boolean = false
override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}"
}
+
+/**
+ * Collects and returns a list of non-unique elements. For now only the window function
+ * members are implemented; most declarative aggregate members throw when accessed.
+ * The two 'offset' parameters are not used by the GPU version, but are kept for compatibility
+ * with the CPU version and automated checks.
+ */
+case class GpuCollectList(child: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0)
+ extends GpuDeclarativeAggregate with GpuAggregateWindowFunction {
+
+ def this(child: Expression) = this(child, 0, 0)
+
+ // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
+ // actual order of input rows.
+ override lazy val deterministic: Boolean = false
+
+ override def nullable: Boolean = false
+
+ override def prettyName: String = "collect_list"
+
+ override def dataType: DataType = ArrayType(child.dataType, false)
+
+ override def children: Seq[Expression] = child :: Nil
+
+ // Window function: project only the child, then collect the column at the first input's index.
+ override val windowInputProjection: Seq[Expression] = Seq(child)
+ override def windowAggregation(inputs: Seq[(ColumnVector, Int)]): AggregationOnColumn =
+ Aggregation.collect().onColumn(inputs.head._2)
+
+ // Declarative aggregate. 'CollectList' does not support it for now, so the
+ // members below should NOT be used yet; this is ensured by
+ // "ExprChecks.aggNotGroupByOrReduction" when trying to override the expression.
+ private lazy val cudfList = AttributeReference("collect_list", dataType)()
+ // Lazy, so the exception is thrown on first access rather than when creating a GpuCollectList.
+ override lazy val initialValues: Seq[GpuExpression] = throw new UnsupportedOperationException
+ override lazy val updateExpressions: Seq[Expression] = throw new UnsupportedOperationException
+ override lazy val mergeExpressions: Seq[GpuExpression] = throw new UnsupportedOperationException
+ override lazy val evaluateExpression: Expression = throw new UnsupportedOperationException
+ override val inputProjection: Seq[Expression] = Seq(child)
+ override def aggBufferAttributes: Seq[AttributeReference] = cudfList :: Nil
+}