Skip to content

Commit

Permalink
feat: improve logic in is_select (#17329)
Browse files Browse the repository at this point in the history
* feat: improve logic in is_select

* Add more edge cases
  • Loading branch information
betodealmeida authored Nov 3, 2021
1 parent 9a4ab10 commit 93bafa0
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 4 deletions.
25 changes: 22 additions & 3 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
Expand Down Expand Up @@ -133,7 +133,26 @@ def limit(self) -> Optional[int]:
def is_select(self) -> bool:
    """
    Return whether the query is a ``SELECT`` statement.

    The first parsed statement is classified with ``sqlparse``. When
    ``sqlparse`` reports the type as ``SELECT`` the answer is ``True``; any
    other known type is ``False``. When the type is ``UNKNOWN`` the top-level
    tokens are inspected directly: the statement counts as a select only if
    it contains no DDL, no DML other than ``SELECT`` (compared
    case-insensitively), does not start with a plain keyword, and contains
    at least one ``SELECT``.

    :returns: ``True`` if the statement is considered a ``SELECT``,
        ``False`` otherwise (including when nothing could be parsed).
    """
    # make sure we strip comments; prevents a bug with comments in the CTE
    parsed = sqlparse.parse(self.strip_comments())

    # nothing was parsed, so there is no statement to classify
    if not parsed:
        return False

    statement = parsed[0]
    statement_type = statement.get_type()
    if statement_type == "SELECT":
        return True

    if statement_type != "UNKNOWN":
        return False

    # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
    # and no DDL is allowed
    has_select = False
    for token in statement:
        if token.ttype == DDL:
            return False
        if token.ttype == DML:
            # `token.value` keeps the casing used in the query, so normalize
            # it before comparing; otherwise `select` would be rejected
            if token.value.upper() != "SELECT":
                return False
            has_select = True

    # return false on `EXPLAIN`, `SET`, `SHOW`, etc.
    if statement[0].ttype == Keyword:
        return False

    return has_select

def is_valid_ctas(self) -> bool:
parsed = sqlparse.parse(self.strip_comments())
Expand All @@ -150,7 +169,7 @@ def is_explain(self) -> bool:
)

# Explain statements will only be the first statement
return statements_without_comments.startswith("EXPLAIN")
return statements_without_comments.upper().startswith("EXPLAIN")

def is_show(self) -> bool:
# Remove comments
Expand Down
49 changes: 48 additions & 1 deletion tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=invalid-name

import sqlparse

from superset.sql_parse import ParsedQuery


def test_cte_with_comments():
def test_cte_with_comments_is_select():
"""
Some CTEs with comments are not correctly identified as SELECTs.
"""
sql = ParsedQuery(
"""WITH blah AS
(SELECT * FROM core_dev.manager_team),
Expand All @@ -44,3 +51,43 @@ def test_cte_with_comments():
INNER JOIN blah2 ON blah2.team_id = blah.team_id"""
)
assert sql.is_select()


def test_cte_is_select():
    """
    A CTE written as `AS(` (no space before the parenthesis) is a SELECT.
    """
    # the `AS(` spelling gets parsed as a function
    parsed_query = ParsedQuery(
        """WITH foo AS(
SELECT
FLOOR(__time TO WEEK) AS "week",
name,
COUNT(DISTINCT user_id) AS "unique_users"
FROM "druid"."my_table"
GROUP BY 1,2
)
SELECT
f.week,
f.name,
f.unique_users
FROM foo f"""
    )
    assert parsed_query.is_select()


def test_unknown_select():
    """
    Check `is_select` on statements whose type sqlparse reports as `UNKNOWN`.
    """
    cases = [
        ("WITH foo AS(SELECT 1) SELECT 1", True),
        ("WITH foo AS(SELECT 1) INSERT INTO my_table (a) VALUES (1)", False),
        ("WITH foo AS(SELECT 1) DELETE FROM my_table", False),
    ]
    for statement, expected in cases:
        # precondition: sqlparse on its own cannot classify the statement
        assert sqlparse.parse(statement)[0].get_type() == "UNKNOWN"
        assert bool(ParsedQuery(statement).is_select()) == expected

0 comments on commit 93bafa0

Please sign in to comment.