Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify Snowflake object name handling in the Snowpark AST #2789

Merged
merged 16 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions src/snowflake/snowpark/_internal/ast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ def build_proto_from_struct_type(
ast_field.nullable = field.nullable


def build_sp_name(name: Union[str, Iterable[str]], expr: proto.SpName) -> None:
if isinstance(name, str):
expr.sp_name_flat.name = name
elif isinstance(name, Iterable):
expr.sp_name_structured.name.extend(name)
else:
raise ValueError(
f"Invalid object name: {name}. The object name must be a string or an iterable of strings."
)


# TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
def _set_fn_name(
name: Union[str, Iterable[str]], fn: proto.FnNameRefExpr
Expand All @@ -358,26 +369,27 @@ def _set_fn_name(
Raises:
ValueError: Raised if the function name is not a string or an iterable of strings.
"""
if isinstance(name, str):
fn.name.fn_name_flat.name = name # type: ignore[attr-defined] # TODO(SNOW-1491199) # "FnNameRefExpr" has no attribute "name"
elif isinstance(name, Iterable):
fn.name.fn_name_structured.name.extend(name) # type: ignore[attr-defined] # TODO(SNOW-1491199) # "FnNameRefExpr" has no attribute "name"
else:
raise ValueError(
f"Invalid function name: {name}. The function name must be a string or an iterable of strings."
)
try:
build_sp_name(name, fn.name.name)
except ValueError as e:
raise ValueError("Invalid function name") from e


# TODO(SNOW-1491199) - This method is not covered by tests until the end of phase 0. Drop the pragma when it is covered.
def build_sp_table_name( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function is missing a return type annotation
expr_builder: proto.SpTableName, name: Union[str, Iterable[str]]
): # pragma: no cover
if isinstance(name, str):
expr_builder.sp_table_name_flat.name = name
elif isinstance(name, Iterable):
expr_builder.sp_table_name_structured.name.extend(name)
else:
raise ValueError(f"Invalid name type {type(name)} for SpTableName entity.")
def build_sp_table_name(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these type-specific functions necessary? One thing I notice is that the re-emitted exceptions lose the diagnostic contents of the original exception.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I notice is that the re-emitted exceptions lose the diagnostic contents of the original exception.

That's weird, raise ... from ... should preserve the contents. The functions aren't strictly necessary, but I find the call sites more readable (the name clarifies the intent), and the errors should have been more readable, too.

expr_builder: proto.SpNameRef, name: Union[str, Iterable[str]]
) -> None: # pragma: no cover
try:
build_sp_name(name, expr_builder.name)
except ValueError as e:
raise ValueError("Invalid table name") from e


def build_sp_view_name(expr: proto.SpNameRef, name: Union[str, Iterable[str]]) -> None:
try:
build_sp_name(name, expr.name)
except ValueError as e:
raise ValueError("Invalid view name") from e


def build_function_expr(
Expand Down Expand Up @@ -1108,7 +1120,7 @@ def build_udf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function i
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1197,7 +1209,7 @@ def build_udaf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location.value = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1294,7 +1306,7 @@ def build_udtf( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down Expand Up @@ -1406,7 +1418,7 @@ def build_sproc( # type: ignore[no-untyped-def] # TODO(SNOW-1491199) # Function
ast.stage_location = stage_location
if imports is not None and len(imports) != 0:
for import_ in imports:
import_expr = proto.SpTableName()
import_expr = proto.SpNameRef()
build_sp_table_name(import_expr, import_)
ast.imports.append(import_expr)
if packages is not None and len(packages) != 0:
Expand Down
Loading
Loading