diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 9a70c83f8ac..2ac20383e89 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -12329,8 +12329,8 @@ Accelerator support is described below.
|
|
|
-S* |
-S* |
+PS* (missing nested BINARY, CALENDAR, UDT) |
+PS* (missing nested BINARY, CALENDAR, UDT) |
|
|
|
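
Reading the generated support table: `Size`'s ARRAY and MAP inputs move from `S*` (supported) to `PS*` (partially supported), with the unsupported nested element types listed inline. A minimal spark-shell sketch of what that means in practice, assuming the plugin is installed; `spark.rapids.sql.enabled` is the plugin's standard enable switch, and the column name `a` plus the expected plan shapes are illustrative:

```scala
// Assumes a SparkSession named `spark` with the RAPIDS plugin on the classpath.
spark.conf.set("spark.rapids.sql.enabled", "true")

// Nested ARRAY elements are covered by the PS* entry, so size() can run on the GPU.
spark.range(2)
  .selectExpr("array(array(id)) AS a")
  .selectExpr("size(a)")
  .explain() // expect a GPU version of Size in the plan

// BINARY elements are listed as missing, so this size() falls back to the CPU.
spark.range(2)
  .selectExpr("array(CAST('x' AS BINARY)) AS a")
  .selectExpr("size(a)")
  .explain() // expect a plain CPU Project here
```
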
diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 367a297ed3f..6c194e69379 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -18,7 +18,11 @@
from data_gen import *
from pyspark.sql.types import *

-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+nested_gens = [ArrayGen(LongGen()),
+ StructGen([("a", LongGen())]),
+ MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen())]
+
+@pytest.mark.parametrize('data_gen', all_gen + nested_gens, ids=idfn)
@pytest.mark.parametrize('size_of_null', ['true', 'false'], ids=idfn)
def test_size_of_array(data_gen, size_of_null):
gen = ArrayGen(data_gen)
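
The new `nested_gens` list extends the existing parametrization so `test_size_of_array` also covers arrays, structs, and maps as element types (the test wraps each `data_gen` in another `ArrayGen`). Roughly the three shapes being exercised, as a one-off Spark SQL sketch; it assumes a SparkSession named `spark`, and the literals are illustrative, whereas the real test generates many rows, nulls included, and compares GPU against CPU output:

```scala
// One hand-written row per nested shape from nested_gens; the expected
// size() results are in the trailing SQL comments.
spark.sql(
  """SELECT size(array(array(1L), array(2L, 3L))) AS arr_of_arr,    -- 2
    |       size(array(named_struct('a', 1L)))    AS arr_of_struct, -- 1
    |       size(array(map('key_0', 'v')))        AS arr_of_map     -- 1
    |""".stripMargin).show()
```
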
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index e4818c31605..3391a89a1bc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2339,7 +2339,8 @@ object GpuOverrides {
expr[Size](
"The size of an array or a map",
ExprChecks.unaryProjectNotLambda(TypeSig.INT, TypeSig.INT,
- (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all),
+ (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL
+ + TypeSig.DECIMAL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
(TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all)),
(a, conf, p, r) => new UnaryExprMeta[Size](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression =
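
This is the substantive check: instead of claiming `TypeSig.all` for nested types on the GPU side, the signature now enumerates what the GPU implementation actually handles inside an ARRAY or MAP, while the Spark-side signature (the unchanged `TypeSig.all` line) still describes everything Spark itself accepts. A self-contained model of the narrowed rule, to make the recursion explicit; this sketch uses Spark's `DataType` tree, not the plugin's real `TypeSig` machinery, and the leaf list only approximates `TypeSig.commonCudfTypes`:

```scala
import org.apache.spark.sql.types._

// Simplified stand-in for the narrowed signature above: the input must be an
// ARRAY or MAP, and every type nested inside it must be on the allow-list,
// so nested BINARY, CALENDAR, and UDT send the expression back to the CPU.
object SizeGpuCheckModel {
  private def nestedOk(dt: DataType): Boolean = dt match {
    case ArrayType(elem, _) => nestedOk(elem)
    case MapType(k, v, _)   => nestedOk(k) && nestedOk(v)
    case StructType(fields) => fields.forall(f => nestedOk(f.dataType))
    case _: DecimalType     => true // TypeSig.DECIMAL
    case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | StringType | DateType | TimestampType |
         NullType => true // roughly TypeSig.commonCudfTypes + TypeSig.NULL
    case _ => false // e.g. BinaryType, CalendarIntervalType, user-defined types
  }

  def sizeIsGpuEligible(input: DataType): Boolean = input match {
    case ArrayType(elem, _) => nestedOk(elem)
    case MapType(k, v, _)   => nestedOk(k) && nestedOk(v)
    case _                  => false
  }
}
```
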
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
index 9129dc9a32d..19c641d1cd9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
@@ -31,40 +31,18 @@ case class GpuSize(child: Expression, legacySizeOfNull: Boolean)
override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable

override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
- val inputBase = input.getBase
- if (inputBase.getRowCount == 0) {
- return GpuColumnVector.from(GpuScalar.from(0), 0, IntegerType).getBase
-    }

// Compute the size of each ArrayData or MapData as the element count of the underlying
// cuDF list, since cuDF represents MapData as a list of structs.
- // We compute list size via subtracting the offset of next element(row) to the current offset.
- val collectionSize = {
- // Here is a hack: using index -1 to fetch the offset column of list.
- // In terms of cuDF native, the offset is the first (index 0) child of list_column_view.
- // In JNI layer, we add 1 to the child index when fetching child column of ListType to keep
- // alignment.
- // So, in JVM layer, we have to use -1 as index to fetch the real first child of list_column.
- withResource(inputBase.getChildColumnView(-1)) { offset =>
- withResource(offset.subVector(1)) { upBound =>
- withResource(offset.subVector(0, offset.getRowCount.toInt - 1)) { lowBound =>
- upBound.sub(lowBound)
+ withResource(input.getBase.countElements()) { collectionSize =>
+ if (legacySizeOfNull) {
+ withResource(GpuScalar.from(-1)) { nullScalar =>
+ withResource(input.getBase.isNull) { inputIsNull =>
+ inputIsNull.ifElse(nullScalar, collectionSize)
}
}
- }
- }
-
- val nullScalar = if (legacySizeOfNull) {
- GpuScalar.from(-1)
- } else {
- GpuScalar.from(null, IntegerType)
- }
-
- withResource(collectionSize) { collectionSize =>
- withResource(nullScalar) { nullScalar =>
- withResource(inputBase.isNull) { inputIsNull =>
- inputIsNull.ifElse(nullScalar, collectionSize)
- }
+ } else {
+ collectionSize.incRefCount()
}
}
}
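
The rewrite drops the hand-rolled offsets arithmetic, including the `getChildColumnView(-1)` hack and the special case for empty input, in favor of cuDF's `countElements`, which produces one element count per list row and leaves null rows null. That reduces the `legacySizeOfNull` handling to remapping those nulls to `-1`, with `incRefCount()` returning the counts column as-is otherwise (the caller closes the result, so the reference count must be bumped before the column escapes the `withResource` block). Row by row, the semantics are simple enough to state as a plain-Scala model; `sizeOf` is a name made up for this sketch, not plugin code:

```scala
// CPU-side model of GpuSize's per-row behavior: countElements() corresponds
// to taking elems.length, a null row stays null, and legacySizeOfNull
// (Spark's spark.sql.legacy.sizeOfNull=true) maps that null to -1 instead.
def sizeOf(row: Option[Seq[Any]], legacySizeOfNull: Boolean): Option[Int] =
  row match {
    case Some(elems)              => Some(elems.length)
    case None if legacySizeOfNull => Some(-1)
    case None                     => None
  }

assert(sizeOf(Some(Seq(1, 2, 3)), legacySizeOfNull = true)  == Some(3))
assert(sizeOf(None,               legacySizeOfNull = true)  == Some(-1))
assert(sizeOf(None,               legacySizeOfNull = false) == None)
```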