Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add annotations to sql_parse.py #27520

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

# pylint: disable=too-many-lines

from __future__ import annotations

import logging
import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional, Union
from typing import Any, cast

import sqlglot
import sqlparse
Expand Down Expand Up @@ -138,7 +140,7 @@ class CtasMethod(StrEnum):
VIEW = "VIEW"


def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
def _extract_limit_from_query(statement: TokenList) -> int | None:
"""
Extract limit clause from SQL statement.
Expand All @@ -159,9 +161,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None


def extract_top_from_query(
statement: TokenList, top_keywords: set[str]
) -> Optional[int]:
def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
"""
Extract top clause value from SQL statement.
Expand All @@ -185,15 +185,15 @@ def extract_top_from_query(
return top


def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
"""
parse the SQL and return the CTE and rest of the block to the caller
:param sql: SQL query
:return: CTE and remainder block to the caller
"""
cte: Optional[str] = None
cte: str | None = None
remainder = sql
stmt = sqlparse.parse(sql)[0]

Expand All @@ -211,7 +211,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder


def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
Expand All @@ -235,8 +235,8 @@ class Table:
"""

table: str
schema: Optional[str] = None
catalog: Optional[str] = None
schema: str | None = None
catalog: str | None = None

def __str__(self) -> str:
"""
Expand All @@ -255,7 +255,7 @@ def __eq__(self, __o: object) -> bool:

def extract_tables_from_statement(
statement: exp.Expression,
dialect: Optional[Dialects],
dialect: Dialects | None,
) -> set[Table]:
"""
Extract all table references in a single statement.
Expand Down Expand Up @@ -334,7 +334,7 @@ class SQLScript:
def __init__(
self,
query: str,
engine: Optional[str] = None,
engine: str | None = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None

Expand Down Expand Up @@ -375,8 +375,8 @@ class SQLStatement:

def __init__(
self,
statement: Union[str, exp.Expression],
engine: Optional[str] = None,
statement: str | exp.Expression,
engine: str | None = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None

Expand All @@ -394,7 +394,7 @@ def __init__(
@staticmethod
def _parse_statement(
sql_statement: str,
dialect: Optional[Dialects],
dialect: Dialects | None,
) -> exp.Expression:
"""
Parse a single SQL statement.
Expand Down Expand Up @@ -437,7 +437,7 @@ def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
engine: Optional[str] = None,
engine: str | None = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
Expand All @@ -446,7 +446,7 @@ def __init__(
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
self._limit: int | None = None

logger.debug("Parsing with sqlparse statement: %s", self.sql)
self._parsed = sqlparse.parse(self.stripped())
Expand Down Expand Up @@ -550,7 +550,7 @@ def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
return source.name in ctes_in_scope

@property
def limit(self) -> Optional[int]:
def limit(self) -> int | None:
return self._limit

def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -631,7 +631,7 @@ def is_select(self) -> bool:

return True

def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
for token in tokens:
if self._is_identifier(token):
for identifier_token in token.tokens:
Expand Down Expand Up @@ -695,7 +695,7 @@ def get_statements(self) -> list[str]:
return statements

@staticmethod
def get_table(tlist: TokenList) -> Optional[Table]:
def get_table(tlist: TokenList) -> Table | None:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
Expand Down Expand Up @@ -731,7 +731,7 @@ def _is_identifier(token: Token) -> bool:
def as_create_table(
self,
table_name: str,
schema_name: Optional[str] = None,
schema_name: str | None = None,
overwrite: bool = False,
method: CtasMethod = CtasMethod.TABLE,
) -> str:
Expand Down Expand Up @@ -891,8 +891,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
) -> Optional[TokenList]:
default_schema: str | None,
) -> TokenList | None:
"""
Given a table name, return any associated RLS predicates.
"""
Expand Down Expand Up @@ -938,7 +938,7 @@ def get_rls_for_table(
def insert_rls_as_subquery(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
default_schema: str | None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
Expand All @@ -954,7 +954,7 @@ def insert_rls_as_subquery(
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
databases.
"""
rls: Optional[TokenList] = None
rls: TokenList | None = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def insert_rls_as_subquery(
def insert_rls_in_predicate(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
default_schema: str | None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
Expand All @@ -1041,7 +1041,7 @@ def insert_rls_in_predicate(
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
"""
rls: Optional[TokenList] = None
rls: TokenList | None = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def insert_rls_in_predicate(

def extract_table_references(
sql_text: str, sqla_dialect: str, show_warning: bool = True
) -> set["Table"]:
) -> set[Table]:
"""
Return all the dependencies from a SQL sql_text.
"""
Expand Down
Loading