Skip to content

Commit

Permalink
add new query_context_validation to the chart data flows
Browse files Browse the repository at this point in the history
  • Loading branch information
ofekisr committed Apr 18, 2022
1 parent ae041fa commit 8fbce8e
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 19 deletions.
23 changes: 15 additions & 8 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@
from superset.views.base import CsvResponse, generate_download_headers
from superset.views.base_api import statsd_metrics

from .commands.command_factory import GetChartDataCommandFactory

if TYPE_CHECKING:
from superset.common.query_context import QueryContext

logger = logging.getLogger(__name__)


class ChartDataRestApi(ChartRestApi):

include_route_methods = {"get_data", "data", "data_from_cache"}

@expose("/<int:pk>/data/", methods=["GET"])
Expand Down Expand Up @@ -133,7 +136,7 @@ def get_data(self, pk: int) -> Response:

try:
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command = GetChartDataCommandFactory.make(query_context)
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
Expand All @@ -150,15 +153,17 @@ def get_data(self, pk: int) -> Response:
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(json_body, command)
return self._run_async(json_body, command) # type: ignore

try:
form_data = json.loads(chart.params)
except (TypeError, json.decoder.JSONDecodeError):
form_data = {}

return self._get_data_response(
command=command, form_data=form_data, datasource=query_context.datasource
command=command, # type: ignore
form_data=form_data,
datasource=query_context.datasource,
)

@expose("/data", methods=["POST"])
Expand Down Expand Up @@ -221,7 +226,7 @@ def data(self) -> Response:

try:
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command = GetChartDataCommandFactory.make(query_context)
command.validate()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
Expand All @@ -238,11 +243,13 @@ def data(self) -> Response:
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(json_body, command)
return self._run_async(json_body, command) # type: ignore

form_data = json_body.get("form_data")
return self._get_data_response(
command, form_data=form_data, datasource=query_context.datasource
command, # type: ignore
form_data=form_data,
datasource=query_context.datasource,
)

@expose("/data/<cache_key>", methods=["GET"])
Expand Down Expand Up @@ -288,7 +295,7 @@ def data_from_cache(self, cache_key: str) -> Response:
try:
cached_data = self._load_query_context_form_from_cache(cache_key)
query_context = self._create_query_context_from_form(cached_data)
command = ChartDataCommand(query_context)
command = GetChartDataCommandFactory.make(query_context)
command.validate()
except ChartDataCacheLoadError:
return self.response_404()
Expand All @@ -297,7 +304,7 @@ def data_from_cache(self, cache_key: str) -> Response:
message=_("Request is incorrect: %(error)s", error=error.messages)
)

return self._get_data_response(command, True)
return self._get_data_response(command, True) # type: ignore

def _run_async(
self, form_data: Dict[str, Any], command: ChartDataCommand
Expand Down
13 changes: 10 additions & 3 deletions superset/charts/data/commands/get_data_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from typing import Any, Dict
from typing import Any, Dict, TYPE_CHECKING

from superset.charts.commands.exceptions import (
ChartDataCacheLoadError,
Expand All @@ -27,12 +29,17 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from ..query_context_validators import QueryContextValidator


class ChartDataCommand(BaseCommand):
_query_context: QueryContext
_validator: QueryContextValidator

def __init__(self, query_context: QueryContext):
def __init__(self, query_context: QueryContext, validator: QueryContextValidator):
self._query_context = query_context
self._validator = validator

def run(self, **kwargs: Any) -> Dict[str, Any]:
# caching is handled in query_context.get_df_payload
Expand Down Expand Up @@ -61,4 +68,4 @@ def run(self, **kwargs: Any) -> Dict[str, Any]:
return return_value

def validate(self) -> None:
self._query_context.raise_for_access()
self._validator.validate(self._query_context)
21 changes: 21 additions & 0 deletions superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def init_app_in_ctx(self) -> None:
self.configure_url_map_converters()
self.configure_data_sources()
self.configure_auth_provider()
self.init_factories()
self.configure_async_queries()

# Hook that provides administrators a handle on the Flask APP
Expand All @@ -483,6 +484,9 @@ def init_app_in_ctx(self) -> None:

self.init_views()

def init_factories(self) -> None:
    """Initialise the application's factory objects.

    Currently this only delegates to
    ``_init_chart_data_command_factory``; it is invoked from
    ``init_app_in_ctx`` after the auth provider has been configured and
    before async queries are configured.
    """
    self._init_chart_data_command_factory()

def check_secret_key(self) -> None:
if self.config["SECRET_KEY"] == CHANGE_ME_SECRET_KEY:
top_banner = 80 * "-" + "\n" + 36 * " " + "WARNING\n" + 80 * "-"
Expand Down Expand Up @@ -667,6 +671,23 @@ def enable_profiling(self) -> None:
if self.config["PROFILING"]:
profiling.init_app(self.superset_app)

@staticmethod
def _init_chart_data_command_factory() -> None:  # pylint: disable=invalid-name
    """Wire ``GetChartDataCommandFactory`` with a query-context validator factory.

    A ``QueryContextValidatorFactory`` is built from the security manager
    and a fresh ``DatasetDAO`` instance, then handed to
    ``GetChartDataCommandFactory.init``.
    """
    # Imports are deliberately function-scoped (see the pylint disable);
    # presumably to avoid import cycles at module load time - TODO confirm.
    # pylint: disable=import-outside-toplevel
    from superset import security_manager
    from superset.charts.data.commands.command_factory import (
        GetChartDataCommandFactory,
    )
    from superset.charts.data.query_context_validators.validaor_factory import (
        QueryContextValidatorFactory,
    )
    from superset.datasets.dao import DatasetDAO

    validator_factory = QueryContextValidatorFactory(
        security_manager, DatasetDAO()  # type: ignore
    )
    GetChartDataCommandFactory.init(validator_factory)


class SupersetIndexView(IndexView):
@expose("/")
Expand Down
8 changes: 3 additions & 5 deletions superset/tasks/async_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flask import current_app, g
from marshmallow import ValidationError

from superset.charts.data.commands.command_factory import GetChartDataCommandFactory
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.exceptions import SupersetVizException
from superset.extensions import (
Expand Down Expand Up @@ -73,15 +74,12 @@ def load_chart_data_into_cache(
job_metadata: Dict[str, Any],
form_data: Dict[str, Any],
) -> None:
# pylint: disable=import-outside-toplevel
from superset.charts.data.commands.get_data_command import ChartDataCommand

try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
command = ChartDataCommand(query_context)
result = command.run(cache=True)
command = GetChartDataCommandFactory.make(query_context)
result = command.run(cache=True) # type: ignore
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
async_query_manager.update_job(
Expand Down
86 changes: 83 additions & 3 deletions tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.single_column_table import (
load_single_column_example_datasource,
)
from tests.integration_tests.test_app import app

import pytest

from superset.charts.data.commands.get_data_command import ChartDataCommand
from superset.connectors.sqla.models import TableColumn, SqlaTable
from superset.errors import SupersetErrorType
from superset.extensions import async_query_manager, db
from superset.extensions import async_query_manager, db, security_manager
from superset.models.annotations import AnnotationLayer
from superset.models.slice import Slice
from superset.superset_typing import AdhocColumn
Expand All @@ -57,7 +60,10 @@
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType

from tests.common.query_context_generator import ANNOTATION_LAYERS
from tests.integration_tests.fixtures.query_context import get_query_context
from tests.integration_tests.fixtures.query_context import (
QueryContextGeneratorInteg,
get_query_context,
)


CHART_DATA_URI = "api/v1/chart/data"
Expand Down Expand Up @@ -766,8 +772,82 @@ def test_with_virtual_table_with_colons_as_datasource(self):
assert "':xyz:qwerty'" in result["query"]
assert "':qwerty:'" in result["query"]

@with_feature_flags(QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_when_actor_does_not_have_permission_to_datasource__403(self):
    """Re-run the not-permitted-actor scenario with the feature flag enabled.

    Delegates entirely to ``test_with_not_permitted_actor__403`` (defined
    elsewhere on the test class) while
    ``QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION`` is switched on.
    """
    self.test_with_not_permitted_actor__403()

@with_feature_flags(QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_when_actor_has_permission_to_datasource__200(self):
    """Run the valid-query-context scenario as ``gamma`` with extra permissions.

    The Gamma role temporarily receives ``can_sql_json`` on ``Superset`` and
    ``database_access`` on ``[examples].(id:1)``, plus access to the
    ``birth_names`` table, before delegating to
    ``test_with_valid_qc__data_is_returned``.  All grants are removed in a
    ``finally`` block so a failing assertion cannot leak permissions into
    other tests.
    """
    pvm_specs = (
        ("can_sql_json", "Superset"),
        ("database_access", "[examples].(id:1)"),
    )

    self.logout()
    self.login(username="gamma")
    table = self.get_table("birth_names")
    # NOTE(review): the grant uses role name "gamma" while every other lookup
    # and the revoke use "Gamma" - confirm the helper resolves both the same.
    self.grant_role_access_to_table(table, "gamma")
    try:
        # set required permissions to gamma role
        role = security_manager.find_role("Gamma")
        pvms = [
            security_manager.find_permission_view_menu(permission, view_menu)
            for permission, view_menu in pvm_specs
        ]
        for pvm in pvms:
            security_manager.add_permission_role(role, pvm)
        self.test_with_valid_qc__data_is_returned()
    finally:
        # Look everything up again rather than reusing the objects fetched
        # before the request, exactly as the original cleanup did.
        role = security_manager.find_role("Gamma")
        table = self.get_table("birth_names")
        pvms = [
            security_manager.find_permission_view_menu(permission, view_menu)
            for permission, view_menu in pvm_specs
        ]
        for pvm in pvms:
            security_manager.del_permission_role(role, pvm)
        self.revoke_role_access_to_table("Gamma", table)

@with_feature_flags(
    QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True, ALLOW_ADHOC_SUBQUERY=True
)
@pytest.mark.usefixtures(
    "load_birth_names_dashboard_with_slices",
    "load_single_column_example_datasource",
)
def test_when_actor_does_not_have_permission_to_metric_datasource__403(self):
    """Expect HTTP 403 when a SQL-expression metric targets another table.

    The ``gamma`` user is granted access to ``birth_names`` only; a metric
    generated for ``single_column_example`` is appended to the first query
    of the payload and the chart-data endpoint must answer 403.  The table
    grant is revoked in a ``finally`` block so it is undone even when the
    assertion fails.
    """
    self.logout()
    self.login(username="gamma")
    table = self.get_table("birth_names")
    # NOTE(review): granted as "gamma" but revoked as "Gamma" - confirm the
    # helpers treat the two role names identically.
    self.grant_role_access_to_table(table, "gamma")
    try:
        metric = QueryContextGeneratorInteg.generate_sql_expression_metric(
            column_name="name", table_name="single_column_example"
        )
        self.query_context_payload["queries"][0]["metrics"].append(metric)
        rv = self.post_assert_metric(
            CHART_DATA_URI, self.query_context_payload, "data"
        )
        assert rv.status_code == 403
    finally:
        self.revoke_role_access_to_table("Gamma", table)

@pytest.mark.chart_data_flow
@with_feature_flags(
    QUERY_CONTEXT_VALIDATION_SQL_EXPRESSION=True, ALLOW_ADHOC_SUBQUERY=True
)
@pytest.mark.usefixtures(
    "load_birth_names_dashboard_with_slices",
    "load_single_column_example_datasource",
)
def test_when_actor_has_permission_to_metric_datasource__200(self):
    """Expect HTTP 200 when the actor may read the metric's table as well.

    The ``gamma`` user is granted access to both ``birth_names`` and
    ``single_column_example``; the first query's metrics are replaced by a
    single SQL-expression metric generated for ``single_column_example`` and
    the chart-data endpoint must answer 200.  Each grant is paired with a
    ``finally`` revoke so a failing assertion cannot leak table access into
    other tests.
    """
    self.logout()
    self.login(username="gamma")
    table = self.get_table("birth_names")
    # NOTE(review): granted as "gamma" but revoked as "Gamma" - confirm the
    # helpers treat the two role names identically.
    self.grant_role_access_to_table(table, "gamma")
    try:
        second_temp_datasource = self.get_table("single_column_example")
        self.grant_role_access_to_table(second_temp_datasource, "gamma")
        try:
            metric = QueryContextGeneratorInteg.generate_sql_expression_metric(
                column_name="name", table_name="single_column_example"
            )
            self.query_context_payload["queries"][0]["metrics"] = [metric]
            rv = self.post_assert_metric(
                CHART_DATA_URI, self.query_context_payload, "data"
            )
            assert rv.status_code == 200
        finally:
            self.revoke_role_access_to_table("Gamma", second_temp_datasource)
    finally:
        self.revoke_role_access_to_table("Gamma", table)


@pytest.mark.chart_data_flow
class TestGetChartDataApi(BaseTestChartDataApi):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_data_when_query_context_is_null(self):
Expand Down

0 comments on commit 8fbce8e

Please sign in to comment.