Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: CTE queries with non-SELECT statements #25014

Merged
merged 5 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,17 @@ def limit(self) -> Optional[int]:
def is_select(self) -> bool:
# make sure we strip comments; prevents a bug with comments in the CTE
parsed = sqlparse.parse(self.strip_comments())

# Check if this is a CTE
if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE:
inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or []
# Check if the inner CTE is not a SELECT
if any(token.ttype == DDL for token in inner_cte) or any(
token.ttype == DML and token.normalized != "SELECT"
for token in inner_cte
):
return False

if parsed[0].get_type() == "SELECT":
return True

Expand All @@ -241,6 +252,17 @@ def is_select(self) -> bool:
token.ttype == DML and token.normalized == "SELECT" for token in parsed[0]
)

def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
    """
    Collect the tokens found inside the parenthesised bodies of a CTE.

    The first token in ``tokens`` accepted by ``self._is_identifier`` that
    yields at least one parenthesised body determines the result. Both of
    these shapes are handled:

    * a single identifier whose direct children include a ``Parenthesis``
      (``WITH foo AS (...) ...``);
    * an identifier group whose children are themselves identifiers, each
      carrying its own ``Parenthesis`` (``WITH a AS (...), b AS (...) ...``).
      NOTE(review): this assumes ``_is_identifier`` accepts the list-style
      grouping sqlparse produces for several CTEs -- confirm against its
      definition. If it does not, this method behaves as before.

    :param tokens: top-level tokens of the parsed statement
    :return: the tokens contained in every parenthesised body that was found
        (concatenated in source order), or ``None`` when there is none
    """
    for token in tokens:
        if not self._is_identifier(token):
            continue

        inner_tokens = []
        for child in token.tokens:
            if isinstance(child, Parenthesis):
                inner_tokens.extend(child.tokens)
            elif child.is_group and self._is_identifier(child):
                # Several CTEs are nested one level deeper; inspect each one
                # so a non-SELECT body cannot hide behind a leading SELECT.
                inner_tokens.extend(self.get_inner_cte_expression([child]) or [])

        if inner_tokens:
            return inner_tokens
    return None

def is_valid_ctas(self) -> bool:
    """
    Report whether the last statement of the query is a SELECT.

    Comments are stripped before parsing. Presumably used to validate the
    query behind a CREATE TABLE AS -- confirm against callers.

    :return: ``True`` when the last parsed statement has type ``"SELECT"``;
        ``False`` otherwise, including when nothing could be parsed
    """
    statements = sqlparse.parse(self.strip_comments())
    # An empty (or comment-only) query parses to no statements at all;
    # indexing ``[-1]`` would raise IndexError instead of answering "no".
    if not statements:
        return False
    return statements[-1].get_type() == "SELECT"
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,42 @@ def test_cte_is_select_lowercase() -> None:
assert sql.is_select()


def test_cte_insert_is_not_select() -> None:
    """
    A CTE whose body is an INSERT ... RETURNING must not be reported as a
    SELECT, even though the outer statement selects from it.
    """
    statement = """WITH foo AS(
INSERT INTO foo (id) VALUES (1) RETURNING 1
) select * FROM foo f"""
    query = ParsedQuery(statement)
    assert query.is_select() is False


def test_cte_delete_is_not_select() -> None:
    """
    A CTE whose body is a DELETE ... RETURNING must not be reported as a
    SELECT, even though the outer statement selects from it.
    """
    statement = """WITH foo AS(
DELETE FROM foo RETURNING *
) select * FROM foo f"""
    query = ParsedQuery(statement)
    assert query.is_select() is False


def test_cte_is_not_select_lowercase() -> None:
    """
    Keyword case must not matter: a CTE whose body is a lowercase ``insert``
    is still not a SELECT.
    """
    statement = """WITH foo AS(
insert into foo (id) values (1) RETURNING 1
) select * FROM foo f"""
    query = ParsedQuery(statement)
    assert query.is_select() is False


def test_unknown_select() -> None:
"""
Test that `is_select` works when sqlparse fails to identify the type.
Expand Down