Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve logic in is_select #17329

Merged
merged 2 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
Expand Down Expand Up @@ -133,7 +133,26 @@ def limit(self) -> Optional[int]:
def is_select(self) -> bool:
    """
    Return whether the statement is a plain ``SELECT`` query.

    The statement is parsed with ``sqlparse`` after comments are stripped.
    When ``sqlparse`` reports the type as ``SELECT`` the answer is ``True``;
    any other recognised type yields ``False``. When the type is ``UNKNOWN``
    the top-level tokens are inspected explicitly: any DDL token, any DML
    token other than ``SELECT``, or a leading bare keyword makes the
    statement a non-select; otherwise at least one top-level ``SELECT`` DML
    token must be present.

    :returns: ``True`` if the statement is considered a ``SELECT``,
        ``False`` otherwise (including when nothing could be parsed).
    """
    # make sure we strip comments; prevents a bug with comments in the CTE
    parsed = sqlparse.parse(self.strip_comments())
    if not parsed:
        # nothing left to classify (e.g. empty or comment-only SQL)
        return False

    statement = parsed[0]
    statement_type = statement.get_type()
    if statement_type == "SELECT":
        return True

    if statement_type != "UNKNOWN":
        return False

    # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
    # and no DDL is allowed
    has_select = False
    for token in statement:
        if token.ttype == DDL:
            return False
        if token.ttype == DML:
            # `normalized` is the upper-cased keyword, so `select` written in
            # lower or mixed case is still recognised as `SELECT`
            if token.normalized != "SELECT":
                return False
            has_select = True

    # return false on `EXPLAIN`, `SET`, `SHOW`, etc.
    if statement[0].ttype == Keyword:
        return False

    return has_select

def is_valid_ctas(self) -> bool:
parsed = sqlparse.parse(self.strip_comments())
Expand All @@ -150,7 +169,7 @@ def is_explain(self) -> bool:
)

# Explain statements will only be the first statement
return statements_without_comments.startswith("EXPLAIN")
return statements_without_comments.upper().startswith("EXPLAIN")

def is_show(self) -> bool:
# Remove comments
Expand Down
49 changes: 48 additions & 1 deletion tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name

import sqlparse

from superset.sql_parse import ParsedQuery


def test_cte_with_comments():
def test_cte_with_comments_is_select():
"""
Some CTES with comments are not correctly identified as SELECTS.
"""
sql = ParsedQuery(
"""WITH blah AS
(SELECT * FROM core_dev.manager_team),
Expand All @@ -44,3 +51,43 @@ def test_cte_with_comments():
INNER JOIN blah2 ON blah2.team_id = blah.team_id"""
)
assert sql.is_select()


def test_cte_is_select():
    """
    A CTE written as ``AS(`` (no space) must still be identified as a SELECT.
    """
    # without a space, `AS(` is tokenized as a function call
    cte_sql = """WITH foo AS(
SELECT
FLOOR(__time TO WEEK) AS "week",
name,
COUNT(DISTINCT user_id) AS "unique_users"
FROM "druid"."my_table"
GROUP BY 1,2
)
SELECT
f.week,
f.name,
f.unique_users
FROM foo f"""
    parsed_query = ParsedQuery(cte_sql)
    assert parsed_query.is_select()


def test_unknown_select():
    """
    Check ``is_select`` on statements whose type sqlparse reports as ``UNKNOWN``.
    """
    # (statement, expected result of `is_select`)
    cases = (
        ("WITH foo AS(SELECT 1) SELECT 1", True),
        ("WITH foo AS(SELECT 1) INSERT INTO my_table (a) VALUES (1)", False),
        ("WITH foo AS(SELECT 1) DELETE FROM my_table", False),
    )
    for statement, expected in cases:
        # precondition: sqlparse cannot classify the statement on its own
        assert sqlparse.parse(statement)[0].get_type() == "UNKNOWN"
        assert bool(ParsedQuery(statement).is_select()) is expected