From 2ecf5672beecbc2c84dc70a3aa9f8de68355c8c6 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Wed, 6 Nov 2024 14:38:03 -0800
Subject: [PATCH] Address review comments from Gene: detect variant Arrow
 structs structurally and extend complex variant UDF tests

---
 .../tests/pandas/test_pandas_udf_scalar.py    |  1 -
 python/pyspark/sql/tests/test_udf.py          | 46 +++++++++++++++++++
 .../apache/spark/sql/util/ArrowUtils.scala    |  3 +--
 .../sql/execution/arrow/ArrowWriter.scala     |  2 +-
 4 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 670c8278f74b0..e82afe3c5a2ec 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -779,7 +779,6 @@ def test_udf_with_variant_output(self):
         expected = [Row(udf=i) for i in range(10)]
 
         for f in [scalar_f, iter_f]:
-            # with self.assertRaises(AnalysisException) as ae:
             result = self.spark.range(10).select(f(col("id")).cast("int").alias("udf")).collect()
             self.assertEqual(result, expected)
 
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 37d4fa9dec725..a6b4e84f1eb6c 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -345,6 +345,7 @@ def test_udf_with_variant_input(self):
     def test_udf_with_complex_variant_input(self):
         for arrow_enabled in ["false", "true"]:
             with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                # struct
                 df = self.spark.range(0, 10).selectExpr(
                     "named_struct('v', parse_json(cast(id as string))) struct_of_v"
                 )
@@ -353,6 +354,24 @@ def test_udf_with_complex_variant_input(self):
                 expected = [Row(udf="{0}".format(i)) for i in range(10)]
                 self.assertEqual(result, expected)
 
+                # array
+                df = self.spark.range(0, 10).selectExpr(
+                    "array(parse_json(cast(id as string))) array_of_v"
+                )
+                u = udf(lambda u: str(u[0]), StringType())
+                result = df.select(u(col("array_of_v"))).collect()
+                expected = [Row(udf="{0}".format(i)) for i in range(10)]
+                self.assertEqual(result, expected)
+
+                # map
+                df = self.spark.range(0, 10).selectExpr(
+                    "map('v', parse_json(cast(id as string))) map_of_v"
+                )
+                u = udf(lambda u: str(u["v"]), StringType())
+                result = df.select(u(col("map_of_v"))).collect()
+                expected = [Row(udf="{0}".format(i)) for i in range(10)]
+                self.assertEqual(result, expected)
+
     def test_udf_with_variant_output(self):
         for arrow_enabled in ["false", "true"]:
             with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
@@ -373,6 +392,33 @@ def test_udf_with_complex_variant_output(self):
         for arrow_enabled in ["false", "true"]:
             with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
                 # The variant value returned corresponds to a JSON string of {"a": "<letter>"}.
+                # struct
+                u = udf(
+                    lambda i: {
+                        "v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))
+                    },
+                    StructType([StructField("v", VariantType(), True)]),
+                )
+                result = self.spark.range(0, 10).select(
+                    u(col("id")).cast("string").alias("udf")
+                ).collect()
+                expected = [Row(udf=f"{{{{\"a\":\"{chr(97 + i)}\"}}}}") for i in range(10)]
+                self.assertEqual(result, expected)
+
+                # array
+                u = udf(
+                    lambda i: [
+                        VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))
+                    ],
+                    ArrayType(VariantType()),
+                )
+                result = self.spark.range(0, 10).select(
+                    u(col("id")).cast("string").alias("udf")
+                ).collect()
+                expected = [Row(udf=f"[{{\"a\":\"{chr(97 + i)}\"}}]") for i in range(10)]
+                self.assertEqual(result, expected)
+
+                # map
                 u = udf(
                     lambda i: {
                         "v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 057715f499f6c..804c05bdf1947 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -144,8 +144,7 @@ private[sql] object ArrowUtils {
       val fieldType = new FieldType(
         nullable,
         ArrowType.Struct.INSTANCE,
-        null,
-        Map("variant" -> "true").asJava)
+        null)
       val metadataFieldType = new FieldType(
         false,
         toArrowType(BinaryType, timeZoneId, largeVarTypes),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index ca7703bef48bb..065b4b8c821a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -373,7 +373,7 @@ private[arrow] class StructWriter(
     val valueVector: StructVector,
     children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {
-  lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true"
+  lazy val isVariant = ArrowUtils.isVariantField(valueVector.getField)
 
   override def setNull(): Unit = {
     var i = 0
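
Note for reviewers: the hard-coded bytes in these tests follow the variant
binary layout. The metadata blob bytes([1, 1, 0, 1, 97]) is a one-entry
dictionary holding the key "a", and the value blob bytes([2, 1, 0, 0, 2, 5,
97 + i]) is a one-field object whose field (dictionary id 0 -> "a") holds
the one-character short string chr(97 + i), so each row decodes to
{"a": "a"} through {"a": "j"}. A minimal sketch to decode them locally,
assuming pyspark 4.0+ where VariantVal and its pure-Python toJson() are
available (no SparkSession needed):

    from pyspark.sql.types import VariantVal

    # One-entry metadata dictionary: version header, size 1, offsets, "a".
    metadata = bytes([1, 1, 0, 1, 97])
    for i in range(3):
        # One-field object: object header, 1 field, field id 0, offsets,
        # short-string header 5, then the character itself.
        value = bytes([2, 1, 0, 0, 2, 5, 97 + i])
        print(VariantVal(value, metadata).toJson())
    # prints {"a":"a"}, {"a":"b"}, {"a":"c"}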
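
The doubled braces in the struct expectation are worth decoding once as well:
in an f-string every {{ renders as a single literal {, and the remaining outer
pair comes from the struct-to-string cast wrapping the variant's JSON. For
i = 0:

    i = 0
    print(f"{{{{\"a\":\"{chr(97 + i)}\"}}}}")  # {{"a":"a"}} (struct cast)
    print(f"[{{\"a\":\"{chr(97 + i)}\"}}]")    # [{"a":"a"}] (array cast)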