diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 9a70c83f8ac..2ac20383e89 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -12329,8 +12329,8 @@ Accelerator support is described below.
-S*
-S*
+PS* (missing nested BINARY, CALENDAR, UDT)
+PS* (missing nested BINARY, CALENDAR, UDT)
diff --git a/integration_tests/src/main/python/collection_ops_test.py b/integration_tests/src/main/python/collection_ops_test.py
index 367a297ed3f..6c194e69379 100644
--- a/integration_tests/src/main/python/collection_ops_test.py
+++ b/integration_tests/src/main/python/collection_ops_test.py
@@ -18,7 +18,11 @@
 from data_gen import *
 from pyspark.sql.types import *
 
-@pytest.mark.parametrize('data_gen', all_gen, ids=idfn)
+nested_gens = [ArrayGen(LongGen()),
+               StructGen([("a", LongGen())]),
+               MapGen(StringGen(pattern='key_[0-9]', nullable=False), StringGen())]
+
+@pytest.mark.parametrize('data_gen', all_gen + nested_gens, ids=idfn)
 @pytest.mark.parametrize('size_of_null', ['true', 'false'], ids=idfn)
 def test_size_of_array(data_gen, size_of_null):
     gen = ArrayGen(data_gen)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index e4818c31605..3391a89a1bc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2339,7 +2339,8 @@ object GpuOverrides {
     expr[Size](
       "The size of an array or a map",
       ExprChecks.unaryProjectNotLambda(TypeSig.INT, TypeSig.INT,
-        (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all),
+        (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.commonCudfTypes + TypeSig.NULL +
+            TypeSig.DECIMAL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
         (TypeSig.ARRAY + TypeSig.MAP).nested(TypeSig.all)),
       (a, conf, p, r) => new UnaryExprMeta[Size](a, conf, p, r) {
         override def convertToGpu(child: Expression): GpuExpression =
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
index 9129dc9a32d..19c641d1cd9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/collectionOperations.scala
@@ -31,40 +31,18 @@ case class GpuSize(child: Expression, legacySizeOfNull: Boolean)
   override def nullable: Boolean = if (legacySizeOfNull) false else super.nullable
 
   override protected def doColumnar(input: GpuColumnVector): ColumnVector = {
-    val inputBase = input.getBase
-    if (inputBase.getRowCount == 0) {
-      return GpuColumnVector.from(GpuScalar.from(0), 0, IntegerType).getBase
-    }
     // Compute sizes of cuDF.ListType to get sizes of each ArrayData or MapData, considering
     // MapData is represented as List of Struct in terms of cuDF.
-    // We compute list size via subtracting the offset of next element(row) to the current offset.
-    val collectionSize = {
-      // Here is a hack: using index -1 to fetch the offset column of list.
-      // In terms of cuDF native, the offset is the first (index 0) child of list_column_view.
-      // In JNI layer, we add 1 to the child index when fetching child column of ListType to keep
-      // alignment.
-      // So, in JVM layer, we have to use -1 as index to fetch the real first child of list_column.
-      withResource(inputBase.getChildColumnView(-1)) { offset =>
-        withResource(offset.subVector(1)) { upBound =>
-          withResource(offset.subVector(0, offset.getRowCount.toInt - 1)) { lowBound =>
-            upBound.sub(lowBound)
+    withResource(input.getBase.countElements()) { collectionSize =>
+      if (legacySizeOfNull) {
+        withResource(GpuScalar.from(-1)) { nullScalar =>
+          withResource(input.getBase.isNull) { inputIsNull =>
+            inputIsNull.ifElse(nullScalar, collectionSize)
           }
         }
-      }
-    }
-
-    val nullScalar = if (legacySizeOfNull) {
-      GpuScalar.from(-1)
-    } else {
-      GpuScalar.from(null, IntegerType)
-    }
-
-    withResource(collectionSize) { collectionSize =>
-      withResource(nullScalar) { nullScalar =>
-        withResource(inputBase.isNull) { inputIsNull =>
-          inputIsNull.ifElse(nullScalar, collectionSize)
+      } else {
+        collectionSize.incRefCount()
       }
     }
   }
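
Note on semantics: the rewritten doColumnar must keep Spark's size() contract, where a
null array/map yields -1 under spark.sql.legacy.sizeOfNull=true and NULL otherwise. A
minimal plain-Scala sketch of that contract (sizeOf and the Option-based modeling are
illustrative only, not plugin API):

    // Models size() on a possibly-null collection value: None stands in for a
    // null row, Some(c) for a non-null array or map.
    def sizeOf(input: Option[Iterable[_]], legacySizeOfNull: Boolean): Option[Int] =
      input match {
        case Some(c)                  => Some(c.size) // element count of the array/map
        case None if legacySizeOfNull => Some(-1)     // legacy mode: null becomes -1
        case None                     => None         // otherwise: null stays null
      }

This mirrors the GPU path above: countElements() produces the per-row counts, and the
isNull/ifElse pair substitutes -1 only when legacySizeOfNull is set; in the non-legacy
branch the counts column already carries nulls for null rows, so it is returned as-is.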
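
For background on the removed hand-rolled path: a cuDF LIST column stores its child data
plus an offsets vector of rowCount + 1 entries, and the length of row i is
offsets(i + 1) - offsets(i). The deleted code computed exactly this adjacent difference
via subVector and sub; countElements() now does the same computation natively. An
illustrative plain-Scala version of the adjacent-difference step (sizesFromOffsets is a
hypothetical helper, for exposition only):

    // For offsets [0, 2, 2, 5], the three list rows have sizes [2, 0, 3].
    def sizesFromOffsets(offsets: Array[Int]): Array[Int] =
      Array.tabulate(offsets.length - 1)(i => offsets(i + 1) - offsets(i))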
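
On resource handling: withResource closes its argument when the block exits, which is why
the non-legacy branch calls collectionSize.incRefCount() before returning it, since cuDF
ColumnVectors are reference-counted and close() only decrements the count. A sketch of
the loan pattern, assuming nothing beyond AutoCloseable (the plugin's own withResource is
equivalent in spirit):

    // Loan pattern: guarantees r.close() even if the block throws.
    def withResource[R <: AutoCloseable, T](r: R)(block: R => T): T =
      try block(r) finally r.close()

Without the incRefCount(), the vector returned from the else branch would already be
released by the time doColumnar's caller used it.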