[SPARK-50238][PYTHON] Add Variant Support in PySpark UDFs/UDTFs/UDAFs #48770

Closed
Changes from 8 commits
14 changes: 12 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
@@ -31,6 +31,7 @@
)
from pyspark.sql.pandas.types import (
from_arrow_type,
is_variant,
to_arrow_type,
_create_converter_from_pandas,
_create_converter_to_pandas,
@@ -420,7 +421,11 @@ def __init__(
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types

if self._df_for_struct and types.is_struct(arrow_column.type):
if (
self._df_for_struct
and types.is_struct(arrow_column.type)
and not is_variant(arrow_column.type)
):
import pandas as pd

series = [
@@ -505,7 +510,12 @@ def _create_batch(self, series):

arrs = []
for s, t in series:
if self._struct_in_pandas == "dict" and t is not None and pa.types.is_struct(t):
if (
self._struct_in_pandas == "dict"
and t is not None
and pa.types.is_struct(t)
and not is_variant(t)
):
# A pandas UDF should return pd.DataFrame when the return type is a struct type.
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
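Taken together, the guards above keep variant-typed columns out of the struct-as-DataFrame special case on both the input (arrow_to_pandas) and output (_create_batch) sides, so a variant column reaches a pandas UDF as a series of VariantVal objects. A minimal usage sketch of the input side, mirroring test_udf_with_variant_input below (hedged: assumes a build with this patch and an active SparkSession named spark):

import pandas as pd
from pyspark.sql.functions import pandas_udf, col

@pandas_udf("string")
def variant_to_str(u: pd.Series) -> pd.Series:
    # Each element is a VariantVal, not a struct row, because of the
    # is_variant guard in arrow_to_pandas.
    return u.apply(str)

df = spark.range(3).selectExpr("parse_json(cast(id as string)) v")
df.select(variant_to_str(col("v"))).show()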
23 changes: 22 additions & 1 deletion python/pyspark/sql/pandas/types.py
@@ -171,7 +171,7 @@ def to_arrow_type(
elif type(dt) == VariantType:
fields = [
pa.field("value", pa.binary(), nullable=False),
pa.field("metadata", pa.binary(), nullable=False),
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
]
arrow_type = pa.struct(fields)
else:
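The b"variant" tag added here is what downstream code keys on. A small inspection sketch (hedged: assumes a build with this patch):

from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import VariantType

at = to_arrow_type(VariantType())
print(at)  # struct<value: binary not null, metadata: binary not null>
# The marker lives in the field-level metadata of the "metadata" field:
for field in at:
    if field.name == "metadata":
        print(field.metadata)  # {b'variant': b'true'}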
@@ -221,6 +221,15 @@ def to_arrow_schema(
return pa.schema(fields)


def is_variant(at: "pa.DataType") -> bool:
    """Check if a PyArrow struct data type represents a variant"""
    import pyarrow.types as types

    assert types.is_struct(at)

    return any(
        (field.name == "metadata" and b"variant" in field.metadata and
         field.metadata[b"variant"] == b"true")
        for field in at
    )

Contributor: Why is this an assert? Should this just return False if it is not a struct?

Contributor Author: This function is intended to be called only on struct types, so the assert is there to prevent developers from using it with non-struct types; I should add a comment. Checking is_struct inside this function would add cost in production (where I'm assuming Python runs in optimized mode, so asserts are disabled).

Contributor: Shouldn't we check that the fields are metadata and value?
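On the recognition side, a struct carrying this tag round-trips back to VariantType. A hedged sketch (assuming a build with this patch):

import pyarrow as pa
from pyspark.sql.pandas.types import from_arrow_type, is_variant

tagged = pa.struct([
    pa.field("value", pa.binary(), nullable=False),
    pa.field("metadata", pa.binary(), nullable=False,
             metadata={b"variant": b"true"}),
])
assert is_variant(tagged)
print(from_arrow_type(tagged))  # VariantType()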


def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
@@ -280,6 +289,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
from_arrow_type(at.item_type, prefer_timestamp_ntz),
)
elif types.is_struct(at):
if is_variant(at):
return VariantType()
return StructType(
[
StructField(
@@ -1295,6 +1306,16 @@ def convert_udt(value: Any) -> Any:

return convert_udt

elif isinstance(dt, VariantType):

    def convert_variant(variant: Any) -> Any:
        assert isinstance(variant, VariantVal)
        return {
            "value": variant.value,
            "metadata": variant.metadata,
        }

    return convert_variant

return None

conv = _converter(data_type)
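For illustration, the dict that convert_variant produces for a concrete value; the buffers are the {"a": "b"} encoding referenced in test_udf_with_variant_output below (hedged sketch, assuming a build with this patch):

from pyspark.sql.types import VariantVal

v = VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))
print(v.toJson())  # {"a":"b"}
# convert_variant returns the two-field dict matching to_arrow_type(VariantType()):
print({"value": v.value, "metadata": v.metadata})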
39 changes: 38 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -20,7 +20,14 @@
from typing import cast

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit
from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType
from pyspark.sql.types import (
DoubleType,
StructType,
StructField,
LongType,
DayTimeIntervalType,
VariantType,
)
from pyspark.errors import ParseException, PythonException, PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.testing.sqlutils import (
@@ -42,33 +49,63 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, VariantType())
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, VariantType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", DoubleType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", VariantType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_basic_with_return_type_string(self):
udf = pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "variant", PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v variant", PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v double", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, "v variant", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, returnType="v double", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(lambda x: x, returnType="v variant", functionType=PandasUDFType.GROUPED_MAP)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
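The decorator form (exercised with DoubleType in test_pandas_udf_decorator below) accepts the new type the same way; a minimal hypothetical sketch:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import VariantType

@pandas_udf(VariantType())
def passthrough(v):
    # Identity over a series whose elements are VariantVal objects.
    return v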

def test_pandas_udf_decorator(self):
@pandas_udf(DoubleType())
def foo(x):
98 changes: 69 additions & 29 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -752,46 +752,86 @@ def check_vectorized_udf_return_scalar(self):

def test_udf_with_variant_input(self):
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
from pyspark.sql.functions import col

scalar_f = pandas_udf(lambda u: str(u), StringType())
scalar_f = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR)
Contributor: Does pandas_udf go through the same path as an arrow UDF?

Contributor Author: Yes, for the most part. I recall that for pandas UDFs to work, I also had to add changes in arrow_to_pandas and _create_batch, because they treat struct types in a special way. Example: https://github.com/apache/spark/pull/48770/files#r1831583273

iter_f = pandas_udf(
lambda it: map(lambda u: str(u), it), StringType(), PandasUDFType.SCALAR_ITER
lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER
)

expected = [Row(udf="{0}".format(i)) for i in range(10)]

for f in [scalar_f, iter_f]:
with self.assertRaises(AnalysisException) as ae:
df.select(f(col("v"))).collect()

self.check_error(
exception=ae.exception,
errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
messageParameters={
"sqlExpr": '"<lambda>(v)"',
"dataType": "VARIANT",
},
)
result = df.select(f(col("v")).alias("udf")).collect()
self.assertEqual(result, expected)

def test_udf_with_variant_output(self):
# Corresponds to a JSON string of {"a": "b"}.
returned_variant = VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))
scalar_f = pandas_udf(lambda x: returned_variant, VariantType())
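# Note: in the variant binary encoding, bytes([12, i]) is an int8 primitive
# holding i, and bytes([1, 0, 0]) is version-1 metadata with an empty dictionary.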
scalar_f = pandas_udf(
lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), VariantType()
)
iter_f = pandas_udf(
lambda it: map(lambda x: returned_variant, it), VariantType(), PandasUDFType.SCALAR_ITER
lambda it: map(lambda u: u.apply(
lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))
), it),
VariantType(),
PandasUDFType.SCALAR_ITER
)

expected = [Row(udf=i) for i in range(10)]

for f in [scalar_f, iter_f]:
with self.assertRaises(AnalysisException) as ae:
self.spark.range(0, 10).select(f()).collect()

self.check_error(
exception=ae.exception,
errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
messageParameters={
"sqlExpr": '"<lambda>()"',
"dataType": "VARIANT",
},
)
result = self.spark.range(10).select(f(col("id")).cast("int").alias("udf")).collect()
self.assertEqual(result, expected)

def test_chained_udfs_with_variant(self):
scalar_first = pandas_udf(
lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), VariantType()
)
iter_first = pandas_udf(
lambda it: map(lambda u: u.apply(
lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))
), it),
VariantType(),
PandasUDFType.SCALAR_ITER
)
scalar_second = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR)
iter_second = pandas_udf(
lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER
)

expected = [Row(udf="{0}".format(i)) for i in range(10)]

for f in [scalar_first, iter_first]:
for s in [scalar_second, iter_second]:
result = self.spark.range(10).select(s(f(col("id"))).alias("udf")).collect()
self.assertEqual(result, expected)

def test_chained_udfs_with_complex_variant(self):
scalar_first = pandas_udf(
lambda u: u.apply(lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))]),
ArrayType(VariantType())
)
iter_first = pandas_udf(
lambda it: map(lambda u: u.apply(
lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))]
), it),
ArrayType(VariantType()),
PandasUDFType.SCALAR_ITER
)
scalar_second = pandas_udf(lambda u: u.apply(lambda v: str(v[0])),
StringType(),
PandasUDFType.SCALAR)
iter_second = pandas_udf(
lambda it: map(lambda u: u.apply(lambda v: str(v[0])), it),
StringType(),
PandasUDFType.SCALAR_ITER
)

expected = [Row(udf="{0}".format(i)) for i in range(10)]

for f in [scalar_first, iter_first]:
for s in [scalar_second, iter_second]:
result = self.spark.range(10).select(s(f(col("id"))).alias("udf")).collect()
self.assertEqual(result, expected)

def test_vectorized_udf_decorator(self):
df = self.spark.range(10)
1 change: 1 addition & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -2542,6 +2542,7 @@ def schema_from_udf(ddl):
("struct<>", True),
("struct<a: string, b: array<long>>", True),
("", True),
("a: int, b: variant", True),
("<a: int, b: variant>", False),
("randomstring", False),
("struct", False),