Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-50238][PYTHON] Add Variant Support in PySpark UDFs/UDTFs/UDAFs #48770

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1045,16 +1045,6 @@
"The input of <functionName> can't be <dataType> type data."
]
},
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
]
},
"UNSUPPORTED_UDF_OUTPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an output data type."
]
},
"VALUE_OUT_OF_RANGE" : {
"message" : [
"The <exprName> must be between <valueRange> (current value = <currentValue>)."
Expand Down
20 changes: 18 additions & 2 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from pyspark.sql.pandas.types import (
from_arrow_type,
is_variant,
to_arrow_type,
_create_converter_from_pandas,
_create_converter_to_pandas,
Expand Down Expand Up @@ -420,7 +421,14 @@ def __init__(
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types

if self._df_for_struct and types.is_struct(arrow_column.type):
# If the arrow type is struct, return a pandas dataframe where the fields of the struct
# correspond to columns in the DataFrame. However, if the arrow struct is actually a
# Variant, which is an atomic type, treat it as a non-struct arrow type.
if (
self._df_for_struct
and types.is_struct(arrow_column.type)
and not is_variant(arrow_column.type)
):
import pandas as pd
harshmotw-db marked this conversation as resolved.
Show resolved Hide resolved

series = [
Expand Down Expand Up @@ -505,7 +513,15 @@ def _create_batch(self, series):

arrs = []
for s, t in series:
if self._struct_in_pandas == "dict" and t is not None and pa.types.is_struct(t):
# Variants are represented in arrow as structs with additional metadata (checked by
# is_variant). If the data type is Variant, return a VariantVal atomic type instead of
# a dict of two binary values.
if (
self._struct_in_pandas == "dict"
and t is not None
and pa.types.is_struct(t)
and not is_variant(t)
):
# A pandas UDF should return pd.DataFrame when the return type is a struct type.
harshmotw-db marked this conversation as resolved.
Show resolved Hide resolved
# If it returns a pd.Series, it should throw an error.
if not isinstance(s, pd.DataFrame):
Expand Down
30 changes: 29 additions & 1 deletion python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def to_arrow_type(
elif type(dt) == VariantType:
fields = [
pa.field("value", pa.binary(), nullable=False),
pa.field("metadata", pa.binary(), nullable=False),
# The metadata field is tagged so we can identify that the arrow struct actually
# represents a variant.
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
harshmotw-db marked this conversation as resolved.
Show resolved Hide resolved
]
arrow_type = pa.struct(fields)
else:
Expand Down Expand Up @@ -221,6 +223,22 @@ def to_arrow_schema(
return pa.schema(fields)


def is_variant(at: "pa.DataType") -> bool:
    """Check if a PyArrow struct data type represents a variant.

    A variant is encoded in Arrow as a struct that has a field named ``value``
    and a field named ``metadata`` whose field-level metadata carries the tag
    ``{b"variant": b"true"}``.

    Parameters
    ----------
    at : pyarrow.DataType
        The Arrow type to inspect. Callers are expected to have already
        verified that it is a struct type.

    Returns
    -------
    bool
        True if the struct carries the variant tag and has a ``value`` field,
        False otherwise.

    Notes
    -----
    The ``assert`` is a deliberate developer guard rather than input
    validation (raised in review; the author chose to keep it): this function
    is meant to be called only after an ``is_struct`` check, and the assertion
    is stripped when Python runs in optimized mode, so it adds no cost there.
    """
    import pyarrow.types as types

    # Developer-facing guard only; see the Notes section of the docstring.
    assert types.is_struct(at)

    # ``Field.metadata`` is None when a field has no metadata attached, so it
    # must be checked before doing a lookup. Otherwise an ordinary struct that
    # merely has an untagged field called "metadata" would raise TypeError
    # instead of being reported as "not a variant".
    has_tagged_metadata = any(
        field.name == "metadata"
        and field.metadata is not None
        and field.metadata.get(b"variant") == b"true"
        for field in at
    )
    return has_tagged_metadata and any(field.name == "value" for field in at)


def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
Expand Down Expand Up @@ -280,6 +298,8 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
from_arrow_type(at.item_type, prefer_timestamp_ntz),
)
elif types.is_struct(at):
if is_variant(at):
return VariantType()
return StructType(
[
StructField(
Expand Down Expand Up @@ -1295,6 +1315,14 @@ def convert_udt(value: Any) -> Any:

return convert_udt

elif isinstance(dt, VariantType):

def convert_variant(variant: Any) -> Any:
harshmotw-db marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(variant, VariantVal)
return {"value": variant.value, "metadata": variant.metadata}

return convert_variant

return None

conv = _converter(data_type)
Expand Down
41 changes: 40 additions & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@
from typing import cast

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit
from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType
from pyspark.sql.types import (
DoubleType,
StructType,
StructField,
LongType,
DayTimeIntervalType,
VariantType,
)
from pyspark.errors import ParseException, PythonException, PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.testing.sqlutils import (
Expand All @@ -42,33 +49,65 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, VariantType())
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(lambda x: x, VariantType(), PandasUDFType.SCALAR)
self.assertEqual(udf.returnType, VariantType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", DoubleType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", DoubleType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

udf = pandas_udf(
lambda x: x, StructType([StructField("v", VariantType())]), PandasUDFType.GROUPED_MAP
)
self.assertEqual(udf.returnType, StructType([StructField("v", VariantType())]))
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_basic_with_return_type_string(self):
    """Return types given as DDL strings are parsed for scalar and grouped-map pandas UDFs.

    Covers both ``double`` and ``variant`` return types, passed positionally
    and through the ``returnType`` / ``functionType`` keywords.
    """

    def check(wrapped, expected_return_type, expected_eval_type):
        # Every case asserts the same pair of properties on the created UDF.
        self.assertEqual(wrapped.returnType, expected_return_type)
        self.assertEqual(wrapped.evalType, expected_eval_type)

    # Scalar UDFs: the string is parsed into the matching atomic type.
    check(
        pandas_udf(lambda x: x, "double", PandasUDFType.SCALAR),
        DoubleType(),
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
    )
    check(
        pandas_udf(lambda x: x, "variant", PandasUDFType.SCALAR),
        VariantType(),
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
    )

    # Grouped-map UDFs, everything positional.
    check(
        pandas_udf(lambda x: x, "v double", PandasUDFType.GROUPED_MAP),
        StructType([StructField("v", DoubleType())]),
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
    )
    check(
        pandas_udf(lambda x: x, "v variant", PandasUDFType.GROUPED_MAP),
        StructType([StructField("v", VariantType())]),
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
    )

    # Grouped-map UDFs, function type passed by keyword.
    check(
        pandas_udf(lambda x: x, "v double", functionType=PandasUDFType.GROUPED_MAP),
        StructType([StructField("v", DoubleType())]),
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
    )
    check(
        pandas_udf(lambda x: x, "v variant", functionType=PandasUDFType.GROUPED_MAP),
        StructType([StructField("v", VariantType())]),
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
    )

    # Grouped-map UDFs, both return type and function type passed by keyword.
    check(
        pandas_udf(lambda x: x, returnType="v double", functionType=PandasUDFType.GROUPED_MAP),
        StructType([StructField("v", DoubleType())]),
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
    )
    check(
        pandas_udf(
            lambda x: x, returnType="v variant", functionType=PandasUDFType.GROUPED_MAP
        ),
        StructType([StructField("v", VariantType())]),
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
    )


def test_pandas_udf_decorator(self):
@pandas_udf(DoubleType())
def foo(x):
Expand Down
Loading