diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 73468d651cbc5..ed16f876d43c8 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -48,6 +48,8 @@ from superset.views.base import CsvResponse, generate_download_headers from superset.views.base_api import statsd_metrics +from .commands.command_factory import GetChartDataCommandFactory + if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -55,6 +57,7 @@ class ChartDataRestApi(ChartRestApi): + include_route_methods = {"get_data", "data", "data_from_cache"} @expose("//data/", methods=["GET"]) @@ -133,7 +136,7 @@ def get_data(self, pk: int) -> Response: try: query_context = self._create_query_context_from_form(json_body) - command = ChartDataCommand(query_context) + command = GetChartDataCommandFactory.make(query_context) command.validate() except QueryObjectValidationError as error: return self.response_400(message=error.message) @@ -150,7 +153,7 @@ def get_data(self, pk: int) -> Response: and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - return self._run_async(json_body, command) + return self._run_async(json_body, command) # type: ignore try: form_data = json.loads(chart.params) @@ -158,7 +161,9 @@ def get_data(self, pk: int) -> Response: form_data = {} return self._get_data_response( - command=command, form_data=form_data, datasource=query_context.datasource + command=command, # type: ignore + form_data=form_data, + datasource=query_context.datasource, ) @expose("/data", methods=["POST"]) @@ -221,7 +226,7 @@ def data(self) -> Response: try: query_context = self._create_query_context_from_form(json_body) - command = ChartDataCommand(query_context) + command = GetChartDataCommandFactory.make(query_context) command.validate() except QueryObjectValidationError as error: return self.response_400(message=error.message) @@ -238,11 +243,13 @@ def data(self) -> Response: and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): - return self._run_async(json_body, command) + return self._run_async(json_body, command) # type: ignore form_data = json_body.get("form_data") return self._get_data_response( - command, form_data=form_data, datasource=query_context.datasource + command, # type: ignore + form_data=form_data, + datasource=query_context.datasource, ) @expose("/data/", methods=["GET"]) @@ -288,7 +295,7 @@ def data_from_cache(self, cache_key: str) -> Response: try: cached_data = self._load_query_context_form_from_cache(cache_key) query_context = self._create_query_context_from_form(cached_data) - command = ChartDataCommand(query_context) + command = GetChartDataCommandFactory.make(query_context) command.validate() except ChartDataCacheLoadError: return self.response_404() @@ -297,7 +304,7 @@ def data_from_cache(self, cache_key: str) -> Response: message=_("Request is incorrect: %(error)s", error=error.messages) ) - return self._get_data_response(command, True) + return self._get_data_response(command, True) # type: ignore def _run_async( self, form_data: Dict[str, Any], command: ChartDataCommand diff --git a/superset/charts/data/commands/get_data_command.py b/superset/charts/data/commands/get_data_command.py index 95f7513f253f6..a04c4c4fd413b 100644 --- a/superset/charts/data/commands/get_data_command.py +++ b/superset/charts/data/commands/get_data_command.py @@ -14,8 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import logging -from typing import Any, Dict +from typing import Any, Dict, TYPE_CHECKING from superset.charts.commands.exceptions import ( ChartDataCacheLoadError, @@ -27,12 +29,17 @@ logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from ..query_context_validators import QueryContextValidator + class ChartDataCommand(BaseCommand): _query_context: QueryContext + _validator: QueryContextValidator - def __init__(self, query_context: QueryContext): + def __init__(self, query_context: QueryContext, validator: QueryContextValidator): self._query_context = query_context + self._validator = validator def run(self, **kwargs: Any) -> Dict[str, Any]: # caching is handled in query_context.get_df_payload @@ -61,4 +68,4 @@ def run(self, **kwargs: Any) -> Dict[str, Any]: return return_value def validate(self) -> None: - self._query_context.raise_for_access() + self._validator.validate(self._query_context) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 1bc6b0b82417d..f146805726bad 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -473,6 +473,7 @@ def init_app_in_ctx(self) -> None: self.configure_url_map_converters() self.configure_data_sources() self.configure_auth_provider() + self.init_factories() self.configure_async_queries() # Hook that provides administrators a handle on the Flask APP @@ -483,6 +484,9 @@ def init_app_in_ctx(self) -> None: self.init_views() + def init_factories(self) -> None: + self._init_chart_data_command_factory() + def check_secret_key(self) -> None: if self.config["SECRET_KEY"] == CHANGE_ME_SECRET_KEY: top_banner = 80 * "-" + "\n" + 36 * " " + "WARNING\n" + 80 * "-" @@ -667,6 +671,23 @@ def enable_profiling(self) -> None: if self.config["PROFILING"]: profiling.init_app(self.superset_app) + @staticmethod + def _init_chart_data_command_factory() -> None: # pylint: disable=invalid-name + # pylint: disable=import-outside-toplevel + from superset import security_manager + from superset.charts.data.commands.command_factory import ( + GetChartDataCommandFactory, + ) + from superset.charts.data.query_context_validators.validaor_factory import ( + QueryContextValidatorFactory, + ) + from superset.datasets.dao import DatasetDAO + + query_context_validator_factory = QueryContextValidatorFactory( + security_manager, DatasetDAO() # type: ignore + ) + GetChartDataCommandFactory.init(query_context_validator_factory) + class SupersetIndexView(IndexView): @expose("/") diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 74adcd080c0c3..9549d44c4468a 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -24,6 +24,7 @@ from flask import current_app, g from marshmallow import ValidationError +from superset.charts.data.commands.command_factory import GetChartDataCommandFactory from superset.charts.schemas import ChartDataQueryContextSchema from superset.exceptions import SupersetVizException from superset.extensions import ( @@ -73,15 +74,12 @@ def load_chart_data_into_cache( job_metadata: Dict[str, Any], form_data: Dict[str, Any], ) -> None: - # pylint: disable=import-outside-toplevel - from superset.charts.data.commands.get_data_command import ChartDataCommand - try: ensure_user_is_set(job_metadata.get("user_id")) set_form_data(form_data) query_context = _create_query_context_from_form(form_data) - command = ChartDataCommand(query_context) - result = command.run(cache=True) + command = GetChartDataCommandFactory.make(query_context) + result = command.run(cache=True) # type: ignore cache_key = result["cache_key"] result_url = f"/api/v1/chart/data/{cache_key}" async_query_manager.update_job( diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 73425fb58f68c..d63a0f8e7356e 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -37,6 +37,9 @@ load_birth_names_dashboard_with_slices, load_birth_names_data, ) +from tests.integration_tests.fixtures.single_column_table import ( + load_single_column_example_datasource, +) from tests.integration_tests.test_app import app import pytest @@ -44,7 +47,7 @@ from superset.charts.data.commands.get_data_command import ChartDataCommand from superset.connectors.sqla.models import TableColumn, SqlaTable from superset.errors import SupersetErrorType -from superset.extensions import async_query_manager, db +from superset.extensions import async_query_manager, db, security_manager from superset.models.annotations import AnnotationLayer from superset.models.slice import Slice from superset.superset_typing import AdhocColumn @@ -57,7 +60,10 @@ from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from tests.common.query_context_generator import ANNOTATION_LAYERS -from tests.integration_tests.fixtures.query_context import get_query_context +from tests.integration_tests.fixtures.query_context import ( + QueryContextGeneratorInteg, + get_query_context, +) CHART_DATA_URI = "api/v1/chart/data" @@ -766,8 +772,82 @@ def test_with_virtual_table_with_colons_as_datasource(self): assert "':xyz:qwerty'" in result["query"] assert "':qwerty:'" in result["query"] + @with_feature_flags(QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_when_actor_does_not_have_permission_to_datasource__403(self): + self.test_with_not_permitted_actor__403() + + @with_feature_flags(QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_when_actor_has_permission_to_datasource__200(self): + self.logout() + self.login(username="gamma") + table = self.get_table("birth_names") + self.grant_role_access_to_table(table, "gamma") + # set required permissions to gamma role + role = security_manager.find_role("Gamma") + pvm1 = security_manager.find_permission_view_menu("can_sql_json", "Superset") + pvm2 = security_manager.find_permission_view_menu( + "database_access", "[examples].(id:1)" + ) + security_manager.add_permission_role(role, pvm1) + security_manager.add_permission_role(role, pvm2) + self.test_with_valid_qc__data_is_returned() + role = security_manager.find_role("Gamma") + table = self.get_table("birth_names") + pvm1 = security_manager.find_permission_view_menu("can_sql_json", "Superset") + pvm2 = security_manager.find_permission_view_menu( + "database_access", "[examples].(id:1)" + ) + security_manager.del_permission_role(role, pvm1) + security_manager.del_permission_role(role, pvm2) + self.revoke_role_access_to_table("Gamma", table) + + @with_feature_flags( + QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True, ALLOW_ADHOC_SUBQUERY=True + ) + @pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", + "load_single_column_example_datasource", + ) + def test_when_actor_does_not_have_permission_to_metric_datasource__403(self): + self.logout() + self.login(username="gamma") + table = self.get_table("birth_names") + self.grant_role_access_to_table(table, "gamma") + metric = QueryContextGeneratorInteg.generate_sql_expression_metric( + column_name="name", table_name="single_column_example" + ) + self.query_context_payload["queries"][0]["metrics"].append(metric) + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 403 + self.revoke_role_access_to_table("Gamma", table) + + @pytest.mark.chart_data_flow + @with_feature_flags( + QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True, ALLOW_ADHOC_SUBQUERY=True + ) + @pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", + "load_single_column_example_datasource", + ) + def test_when_actor_has_permission_to_metric_datasource__200(self): + self.logout() + self.login(username="gamma") + table = self.get_table("birth_names") + self.grant_role_access_to_table(table, "gamma") + second_temp_datasource = self.get_table("single_column_example") + self.grant_role_access_to_table(second_temp_datasource, "gamma") + metric = QueryContextGeneratorInteg.generate_sql_expression_metric( + column_name="name", table_name="single_column_example" + ) + self.query_context_payload["queries"][0]["metrics"] = [metric] + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 200 + self.revoke_role_access_to_table("Gamma", table) + self.revoke_role_access_to_table("Gamma", second_temp_datasource) + -@pytest.mark.chart_data_flow class TestGetChartDataApi(BaseTestChartDataApi): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_data_when_query_context_is_null(self):