From 2b886d6da226900b235c719f6532fcb1d71d51b5 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 00:08:19 +0300 Subject: [PATCH 01/12] Fix: set quote_identifiers in qualify, add normalize flag in schema --- sqlglot/dataframe/sql/dataframe.py | 4 ++- sqlglot/optimizer/normalize_identifiers.py | 7 +++-- sqlglot/optimizer/optimizer.py | 13 ++++----- sqlglot/schema.py | 5 +++- tests/test_optimizer.py | 33 +++++++++------------- tests/test_schema.py | 4 +++ 6 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index d6d76abe8d..3f09c3bd72 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -299,7 +299,9 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = optimize_func(select_expression, identify="always") + select_expression = t.cast( + exp.Select, optimize_func(select_expression, identify="always") + ) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index bf4e3329de..e35caea258 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -1,10 +1,11 @@ from sqlglot import exp from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType +if t.TYPE_CHECKING: + from sqlglot._typing import E -def normalize_identifiers( - expression: exp.Expression, dialect: DialectType = None -) -> exp.Expression: + +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: """ Normalize all unquoted identifiers to either lower or upper case, depending on the dialect. This essentially makes those identifiers case-insensitive. diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d2b1054a73..2853cbb89d 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -45,7 +45,7 @@ def optimize( dialect: DialectType = None, rules: t.Sequence[t.Callable] = RULES, **kwargs, -): +) -> exp.Expression: """ Rewrite a sqlglot AST into an optimized form. @@ -63,11 +63,11 @@ def optimize( dialect: The dialect to parse the sql string. rules: sequence of optimizer rules to use. Many of the rules require tables and columns to be qualified. - Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know - what you're doing! + Do not remove `qualify` from the sequence of rules unless you know what you're doing! **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. + Returns: - sqlglot.Expression: optimized expression + The optimized expression. """ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) possible_kwargs = { @@ -76,11 +76,10 @@ def optimize( "schema": schema, "dialect": dialect, "isolate_tables": True, # needed for other optimizations to perform well - "quote_identifiers": False, # this happens in canonicalize **kwargs, } - expression = exp.maybe_parse(expression, dialect=dialect, copy=True) + expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames @@ -89,4 +88,4 @@ def optimize( } expression = rule(expression, **rule_kwargs) - return expression + return t.cast(exp.Expression, expression) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 25abaa202e..e7e7d3d4e6 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -177,6 +177,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): 2. {db: {table: set(*cols)}}} 3. {catalog: {db: {table: set(*cols)}}}} dialect: The dialect to be used for custom type mappings & parsing string arguments. + normalize: Whether to normalize identifier names according to the given dialect or not. """ def __init__( @@ -184,9 +185,11 @@ def __init__( schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, dialect: DialectType = None, + normalize: bool = True, ) -> None: self.dialect = dialect self.visible = visible or {} + self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} super().__init__(self._normalize(schema or {})) @@ -333,7 +336,7 @@ def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = Non name = identifier.name - if identifier.quoted: + if not self.normalize or identifier.quoted: return name return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 88c99c28ea..6e12f13a4d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -699,23 +699,18 @@ def test_quotes(self): } } - self.assertEqual( - optimizer.qualify.qualify( - parse_one( - """ - SELECT * FROM example."source" - """ - ), - dialect="snowflake", - schema=schema, - ).sql(pretty=True), - parse_one( - """ - SELECT - "source"."ID" AS "ID", - "source"."name" AS "name", - "source"."payload" AS "payload" - FROM "EXAMPLE"."source" AS "source" + expected = parse_one( """ - ).sql(pretty=True), - ) + SELECT + "source"."ID" AS "ID", + "source"."name" AS "name", + "source"."payload" AS "payload" + FROM "EXAMPLE"."source" AS "source" + """, + read="snowflake", + ).sql(pretty=True, dialect="snowflake") + + for func in (optimizer.qualify.qualify, optimizer.optimize): + source_query = parse_one('SELECT * FROM example."source"', read="snowflake") + transformed = func(source_query, dialect="snowflake", schema=schema) + self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected) diff --git a/tests/test_schema.py b/tests/test_schema.py index 072a41dbc1..83aad213cf 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -221,3 +221,7 @@ def test_schema_normalization(self): # Check that names are normalized to uppercase for Snowflake schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake") self.assertEqual(schema.column_names(exp.Table(this="x")), ["FOO", "bLa"]) + + # Check that switching off the normalization logic works as expected + schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake") + self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"]) From a0bce4fbd96d160d01f7024699a2e2e37b2f3e48 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 00:14:48 +0300 Subject: [PATCH 02/12] import typing as t --- sqlglot/optimizer/normalize_identifiers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index e35caea258..75105b0814 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -1,3 +1,5 @@ +import typing as t + from sqlglot import exp from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType From a57fa4cfdd796f25091b5e5dfe644bda1e95c45a Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 00:19:48 +0300 Subject: [PATCH 03/12] Fixup --- sqlglot/optimizer/normalize_identifiers.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py index 75105b0814..1e5c104242 100644 --- a/sqlglot/optimizer/normalize_identifiers.py +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -1,11 +1,7 @@ -import typing as t - from sqlglot import exp +from sqlglot._typing import E from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType -if t.TYPE_CHECKING: - from sqlglot._typing import E - def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: """ From 0545ee16775064f167a48f88d4af00ce78179914 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 00:31:11 +0300 Subject: [PATCH 04/12] PR feedback --- sqlglot/dataframe/sql/dataframe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 3f09c3bd72..3fc923238f 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -299,9 +299,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = t.cast( - exp.Select, optimize_func(select_expression, identify="always") - ) + select_expression = t.cast(exp.Select, optimize_func(select_expression)) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: From 670e3a73b95ebb90e64d8cca7d2e90bbc0e338bc Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 20:43:25 +0300 Subject: [PATCH 05/12] Use new quote_identifiers rule before annotate_types --- sqlglot/optimizer/canonicalize.py | 13 ++--------- sqlglot/optimizer/optimizer.py | 2 ++ sqlglot/optimizer/qualify.py | 27 ++++++++++++++--------- sqlglot/optimizer/qualify_columns.py | 18 +-------------- sqlglot/optimizer/quote_identifiers.py | 23 +++++++++++++++++++ tests/fixtures/optimizer/canonicalize.sql | 4 ++-- tests/test_optimizer.py | 7 +----- 7 files changed, 47 insertions(+), 47 deletions(-) create mode 100644 sqlglot/optimizer/quote_identifiers.py diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index a74fa87460..da2fce8f3c 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,18 +1,11 @@ from __future__ import annotations import itertools -import typing as t from sqlglot import exp -from sqlglot.optimizer.qualify_columns import quote_identifiers -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - -def canonicalize( - expression: exp.Expression, identify: bool = True, dialect: DialectType = None -) -> exp.Expression: +def canonicalize(expression: exp.Expression) -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -20,15 +13,13 @@ def canonicalize( Args: expression: The expression to canonicalize. - identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize, identify=identify, dialect=dialect) + exp.replace_children(expression, canonicalize) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) expression = ensure_bool_predicates(expression) - expression = quote_identifiers(expression, dialect=dialect, identify=identify) return expression diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 2853cbb89d..35a7ba53e9 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -16,6 +16,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify import qualify +from sqlglot.optimizer.quote_identifiers import quote_identifiers from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema @@ -31,6 +32,7 @@ merge_subqueries, eliminate_joins, eliminate_ctes, + quote_identifiers, annotate_types, canonicalize, simplify, diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index ea9d4ebccd..95f8a81b3f 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -4,10 +4,16 @@ from sqlglot import exp from sqlglot.dialects.dialect import DialectType -from sqlglot.optimizer import qualify_columns from sqlglot.optimizer.isolate_table_selects import isolate_table_selects from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_columns import ( + qualify_columns as qualify_columns_func, + validate_qualify_columns as validate_qualify_columns_func, +) from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.quote_identifiers import ( + quote_identifiers as quote_identifiers_func, +) from sqlglot.schema import Schema, ensure_schema @@ -20,6 +26,7 @@ def qualify( expand_alias_refs: bool = True, infer_schema: t.Optional[bool] = None, isolate_tables: bool = False, + qualify_columns: bool = True, validate_qualify_columns: bool = True, quote_identifiers: bool = True, identify: bool = True, @@ -44,11 +51,13 @@ def qualify( expand_alias_refs: Whether or not to expand references to aliases. infer_schema: Whether or not to infer the schema if missing. isolate_tables: Whether or not to isolate table selects. + qualify_columns: Whether or not to qualify columns. validate_qualify_columns: Whether or not to validate columns. quote_identifiers: Whether or not to run the quote_identifiers step. This step is necessary to ensure correctness for case sensitive queries. But this flag is provided in case this step is performed at a later time. identify: If True, quote all identifiers, else only necessary ones. + Returns: The qualified expression. """ @@ -59,19 +68,15 @@ def qualify( if isolate_tables: expression = isolate_table_selects(expression, schema=schema) - expression = qualify_columns.qualify_columns( - expression, - schema, - expand_alias_refs=expand_alias_refs, - infer_schema=infer_schema, - ) + if qualify_columns: + expression = qualify_columns_func( + expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema + ) if quote_identifiers: - expression = expression.transform( - qualify_columns.quote_identifiers, dialect, identify, copy=False - ) + expression = quote_identifiers_func(expression, dialect, identify, copy=False) if validate_qualify_columns: - qualify_columns.validate_qualify_columns(expression) + validate_qualify_columns_func(expression) return expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 799dd0633a..2d8bf8cc0f 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -4,9 +4,8 @@ import typing as t from sqlglot import alias, exp -from sqlglot.dialects.dialect import DialectType from sqlglot.errors import OptimizeError -from sqlglot.helper import case_sensitive, seq_get +from sqlglot.helper import seq_get from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope from sqlglot.schema import Schema, ensure_schema @@ -414,21 +413,6 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -def quote_identifiers( - expression: exp.Expression, dialect: DialectType, identify: bool -) -> exp.Expression: - """Makes sure all identifiers that need to be quoted are quoted.""" - if isinstance(expression, exp.Identifier): - name = expression.this - expression.set( - "quoted", - identify - or case_sensitive(name, dialect=dialect) - or not exp.SAFE_IDENTIFIER_RE.match(name), - ) - return expression - - class Resolver: """ Helper for resolving columns. diff --git a/sqlglot/optimizer/quote_identifiers.py b/sqlglot/optimizer/quote_identifiers.py new file mode 100644 index 0000000000..28a1747aa3 --- /dev/null +++ b/sqlglot/optimizer/quote_identifiers.py @@ -0,0 +1,23 @@ +from sqlglot import exp +from sqlglot.helper import case_sensitive +from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType as DialectType + + +def quote_identifiers( + expression: E, dialect: DialectType = None, identify: bool = True, copy: bool = True +) -> E: + """Makes sure all identifiers that need to be quoted are quoted.""" + + def _quote(expression: E) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify + or case_sensitive(name, dialect=dialect) + or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + return expression + + return expression.transform(_quote, copy=copy) diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index ccf2f16b7e..1fc44efbe1 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -10,8 +10,8 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w"; SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; SELECT 1 + 3.2 AS "a" FROM "w" AS "w"; -SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day; -SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day AS "_col_0"; +SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day; +SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0"; -------------------------------------- -- Ensure boolean predicates diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 6e12f13a4d..18c3a1fcd8 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -303,12 +303,7 @@ def test_eliminate_subqueries(self): def test_canonicalize(self): optimize = partial( optimizer.optimize, - rules=[ - optimizer.qualify_tables.qualify_tables, - optimizer.qualify_columns.qualify_columns, - annotate_types, - optimizer.canonicalize.canonicalize, - ], + rules=[optimizer.qualify.qualify, annotate_types, optimizer.canonicalize.canonicalize], ) self.check_file("canonicalize", optimize, schema=self.schema) From 55025aba8ab2335895f7c049f79ffb16968f0bfc Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 20:45:12 +0300 Subject: [PATCH 06/12] Reset quote_identifiers kwarg to False in optimize --- sqlglot/optimizer/optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 35a7ba53e9..d76d061e1d 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -78,6 +78,7 @@ def optimize( "schema": schema, "dialect": dialect, "isolate_tables": True, # needed for other optimizations to perform well + "quote_identifiers": False, # this happens in canonicalize **kwargs, } From 3e6a9ee389b672a324bcb01e6d0562306bed0268 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 20:46:28 +0300 Subject: [PATCH 07/12] Formatting --- sqlglot/optimizer/quote_identifiers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/optimizer/quote_identifiers.py b/sqlglot/optimizer/quote_identifiers.py index 28a1747aa3..7aae241a88 100644 --- a/sqlglot/optimizer/quote_identifiers.py +++ b/sqlglot/optimizer/quote_identifiers.py @@ -1,7 +1,7 @@ from sqlglot import exp -from sqlglot.helper import case_sensitive from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType as DialectType +from sqlglot.helper import case_sensitive def quote_identifiers( From 06f064a75f8119de0b0845ab99b8c7855222d8f7 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 20:49:20 +0300 Subject: [PATCH 08/12] Set kwargs instead of positional arguments in qualify --- sqlglot/optimizer/qualify.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 95f8a81b3f..4f7ca4ded9 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -74,7 +74,9 @@ def qualify( ) if quote_identifiers: - expression = quote_identifiers_func(expression, dialect, identify, copy=False) + expression = quote_identifiers_func( + expression, dialect=dialect, identify=identify, copy=False + ) if validate_qualify_columns: validate_qualify_columns_func(expression) From 1dbd984fa1d260ab915f1d389a294e49cfdbbf3a Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 20:53:48 +0300 Subject: [PATCH 09/12] Include quote_identifiers rule in test_canonicalize --- tests/test_optimizer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 18c3a1fcd8..50b87a6f2a 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -303,7 +303,12 @@ def test_eliminate_subqueries(self): def test_canonicalize(self): optimize = partial( optimizer.optimize, - rules=[optimizer.qualify.qualify, annotate_types, optimizer.canonicalize.canonicalize], + rules=[ + optimizer.qualify.qualify, + optimizer.quote_identifiers.quote_identifiers, + annotate_types, + optimizer.canonicalize.canonicalize + ], ) self.check_file("canonicalize", optimize, schema=self.schema) From 6dd3559b96d4715032e2d22a3bc5390eddc2b391 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 20:58:20 +0300 Subject: [PATCH 10/12] Formatting --- tests/test_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 50b87a6f2a..461a6904e2 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -307,7 +307,7 @@ def test_canonicalize(self): optimizer.qualify.qualify, optimizer.quote_identifiers.quote_identifiers, annotate_types, - optimizer.canonicalize.canonicalize + optimizer.canonicalize.canonicalize, ], ) self.check_file("canonicalize", optimize, schema=self.schema) From d20e241ff2e91aae15a1c8833f13fed3679e90a1 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 21:11:29 +0300 Subject: [PATCH 11/12] PR feedback --- sqlglot/optimizer/optimizer.py | 2 +- sqlglot/optimizer/qualify.py | 4 +--- sqlglot/optimizer/qualify_columns.py | 23 ++++++++++++++++++++++- sqlglot/optimizer/quote_identifiers.py | 23 ----------------------- tests/test_optimizer.py | 2 +- 5 files changed, 25 insertions(+), 29 deletions(-) delete mode 100644 sqlglot/optimizer/quote_identifiers.py diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d76d061e1d..dbe33a2088 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -16,7 +16,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify import qualify -from sqlglot.optimizer.quote_identifiers import quote_identifiers +from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index 4f7ca4ded9..e58d4a03b0 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -8,12 +8,10 @@ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import ( qualify_columns as qualify_columns_func, + quote_identifiers as quote_identifiers_func, validate_qualify_columns as validate_qualify_columns_func, ) from sqlglot.optimizer.qualify_tables import qualify_tables -from sqlglot.optimizer.quote_identifiers import ( - quote_identifiers as quote_identifiers_func, -) from sqlglot.schema import Schema, ensure_schema diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 2d8bf8cc0f..220a414e13 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -4,8 +4,10 @@ import typing as t from sqlglot import alias, exp +from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType from sqlglot.errors import OptimizeError -from sqlglot.helper import seq_get +from sqlglot.helper import case_sensitive, seq_get from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope from sqlglot.schema import Schema, ensure_schema @@ -413,6 +415,25 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) +def quote_identifiers( + expression: E, dialect: DialectType = None, identify: bool = True, copy: bool = True +) -> E: + """Makes sure all identifiers that need to be quoted are quoted.""" + + def _quote(expression: E) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify + or case_sensitive(name, dialect=dialect) + or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + return expression + + return expression.transform(_quote, copy=copy) + + class Resolver: """ Helper for resolving columns. diff --git a/sqlglot/optimizer/quote_identifiers.py b/sqlglot/optimizer/quote_identifiers.py deleted file mode 100644 index 7aae241a88..0000000000 --- a/sqlglot/optimizer/quote_identifiers.py +++ /dev/null @@ -1,23 +0,0 @@ -from sqlglot import exp -from sqlglot._typing import E -from sqlglot.dialects.dialect import DialectType as DialectType -from sqlglot.helper import case_sensitive - - -def quote_identifiers( - expression: E, dialect: DialectType = None, identify: bool = True, copy: bool = True -) -> E: - """Makes sure all identifiers that need to be quoted are quoted.""" - - def _quote(expression: E) -> E: - if isinstance(expression, exp.Identifier): - name = expression.this - expression.set( - "quoted", - identify - or case_sensitive(name, dialect=dialect) - or not exp.SAFE_IDENTIFIER_RE.match(name), - ) - return expression - - return expression.transform(_quote, copy=copy) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 461a6904e2..2ae6da993a 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -305,7 +305,7 @@ def test_canonicalize(self): optimizer.optimize, rules=[ optimizer.qualify.qualify, - optimizer.quote_identifiers.quote_identifiers, + optimizer.qualify_columns.quote_identifiers, annotate_types, optimizer.canonicalize.canonicalize, ], From e0a3e3c5ba3e65890ad82a1f5734be79c715cca4 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 30 May 2023 22:39:15 +0300 Subject: [PATCH 12/12] Remove copy arg from quote_identifiers --- sqlglot/optimizer/qualify.py | 4 +--- sqlglot/optimizer/qualify_columns.py | 6 ++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py index e58d4a03b0..5fdbde81cb 100644 --- a/sqlglot/optimizer/qualify.py +++ b/sqlglot/optimizer/qualify.py @@ -72,9 +72,7 @@ def qualify( ) if quote_identifiers: - expression = quote_identifiers_func( - expression, dialect=dialect, identify=identify, copy=False - ) + expression = quote_identifiers_func(expression, dialect=dialect, identify=identify) if validate_qualify_columns: validate_qualify_columns_func(expression) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 220a414e13..4a311714ba 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -415,9 +415,7 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -def quote_identifiers( - expression: E, dialect: DialectType = None, identify: bool = True, copy: bool = True -) -> E: +def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: """Makes sure all identifiers that need to be quoted are quoted.""" def _quote(expression: E) -> E: @@ -431,7 +429,7 @@ def _quote(expression: E) -> E: ) return expression - return expression.transform(_quote, copy=copy) + return expression.transform(_quote, copy=False) class Resolver: