Skip to content

Commit

Permalink
some of the changes recommended by Gene
Browse files — browse the repository at this point in the history
  • Loading branch information
harshmotw-db committed Nov 6, 2024
1 parent 876d5ca commit 2ecf567
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 4 deletions.
1 change: 0 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,6 @@ def test_udf_with_variant_output(self):
expected = [Row(udf=i) for i in range(10)]

for f in [scalar_f, iter_f]:
# with self.assertRaises(AnalysisException) as ae:
result = self.spark.range(10).select(f(col("id")).cast("int").alias("udf")).collect()
self.assertEqual(result, expected)

Expand Down
46 changes: 46 additions & 0 deletions python/pyspark/sql/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def test_udf_with_variant_input(self):
def test_udf_with_complex_variant_input(self):
for arrow_enabled in ["false", "true"]:
with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
# struct<variant>
df = self.spark.range(0, 10).selectExpr(
"named_struct('v', parse_json(cast(id as string))) struct_of_v"
)
Expand All @@ -353,6 +354,24 @@ def test_udf_with_complex_variant_input(self):
expected = [Row(udf="{0}".format(i)) for i in range(10)]
self.assertEqual(result, expected)

# array<variant>
df = self.spark.range(0, 10).selectExpr(
"array(parse_json(cast(id as string))) array_of_v"
)
u = udf(lambda u: str(u[0]), StringType())
result = df.select(u(col("array_of_v"))).collect()
expected = [Row(udf="{0}".format(i)) for i in range(10)]
self.assertEqual(result, expected)

# map<string, variant>
df = self.spark.range(0, 10).selectExpr(
"map('v', parse_json(cast(id as string))) map_of_v"
)
u = udf(lambda u: str(u["v"]), StringType())
result = df.select(u(col("map_of_v"))).collect()
expected = [Row(udf="{0}".format(i)) for i in range(10)]
self.assertEqual(result, expected)

def test_udf_with_variant_output(self):
for arrow_enabled in ["false", "true"]:
with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
Expand All @@ -373,6 +392,33 @@ def test_udf_with_complex_variant_output(self):
for arrow_enabled in ["false", "true"]:
with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
# The variant value returned corresponds to a JSON string of {"a": "<a-j>"}.
# struct<variant>
u = udf(
lambda i: {
"v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))
},
StructType([StructField("v", VariantType(), True)]),
)
result = self.spark.range(0, 10).select(
u(col("id")).cast("string").alias("udf")
).collect()
expected = [Row(udf=f"{{{{\"a\":\"{chr(97 + i)}\"}}}}") for i in range(10)]
self.assertEqual(result, expected)

# array<variant>
u = udf(
lambda i: [
VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))
],
ArrayType(VariantType()),
)
result = self.spark.range(0, 10).select(
u(col("id")).cast("string").alias("udf")
).collect()
expected = [Row(udf=f"[{{\"a\":\"{chr(97 + i)}\"}}]") for i in range(10)]
self.assertEqual(result, expected)

# map<string, variant>
u = udf(
lambda i: {
"v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ private[sql] object ArrowUtils {
val fieldType = new FieldType(
nullable,
ArrowType.Struct.INSTANCE,
null,
Map("variant" -> "true").asJava)
null
)
val metadataFieldType = new FieldType(
false,
toArrowType(BinaryType, timeZoneId, largeVarTypes),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ private[arrow] class StructWriter(
val valueVector: StructVector,
children: Array[ArrowFieldWriter]) extends ArrowFieldWriter {

lazy val isVariant = valueVector.getField.getMetadata.get("variant") == "true"
lazy val isVariant = ArrowUtils.isVariantField(valueVector.getField)

override def setNull(): Unit = {
var i = 0
Expand Down

0 comments on commit 2ecf567

Please sign in to comment.