Skip to content

Commit

Permalink
Feat!(snowflake): normalize identifiers as uppercase (#1697)
Browse files Browse the repository at this point in the history
* Feat!(snowflake): normalize identifiers as uppercase

* fixup

* fixup

* Fix\!: lower aliases as well

* Fix: maintain quotes in aliased names

* Apply suggestions from code review

Co-authored-by: Jo <[email protected]>

* Simplify normalize_identifiers.

* Update sqlglot/helper.py

Co-authored-by: Jo <[email protected]>

* default identify to true

---------

Co-authored-by: George Sittas <[email protected]>
Co-authored-by: Jo <[email protected]>
  • Loading branch information
3 people authored May 29, 2023
1 parent c2c955c commit dcfe67f
Show file tree
Hide file tree
Showing 18 changed files with 303 additions and 148 deletions.
5 changes: 5 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
from sqlglot._typing import E


# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}


class Dialects(str, Enum):
DIALECT = ""

Expand Down
3 changes: 1 addition & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,8 +874,7 @@ def alias_column_names(self) -> t.List[str]:
table_alias = self.args.get("alias")
if not table_alias:
return []
column_list = table_alias.assert_is(TableAlias).args.get("columns") or []
return [c.name for c in column_list]
return [c.name for c in table_alias.args.get("columns") or []]

@property
def selects(self):
Expand Down
18 changes: 15 additions & 3 deletions sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
if t.TYPE_CHECKING:
from sqlglot import exp
from sqlglot._typing import E, T
from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression

CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
Expand Down Expand Up @@ -419,18 +420,29 @@ def first(it: t.Iterable[T]) -> T:
return next(i for i in it)


def should_identify(text: str, identify: str | bool) -> bool:
def case_sensitive(text: str, dialect: DialectType) -> bool:
"""Checks if text contains any case sensitive characters depending on dialect."""
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE

unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
return any(unsafe(char) for char in text)


def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
"""Checks if text should be identified given an identify option.
Args:
text: the text to check.
identify: "always" | True - always returns true, "safe" - true if no upper case
identify:
"always" or `True`: always returns true.
"safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
dialect: the dialect to use in order to decide whether a text should be identified.
Returns:
Whether or not a string should be identified.
"""
if identify is True or identify == "always":
return True
if identify == "safe":
return not any(char.isupper() for char in text)
return not case_sensitive(text, dialect)
return False
19 changes: 11 additions & 8 deletions sqlglot/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@

from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, optimize
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer import Scope, build_scope, qualify

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
Expand Down Expand Up @@ -41,8 +38,8 @@ def lineage(
sql: str | exp.Expression,
schema: t.Optional[t.Dict | Schema] = None,
sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
dialect: DialectType = None,
**kwargs,
) -> Node:
"""Build the lineage graph for a column of a SQL query.
Expand All @@ -51,8 +48,8 @@ def lineage(
sql: The SQL string or expression.
schema: The schema of tables.
sources: A mapping of queries which will be used to continue building lineage.
rules: Optimizer rules to apply, by default only qualifying tables and columns.
dialect: The dialect of input SQL.
**kwargs: Qualification optimizer kwargs.
Returns:
A lineage node.
Expand All @@ -69,8 +66,14 @@ def lineage(
},
)

optimized = optimize(expression, schema=schema, rules=rules)
scope = build_scope(optimized)
qualified = qualify.qualify(
expression,
dialect=dialect,
schema=schema,
**{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore
)

scope = build_scope(qualified)

if not scope:
raise SqlglotError("Cannot build lineage, sql must be SELECT")
Expand Down
17 changes: 10 additions & 7 deletions sqlglot/optimizer/canonicalize.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import itertools
import typing as t

from sqlglot import exp
from sqlglot.helper import should_identify
from sqlglot.optimizer.qualify_columns import quote_identifiers

if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType

def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:

def canonicalize(
expression: exp.Expression, identify: bool = True, dialect: DialectType = None
) -> exp.Expression:
"""Converts a sql expression into a standard form.
This method relies on annotate_types because many of the
Expand All @@ -16,16 +22,13 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr
expression: The expression to canonicalize.
identify: Whether or not to force identify identifier.
"""
exp.replace_children(expression, canonicalize, identify=identify)
exp.replace_children(expression, canonicalize, identify=identify, dialect=dialect)

expression = add_text_to_concat(expression)
expression = coerce_type(expression)
expression = remove_redundant_casts(expression)
expression = ensure_bool_predicates(expression)

if isinstance(expression, exp.Identifier):
if should_identify(expression.this, identify):
expression.set("quoted", True)
expression = quote_identifiers(expression, dialect=dialect, identify=identify)

return expression

Expand Down
88 changes: 0 additions & 88 deletions sqlglot/optimizer/lower_identities.py

This file was deleted.

37 changes: 37 additions & 0 deletions sqlglot/optimizer/normalize_identifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from sqlglot import exp
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType


def normalize_identifiers(
expression: exp.Expression, dialect: DialectType = None
) -> exp.Expression:
"""
Normalize all unquoted identifiers to either lower or upper case, depending on
the dialect. This essentially makes those identifiers case-insensitive.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
>>> normalize_identifiers(expression).sql()
'SELECT bar.a AS a FROM "Foo".bar'
Args:
expression: The expression to transform.
dialect: The dialect to use in order to decide how to normalize identifiers.
Returns:
The transformed expression.
"""
return expression.transform(_normalize, dialect, copy=False)


def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
if isinstance(node, exp.Identifier) and not node.quoted:
node.set(
"this",
node.this.upper()
if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
else node.this.lower(),
)

return node
21 changes: 11 additions & 10 deletions sqlglot/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,19 @@
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.lower_identities import lower_identities
from sqlglot.optimizer.merge_subqueries import merge_subqueries
from sqlglot.optimizer.normalize import normalize
from sqlglot.optimizer.optimize_joins import optimize_joins
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.simplify import simplify
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
from sqlglot.schema import ensure_schema

RULES = (
lower_identities,
qualify_tables,
isolate_table_selects,
qualify_columns,
qualify,
pushdown_projections,
validate_qualify_columns,
normalize,
unnest_subqueries,
pushdown_predicates,
Expand Down Expand Up @@ -77,7 +70,15 @@ def optimize(
sqlglot.Expression: optimized expression
"""
schema = ensure_schema(schema or sqlglot.schema, dialect=dialect)
possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
possible_kwargs = {
"db": db,
"catalog": catalog,
"schema": schema,
"dialect": dialect,
"isolate_tables": True, # needed for other optimizations to perform well
"quote_identifiers": False, # this happens in canonicalize
**kwargs,
}
expression = exp.maybe_parse(expression, dialect=dialect, copy=True)

for rule in rules:
Expand Down
Loading

0 comments on commit dcfe67f

Please sign in to comment.