diff --git a/superset/security/api.py b/superset/security/api.py index 97720a64f3c61..e1ba8c50c25e6 100644 --- a/superset/security/api.py +++ b/superset/security/api.py @@ -62,7 +62,7 @@ def convert_enum_to_value( class RlsRuleSchema(PermissiveSchema): - dataset = fields.Integer(required=True) # todo make this optional when possible + dataset = fields.Integer() clause = fields.String(required=True) # todo other options? diff --git a/superset/security/guest_token.py b/superset/security/guest_token.py index c1f7ea9e9677c..13a88a10fe2e4 100644 --- a/superset/security/guest_token.py +++ b/superset/security/guest_token.py @@ -40,7 +40,7 @@ class GuestTokenResource(TypedDict): class GuestTokenRlsRule(TypedDict): - dataset: str + dataset: Optional[str] clause: str diff --git a/superset/security/manager.py b/superset/security/manager.py index 01a5d084b1f80..5ca81b2a9546e 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1123,7 +1123,12 @@ def get_guest_rls_filters( """ guest_user = self.get_current_guest_user_if_guest() if guest_user: - return [rule for rule in guest_user.rls if rule["dataset"] == dataset.id] + return [ + rule + for rule in guest_user.rls + if not rule.get("dataset") + or str(rule.get("dataset")) == str(dataset.id) + ] return [] def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index d933fc3fec253..b7db7242af037 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -16,7 +16,7 @@ # under the License. # isort:skip_file import re -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from unittest import mock import pytest @@ -24,7 +24,11 @@ from superset import db, security_manager from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable -from superset.security.guest_token import GuestTokenRlsRule, GuestTokenResourceType +from superset.security.guest_token import ( + GuestTokenRlsRule, + GuestTokenResourceType, + GuestUser, +) from ..base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -210,7 +214,7 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self): RLS_ALICE_REGEX = re.compile(r"name = 'Alice'") -RLS_GENDER_REGEX = re.compile(r"gender = 'girl'") +RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)") @mock.patch.dict( @@ -223,7 +227,7 @@ def default_rls_rule(self): "clause": "name = 'Alice'", } - def guest_user_with_rls(self, rules: List[Any] = None): + def guest_user_with_rls(self, rules: Optional[List[Any]] = None) -> GuestUser: if rules is None: rules = [self.default_rls_rule()] return security_manager.get_guest_user_from_token( @@ -273,3 +277,28 @@ def test_multiple_rls_filters_are_unionized(self): self.assertRegexpMatches(sql, RLS_ALICE_REGEX) self.assertRegexpMatches(sql, RLS_GENDER_REGEX) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_rls_filter_for_all_datasets(self): + births = self.get_table(name="birth_names") + energy = self.get_table(name="energy_usage") + guest = self.guest_user_with_rls(rules=[{"clause": "name = 'Alice'"}]) + guest.resources.append({type: "dashboard", id: energy.id}) + g.user = guest + births_sql = births.get_query_str(query_obj) + energy_sql = energy.get_query_str(query_obj) + + self.assertRegexpMatches(births_sql, RLS_ALICE_REGEX) + self.assertRegexpMatches(energy_sql, RLS_ALICE_REGEX) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_dataset_id_can_be_string(self): + dataset = self.get_table(name="birth_names") + str_id = str(dataset.id) + g.user = self.guest_user_with_rls( + rules=[{"dataset": str_id, "clause": "name = 'Alice'"}] + ) + sql = dataset.get_query_str(query_obj) + + self.assertRegexpMatches(sql, RLS_ALICE_REGEX)