dtypes: Add type conversion to pyspark types #574

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions fennel/client_tests/test_featureset.py
@@ -1232,6 +1232,7 @@ class IndexFeatures:
    assert response.status_code == requests.codes.OK, response.json()


@pytest.mark.integration
@mock
def test_query_time_features(client):
    @meta(owner="[email protected]")
118 changes: 118 additions & 0 deletions fennel/internal_lib/schema/schema.py
@@ -38,6 +38,25 @@
    parse_datetime_in_value,
)

import decimal
from pyspark.sql.types import (
    DataType,
    LongType,
    DoubleType,
    StringType,
    BooleanType,
    TimestampType,
    DateType,
    BinaryType,
    DecimalType,
    NullType,
    ArrayType,
    MapType,
    StructType,
    StructField,
)


FENNEL_STRUCT = "__fennel_struct__"
FENNEL_STRUCT_SRC_CODE = "__fennel_struct_src_code__"
FENNEL_STRUCT_DEPENDENCIES_SRC_CODE = "__fennel_struct_dependencies_src_code__"
@@ -976,3 +995,102 @@ def check_dtype_has_struct_type(dtype: schema_proto.DataType) -> bool:
    elif dtype.HasField("map_type"):
        return check_dtype_has_struct_type(dtype.map_type.value)
    return False


def to_spark_type(py_type: Any, nullable: bool = True) -> DataType:
    """
    Recursively convert a Python type to a corresponding PySpark SQL DataType.

    Args:
        py_type: The Python type to convert.
        nullable: Whether the resulting DataType should be nullable.

    Returns:
        A PySpark SQL DataType corresponding to the given Python type.
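
    Example (illustrative only; DataType reprs vary across PySpark versions):
        to_spark_type(Dict[str, List[int]])
        # -> MapType(StringType(), ArrayType(LongType(), True), True)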
"""
origin = get_origin(py_type)

# Handle Optional types (Union with None)
if origin is Union:
args = get_args(py_type)
if len(args) == 2 and type(None) in args:
# It's Optional[T]
non_none_type = (
args[0] if args[1] is type(None) else args[1] # noqa: E721
) # noqa: E721
return to_spark_type(non_none_type, nullable=True)
else:
# Unions of multiple types are not directly supported; default to StringType
return StringType()
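
    # Handle Embedding[N]: mapped to a non-null array of doubles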
    elif isinstance(py_type, _Embedding):
        return ArrayType(DoubleType(), containsNull=False)

    # Handle List[T]
    elif origin in (list, List):
        element_type = get_args(py_type)[0]
        spark_element_type = to_spark_type(element_type)
        return ArrayType(spark_element_type, containsNull=True)

    # Handle Dict[K, V]
    elif origin in (dict, Dict):
        key_type, value_type = get_args(py_type)
        spark_key_type = to_spark_type(
            key_type, nullable=False
        )  # Keys cannot be null
        spark_value_type = to_spark_type(value_type)
        return MapType(spark_key_type, spark_value_type, valueContainsNull=True)

    # Handle dataclass (StructType)
    elif dataclasses.is_dataclass(py_type):
        fields = []
        for field in dataclasses.fields(py_type):
            field_name = field.name
            field_type = field.type
            field_nullable = False

            # Check for Optional fields
            field_origin = get_origin(field_type)
            if field_origin is Union:
                field_args = get_args(field_type)
                if len(field_args) == 2 and type(None) in field_args:
                    field_nullable = True
                    field_type = (
                        field_args[0]
                        if field_args[1] is type(None)  # noqa: E721
                        else field_args[1]
                    )

            spark_field_type = to_spark_type(
                field_type, nullable=field_nullable
            )
            fields.append(
                StructField(
                    field_name, spark_field_type, nullable=field_nullable
                )
            )
        return StructType(fields)

    # Handle basic types
    elif py_type is int:
        return LongType()
    elif py_type is float:
        return DoubleType()
    elif py_type is str:
        return StringType()
    elif py_type is bool:
        return BooleanType()
    elif py_type is datetime:
        return TimestampType()
    elif py_type is date:
        return DateType()
    elif py_type is bytes:
        return BinaryType()
    elif py_type is decimal.Decimal:
        # Default precision and scale; adjust as needed
        return DecimalType(precision=38, scale=18)

    elif py_type is type(None):  # noqa: E721
        return NullType()

    else:
        raise ValueError(f"Unsupported type: {py_type}")
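
A minimal usage sketch of the new helper (a sketch only, not part of this diff: the User dataclass, field names, and row values are invented for illustration, and it assumes a local SparkSession):

import dataclasses
from typing import List, Optional

from pyspark.sql import SparkSession

from fennel.internal_lib.schema.schema import to_spark_type

@dataclasses.dataclass
class User:
    name: str
    age: Optional[int]
    tags: List[str]

# Derive a Spark schema from the dataclass and use it to build a DataFrame.
schema = to_spark_type(User)  # StructType with name, age, tags fields
spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("alice", 30, ["a", "b"])], schema=schema)
df.printSchema()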
106 changes: 106 additions & 0 deletions fennel/internal_lib/schema/test_schema.py
@@ -6,6 +6,7 @@
from datetime import datetime, date, timedelta, timezone
from decimal import Decimal as PythonDecimal
from typing import Dict, List, Optional, Union, get_type_hints
from dataclasses import dataclass

import fennel.gen.schema_pb2 as proto
from fennel.dtypes.dtypes import (
@@ -26,8 +27,25 @@
    is_hashable,
    parse_json,
    convert_dtype_to_arrow_type,
    to_spark_type,
    validate_val_with_dtype,
)
from pyspark.sql.types import (
    DataType,
    LongType,
    DoubleType,
    StringType,
    BooleanType,
    TimestampType,
    DateType,
    BinaryType,
    DecimalType,
    NullType,
    ArrayType,
    MapType,
    StructType,
    StructField,
)


def test_get_data_type():
@@ -960,3 +978,91 @@ class ComplexStruct:
        proto = get_datatype(original_type)
        converted_type = from_proto(proto)
        assert_struct_fields_match(self, original_type, converted_type)


def test_to_spark_type():
    assert to_spark_type(int) == LongType()
    assert to_spark_type(float) == DoubleType()
    assert to_spark_type(str) == StringType()
    assert to_spark_type(datetime) == TimestampType()
    assert to_spark_type(date) == DateType()
    # Types in pyspark are nullable by default
    assert to_spark_type(Optional[int]) == LongType()

    # Test complex types
    assert to_spark_type(List[int]) == ArrayType(LongType(), containsNull=True)
    assert to_spark_type(Dict[str, float]) == MapType(
        StringType(), DoubleType(), valueContainsNull=True
    )
    assert to_spark_type(Optional[List[str]]) == ArrayType(
        StringType(), containsNull=True
    )

    # Test nested complex types
    assert to_spark_type(List[Dict[str, List[float]]]) == ArrayType(
        MapType(
            StringType(),
            ArrayType(DoubleType(), containsNull=True),
            valueContainsNull=True,
        ),
        containsNull=True,
    )

    assert to_spark_type(Union[int, str, float]) == StringType()

    # Test Embedding type (should default to ArrayType of DoubleType)
    assert to_spark_type(Embedding[10]) == ArrayType(
        DoubleType(), containsNull=False
    )

    # Test complex nested structure
    complex_type = Dict[str, List[Optional[Dict[int, Union[str, float]]]]]
    expected_complex_type = MapType(
        StringType(),
        ArrayType(
            MapType(
                LongType(),
                StringType(),
                valueContainsNull=True,
            ),
            containsNull=True,
        ),
        valueContainsNull=True,
    )
    assert to_spark_type(complex_type) == expected_complex_type

    @struct
    class Address:
        street: str
        city: str
        zip_code: Optional[int]

    @struct
    class Person:
        name: str
        age: int
        address: Address
        emails: List[str]

    # Convert Person dataclass to StructType
    spark_type = to_spark_type(Person)

    expected_spark_type = StructType(
        [
            StructField("name", StringType(), False),
            StructField("age", LongType(), False),
            StructField(
                "address",
                StructType(
                    [
                        StructField("street", StringType(), False),
                        StructField("city", StringType(), False),
                        StructField("zip_code", LongType(), True),
                    ]
                ),
                False,
            ),
            StructField("emails", ArrayType(StringType()), False),
        ]
    )
    assert spark_type == expected_spark_type