Skip to content

Commit

Permalink
fix: don't strip SQL comments in Explore - 2nd try (apache#28753)
Browse files Browse the repository at this point in the history
(cherry picked from commit 514eda8)
  • Loading branch information
mistercrunch authored and sadpandajoe committed Jun 28, 2024
1 parent e21ec98 commit c45f71b
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/superset-python-integrationtest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
mysql+mysqldb://superset:superset@127.0.0.1:13306/superset?charset=utf8mb4&binary_prefix=true
services:
mysql:
image: mysql:5.7
image: mysql:8.0
env:
MYSQL_ROOT_PASSWORD: root
ports:
Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,7 @@ def get_from_clause(
if not self.is_virtual:
return self.get_sqla_table(), None

from_sql = self.get_rendered_sql(template_processor)
from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
Expand Down
5 changes: 2 additions & 3 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,9 +916,8 @@ def apply_top_to_sql(cls, sql: str, limit: int) -> str:
cte = None
sql_remainder = None
sql = sql.strip(" \t\n;")
sql_statement = sqlparse.format(sql, strip_comments=True)
query_limit: int | None = sql_parse.extract_top_from_query(
sql_statement, cls.top_keywords
sql, cls.top_keywords
)
if not limit:
final_limit = query_limit
Expand All @@ -927,7 +926,7 @@ def apply_top_to_sql(cls, sql: str, limit: int) -> str:
else:
final_limit = limit
if not cls.allows_cte_in_subquery:
cte, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement)
cte, sql_remainder = sql_parse.get_cte_remainder_query(sql)
if cte:
str_statement = str(sql_remainder)
cte = cte + "\n"
Expand Down
7 changes: 3 additions & 4 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,8 +1060,7 @@ def get_rendered_sql(
"""
Render sql with template engine (Jinja).
"""

sql = self.sql
sql = self.sql.strip("\t\r\n; ")
if template_processor:
try:
sql = template_processor.process_template(sql)
Expand All @@ -1072,7 +1071,7 @@ def get_rendered_sql(
msg=ex.message,
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
sql = sqlparse.format(sql.strip("\t\r\n; "))
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
Expand All @@ -1093,7 +1092,7 @@ def get_from_clause(
CTE, the CTE is returned as the second value in the return tuple.
"""

from_sql = self.get_rendered_sql(template_processor)
from_sql = self.get_rendered_sql(template_processor) + "\n"
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
Expand Down
81 changes: 62 additions & 19 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import contextlib
import functools
import os
from textwrap import dedent
from typing import Any, Callable, TYPE_CHECKING
from unittest.mock import patch

Expand Down Expand Up @@ -296,25 +297,67 @@ def virtual_dataset():
dataset = SqlaTable(
table_name="virtual_dataset",
sql=(
"SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6 "
"UNION ALL "
"SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00', NULL "
"UNION ALL "
"SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00', 3 "
"UNION ALL "
"SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4 "
"UNION ALL "
"SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5 "
"UNION ALL "
"SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 "
"UNION ALL "
"SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00', 7 "
"UNION ALL "
"SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00', 8 "
"UNION ALL "
"SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00', 9 "
"UNION ALL "
"SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10"
dedent("""\
SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6
UNION ALL
SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00', NULL
UNION ALL
SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00', 3
UNION ALL
SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4
UNION ALL
SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5
UNION ALL
SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6
UNION ALL
SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00', 7
UNION ALL
SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00', 8
UNION ALL
SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00', 9
UNION ALL
SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10
""")
),
database=get_example_database(),
)
TableColumn(column_name="col1", type="INTEGER", table=dataset)
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
# Different database dialect datetime type is not consistent, so temporarily use varchar
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col6", type="INTEGER", table=dataset)

SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.add(dataset)
db.session.commit()

yield dataset

db.session.delete(dataset)
db.session.commit()


@pytest.fixture
def virtual_dataset_with_comments():
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

dataset = SqlaTable(
table_name="virtual_dataset_with_comments",
sql=(
dedent("""\
--COMMENT
/*COMMENT*/
WITH cte as (--COMMENT
SELECT 2 as col1, /*COMMENT*/'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10
)
SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6
\n /* COMMENT */ \n
UNION ALL/*COMMENT*/
SELECT 1 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 --COMMENT
UNION ALL--COMMENT
SELECT * FROM cte --COMMENT""")
),
database=get_example_database(),
)
Expand Down
4 changes: 3 additions & 1 deletion tests/integration_tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,9 @@ def test_comments_in_sqlatable_query(self):
database=get_example_database(),
)
rendered_query = str(table.get_from_clause()[0])
self.assertEqual(clean_query, rendered_query)
assert "comment 1" in rendered_query
assert "comment 2" in rendered_query
assert "FROM tbl" in rendered_query

def test_slice_payload_no_datasource(self):
form_data = {
Expand Down
8 changes: 5 additions & 3 deletions tests/integration_tests/datasource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,12 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
assert "coltypes" in rv2.json["result"]
assert "data" in rv2.json["result"]

eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
sql = (
f"select * from ({virtual_dataset.sql}) as tbl "
f'limit {app.config["SAMPLES_ROW_LIMIT"]}'
)
eager_samples = virtual_dataset.database.get_df(sql)

# the col3 is Decimal
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
Expand Down
31 changes: 31 additions & 0 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,3 +1164,34 @@ def test_time_offset_with_temporal_range_filter(app_context, physical_dataset):
re.search(r"WHERE col6 >= .*2001-10-01", sqls[1])
and re.search(r"AND col6 < .*2002-10-01", sqls[1])
) is not None


def test_virtual_dataset_with_comments(app_context, virtual_dataset_with_comments):
    """
    Query a virtual dataset whose SQL contains comments, end to end.

    Builds a full-result query context that groups by ``col1``/``col2`` with
    the ``count`` metric, pivots on ``col2`` indexed by ``col1``, flattens the
    result, and checks that the resulting dataframe has exactly three rows.
    """
    dataset = virtual_dataset_with_comments

    # Post-processing pipeline: pivot first, then flatten the pivoted columns.
    pivot_step = {
        "operation": "pivot",
        "options": {
            "aggregates": {"count": {"operator": "mean"}},
            "columns": ["col2"],
            "index": ["col1"],
        },
    }
    flatten_step = {"operation": "flatten"}

    query = {
        "columns": ["col1", "col2"],
        "metrics": ["count"],
        "post_processing": [pivot_step, flatten_step],
    }

    query_context = QueryContextFactory().create(
        datasource={"type": dataset.type, "id": dataset.id},
        queries=[query],
        result_type=ChartDataResultType.FULL,
        force=True,
    )

    payload = query_context.get_df_payload(query_context.queries[0])
    # NOTE(review): presumably one row per distinct col1 value produced by the
    # fixture's SQL -- confirm against the fixture definition.
    assert len(payload["df"]) == 3

0 comments on commit c45f71b

Please sign in to comment.