From 1c4c7828c376410cbd437fe2a30e43d67b605a21 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Sat, 19 Mar 2022 06:08:06 +0800 Subject: [PATCH] feat: add support for comments in adhoc clauses (#19248) * feat: add support for comments in adhoc clauses * sanitize remaining freeform clauses * sanitize adhoc having in frontend * address review comment --- superset/common/query_object.py | 10 +++--- superset/connectors/sqla/models.py | 19 ++++++---- superset/sql_parse.py | 32 ++++++++++++++--- superset/utils/core.py | 7 ++-- superset/viz.py | 11 +++--- tests/unit_tests/sql_parse_tests.py | 55 +++++++++++++---------------- 6 files changed, 80 insertions(+), 54 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 257ab83607798..b6ff1d3ab483b 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -29,7 +29,7 @@ QueryClauseValidationException, QueryObjectValidationError, ) -from superset.sql_parse import validate_filter_clause +from superset.sql_parse import sanitize_clause from superset.typing_local import Metric, OrderBy from superset.utils import pandas_postprocessing from superset.utils.core import ( @@ -281,7 +281,7 @@ def validate( try: self._validate_there_are_no_missing_series() self._validate_no_have_duplicate_labels() - self._validate_filters() + self._sanitize_filters() return None except QueryObjectValidationError as ex: if raise_exceptions: @@ -300,12 +300,14 @@ def _validate_no_have_duplicate_labels(self) -> None: ) ) - def _validate_filters(self) -> None: + def _sanitize_filters(self) -> None: for param in ("where", "having"): clause = self.extras.get(param) if clause: try: - validate_filter_clause(clause) + sanitized_clause = sanitize_clause(clause) + if sanitized_clause != clause: + self.extras[param] = sanitized_clause except QueryClauseValidationException as ex: raise QueryObjectValidationError(ex.message) from ex diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 44c4f06f68054..80bbb1729c01e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -77,7 +77,10 @@ get_virtual_table_metadata, ) from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression -from superset.exceptions import QueryObjectValidationError +from superset.exceptions import ( + QueryClauseValidationException, + QueryObjectValidationError, +) from superset.jinja_context import ( BaseTemplateProcessor, ExtraCache, @@ -86,7 +89,7 @@ from superset.models.annotations import Annotation from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, sanitize_clause from superset.typing_local import AdhocMetric, Metric, OrderBy, QueryObjectDict from superset.utils import core as utils from superset.utils.core import ( @@ -858,6 +861,10 @@ def adhoc_metric_to_sqla( elif expression_type == utils.AdhocMetricExpressionType.SQL: tp = self.get_template_processor() expression = tp.process_template(cast(str, metric["sqlExpression"])) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -1265,7 +1272,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma where = extras.get("where") if where: try: - where = template_processor.process_template(where) + where = template_processor.process_template(f"({where})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1273,11 +1280,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - where_clause_and += [text(f"({where})")] + where_clause_and += [self.text(where)] having = extras.get("having") if having: try: - having = template_processor.process_template(having) + having = template_processor.process_template(f"({having})") except TemplateError as ex: raise QueryObjectValidationError( _( @@ -1285,7 +1292,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma msg=ex.message, ) ) from ex - having_clause_and += [text(f"({having})")] + having_clause_and += [self.text(having)] if apply_fetch_values_predicate and self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate()) if granularity: diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 8b18b9f8ecba9..4134a9eb14b56 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -28,6 +28,18 @@ remove_quotes, Token, TokenList, + Where, +) +from sqlparse.tokens import ( + Comment, + CTE, + DDL, + DML, + Keyword, + Name, + Punctuation, + String, + Whitespace, ) from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace from sqlparse.utils import imt @@ -349,21 +361,31 @@ def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: return str_res -def validate_filter_clause(clause: str) -> None: - if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause): - raise QueryClauseValidationException("Filter clause contains comment") - +def sanitize_clause(clause: str) -> str: + # clause = sqlparse.format(clause, strip_comments=True) statements = sqlparse.parse(clause) if len(statements) != 1: - raise QueryClauseValidationException("Filter clause contains multiple queries") + raise QueryClauseValidationException("Clause contains multiple statements") open_parens = 0 + previous_token = None for token in statements[0]: + if token.value == "/" and previous_token and previous_token.value == "*": + raise QueryClauseValidationException("Closing unopened multiline comment") + if token.value == "*" and previous_token and previous_token.value == "/": + raise QueryClauseValidationException("Unclosed multiline comment") if token.value in (")", "("): open_parens += 1 if token.value == "(" else -1 if open_parens < 0: raise QueryClauseValidationException( "Closing unclosed parenthesis in filter clause" ) + previous_token = token if open_parens > 0: raise QueryClauseValidationException("Unclosed parenthesis in filter clause") + + if previous_token and previous_token.ttype in Comment: + if previous_token.value[-1] != "\n": + clause = f"{clause}\n" + + return clause diff --git a/superset/utils/core.py b/superset/utils/core.py index 3d18bb870391a..dd950b9387362 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -97,6 +97,7 @@ SupersetException, SupersetTimeoutException, ) +from superset.sql_parse import sanitize_clause from superset.typing_local import ( AdhocMetric, AdhocMetricColumn, @@ -1392,10 +1393,12 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name } ) elif expression_type == "SQL": + sql_expression = adhoc_filter.get("sqlExpression") + sql_expression = sanitize_clause(sql_expression) if clause == "WHERE": - sql_where_filters.append(adhoc_filter.get("sqlExpression")) + sql_where_filters.append(sql_expression) elif clause == "HAVING": - sql_having_filters.append(adhoc_filter.get("sqlExpression")) + sql_having_filters.append(sql_expression) form_data["where"] = " AND ".join( ["({})".format(sql) for sql in sql_where_filters] ) diff --git a/superset/viz.py b/superset/viz.py index 3bc1962c82de2..7a39ab8253ad2 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -28,7 +28,6 @@ from collections import defaultdict, OrderedDict from datetime import date, datetime, timedelta from itertools import product -from collections import Counter from typing import ( Any, Callable, @@ -61,14 +60,13 @@ from superset.exceptions import ( CacheLoadError, NullValueException, - QueryClauseValidationException, QueryObjectValidationError, SpatialException, SupersetSecurityException, ) from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult -from superset.sql_parse import validate_filter_clause +from superset.sql_parse import sanitize_clause from superset.typing_local import Metric, QueryObjectDict, VizData, VizPayload from superset.utils import core as utils, csv from superset.utils.cache import set_and_log_cache @@ -362,10 +360,9 @@ def query_obj(self) -> QueryObjectDict: # pylint: disable=too-many-locals for param in ("where", "having"): clause = self.form_data.get(param) if clause: - try: - validate_filter_clause(clause) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex + sanitized_clause = sanitize_clause(clause) + if sanitized_clause != clause: + self.form_data[param] = sanitized_clause # extras are used to query elements specific to a datasource type # for instance the extra where clause that applies only to Tables diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 410e55229d4fd..3365db83d0425 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -19,13 +19,14 @@ from typing import Set import pytest +import sqlparse from superset.exceptions import QueryClauseValidationException from superset.sql_parse import ( ParsedQuery, + sanitize_clause, strip_comments_from_sql, Table, - validate_filter_clause, ) @@ -1091,49 +1092,43 @@ def test_strip_comments_from_sql() -> None: ) -def test_validate_filter_clause_valid(): +def test_sanitize_clause_valid(): # regular clauses - assert validate_filter_clause("col = 1") is None - assert validate_filter_clause("1=\t\n1") is None - assert validate_filter_clause("(col = 1)") is None - assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None + assert sanitize_clause("col = 1") == "col = 1" + assert sanitize_clause("1=\t\n1") == "1=\t\n1" + assert sanitize_clause("(col = 1)") == "(col = 1)" + assert sanitize_clause("(col1 = 1) AND (col2 = 2)") == "(col1 = 1) AND (col2 = 2)" + assert sanitize_clause("col = 'abc' -- comment") == "col = 'abc' -- comment\n" - # Valid literal values that appear to be invalid - assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None - assert validate_filter_clause("col = 'select 1; select 2'") is None - assert validate_filter_clause("col = 'abc -- comment'") is None - - -def test_validate_filter_clause_closing_unclosed(): - with pytest.raises(QueryClauseValidationException): - validate_filter_clause("col1 = 1) AND (col2 = 2)") - - -def test_validate_filter_clause_unclosed(): - with pytest.raises(QueryClauseValidationException): - validate_filter_clause("(col1 = 1) AND (col2 = 2") + # Valid literal values that at could be flagged as invalid by a naive query parser + assert ( + sanitize_clause("col = 'col1 = 1) AND (col2 = 2'") + == "col = 'col1 = 1) AND (col2 = 2'" + ) + assert sanitize_clause("col = 'select 1; select 2'") == "col = 'select 1; select 2'" + assert sanitize_clause("col = 'abc -- comment'") == "col = 'abc -- comment'" -def test_validate_filter_clause_closing_and_unclosed(): +def test_sanitize_clause_closing_unclosed(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("col1 = 1) AND (col2 = 2") + sanitize_clause("col1 = 1) AND (col2 = 2)") -def test_validate_filter_clause_closing_and_unclosed_nested(): +def test_sanitize_clause_unclosed(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("(col1 = 1)) AND ((col2 = 2)") + sanitize_clause("(col1 = 1) AND (col2 = 2") -def test_validate_filter_clause_multiple(): +def test_sanitize_clause_closing_and_unclosed(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("TRUE; SELECT 1") + sanitize_clause("col1 = 1) AND (col2 = 2") -def test_validate_filter_clause_comment(): +def test_sanitize_clause_closing_and_unclosed_nested(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("1 = 1 -- comment") + sanitize_clause("(col1 = 1)) AND ((col2 = 2)") -def test_validate_filter_clause_subquery_comment(): +def test_sanitize_clause_multiple(): with pytest.raises(QueryClauseValidationException): - validate_filter_clause("(1 = 1 -- comment\n)") + sanitize_clause("TRUE; SELECT 1")