From dcfe67f5aff6a1f8739c704dfc4f71ad2f2b293e Mon Sep 17 00:00:00 2001 From: Toby Mao Date: Sun, 28 May 2023 19:03:03 -0700 Subject: [PATCH] Feat!(snowflake): normalize identifiers as uppercase (#1697) * Feat!(snowflake): normalize identifiers as uppercase * fixup * fixup * Fix\!: lower aliases as well * Fix: maintain quotes in aliased names * Apply suggestions from code review Co-authored-by: Jo <46752250+GeorgeSittas@users.noreply.github.com> * Simplify normalize_identifiers. * Update sqlglot/helper.py Co-authored-by: Jo <46752250+GeorgeSittas@users.noreply.github.com> * default identify to true --------- Co-authored-by: George Sittas Co-authored-by: Jo <46752250+GeorgeSittas@users.noreply.github.com> --- sqlglot/dialects/dialect.py | 5 ++ sqlglot/expressions.py | 3 +- sqlglot/helper.py | 18 +++- sqlglot/lineage.py | 19 ++-- sqlglot/optimizer/canonicalize.py | 17 ++-- sqlglot/optimizer/lower_identities.py | 88 ------------------- sqlglot/optimizer/normalize_identifiers.py | 37 ++++++++ sqlglot/optimizer/optimizer.py | 21 ++--- sqlglot/optimizer/qualify.py | 77 ++++++++++++++++ sqlglot/optimizer/qualify_columns.py | 25 ++++-- sqlglot/schema.py | 10 ++- ...entities.sql => normalize_identifiers.sql} | 30 +++++-- tests/fixtures/optimizer/optimizer.sql | 27 +++++- tests/fixtures/optimizer/qualify_columns.sql | 3 +- tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 8 +- tests/test_executor.py | 4 + tests/test_optimizer.py | 55 ++++++++++-- tests/test_schema.py | 4 + 18 files changed, 303 insertions(+), 148 deletions(-) delete mode 100644 sqlglot/optimizer/lower_identities.py create mode 100644 sqlglot/optimizer/normalize_identifiers.py create mode 100644 sqlglot/optimizer/qualify.py rename tests/fixtures/optimizer/{lower_identities.sql => normalize_identifiers.sql} (54%) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 013dd2aebf..890a3c3cb5 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -15,6 +15,11 @@ from sqlglot._typing import E +# Only Snowflake is currently known to resolve unquoted identifiers as uppercase. +# https://docs.snowflake.com/en/sql-reference/identifiers-syntax +RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"} + + class Dialects(str, Enum): DIALECT = "" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index adca4eea05..47d72594fe 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -874,8 +874,7 @@ def alias_column_names(self) -> t.List[str]: table_alias = self.args.get("alias") if not table_alias: return [] - column_list = table_alias.assert_is(TableAlias).args.get("columns") or [] - return [c.name for c in column_list] + return [c.name for c in table_alias.args.get("columns") or []] @property def selects(self): diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 2cffb1740e..191051ce42 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -13,6 +13,7 @@ if t.TYPE_CHECKING: from sqlglot import exp from sqlglot._typing import E, T + from sqlglot.dialects.dialect import DialectType from sqlglot.expressions import Expression CAMEL_CASE_PATTERN = re.compile("(? 
T:
     return next(i for i in it)
 
 
-def should_identify(text: str, identify: str | bool) -> bool:
+def case_sensitive(text: str, dialect: DialectType) -> bool:
+    """Checks if `text` contains any characters that are case sensitive for `dialect`."""
+    from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE
+
+    unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
+    return any(unsafe(char) for char in text)
+
+
+def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
     """Checks if text should be identified given an identify option.
 
     Args:
         text: the text to check.
-        identify: "always" | True - always returns true, "safe" - true if no upper case
+        identify:
+            "always" or `True`: always returns true.
+            "safe": true only if `text` has no characters that are case sensitive for `dialect`, so quoting it cannot change how it resolves.
+        dialect: the dialect to use in order to decide whether a text should be identified.
 
     Returns:
         Whether or not a string should be identified.
@@ -432,5 +444,5 @@ def should_identify(text: str, identify: str | bool) -> bool:
     if identify is True or identify == "always":
         return True
     if identify == "safe":
-        return not any(char.isupper() for char in text)
+        return not case_sensitive(text, dialect)
     return False
diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py
index 217b2fc362..04a807322a 100644
--- a/sqlglot/lineage.py
+++ b/sqlglot/lineage.py
@@ -6,10 +6,7 @@
 
 from sqlglot import Schema, exp, maybe_parse
 from sqlglot.errors import SqlglotError
-from sqlglot.optimizer import Scope, build_scope, optimize
-from sqlglot.optimizer.lower_identities import lower_identities
-from sqlglot.optimizer.qualify_columns import qualify_columns
-from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer import Scope, build_scope, qualify
 
 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import DialectType
@@ -41,8 +38,8 @@ def lineage(
     sql: str | exp.Expression,
     schema: t.Optional[t.Dict | Schema] = None,
     sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None,
-    rules: t.Sequence[t.Callable] = (lower_identities, qualify_tables, qualify_columns),
     dialect: DialectType = None,
+    **kwargs,
 ) -> Node:
     """Build the lineage graph for a column of a SQL query.
 
@@ -51,8 +48,8 @@ def lineage(
         sql: The SQL string or expression.
         schema: The schema of tables.
         sources: A mapping of queries which will be used to continue building lineage.
-        rules: Optimizer rules to apply, by default only qualifying tables and columns.
         dialect: The dialect of input SQL.
+        **kwargs: Keyword arguments forwarded to the qualification optimizer.
 
     Returns:
         A lineage node.
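[Editor's note — illustrative sketch, not part of the patch. Assuming the patched
sqlglot.helper module above, the new dialect-aware checks should behave as follows;
the asserted values follow directly from RESOLVES_IDENTIFIERS_AS_UPPERCASE
containing only "snowflake".]

from sqlglot.helper import case_sensitive, should_identify

# Most dialects fold unquoted identifiers to lowercase, so an uppercase
# character is what makes a name case sensitive.
assert case_sensitive("Foo", dialect=None) is True
assert case_sensitive("foo", dialect=None) is False

# Snowflake folds unquoted identifiers to uppercase, so the check flips:
# lowercase characters are now the "unsafe" ones.
assert case_sensitive("FOO", dialect="snowflake") is False
assert case_sensitive("foo", dialect="snowflake") is True

# identify="safe" quotes an identifier only when quoting cannot change how
# it resolves, i.e. when it has no case sensitive characters.
assert should_identify("foo", "safe", dialect=None) is True
assert should_identify("Foo", "safe", dialect=None) is False
assert should_identify("Foo", "always", dialect=None) is True

[End of editor's note; the patch continues below.]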
@@ -69,8 +66,14 @@ def lineage(
         },
     )
 
-    optimized = optimize(expression, schema=schema, rules=rules)
-    scope = build_scope(optimized)
+    qualified = qualify.qualify(
+        expression,
+        dialect=dialect,
+        schema=schema,
+        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
+    )
+
+    scope = build_scope(qualified)
     if not scope:
         raise SqlglotError("Cannot build lineage, sql must be SELECT")
 
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index ef929ac780..a74fa87460 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -1,12 +1,18 @@
 from __future__ import annotations
 
 import itertools
+import typing as t
 
 from sqlglot import exp
-from sqlglot.helper import should_identify
+from sqlglot.optimizer.qualify_columns import quote_identifiers
 
+if t.TYPE_CHECKING:
+    from sqlglot.dialects.dialect import DialectType
 
-def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression:
+
+def canonicalize(
+    expression: exp.Expression, identify: bool = True, dialect: DialectType = None
+) -> exp.Expression:
     """Converts a sql expression into a standard form.
 
     This method relies on annotate_types because many of the
@@ -16,16 +22,13 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr
         expression: The expression to canonicalize.
         identify: Whether or not to force-quote identifiers.
     """
-    exp.replace_children(expression, canonicalize, identify=identify)
+    exp.replace_children(expression, canonicalize, identify=identify, dialect=dialect)
 
     expression = add_text_to_concat(expression)
     expression = coerce_type(expression)
     expression = remove_redundant_casts(expression)
     expression = ensure_bool_predicates(expression)
-
-    if isinstance(expression, exp.Identifier):
-        if should_identify(expression.this, identify):
-            expression.set("quoted", True)
+    expression = quote_identifiers(expression, dialect=dialect, identify=identify)
 
     return expression
diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py
deleted file mode 100644
index fae172604c..0000000000
--- a/sqlglot/optimizer/lower_identities.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from sqlglot import exp
-
-
-def lower_identities(expression):
-    """
-    Convert all unquoted identifiers to lower case.
-
-    Assuming the schema is all lower case, this essentially makes identifiers case-insensitive.
-
-    Example:
-        >>> import sqlglot
-        >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
-        >>> lower_identities(expression).sql()
-        'SELECT bar.a AS A FROM "Foo".bar'
-
-    Args:
-        expression (sqlglot.Expression): expression to quote
-    Returns:
-        sqlglot.Expression: quoted expression
-    """
-    # We need to leave the output aliases unchanged, so the selects need special handling
-    _lower_selects(expression)
-
-    # These clauses can reference output aliases and also need special handling
-    _lower_order(expression)
-    _lower_having(expression)
-
-    # We've already handled these args, so don't traverse into them
-    traversed = {"expressions", "order", "having"}
-
-    if isinstance(expression, exp.Subquery):
-        # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1
-        lower_identities(expression.this)
-        traversed |= {"this"}
-
-    if isinstance(expression, exp.Union):
-        # Union, e.g. 
SELECT A AS A FROM X UNION SELECT A AS A FROM X - lower_identities(expression.left) - lower_identities(expression.right) - traversed |= {"this", "expression"} - - for k, v in expression.iter_expressions(): - if k in traversed: - continue - v.transform(_lower, copy=False) - - return expression - - -def _lower_selects(expression): - for e in expression.expressions: - # Leave output aliases as-is - e.unalias().transform(_lower, copy=False) - - -def _lower_order(expression): - order = expression.args.get("order") - - if not order: - return - - output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)} - - for ordered in order.expressions: - # Don't lower references to output aliases - if not ( - isinstance(ordered.this, exp.Column) - and not ordered.this.table - and ordered.this.name in output_aliases - ): - ordered.transform(_lower, copy=False) - - -def _lower_having(expression): - having = expression.args.get("having") - - if not having: - return - - # Don't lower references to output aliases - for agg in having.find_all(exp.AggFunc): - agg.transform(_lower, copy=False) - - -def _lower(node): - if isinstance(node, exp.Identifier) and not node.quoted: - node.set("this", node.this.lower()) - return node diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py new file mode 100644 index 0000000000..bf4e3329de --- /dev/null +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -0,0 +1,37 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType + + +def normalize_identifiers( + expression: exp.Expression, dialect: DialectType = None +) -> exp.Expression: + """ + Normalize all unquoted identifiers to either lower or upper case, depending on + the dialect. This essentially makes those identifiers case-insensitive. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') + >>> normalize_identifiers(expression).sql() + 'SELECT bar.a AS a FROM "Foo".bar' + + Args: + expression: The expression to transform. + dialect: The dialect to use in order to decide how to normalize identifiers. + + Returns: + The transformed expression. 
+ """ + return expression.transform(_normalize, dialect, copy=False) + + +def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression: + if isinstance(node, exp.Identifier) and not node.quoted: + node.set( + "this", + node.this.upper() + if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE + else node.this.lower(), + ) + + return node diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index e2cd0c4c63..d2b1054a73 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -10,26 +10,19 @@ from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries -from sqlglot.optimizer.isolate_table_selects import isolate_table_selects -from sqlglot.optimizer.lower_identities import lower_identities from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections -from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns -from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.qualify import qualify from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema RULES = ( - lower_identities, - qualify_tables, - isolate_table_selects, - qualify_columns, + qualify, pushdown_projections, - validate_qualify_columns, normalize, unnest_subqueries, pushdown_predicates, @@ -77,7 +70,15 @@ def optimize( sqlglot.Expression: optimized expression """ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} + possible_kwargs = { + "db": db, + "catalog": catalog, + "schema": schema, + "dialect": dialect, + "isolate_tables": True, # needed for other optimizations to perform well + "quote_identifiers": False, # this happens in canonicalize + **kwargs, + } expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py new file mode 100644 index 0000000000..ea9d4ebccd --- /dev/null +++ b/sqlglot/optimizer/qualify.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.dialects.dialect import DialectType +from sqlglot.optimizer import qualify_columns +from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.schema import Schema, ensure_schema + + +def qualify( + expression: exp.Expression, + dialect: DialectType = None, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[dict | Schema] = None, + expand_alias_refs: bool = True, + infer_schema: t.Optional[bool] = None, + isolate_tables: bool = False, + validate_qualify_columns: bool = True, + quote_identifiers: bool = True, + identify: bool = True, +) -> exp.Expression: + """ + Rewrite sqlglot AST to have normalized and qualified tables and columns. + + This step is necessary for all further SQLGlot optimizations. 
+
+    Example:
+        >>> import sqlglot
+        >>> schema = {"tbl": {"col": "INT"}}
+        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+        >>> qualify(expression, schema=schema).sql()
+        'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"'
+
+    Args:
+        expression: Expression to qualify.
+        db: Default database name for tables.
+        catalog: Default catalog name for tables.
+        schema: Schema to infer column names and types.
+        expand_alias_refs: Whether or not to expand references to aliases.
+        infer_schema: Whether or not to infer the schema if missing.
+        isolate_tables: Whether or not to isolate table selects.
+        validate_qualify_columns: Whether or not to validate columns.
+        quote_identifiers: Whether or not to run the quote_identifiers step.
+            This step is necessary to ensure correctness for case-sensitive queries,
+            but the flag is provided so that quoting can be deferred to a later stage.
+        identify: If True, quote all identifiers, else only necessary ones.
+
+    Returns:
+        The qualified expression.
+    """
+    schema = ensure_schema(schema, dialect=dialect)
+    expression = normalize_identifiers(expression, dialect=dialect)
+    expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema)
+
+    if isolate_tables:
+        expression = isolate_table_selects(expression, schema=schema)
+
+    expression = qualify_columns.qualify_columns(
+        expression,
+        schema,
+        expand_alias_refs=expand_alias_refs,
+        infer_schema=infer_schema,
+    )
+
+    if quote_identifiers:
+        expression = expression.transform(
+            qualify_columns.quote_identifiers, dialect, identify, copy=False
+        )
+
+    if validate_qualify_columns:
+        qualify_columns.validate_qualify_columns(expression)
+
+    return expression
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 65a943f51a..799dd0633a 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -4,8 +4,9 @@
 import typing as t
 
 from sqlglot import alias, exp
+from sqlglot.dialects.dialect import DialectType
 from sqlglot.errors import OptimizeError
-from sqlglot.helper import seq_get
+from sqlglot.helper import case_sensitive, seq_get
 from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 from sqlglot.schema import Schema, ensure_schema
 
@@ -404,9 +405,6 @@ def _qualify_outputs(scope):
             selection = alias(
                 selection,
                 alias=selection.output_name or f"_col_{i}",
-                quoted=True
-                if isinstance(selection, exp.Column) and selection.this.quoted
-                else None,
             )
         if aliased_column:
             selection.set("alias", exp.to_identifier(aliased_column))
@@ -416,6 +414,21 @@ def _qualify_outputs(scope):
     scope.expression.set("expressions", new_selections)
 
 
+def quote_identifiers(
+    expression: exp.Expression, dialect: DialectType, identify: bool
+) -> exp.Expression:
+    """Makes sure all identifiers that need to be quoted are quoted."""
+    if isinstance(expression, exp.Identifier):
+        name = expression.this
+        expression.set(
+            "quoted",
+            identify
+            or case_sensitive(name, dialect=dialect)
+            or not exp.SAFE_IDENTIFIER_RE.match(name),
+        )
+    return expression
+
+
 class Resolver:
     """
     Helper for resolving columns. 
@@ -469,9 +482,7 @@ def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: if node_alias: return exp.to_identifier(node_alias.this) - return exp.to_identifier( - table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None - ) + return exp.to_identifier(table_name) @property def all_columns(self): diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 6bb5a362de..25abaa202e 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -6,6 +6,7 @@ import sqlglot from sqlglot import expressions as exp from sqlglot._typing import T +from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie @@ -22,6 +23,8 @@ class Schema(abc.ABC): """Abstract base class for database schemas""" + dialect: DialectType + @abc.abstractmethod def add_table( self, @@ -328,7 +331,12 @@ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = Non except ParseError: return name if isinstance(name, str) else name.name - return identifier.name if identifier.quoted else identifier.name.lower() + name = identifier.name + + if identifier.quoted: + return name + + return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those diff --git a/tests/fixtures/optimizer/lower_identities.sql b/tests/fixtures/optimizer/normalize_identifiers.sql similarity index 54% rename from tests/fixtures/optimizer/lower_identities.sql rename to tests/fixtures/optimizer/normalize_identifiers.sql index cea346f1aa..ddb755f61e 100644 --- a/tests/fixtures/optimizer/lower_identities.sql +++ b/tests/fixtures/optimizer/normalize_identifiers.sql @@ -1,11 +1,19 @@ SELECT a FROM x; SELECT a FROM x; +# dialect: snowflake +SELECT A FROM X; +SELECT A FROM X; + SELECT "A" FROM "X"; SELECT "A" FROM "X"; SELECT a AS A FROM x; -SELECT a AS A FROM x; +SELECT a AS a FROM x; + +# dialect: snowflake +SELECT A AS a FROM X; +SELECT A AS A FROM X; SELECT * FROM x; SELECT * FROM x; @@ -13,29 +21,37 @@ SELECT * FROM x; SELECT A FROM x; SELECT a FROM x; +# dialect: snowflake +SELECT a FROM X; +SELECT A FROM X; + SELECT a FROM X; SELECT a FROM x; +# dialect: snowflake +SELECT A FROM x; +SELECT A FROM X; + SELECT A AS A FROM (SELECT a AS A FROM x); -SELECT a AS A FROM (SELECT a AS a FROM x); +SELECT a AS a FROM (SELECT a AS a FROM x); SELECT a AS B FROM x ORDER BY B; -SELECT a AS B FROM x ORDER BY B; +SELECT a AS b FROM x ORDER BY b; SELECT A FROM x ORDER BY A; SELECT a FROM x ORDER BY a; SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0; -SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0; +SELECT a AS b FROM x GROUP BY a HAVING SUM(b) > 0; SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0; -SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 0; +SELECT a AS b, SUM(b) AS c FROM x GROUP BY a HAVING c > 0; SELECT A FROM X UNION SELECT A FROM X; SELECT a FROM x UNION SELECT a FROM x; SELECT A AS A FROM X UNION SELECT A AS A FROM X; -SELECT a AS A FROM x UNION SELECT a AS A FROM x; +SELECT a AS a FROM x UNION SELECT a AS a FROM x; (SELECT A AS A FROM X); -(SELECT a AS A FROM x); +(SELECT a AS a FROM x); diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index d9597e5395..e0567d72c9 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -577,10 +577,11 @@ FROM `u_cte` AS 
`u_cte` PIVOT(SUM(`u_cte`.`f`) AS `sum` FOR `u_cte`.`h` IN ('x', # dialect: snowflake SELECT * FROM u PIVOT (SUM(f) FOR h IN ('x', 'y')); SELECT - "_q_0"."g" AS "g", + "_q_0"."G" AS "G", "_q_0"."'x'" AS "'x'", "_q_0"."'y'" AS "'y'" -FROM "u" AS "u" PIVOT(SUM("u"."f") FOR "u"."h" IN ('x', 'y')) AS "_q_0"; +FROM "U" AS "U" PIVOT(SUM("U"."F") FOR "U"."H" IN ('x', 'y')) AS "_q_0" +; # title: selecting all columns from a pivoted source and generating spark # note: spark doesn't allow pivot aliases or qualified columns for the pivot's "field" (`h`) @@ -596,3 +597,25 @@ FROM ( * FROM `u` AS `u` PIVOT(SUM(`u`.`f`) FOR `h` IN ('x', 'y')) ) AS `_q_0`; + +# title: quoting is maintained +# dialect: snowflake +with cte1("id", foo) as (select 1, 2) select "id" from cte1; +WITH "CTE1" AS ( + SELECT + 1 AS "id" +) +SELECT + "CTE1"."id" AS "id" +FROM "CTE1"; + +# title: ensures proper quoting happens after all optimizations +# execute: false +SELECT "foO".x FROM (SELECT 1 AS x) AS "foO"; +WITH "foO" AS ( + SELECT + 1 AS "x" +) +SELECT + "foO"."x" AS "x" +FROM "foO" AS "foO"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index e64148c36c..7be2c7f3e1 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -5,7 +5,7 @@ SELECT a FROM x; SELECT x.a AS a FROM x AS x; SELECT "a" FROM x; -SELECT x."a" AS "a" FROM x AS x; +SELECT x.a AS a FROM x AS x; # execute: false SELECT a FROM zz GROUP BY a ORDER BY a; @@ -402,7 +402,6 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; -# dialect: snowflake SELECT * FROM x QUALIFY COUNT(a) OVER (PARTITION BY b) > 1; SELECT x.a AS a, x.b AS b FROM x AS x QUALIFY COUNT(x.a) OVER (PARTITION BY x.b) > 1; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 0f2e20671b..a6ee325c2c 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -12011,10 +12011,10 @@ GROUP BY cc_call_center_id, cd_education_status ORDER BY Sum(cr_net_loss) DESC; SELECT - "call_center"."cc_call_center_id" AS Call_Center, - "call_center"."cc_name" AS Call_Center_Name, - "call_center"."cc_manager" AS Manager, - SUM("catalog_returns"."cr_net_loss") AS Returns_Loss + "call_center"."cc_call_center_id" AS "call_center", + "call_center"."cc_name" AS "call_center_name", + "call_center"."cc_manager" AS "manager", + SUM("catalog_returns"."cr_net_loss") AS "returns_loss" FROM "call_center" AS "call_center" JOIN "catalog_returns" AS "catalog_returns" ON "catalog_returns"."cr_call_center_sk" = "call_center"."cc_call_center_sk" diff --git a/tests/test_executor.py b/tests/test_executor.py index 56a6674a07..a121dea0c7 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -609,6 +609,10 @@ def test_scalar_functions(self): def test_case_sensitivity(self): result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]}) + self.assertEqual(result.columns, ("a",)) + self.assertEqual(result.rows, [(1,)]) + + result = execute('SELECT A AS "A" FROM X', tables={"x": [{"a": 1}]}) self.assertEqual(result.columns, ("A",)) self.assertEqual(result.rows, [(1,)]) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 
d8b443550e..88c99c28ea 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -20,13 +20,14 @@ ) -def parse_and_optimize(func, sql, dialect, **kwargs): - return func(parse_one(sql, read=dialect), **kwargs) +def parse_and_optimize(func, sql, read_dialect, **kwargs): + return func(parse_one(sql, read=read_dialect), **kwargs) def qualify_columns(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, infer_schema=True, **kwargs) + expression = optimizer.qualify.qualify( + expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs + ) return expression @@ -98,7 +99,7 @@ def setUp(self): }, } - def check_file(self, file, func, pretty=False, execute=False, **kwargs): + def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs): with ProcessPoolExecutor() as pool: results = {} @@ -113,6 +114,9 @@ def check_file(self, file, func, pretty=False, execute=False, **kwargs): if leave_tables_isolated is not None: func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) + if set_dialect and dialect: + func_kwargs["dialect"] = dialect + future = pool.submit(parse_and_optimize, func, sql, dialect, **func_kwargs) results[future] = ( sql, @@ -157,6 +161,7 @@ def test_optimize(self): pretty=True, execute=True, schema=schema, + set_dialect=True, ) def test_isolate_table_selects(self): @@ -217,8 +222,12 @@ def test_qualify_columns__invalid(self): ) optimizer.qualify_columns.validate_qualify_columns(expression) - def test_lower_identities(self): - self.check_file("lower_identities", optimizer.lower_identities.lower_identities) + def test_normalize_identifiers(self): + self.check_file( + "normalize_identifiers", + optimizer.normalize_identifiers.normalize_identifiers, + set_dialect=True, + ) def test_pushdown_projection(self): self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) @@ -678,3 +687,35 @@ def test_schema_with_spaces(self): optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema), parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'), ) + + def test_quotes(self): + schema = { + "example": { + '"source"': { + "id": "text", + '"name"': "text", + '"payload"': "text", + } + } + } + + self.assertEqual( + optimizer.qualify.qualify( + parse_one( + """ + SELECT * FROM example."source" + """ + ), + dialect="snowflake", + schema=schema, + ).sql(pretty=True), + parse_one( + """ + SELECT + "source"."ID" AS "ID", + "source"."name" AS "name", + "source"."payload" AS "payload" + FROM "EXAMPLE"."source" AS "source" + """ + ).sql(pretty=True), + ) diff --git a/tests/test_schema.py b/tests/test_schema.py index 94495b1f0f..072a41dbc1 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -217,3 +217,7 @@ def test_schema_normalization(self): schema = MappingSchema() schema.add_table("Foo", {"SomeColumn": "INT", '"SomeColumn"': "DOUBLE"}) self.assertEqual(schema.column_names(exp.Table(this="fOO")), ["somecolumn", "SomeColumn"]) + + # Check that names are normalized to uppercase for Snowflake + schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake") + self.assertEqual(schema.column_names(exp.Table(this="x")), ["FOO", "bLa"])
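[Editor's note — illustrative sketch, not part of the patch. A minimal end-to-end
usage example of the two new entry points added above; the expected outputs are
inferred from the docstrings, fixtures and tests in this patch.]

import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify import qualify

# Unquoted identifiers normalize to lowercase by default; quoted ones are kept.
expr = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
print(normalize_identifiers(expr.copy()).sql())
# SELECT bar.a AS a FROM "Foo".bar

# For Snowflake, the same unquoted identifiers normalize to uppercase instead.
print(normalize_identifiers(expr.copy(), dialect="snowflake").sql())
# SELECT BAR.A AS A FROM "Foo".BAR

# qualify() now bundles identifier normalization, table/column qualification
# and quoting into a single rule, as used by the optimizer and lineage.
print(
    qualify(
        sqlglot.parse_one("SELECT col FROM tbl"),
        schema={"tbl": {"col": "INT"}},
        dialect="snowflake",
    ).sql()
)
# SELECT "TBL"."COL" AS "COL" FROM "TBL" AS "TBL"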