Skip to content

Commit

Permalink
feat: add support for comments in adhoc clauses (apache#19248)
Browse files Browse the repository at this point in the history
* feat: add support for comments in adhoc clauses

* sanitize remaining freeform clauses

* sanitize adhoc having in frontend

* address review comment
  • Loading branch information
villebro authored and yangfei4913438 committed Apr 2, 2022
1 parent c1810a2 commit 1c4c782
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 54 deletions.
10 changes: 6 additions & 4 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.sql_parse import validate_filter_clause
from superset.sql_parse import sanitize_clause
from superset.typing_local import Metric, OrderBy
from superset.utils import pandas_postprocessing
from superset.utils.core import (
Expand Down Expand Up @@ -281,7 +281,7 @@ def validate(
try:
self._validate_there_are_no_missing_series()
self._validate_no_have_duplicate_labels()
self._validate_filters()
self._sanitize_filters()
return None
except QueryObjectValidationError as ex:
if raise_exceptions:
Expand All @@ -300,12 +300,14 @@ def _validate_no_have_duplicate_labels(self) -> None:
)
)

def _validate_filters(self) -> None:
def _sanitize_filters(self) -> None:
for param in ("where", "having"):
clause = self.extras.get(param)
if clause:
try:
validate_filter_clause(clause)
sanitized_clause = sanitize_clause(clause)
if sanitized_clause != clause:
self.extras[param] = sanitized_clause
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex

Expand Down
19 changes: 13 additions & 6 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@
get_virtual_table_metadata,
)
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
)
from superset.jinja_context import (
BaseTemplateProcessor,
ExtraCache,
Expand All @@ -86,7 +89,7 @@
from superset.models.annotations import Annotation
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, sanitize_clause
from superset.typing_local import AdhocMetric, Metric, OrderBy, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import (
Expand Down Expand Up @@ -858,6 +861,10 @@ def adhoc_metric_to_sqla(
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
Expand Down Expand Up @@ -1265,27 +1272,27 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
where = extras.get("where")
if where:
try:
where = template_processor.process_template(where)
where = template_processor.process_template(f"({where})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in WHERE clause: %(msg)s",
msg=ex.message,
)
) from ex
where_clause_and += [text(f"({where})")]
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
try:
having = template_processor.process_template(having)
having = template_processor.process_template(f"({having})")
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in HAVING clause: %(msg)s",
msg=ex.message,
)
) from ex
having_clause_and += [text(f"({having})")]
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
Expand Down
32 changes: 27 additions & 5 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@
remove_quotes,
Token,
TokenList,
Where,
)
from sqlparse.tokens import (
Comment,
CTE,
DDL,
DML,
Keyword,
Name,
Punctuation,
String,
Whitespace,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
Expand Down Expand Up @@ -349,21 +361,31 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
return str_res


def validate_filter_clause(clause: str) -> None:
if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
raise QueryClauseValidationException("Filter clause contains comment")

# Validate a free-form SQL clause and return it, possibly with a trailing
# newline appended.
#
# Parameters:
#   clause: the SQL fragment to check.
# Returns:
#   The clause unchanged, or the clause plus "\n" when the last token seen is a
#   comment token whose text does not already end in a newline.
# Raises:
#   QueryClauseValidationException when the text does not parse to exactly one
#   statement, when a "*" token is directly followed by "/" (or "/" by "*"),
#   or when "(" and ")" tokens are unbalanced.
#
# NOTE(review): this text comes from a rendered diff with indentation stripped,
# so nesting is not visible here -- confirm structure against the repository.
def sanitize_clause(clause: str) -> str:
# Commented-out alternative that would have asked sqlparse to strip comments.
# clause = sqlparse.format(clause, strip_comments=True)
statements = sqlparse.parse(clause)
# Anything other than exactly one parsed statement is rejected.
if len(statements) != 1:
# NOTE(review): two consecutive raise statements follow; presumably the first
# is the line removed by this commit and the second its replacement -- confirm.
raise QueryClauseValidationException("Filter clause contains multiple queries")
raise QueryClauseValidationException("Clause contains multiple statements")
# Running count of "(" tokens seen minus ")" tokens seen.
open_parens = 0

# Token from the previous loop iteration; None before the first token.
previous_token = None
# NOTE(review): iterates the statement object directly; presumably this yields
# only its top-level tokens rather than a flattened stream -- verify with sqlparse.
for token in statements[0]:
# A "*" token immediately followed by a "/" token, i.e. a bare "*/".
if token.value == "/" and previous_token and previous_token.value == "*":
raise QueryClauseValidationException("Closing unopened multiline comment")
# A "/" token immediately followed by a "*" token, i.e. a bare "/*".
if token.value == "*" and previous_token and previous_token.value == "/":
raise QueryClauseValidationException("Unclosed multiline comment")
if token.value in (")", "("):
open_parens += 1 if token.value == "(" else -1
# A negative count means a ")" was seen before any matching "(".
if open_parens < 0:
raise QueryClauseValidationException(
"Closing unclosed parenthesis in filter clause"
)
previous_token = token
# A positive count after the loop means at least one "(" was never closed.
if open_parens > 0:
raise QueryClauseValidationException("Unclosed parenthesis in filter clause")

# previous_token now holds the last token iterated (None if there were none).
# Presumably the newline keeps text that callers append after the clause, such
# as a closing parenthesis, from being absorbed by a trailing "--" comment.
if previous_token and previous_token.ttype in Comment:
if previous_token.value[-1] != "\n":
clause = f"{clause}\n"

return clause
7 changes: 5 additions & 2 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
SupersetException,
SupersetTimeoutException,
)
from superset.sql_parse import sanitize_clause
from superset.typing_local import (
AdhocMetric,
AdhocMetricColumn,
Expand Down Expand Up @@ -1392,10 +1393,12 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
}
)
elif expression_type == "SQL":
sql_expression = adhoc_filter.get("sqlExpression")
sql_expression = sanitize_clause(sql_expression)
if clause == "WHERE":
sql_where_filters.append(adhoc_filter.get("sqlExpression"))
sql_where_filters.append(sql_expression)
elif clause == "HAVING":
sql_having_filters.append(adhoc_filter.get("sqlExpression"))
sql_having_filters.append(sql_expression)
form_data["where"] = " AND ".join(
["({})".format(sql) for sql in sql_where_filters]
)
Expand Down
11 changes: 4 additions & 7 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from collections import defaultdict, OrderedDict
from datetime import date, datetime, timedelta
from itertools import product
from collections import Counter
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -61,14 +60,13 @@
from superset.exceptions import (
CacheLoadError,
NullValueException,
QueryClauseValidationException,
QueryObjectValidationError,
SpatialException,
SupersetSecurityException,
)
from superset.extensions import cache_manager, security_manager
from superset.models.helpers import QueryResult
from superset.sql_parse import validate_filter_clause
from superset.sql_parse import sanitize_clause
from superset.typing_local import Metric, QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils, csv
from superset.utils.cache import set_and_log_cache
Expand Down Expand Up @@ -362,10 +360,9 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals
for param in ("where", "having"):
clause = self.form_data.get(param)
if clause:
try:
validate_filter_clause(clause)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sanitized_clause = sanitize_clause(clause)
if sanitized_clause != clause:
self.form_data[param] = sanitized_clause

# extras are used to query elements specific to a datasource type
# for instance the extra where clause that applies only to Tables
Expand Down
55 changes: 25 additions & 30 deletions tests/unit_tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from typing import Set

import pytest
import sqlparse

from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
Table,
validate_filter_clause,
)


Expand Down Expand Up @@ -1091,49 +1092,43 @@ def test_strip_comments_from_sql() -> None:
)


def test_validate_filter_clause_valid():
def test_sanitize_clause_valid():
# regular clauses
assert validate_filter_clause("col = 1") is None
assert validate_filter_clause("1=\t\n1") is None
assert validate_filter_clause("(col = 1)") is None
assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None
assert sanitize_clause("col = 1") == "col = 1"
assert sanitize_clause("1=\t\n1") == "1=\t\n1"
assert sanitize_clause("(col = 1)") == "(col = 1)"
assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)"
assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n"

# Valid literal values that appear to be invalid
assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None
assert validate_filter_clause("col = 'select 1; select 2'") is None
assert validate_filter_clause("col = 'abc -- comment'") is None


def test_validate_filter_clause_closing_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2)")


def test_validate_filter_clause_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1) AND (col2 = 2")
# Valid literal values that could be flagged as invalid by a naive query parser
assert (
sanitize_clause("col = 'col1 = 1) AND (col2 = 2'")
== "col = 'col1 = 1) AND (col2 = 2'"
)
assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'"
assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'"


def test_validate_filter_clause_closing_and_unclosed():
def test_sanitize_clause_closing_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("col1 = 1) AND (col2 = 2")
sanitize_clause("col1 = 1) AND (col2 = 2)")


def test_validate_filter_clause_closing_and_unclosed_nested():
def test_sanitize_clause_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(col1 = 1)) AND ((col2 = 2)")
sanitize_clause("(col1 = 1) AND (col2 = 2")


def test_validate_filter_clause_multiple():
def test_sanitize_clause_closing_and_unclosed():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("TRUE; SELECT 1")
sanitize_clause("col1 = 1) AND (col2 = 2")


def test_validate_filter_clause_comment():
def test_sanitize_clause_closing_and_unclosed_nested():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("1 = 1 -- comment")
sanitize_clause("(col1 = 1)) AND ((col2 = 2)")


def test_validate_filter_clause_subquery_comment():
def test_sanitize_clause_multiple():
with pytest.raises(QueryClauseValidationException):
validate_filter_clause("(1 = 1 -- comment\n)")
sanitize_clause("TRUE; SELECT 1")

0 comments on commit 1c4c782

Please sign in to comment.