diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index 9a70c83f8ac..57642f0d5be 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -699,9 +699,9 @@ Accelerator supports are described below.
NS |
NS |
NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
-NS |
-NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
@@ -10573,9 +10573,9 @@ Accelerator support is described below.
NS |
NS |
NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
-NS |
-NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
@@ -10616,9 +10616,9 @@ Accelerator support is described below.
NS |
NS |
NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
-NS |
-NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
@@ -10659,9 +10659,9 @@ Accelerator support is described below.
NS |
NS |
NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
-NS |
-NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
@@ -10702,9 +10702,9 @@ Accelerator support is described below.
NS |
NS |
NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
-NS |
-NS |
+PS* (missing nested DECIMAL, NULL, BINARY, CALENDAR, MAP, UDT) |
NS |
diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py
index 0abb3b72597..22d9d94ddc7 100644
--- a/integration_tests/src/main/python/data_gen.py
+++ b/integration_tests/src/main/python/data_gen.py
@@ -826,6 +826,26 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
decimal_gen_default, decimal_gen_scale_precision, decimal_gen_same_scale_precision,
decimal_gen_64bit]
+# Pyarrow will complain the error as below if the timestamp is out of range for both CPU and GPU,
+# so narrow down the time range to avoid exceptions causing test failures.
+#
+# "pyarrow.lib.ArrowInvalid: Casting from timestamp[us, tz=UTC] to timestamp[ns]
+# would result in out of bounds timestamp: 51496791452587000"
+#
+# This issue has been fixed in pyarrow by the PR https://github.com/apache/arrow/pull/7169
+# However it still requires PySpark to specify the new argument "timestamp_as_object".
+arrow_common_gen = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen,
+ string_gen, boolean_gen, date_gen,
+ TimestampGen(start=datetime(1970, 1, 1, tzinfo=timezone.utc),
+ end=datetime(2262, 1, 1, tzinfo=timezone.utc))]
+
+arrow_array_gens = [ArrayGen(subGen) for subGen in arrow_common_gen] + nested_array_gens_sample
+
+arrow_one_level_struct_gen = StructGen([
+ ['child'+str(i), sub_gen] for i, sub_gen in enumerate(arrow_common_gen)])
+
+arrow_struct_gens = [arrow_one_level_struct_gen,
+ StructGen([['child0', ArrayGen(short_gen)], ['child1', arrow_one_level_struct_gen]])]
# This function adds a new column named uniq_int where each row
# has a new unique integer value. It just starts at 0 and
diff --git a/integration_tests/src/main/python/udf_test.py b/integration_tests/src/main/python/udf_test.py
index 95cbd8cbeb8..36ba96ac499 100644
--- a/integration_tests/src/main/python/udf_test.py
+++ b/integration_tests/src/main/python/udf_test.py
@@ -45,6 +45,8 @@
'spark.rapids.sql.exec.WindowInPandasExec': 'true'
}
+data_gens_nested_for_udf = arrow_array_gens + arrow_struct_gens
+
####################################################################
# NOTE: pytest does not play well with pyspark udfs, because pyspark
# tries to import the dependencies for top level functions and
@@ -78,6 +80,17 @@ def iterator_add(to_process: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[
conf=arrow_udf_conf)
+@pytest.mark.parametrize('data_gen', data_gens_nested_for_udf, ids=idfn)
+def test_pandas_scalar_udf_nested_type(data_gen):
+ def nested_size(nested):
+ return pd.Series([nested.size]).repeat(len(nested))
+
+ my_udf = f.pandas_udf(nested_size, returnType=LongType())
+ assert_gpu_and_cpu_are_equal_collect(
+ lambda spark: unary_op_df(spark, data_gen).select(my_udf(f.col('a'))),
+ conf=arrow_udf_conf)
+
+
@approximate_float
@allow_non_gpu('AggregateInPandasExec', 'PythonUDF', 'Alias')
@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index e4818c31605..067eebc0ab7 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1979,7 +1979,7 @@ object GpuOverrides {
TypeSig.all,
repeatingParamCheck = Some(RepeatingParamCheck(
"param",
- TypeSig.commonCudfTypes,
+ (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[PythonUDF](a, conf, p, r) {
override def replaceMessage: String = "not block GPU acceleration"
@@ -2639,7 +2639,9 @@ object GpuOverrides {
"The backend of the Scalar Pandas UDFs. Accelerates the data transfer between the" +
" Java process and the Python process. It also supports scheduling GPU resources" +
" for the Python process when enabled",
- ExecChecks(TypeSig.commonCudfTypes, TypeSig.all),
+ ExecChecks(
+ (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.STRUCT).nested(),
+ TypeSig.all),
(e, conf, p, r) =>
new SparkPlanMeta[ArrowEvalPythonExec](e, conf, p, r) {
val udfs: Seq[BaseExprMeta[PythonUDF]] =
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
index 318e7dfbd74..cd0770757b0 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
@@ -26,14 +26,14 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import ai.rapids.cudf.{AggregationOnColumn, ArrowIPCOptions, ArrowIPCWriterOptions, ColumnVector, HostBufferConsumer, HostBufferProvider, HostMemoryBuffer, NvtxColor, NvtxRange, StreamedTableReader, Table}
-import com.nvidia.spark.rapids.{Arm, ConcatAndConsumeAll, GpuAggregateWindowFunction, GpuBindReferences, GpuColumnVector, GpuColumnVectorFromBuffer, GpuExec, GpuMetric, GpuProjectExec, GpuSemaphore, GpuUnevaluable, RapidsBuffer, SpillableColumnarBatch, SpillPriorities}
+import ai.rapids.cudf._
+import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuMetric._
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonRDD, SpecialLengths}
+import org.apache.spark.api.python._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.python.PythonUDFRunner
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
@@ -436,12 +436,13 @@ class GpuArrowPythonRunner(
table.close()
GpuSemaphore.releaseIfNecessary(TaskContext.get())
})
- pythonInSchema.foreach { field =>
- if (field.nullable) {
- builder.withColumnNames(field.name)
- } else {
- builder.withNotNullableColumnNames(field.name)
- }
+ // Flatten the names of nested struct columns, required by cudf arrow IPC writer.
+ flattenNames(pythonInSchema).foreach { case (name, nullable) =>
+ if (nullable) {
+ builder.withColumnNames(name)
+ } else {
+ builder.withNotNullableColumnNames(name)
+ }
}
Table.writeArrowIPCChunked(builder.build(), new BufferToStreamWriter(dataOut))
}
@@ -463,6 +464,16 @@ class GpuArrowPythonRunner(
if (onDataWriteFinished != null) onDataWriteFinished()
}
}
+
+ private def flattenNames(d: DataType, nullable: Boolean=true): Seq[(String, Boolean)] =
+ d match {
+ case s: StructType =>
+ s.flatMap(sf => Seq((sf.name, sf.nullable)) ++ flattenNames(sf.dataType, sf.nullable))
+ case m: MapType =>
+ flattenNames(m.keyType, nullable) ++ flattenNames(m.valueType, nullable)
+ case a: ArrayType => flattenNames(a.elementType, nullable)
+ case _ => Nil
+ }
}
}
}