Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1865926: Infer schema for StructType columns from nested Rows #2805

Merged
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d9bf2cb
SNOW-1829870: Allow structured types to be enabled by default
sfc-gh-jrose Dec 5, 2024
ec43e1a
type checking
sfc-gh-jrose Dec 6, 2024
7f3a5fd
lint
sfc-gh-jrose Dec 6, 2024
2e0dce9
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 16, 2024
ed232de
Move flag to context
sfc-gh-jrose Dec 16, 2024
0dd7b91
typo
sfc-gh-jrose Dec 16, 2024
13c1424
SNOW-1852779 Fix AST encoding for Column `in_`, `asc`, and `desc` (#2…
sfc-gh-vbudati Dec 16, 2024
a787e74
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 16, 2024
b32806f
merge main and fix test
sfc-gh-jrose Dec 17, 2024
c3db223
make feature flag thread safe
sfc-gh-jrose Dec 17, 2024
1c262d7
typo
sfc-gh-jrose Dec 17, 2024
869931f
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 17, 2024
0caef58
Fix ast test
sfc-gh-jrose Dec 17, 2024
2380040
move lock
sfc-gh-jrose Dec 18, 2024
995e519
test coverage
sfc-gh-jrose Dec 18, 2024
1b89027
remove context manager
sfc-gh-jrose Dec 18, 2024
4fc61d4
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 19, 2024
26fd29e
switch to using patch
sfc-gh-jrose Dec 19, 2024
9295e11
move test to other module
sfc-gh-jrose Dec 19, 2024
fcd16d7
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 19, 2024
77a57a6
fix broken import
sfc-gh-jrose Dec 19, 2024
4769169
another broken import
sfc-gh-jrose Dec 19, 2024
af5af87
another test fix
sfc-gh-jrose Dec 19, 2024
dea741b
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 20, 2024
ee22980
SNOW-1865926: Infer schema for StructType columns from nested Rows
sfc-gh-jrose Dec 20, 2024
be73744
Merge branch 'main' into jrose_snow_1865926_create_dataframe_default_…
sfc-gh-jrose Jan 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,16 @@ def to_sql(
return f"'{binascii.hexlify(bytes(value)).decode()}' :: BINARY"

if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: ARRAY"
type_str = "ARRAY"
if datatype.structured:
type_str = convert_sp_to_sf_type(datatype)
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"

if isinstance(value, dict) and isinstance(datatype, MapType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: OBJECT"
type_str = "OBJECT"
if datatype.structured:
type_str = convert_sp_to_sf_type(datatype)
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"

if isinstance(datatype, VariantType):
# PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
Expand Down Expand Up @@ -260,11 +266,14 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
return "to_timestamp('2020-09-16 06:30:00')"
if isinstance(data_type, ArrayType):
if data_type.structured:
assert isinstance(data_type.element_type, DataType)
element = schema_expression(data_type.element_type, is_nullable)
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
return "to_array(0)"
if isinstance(data_type, MapType):
if data_type.structured:
assert isinstance(data_type.key_type, DataType)
assert isinstance(data_type.value_type, DataType)
key = schema_expression(data_type.key_type, is_nullable)
value = schema_expression(data_type.value_type, is_nullable)
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
Expand Down
23 changes: 16 additions & 7 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from snowflake.connector.cursor import ResultMetadata
from snowflake.connector.options import installed_pandas, pandas
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import (
LTZ,
NTZ,
Expand Down Expand Up @@ -159,7 +160,7 @@ def convert_metadata_to_sp_type(
[
StructField(
field.name
if context._should_use_structured_type_semantics
if context._should_use_structured_type_semantics()
else quote_name(field.name, keep_case=True),
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
Expand Down Expand Up @@ -187,12 +188,15 @@ def convert_sf_to_sp_type(
max_string_size: int,
) -> DataType:
"""Convert the Snowflake logical type to the Snowpark type."""
semi_structured_fill = (
None if context._should_use_structured_type_semantics() else StringType()
)
if column_type_name == "ARRAY":
return ArrayType(StringType())
return ArrayType(semi_structured_fill)
if column_type_name == "VARIANT":
return VariantType()
if column_type_name in {"OBJECT", "MAP"}:
return MapType(StringType(), StringType())
return MapType(semi_structured_fill, semi_structured_fill)
if column_type_name == "GEOGRAPHY":
return GeographyType()
if column_type_name == "GEOMETRY":
Expand Down Expand Up @@ -438,6 +442,8 @@ def infer_type(obj: Any) -> DataType:
if key is not None and value is not None:
return MapType(infer_type(key), infer_type(value))
return MapType(NullType(), NullType())
elif isinstance(obj, Row) and context._should_use_structured_type_semantics():
return infer_schema(obj)
Comment on lines +445 to +446
Copy link
Contributor

@sfc-gh-aling sfc-gh-aling Jan 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a chance that the given datatype is not StructType while users still input a Row as data? if so what would happen, do we error out?

or Row data always auto inferred as StructType?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what that scenario would look like. Can you give an example?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I didn't make it clear, actually I had two questions:

  1. when the infer schema logic would be triggered for Row values -- is it only when schema is not explicitly set?

  2. for my own learning purposes, will following Row input + MapType be a valid input?

    struct = Row(f1="v1", f2=2)
    df = structured_type_session.create_dataframe(
        [
            (struct),
        ],
        schema=[StructureType(MapType(xxx)],
    )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Yes, the schema is only inferred if no explicitly set.
  2. Today this example would give an error like this:
>>> struct = Row(f1="v1", f2=2)
>>> df = session.create_dataframe([(struct,),], schema=StructType([StructField("test", MapType())]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "src/snowflake/snowpark/_internal/utils.py", line 960, in func_call_wrapper
    return func(*args, **kwargs)
  File "src/snowflake/snowpark/session.py", line 3318, in create_dataframe
    raise TypeError(
TypeError: Cannot cast <class 'snowflake.snowpark.row.Row'>(Row(f1='v1', f2=2)) to MapType(StringType(), StringType()).

Currently createDataFrame does not know how to handle casting Rows to Maps. You could get it to work by calling Row.as_dict if you wanted it to be a MapType.

After this change it's improved slightly:

# Inferred
>>> df = session.create_dataframe([(struct),])
>>> df.schema
StructType([StructField('F1', StringType(), nullable=False), StructField('F2', LongType(), nullable=False)])
 
# or explicit
>>> df = session.create_dataframe([(struct),], schema=StructType([StructField('F1', StringType(), nullable=False), StructField('F2', LongType(), nullable=False)]))

Without this change the explicit schema still works.

elif isinstance(obj, (list, tuple)):
for v in obj:
if v is not None:
Expand Down Expand Up @@ -534,7 +540,10 @@ def merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType
return a


def python_value_str_to_object(value, tp: DataType) -> Any:
def python_value_str_to_object(value, tp: Optional[DataType]) -> Any:
if tp is None:
return None

if isinstance(tp, StringType):
return value

Expand Down Expand Up @@ -643,7 +652,7 @@ def python_type_to_snow_type(
element_type = (
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
return ArrayType(element_type), False

Expand All @@ -653,12 +662,12 @@ def python_type_to_snow_type(
key_type = (
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
value_type = (
python_type_to_snow_type(tp_args[1], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
return MapType(key_type, value_type), False

Expand Down
13 changes: 11 additions & 2 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Callable, Optional

import snowflake.snowpark
import threading

_use_scoped_temp_objects = True

Expand All @@ -21,8 +22,16 @@
_should_continue_registration: Optional[Callable[..., bool]] = None


# Global flag that determines if structured type semantics should be used
_should_use_structured_type_semantics = False
# Internal-only global flag that determines if structured type semantics should be used
_use_structured_type_semantics = False
_use_structured_type_semantics_lock = threading.RLock()


def _should_use_structured_type_semantics():
global _use_structured_type_semantics
global _use_structured_type_semantics_lock
with _use_structured_type_semantics_lock:
return _use_structured_type_semantics


def get_active_session() -> "snowflake.snowpark.Session":
Expand Down
9 changes: 9 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import pkg_resources

import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
import snowflake.snowpark.context as context
from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
from snowflake.connector.pandas_tools import write_pandas
Expand Down Expand Up @@ -3294,6 +3295,14 @@ def convert_row_to_list(
data_type, (MapType, StructType)
):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
elif (
isinstance(value, Row)
and isinstance(data_type, StructType)
and context._should_use_structured_type_semantics()
):
converted_row.append(
json.dumps(value.as_dict(), cls=PythonObjJSONEncoder)
)
elif isinstance(data_type, VariantType):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
elif isinstance(data_type, GeographyType):
Expand Down
70 changes: 54 additions & 16 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from enum import Enum
from typing import Generic, List, Optional, Type, TypeVar, Union, Dict, Any

import snowflake.snowpark.context as context
import snowflake.snowpark._internal.analyzer.expression as expression
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

# Use correct version from here:
from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name
import snowflake.snowpark.context as context

# TODO: connector installed_pandas is broken. If pyarrow is not installed, but pandas is this function returns the wrong answer.
# The core issue is that in the connector detection of both pandas/arrow are mixed, which is wrong.
Expand Down Expand Up @@ -334,16 +334,22 @@ class ArrayType(DataType):
def __init__(
self,
element_type: Optional[DataType] = None,
structured: bool = False,
structured: Optional[bool] = None,
) -> None:
self.structured = structured
self.element_type = element_type if element_type else StringType()
if context._should_use_structured_type_semantics():
self.structured = (
structured if structured is not None else element_type is not None
)
self.element_type = element_type
else:
self.structured = structured or False
self.element_type = element_type if element_type else StringType()

def __repr__(self) -> str:
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"

def _as_nested(self) -> "ArrayType":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
element_type = self.element_type
if isinstance(element_type, (ArrayType, MapType, StructType)):
Expand Down Expand Up @@ -378,6 +384,10 @@ def json_value(self) -> Dict[str, Any]:

def _fill_ast(self, ast: proto.SpDataType) -> None:
ast.sp_array_type.structured = self.structured
if self.element_type is None:
raise NotImplementedError(
"SNOW-1862700: AST does not support empty element_type."
)
self.element_type._fill_ast(ast.sp_array_type.ty)


Expand All @@ -388,20 +398,36 @@ def __init__(
self,
key_type: Optional[DataType] = None,
value_type: Optional[DataType] = None,
structured: bool = False,
structured: Optional[bool] = None,
) -> None:
self.structured = structured
self.key_type = key_type if key_type else StringType()
self.value_type = value_type if value_type else StringType()
if context._should_use_structured_type_semantics():
if (key_type is None and value_type is not None) or (
key_type is not None and value_type is None
):
raise ValueError(
"Must either set both key_type and value_type or leave both unset."
)
self.structured = (
structured if structured is not None else key_type is not None
)
self.key_type = key_type
self.value_type = value_type
else:
self.structured = structured or False
self.key_type = key_type if key_type else StringType()
self.value_type = value_type if value_type else StringType()

def __repr__(self) -> str:
return f"MapType({repr(self.key_type) if self.key_type else ''}, {repr(self.value_type) if self.value_type else ''})"
type_str = ""
if self.key_type and self.value_type:
type_str = f"{repr(self.key_type)}, {repr(self.value_type)}"
return f"MapType({type_str})"

def is_primitive(self):
return False

def _as_nested(self) -> "MapType":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
value_type = self.value_type
if isinstance(value_type, (ArrayType, MapType, StructType)):
Expand Down Expand Up @@ -447,6 +473,10 @@ def valueType(self):

def _fill_ast(self, ast: proto.SpDataType) -> None:
ast.sp_map_type.structured = self.structured
if self.key_type is None or self.value_type is None:
raise NotImplementedError(
"SNOW-1862700: AST does not support empty key or value type."
)
self.key_type._fill_ast(ast.sp_map_type.key_ty)
self.value_type._fill_ast(ast.sp_map_type.value_ty)

Expand Down Expand Up @@ -578,7 +608,7 @@ def __init__(

@property
def name(self) -> str:
if self._is_column or not context._should_use_structured_type_semantics:
if self._is_column or not context._should_use_structured_type_semantics():
return self.column_identifier.name
else:
return self._name
Expand All @@ -593,7 +623,7 @@ def name(self, n: Union[ColumnIdentifier, str]) -> None:
self.column_identifier = ColumnIdentifier(n)

def _as_nested(self) -> "StructField":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
datatype = self.datatype
if isinstance(datatype, (ArrayType, MapType, StructType)):
Expand Down Expand Up @@ -651,9 +681,17 @@ class StructType(DataType):
"""Represents a table schema or structured column. Contains :class:`StructField` for each field."""

def __init__(
self, fields: Optional[List["StructField"]] = None, structured=False
self,
fields: Optional[List["StructField"]] = None,
structured: Optional[bool] = None,
) -> None:
self.structured = structured
if context._should_use_structured_type_semantics():
self.structured = (
structured if structured is not None else fields is not None
)
else:
self.structured = structured or False

self.fields = []
for field in fields or []:
self.add(field)
Expand Down Expand Up @@ -683,7 +721,7 @@ def add(
return self

def _as_nested(self) -> "StructType":
if not context._should_use_structured_type_semantics:
if not context._should_use_structured_type_semantics():
return self
return StructType(
[field._as_nested() for field in self.fields], self.structured
Expand Down
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
TempObjectType,
parse_positional_args_to_list,
publicapi,
warning,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.types import DataType
from snowflake.snowpark.types import DataType, MapType

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
Expand Down Expand Up @@ -710,6 +711,14 @@ def _do_register_udaf(
name,
)

if isinstance(return_type, MapType):
if return_type.structured:
warning(
"_do_register_udaf",
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object.",
)
return_type = MapType()

# Capture original parameters.
if _emit_ast:
stmt = self._session._ast_batch.assign()
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,10 @@ def _do_register_udtf(
output_schema=output_schema,
)

# Structured Struct is interpreted as Object by function registration
# Force unstructured to ensure Table return type.
output_schema.structured = False

# Capture original parameters.
if _emit_ast:
stmt = self._session._ast_batch.assign()
Expand Down
Loading
Loading