fix: improve explore REST api validations #27395

Merged · 3 commits · Mar 5, 2024
Changes from 2 commits
13 changes: 10 additions & 3 deletions superset/commands/explore/get.py
@@ -21,7 +21,7 @@
 
 import simplejson as json
 from flask import request
-from flask_babel import lazy_gettext as _
+from flask_babel import gettext as __, lazy_gettext as _
 from sqlalchemy.exc import SQLAlchemyError
 
 from superset.commands.base import BaseCommand
@@ -37,6 +37,7 @@
 from superset.exceptions import SupersetException
 from superset.explore.exceptions import WrongEndpointError
 from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
+from superset.extensions import security_manager
 from superset.utils import core as utils
 from superset.views.utils import (
     get_datasource_info,
@@ -61,7 +62,6 @@
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def run(self) -> Optional[dict[str, Any]]:
         initial_form_data = {}
-
         if self._permalink_key is not None:
             command = GetExplorePermalinkCommand(self._permalink_key)
             permalink_value = command.run()
@@ -110,12 +110,19 @@
             self._datasource_type = SqlaTable.type
 
         datasource: Optional[BaseDatasource] = None
+
         if self._datasource_id is not None:
             with contextlib.suppress(DatasourceNotFound):
                 datasource = DatasourceDAO.get_datasource(
                     cast(str, self._datasource_type), self._datasource_id
                 )
-        datasource_name = datasource.name if datasource else _("[Missing Dataset]")
+
+        datasource_name = _("[Missing Dataset]")
+
+        if datasource:
+            datasource_name = datasource.name
+            security_manager.can_access_datasource(datasource)
+
         viz_type = form_data.get("viz_type")
         if not viz_type and datasource and datasource.default_endpoint:
             raise WrongEndpointError(redirect=datasource.default_endpoint)

Codecov / codecov/patch warning on superset/commands/explore/get.py#L123-L124: added lines L123-L124 were not covered by tests.
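For readers skimming the diff, the shape of the change is: DatasourceDAO.get_datasource raises DatasourceNotFound rather than returning None, contextlib.suppress turns that failure into a no-op so datasource keeps its None default, and the new if-block sets the display name and runs the access check only when a datasource was actually resolved. Below is a minimal, self-contained sketch of that guarded-lookup pattern; the get_datasource stub is hypothetical, not Superset's DAO.

import contextlib
from typing import Optional


class DatasourceNotFound(Exception):
    """Stand-in for Superset's DatasourceNotFound exception."""


def get_datasource(datasource_id: int) -> object:
    # Hypothetical lookup mirroring DatasourceDAO.get_datasource, which
    # raises rather than returning None when nothing matches.
    raise DatasourceNotFound(f"no datasource with id {datasource_id}")


datasource: Optional[object] = None

# suppress() swallows only the listed exception type, so a missing
# datasource leaves `datasource` as None instead of aborting the request.
with contextlib.suppress(DatasourceNotFound):
    datasource = get_datasource(42)

# Default first, then narrow inside the guard, so follow-up work (such as
# the access check the PR adds) runs only when a datasource was resolved.
datasource_name = "[Missing Dataset]"
if datasource is not None:
    datasource_name = getattr(datasource, "name", datasource_name)

print(datasource_name)  # prints "[Missing Dataset]"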
20 changes: 19 additions & 1 deletion tests/integration_tests/explore/api_tests.py
@@ -197,7 +197,7 @@ def test_get_from_permalink_unknown_key(test_client, login_as_admin):
 
 
 @patch("superset.security.SupersetSecurityManager.can_access_datasource")
-def test_get_dataset_access_denied(
+def test_get_dataset_access_denied_with_form_data_key(
     mock_can_access_datasource, test_client, login_as_admin, dataset
 ):
     message = "Dataset access denied"
@@ -214,6 +214,24 @@ def test_get_dataset_access_denied(
     assert data["message"] == message
 
 
+@patch("superset.security.SupersetSecurityManager.can_access_datasource")
+def test_get_dataset_access_denied(
+    mock_can_access_datasource, test_client, login_as_admin, dataset
+):
+    message = "Dataset access denied"
+    mock_can_access_datasource.side_effect = DatasetAccessDeniedError(
+        message=message, datasource_id=dataset.id, datasource_type=dataset.type
+    )
+    resp = test_client.get(
+        f"api/v1/explore/?datasource_id={dataset.id}&datasource_type={dataset.type}"
+    )
+    data = json.loads(resp.data.decode("utf-8"))
+    assert resp.status_code == 403
+    assert data["datasource_id"] == dataset.id
+    assert data["datasource_type"] == dataset.type
+    assert data["message"] == message
+
+
 @patch("superset.daos.datasource.DatasourceDAO.get_datasource")
 def test_wrong_endpoint(mock_get_datasource, test_client, login_as_admin, dataset):
     dataset.default_endpoint = "another_endpoint"
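For orientation, here is a rough client-side sketch of the behavior the new test pins down: a GET to api/v1/explore/ with a datasource_id and datasource_type the caller cannot access returns 403, and the JSON body echoes the datasource and carries the denial message. BASE_URL and TOKEN are placeholders and the auth flow is elided; none of this is part of the PR itself.

import requests

BASE_URL = "http://localhost:8088"  # hypothetical local Superset instance
TOKEN = "my-access-token"  # hypothetical bearer token; auth flow elided

resp = requests.get(
    f"{BASE_URL}/api/v1/explore/",
    params={"datasource_id": 1, "datasource_type": "table"},
    headers={"Authorization": f"Bearer {TOKEN}"},
)

if resp.status_code == 403:
    body = resp.json()
    # Per the test above, the denial payload echoes the datasource and
    # carries a human-readable message.
    print(body["datasource_id"], body["datasource_type"], body["message"])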