
Refactor!: move normalization logic in Dialect, update case-sensitivity info #1784

Merged
8 commits merged on Jun 16, 2023
18 changes: 11 additions & 7 deletions sqlglot/dataframe/README.md
@@ -9,7 +9,7 @@ Currently many of the common operations are covered and more functionality will
## Instructions
* [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install); that is all that is required to generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine, which requires that engine's client library.
* Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>)`.
* Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('<table_name>', <column_structure>, dialect="spark")`.
* The column structure can be defined in the following ways:
  * A dictionary where the keys are column names and the values are strings of Spark SQL type names.
    * Ex: `{'cola': 'string', 'colb': 'int'}`
@@ -33,12 +33,16 @@
import sqlglot
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.dataframe.sql import functions as F

sqlglot.schema.add_table('employee', {
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
}) # Register the table structure prior to reading from the table
sqlglot.schema.add_table(
'employee',
{
'employee_id': 'INT',
'fname': 'STRING',
'lname': 'STRING',
'age': 'INT',
},
dialect="spark",
) # Register the table structure prior to reading from the table

spark = SparkSession()
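A possible continuation of this example (an illustrative sketch, not part of the diff; `df.sql` returns a list of SQL strings, as the `dataframe.py` change below shows):

df = spark.table('employee').select(F.col('employee_id'), F.col('fname'))
print(*df.sql(dialect="spark"), sep="\n")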

5 changes: 5 additions & 0 deletions sqlglot/dataframe/sql/column.py
@@ -5,6 +5,7 @@
import sqlglot
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.types import DataType
from sqlglot.dialects import Spark
from sqlglot.helper import flatten, is_iterable

if t.TYPE_CHECKING:
@@ -22,6 +23,10 @@ def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expressio
expression = sqlglot.maybe_parse(expression, dialect="spark")
if expression is None:
raise ValueError(f"Could not parse {expression}")

if isinstance(expression, exp.Column):
expression.transform(Spark.normalize_identifier, copy=False)

self.expression: exp.Expression = expression

def __repr__(self):
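With this change, column identifiers are normalized on construction. An illustrative sketch of the effect (not part of the diff): Spark treats identifiers as case-insensitive, so they are lowercased whether quoted or not.

from sqlglot import parse_one
from sqlglot.dialects import Spark

expr = parse_one("SELECT `Employee_ID` FROM tbl", read="spark")
expr = expr.transform(Spark.normalize_identifier)
print(expr.sql(dialect="spark"))  # SELECT `employee_id` FROM tbl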
1 change: 1 addition & 0 deletions sqlglot/dataframe/sql/dataframe.py
@@ -316,6 +316,7 @@ def sql(self, dialect="spark", optimize=True, **kwargs) -> t.List[str]:
expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
dialect="spark",
)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
2 changes: 2 additions & 0 deletions sqlglot/dataframe/sql/normalize.py
@@ -5,6 +5,7 @@
from sqlglot import expressions as exp
from sqlglot.dataframe.sql.column import Column
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
from sqlglot.dialects import Spark
from sqlglot.helper import ensure_list

NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
@@ -19,6 +20,7 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
for expression in expressions:
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
Spark.normalize_identifier(identifier)
replace_alias_name_with_cte_name(spark, expression_context, identifier)
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)

14 changes: 6 additions & 8 deletions sqlglot/dataframe/sql/readwriter.py
@@ -4,7 +4,8 @@

import sqlglot
from sqlglot import expressions as exp
from sqlglot.helper import object_to_dict, should_identify
from sqlglot.dialects import Spark
from sqlglot.helper import object_to_dict

if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.dataframe import DataFrame
@@ -18,17 +19,14 @@ def __init__(self, spark: SparkSession):
def table(self, tableName: str) -> DataFrame:
from sqlglot.dataframe.sql.dataframe import DataFrame

sqlglot.schema.add_table(tableName)
sqlglot.schema.add_table(tableName, dialect="spark")

return DataFrame(
self.spark,
exp.Select()
.from_(tableName)
.from_(exp.to_table(tableName, dialect="spark").transform(Spark.normalize_identifier))
.select(
*(
column if should_identify(column, "safe") else f'"{column}"'
for column in sqlglot.schema.column_names(tableName)
)
*(column for column in sqlglot.schema.column_names(tableName, dialect="spark"))
),
)

@@ -73,7 +71,7 @@ def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> Data
)
df = self._df.copy(output_expression_container=output_expression_container)
if self._by_name:
columns = sqlglot.schema.column_names(tableName, only_visible=True)
columns = sqlglot.schema.column_names(tableName, only_visible=True, dialect="spark")
df = df._convert_leaf_to_cte().select(*columns)

return self.copy(_df=df)
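The dialect-aware schema calls above are expected to normalize registered table and column names; a sketch under that assumption (names are hypothetical):

import sqlglot

sqlglot.schema.add_table('Employees', {'EmployeeID': 'INT'}, dialect="spark")
print(sqlglot.schema.column_names('employees', dialect="spark"))  # ['employeeid']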
3 changes: 3 additions & 0 deletions sqlglot/dialects/bigquery.py
@@ -106,6 +106,9 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
class BigQuery(Dialect):
UNNEST_COLUMN_ONLY = True

# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

TIME_MAPPING = {
"%D": "%m/%d/%y",
}
75 changes: 67 additions & 8 deletions sqlglot/dialects/dialect.py
@@ -4,21 +4,14 @@
from enum import Enum

from sqlglot import exp
from sqlglot._typing import E
from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

if t.TYPE_CHECKING:
from sqlglot._typing import E


# Only Snowflake is currently known to resolve unquoted identifiers as uppercase.
# https://docs.snowflake.com/en/sql-reference/identifiers-syntax
RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"}


class Dialects(str, Enum):
DIALECT = ""
@@ -129,6 +122,8 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
if not klass.STRICT_STRING_CONCAT:
klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe

klass.generator_class.can_identify = klass.can_identify

return klass


@@ -142,6 +137,10 @@ class Dialect(metaclass=_Dialect):
# Determines whether or not the table alias comes after tablesample
ALIAS_POST_TABLESAMPLE = False

# Determines whether or not unquoted identifiers are resolved as uppercase
# When set to None, it means that the dialect treats all identifiers as case-insensitive
RESOLVES_IDENTIFIERS_AS_UPPERCASE: t.Optional[bool] = False
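# Illustrative summary of the three settings (not part of the diff):
#   False -> unquoted identifiers are lowercased (the default)
#   True  -> unquoted identifiers are uppercased (e.g. Snowflake)
#   None  -> identifiers are lowercased even when quoted (e.g. BigQuery, Hive)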

# Determines whether or not an unquoted identifier can start with a digit
IDENTIFIERS_CAN_START_WITH_DIGIT = False

@@ -216,6 +215,66 @@ def format_time(

return expression

@classmethod
def normalize_identifier(cls, expression: E) -> E:
"""
Normalizes an unquoted identifier to either lower or upper case, thus essentially
making it case-insensitive. If a dialect treats all identifiers as case-insensitive,
they will be normalized regardless of being quoted or not.
"""
if isinstance(expression, exp.Identifier) and (
not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None
):
expression.set(
"this",
expression.this.upper()
if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE
else expression.this.lower(),
)

return expression
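# Illustrative usage (a sketch; assumes the dialect settings added in this PR):
#   >>> from sqlglot import exp
#   >>> from sqlglot.dialects import Snowflake, Hive
#   >>> Snowflake.normalize_identifier(exp.to_identifier("foo")).sql()
#   'FOO'
#   >>> Hive.normalize_identifier(exp.to_identifier("Foo", quoted=True)).sql("hive")
#   '`foo`'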

@classmethod
def case_sensitive(cls, text: str) -> bool:
"""Checks if text contains any case sensitive characters, based on the dialect's rules."""
if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None:
return False

unsafe = str.islower if cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
return any(unsafe(char) for char in text)
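# For example (illustrative): under the default lowercase resolution only
# uppercase characters are "unsafe", while Snowflake flips the check:
#   >>> Dialect.case_sensitive("Foo"), Dialect.case_sensitive("foo")
#   (True, False)
#   >>> Snowflake.case_sensitive("foo")
#   True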

@classmethod
def can_identify(cls, text: str, identify: str | bool = "safe") -> bool:
"""Checks if text can be identified given an identify option.

Args:
text: The text to check.
identify:
"always" or `True`: Always returns true.
"safe": True if the identifier is case-insensitive.

Returns:
Whether or not the given text can be identified.
"""
if identify is True or identify == "always":
return True

if identify == "safe":
return not cls.case_sensitive(text)

return False
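# Illustrative behavior (a sketch):
#   >>> Dialect.can_identify("foo", "safe")     # case-insensitive, safe to quote
#   True
#   >>> Dialect.can_identify("Foo", "safe")     # case-sensitive, not safe
#   False
#   >>> Dialect.can_identify("Foo", "always")
#   True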

@classmethod
def quote_identifier(cls, expression: E, identify: bool = True) -> E:
if isinstance(expression, exp.Identifier):
name = expression.this
expression.set(
"quoted",
identify or cls.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
)

return expression
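# Illustrative behavior (a sketch): quoting is forced by `identify`, or by a
# case-sensitive or syntactically unsafe name:
#   >>> Dialect.quote_identifier(exp.to_identifier("foo"), identify=False).sql()
#   'foo'
#   >>> Dialect.quote_identifier(exp.to_identifier("foo bar"), identify=False).sql()
#   '"foo bar"'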

def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)

3 changes: 3 additions & 0 deletions sqlglot/dialects/duckdb.py
@@ -88,6 +88,9 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract
class DuckDB(Dialect):
NULL_ORDERING = "nulls_are_last"

# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
3 changes: 3 additions & 0 deletions sqlglot/dialects/hive.py
@@ -153,6 +153,9 @@ class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True

# https://spark.apache.org/docs/latest/sql-ref-identifier.html#description
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

TIME_MAPPING = {
"y": "%Y",
"Y": "%Y",
5 changes: 5 additions & 0 deletions sqlglot/dialects/presto.py
@@ -172,6 +172,11 @@ class Presto(Dialect):
TIME_MAPPING = MySQL.TIME_MAPPING
STRICT_STRING_CONCAT = True

# https://github.com/trinodb/trino/issues/17
# https://github.com/trinodb/trino/issues/12289
# https://github.com/prestodb/presto/issues/2863
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
3 changes: 3 additions & 0 deletions sqlglot/dialects/redshift.py
@@ -14,6 +14,9 @@ def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONEx


class Redshift(Postgres):
# https://docs.aws.amazon.com/redshift/latest/dg/r_names.html
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
**Postgres.TIME_MAPPING,
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
@@ -167,6 +167,7 @@ def _parse_convert_timezone(args: t.List) -> exp.Expression:


class Snowflake(Dialect):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = True
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"

Expand Down
3 changes: 3 additions & 0 deletions sqlglot/dialects/sqlite.py
@@ -59,6 +59,9 @@ def _transform_create(expression: exp.Expression) -> exp.Expression:


class SQLite(Dialect):
# https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None

class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
6 changes: 4 additions & 2 deletions sqlglot/generator.py
@@ -5,7 +5,7 @@

from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify
from sqlglot.helper import apply_index_offset, csv, seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType

@@ -266,6 +268,8 @@ class Generator:
NORMALIZE_FUNCTIONS: bool | str = "upper"
NULL_ORDERING = "nulls_are_small"

can_identify: t.Callable[[str, str | bool], bool]

# Delimiters for quotes, identifiers and the corresponding escape characters
QUOTE_START = "'"
QUOTE_END = "'"
@@ -886,7 +888,7 @@ def identifier_sql(self, expression: exp.Identifier) -> str:
text = text.replace(self.IDENTIFIER_END, self._escaped_identifier_end)
if (
expression.quoted
or should_identify(text, self.identify)
or self.can_identify(text, self.identify)
or lower in self.RESERVED_KEYWORDS
or (not self.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit())
):
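With `can_identify` wired into the generator, `identify="safe"` should quote only identifiers whose resolution is unaffected by quoting; an illustrative sketch, assuming `identify` is forwarded through `transpile`:

import sqlglot

print(sqlglot.transpile("SELECT a, B FROM t", identify="safe")[0])
# SELECT "a", B FROM "t"  -- "a" and "t" are case-insensitive, B is not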
29 changes: 0 additions & 29 deletions sqlglot/helper.py
@@ -14,7 +14,6 @@
if t.TYPE_CHECKING:
from sqlglot import exp
from sqlglot._typing import E, T
from sqlglot.dialects.dialect import DialectType
from sqlglot.expressions import Expression

CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
@@ -430,31 +429,3 @@ def first(it: t.Iterable[T]) -> T:
Useful for sets.
"""
return next(i for i in it)


def case_sensitive(text: str, dialect: DialectType) -> bool:
"""Checks if text contains any case sensitive characters depending on dialect."""
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE

unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper
return any(unsafe(char) for char in text)


def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool:
"""Checks if text should be identified given an identify option.

Args:
text: the text to check.
identify:
"always" or `True`: always returns true.
"safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`.
dialect: the dialect to use in order to decide whether a text should be identified.

Returns:
Whether or not a string should be identified.
"""
if identify is True or identify == "always":
return True
if identify == "safe":
return not case_sensitive(text, dialect)
return False
25 changes: 8 additions & 17 deletions sqlglot/optimizer/normalize_identifiers.py
@@ -1,12 +1,15 @@
from sqlglot import exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType
from sqlglot.dialects.dialect import Dialect, DialectType


def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
"""
Normalize all unquoted identifiers to either lower or upper case, depending on
the dialect. This essentially makes those identifiers case-insensitive.
Normalize all unquoted identifiers to either lower or upper case, depending
on the dialect. This essentially makes those identifiers case-insensitive.

Note:
Some dialects (e.g. BigQuery) treat identifiers as case-insensitive even
when they're quoted, so in these cases all identifiers are normalized.

Example:
>>> import sqlglot
@@ -21,16 +24,4 @@ def normalize_identifiers(expression: E, dialect: DialectType = None) -> E:
Returns:
The transformed expression.
"""
return expression.transform(_normalize, dialect, copy=False)


def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression:
if isinstance(node, exp.Identifier) and not node.quoted:
node.set(
"this",
node.this.upper()
if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE
else node.this.lower(),
)

return node
return expression.transform(Dialect.get_or_raise(dialect).normalize_identifier, copy=False)
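Per the note above about BigQuery (an illustrative sketch), even quoted identifiers are normalized there:

import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

e = sqlglot.parse_one("SELECT `Foo` FROM `Bar`", read="bigquery")
print(normalize_identifiers(e, dialect="bigquery").sql("bigquery"))
# SELECT `foo` FROM `bar`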