Skip to content

Commit

Permalink
SNOW-1853347: Add mechanism to allow changing type strs when printing…
Browse files Browse the repository at this point in the history
… schema
  • Loading branch information
sfc-gh-jrose committed Jan 2, 2025
1 parent ba31301 commit beb70ea
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5655,7 +5655,10 @@ def convert(col: ColumnOrName) -> Expression:
return exprs

def _format_schema(
self, level: Optional[int] = None, translate_columns: Optional[dict] = None
self,
level: Optional[int] = None,
translate_columns: Optional[dict] = None,
translate_types: Optional[dict] = None,
) -> str:
def _format_datatype(name, dtype, nullable=None, depth=0):
if level is not None and depth >= level:
Expand All @@ -5669,6 +5672,10 @@ def _format_datatype(name, dtype, nullable=None, depth=0):
extra_lines = []
type_str = dtype.__class__.__name__

translated = None
if translate_types:
translated = translate_types.get(type_str, type_str)

# Structured Type format their parameters on multiple lines.
if isinstance(dtype, ArrayType):
extra_lines = [
Expand All @@ -5695,7 +5702,7 @@ def _format_datatype(name, dtype, nullable=None, depth=0):

return "\n".join(
[
f"{prefix} |-- {name}: {type_str}{nullable_str}",
f"{prefix} |-- {name}: {translated or type_str}{nullable_str}",
]
+ [f"{line}" for line in extra_lines if line]
)
Expand Down
13 changes: 13 additions & 0 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,19 @@ def test_structured_type_print_schema(
== 'root\n |-- "map": MapType (nullable = True)'
)

# Check that column types can be translated
assert (
df._format_schema(
2,
translate_types={
"MapType": "dict",
"StringType": "str",
"ArrayType": "list",
},
)
== 'root\n |-- "MAP": dict (nullable = True)\n | |-- key: str\n | |-- value: list'
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
Expand Down

0 comments on commit beb70ea

Please sign in to comment.