From 88c09c21380fdfe072f9c235d5b282b921d9b9f3 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Wed, 15 Sep 2021 20:00:26 +0300 Subject: [PATCH] feat(jinja): improve url parameter formatting (#16711) * feat(jinja): improve url parameter formatting * add UPDATING.md * fix test --- UPDATING.md | 3 ++ superset/jinja_context.py | 22 +++++++++++-- tests/integration_tests/base_tests.py | 13 ++++++-- .../integration_tests/jinja_context_tests.py | 31 +++++++++++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/UPDATING.md b/UPDATING.md index 711ccad1a6717..b319c196279f1 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -25,6 +25,9 @@ assists people when migrating to a new version. ## Next ### Breaking Changes + +- [16711](https://github.com/apache/incubator-superset/pull/16711): The `url_param` Jinja function will now by default escape the result. For instance, the value `O'Brien` will now be changed to `O''Brien`. To disable this behavior, call `url_param` with `escape_result` set to `False`: `url_param("my_key", "my default", escape_result=False)`. + ### Potential Downtime ### Deprecations ### Other diff --git a/superset/jinja_context.py b/superset/jinja_context.py index ffcf497d6f231..e6a4cab963fb7 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -34,6 +34,8 @@ from flask_babel import gettext as _ from jinja2 import DebugUndefined from jinja2.sandbox import SandboxedEnvironment +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.types import String from typing_extensions import TypedDict from superset.exceptions import SupersetTemplateException @@ -95,9 +97,11 @@ def __init__( self, extra_cache_keys: Optional[List[Any]] = None, removed_filters: Optional[List[str]] = None, + dialect: Optional[Dialect] = None, ): self.extra_cache_keys = extra_cache_keys self.removed_filters = removed_filters if removed_filters is not None else [] + self.dialect = dialect def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]: """ @@ -145,7 +149,11 @@ def cache_key_wrapper(self, key: Any) -> Any: return key def url_param( - self, param: str, default: Optional[str] = None, add_to_cache_keys: bool = True + self, + param: str, + default: Optional[str] = None, + add_to_cache_keys: bool = True, + escape_result: bool = True, ) -> Optional[str]: """ Read a url or post parameter and use it in your SQL Lab query. @@ -166,6 +174,7 @@ def url_param( :param param: the parameter to lookup :param default: the value to return in the absence of the parameter :param add_to_cache_keys: Whether the value should be included in the cache key + :param escape_result: Should special characters in the result be escaped :returns: The URL parameters """ @@ -178,6 +187,11 @@ def url_param( form_data, _ = get_form_data() url_params = form_data.get("url_params") or {} result = url_params.get(param, default) + if result and escape_result and self.dialect: + # use the dialect specific quoting logic to escape string + result = String().literal_processor(dialect=self.dialect)(value=result)[ + 1:-1 + ] if add_to_cache_keys: self.cache_key_wrapper(result) return result @@ -430,7 +444,11 @@ def process_template(self, sql: str, **kwargs: Any) -> str: class JinjaTemplateProcessor(BaseTemplateProcessor): def set_context(self, **kwargs: Any) -> None: super().set_context(**kwargs) - extra_cache = ExtraCache(self._extra_cache_keys, self._removed_filters) + extra_cache = ExtraCache( + extra_cache_keys=self._extra_cache_keys, + removed_filters=self._removed_filters, + dialect=self._database.get_dialect(), + ) self._context.update( { "url_param": partial(safe_proxy, extra_cache.url_param), diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 7e4ebfd7df035..e808badf1fe1a 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -28,9 +28,11 @@ from flask import Response from flask_appbuilder.security.sqla import models as ab_models from flask_testing import TestCase +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.ext.declarative.api import DeclarativeMeta from sqlalchemy.orm import Session from sqlalchemy.sql import func +from sqlalchemy.dialects.mysql import dialect from tests.integration_tests.test_app import app from superset.sql_parse import CtasMethod @@ -422,7 +424,7 @@ def create_fake_db_for_macros(self): self.login(username="admin") database_name = "db_for_macros_testing" db_id = 200 - return self.get_or_create( + database = self.get_or_create( cls=models.Database, criteria={"database_name": database_name}, session=db.session, @@ -430,7 +432,14 @@ def create_fake_db_for_macros(self): id=db_id, ) - def delete_fake_db_for_macros(self): + def mock_get_dialect() -> Dialect: + return dialect() + + database.get_dialect = mock_get_dialect + return database + + @staticmethod + def delete_fake_db_for_macros(): database = ( db.session.query(Database) .filter(Database.database_name == "db_for_macros_testing") diff --git a/tests/integration_tests/jinja_context_tests.py b/tests/integration_tests/jinja_context_tests.py index a990968e5adae..b82adfa05f9dc 100644 --- a/tests/integration_tests/jinja_context_tests.py +++ b/tests/integration_tests/jinja_context_tests.py @@ -20,6 +20,7 @@ from unittest import mock import pytest +from sqlalchemy.dialects.postgresql import dialect import tests.integration_tests.test_app from superset import app @@ -199,6 +200,36 @@ def test_url_param_form_data(self) -> None: cache = ExtraCache() self.assertEqual(cache.url_param("foo"), "bar") + def test_url_param_escaped_form_data(self) -> None: + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + self.assertEqual(cache.url_param("foo"), "O''Brien") + + def test_url_param_escaped_default_form_data(self) -> None: + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + self.assertEqual(cache.url_param("bar", "O'Malley"), "O''Malley") + + def test_url_param_unescaped_form_data(self) -> None: + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + self.assertEqual(cache.url_param("foo", escape_result=False), "O'Brien") + + def test_url_param_unescaped_default_form_data(self) -> None: + with app.test_request_context( + query_string={"form_data": json.dumps({"url_params": {"foo": "O'Brien"}})} + ): + cache = ExtraCache(dialect=dialect()) + self.assertEqual( + cache.url_param("bar", "O'Malley", escape_result=False), "O'Malley" + ) + def test_safe_proxy_primitive(self) -> None: def func(input: Any) -> Any: return input