diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 22516934db139..311889b090596 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -963,14 +963,28 @@ def _get_sqla_row_level_filters( :returns: A list of SQL clauses to be ANDed together. :rtype: List[str] """ - filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list) + all_filters: List[TextClause] = [] + filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): clause = self.text( f"({template_processor.process_template(filter_.clause)})" ) - filters_grouped[filter_.group_key or filter_.id].append(clause) - return [or_(*clauses) for clauses in filters_grouped.values()] + if filter_.group_key: + filter_groups[filter_.group_key].append(clause) + else: + all_filters.append(clause) + + if is_feature_enabled("EMBEDDED_SUPERSET"): + for rule in security_manager.get_guest_rls_filters(self): + clause = self.text( + f"({template_processor.process_template(rule['clause'])})" + ) + all_filters.append(clause) + + grouped_filters = [or_(*clauses) for clauses in filter_groups.values()] + all_filters.extend(grouped_filters) + return all_filters except TemplateError as ex: raise QueryObjectValidationError( _("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,) diff --git a/superset/security/api.py b/superset/security/api.py index 54efcd07e0dbd..b919e29f78ddd 100644 --- a/superset/security/api.py +++ b/superset/security/api.py @@ -15,34 +15,59 @@ # specific language governing permissions and limitations # under the License. import logging +from typing import Any, Dict from flask import request, Response from flask_appbuilder import expose from flask_appbuilder.api import BaseApi, safe from flask_appbuilder.security.decorators import permission_name, protect from flask_wtf.csrf import generate_csrf -from marshmallow import fields, Schema, ValidationError +from marshmallow import EXCLUDE, fields, post_load, Schema, ValidationError +from marshmallow_enum import EnumField from superset.extensions import event_logger +from superset.security.guest_token import GuestTokenResourceType logger = logging.getLogger(__name__) -class UserSchema(Schema): +class PermissiveSchema(Schema): + """ + A marshmallow schema that ignores unexpected fields, instead of throwing an error. + """ + + class Meta: # pylint: disable=too-few-public-methods + unknown = EXCLUDE + + +class UserSchema(PermissiveSchema): username = fields.String() first_name = fields.String() last_name = fields.String() -class ResourceSchema(Schema): - type = fields.String(required=True) # todo figure out how to make this an enum +class ResourceSchema(PermissiveSchema): + type = EnumField(GuestTokenResourceType, by_value=True, required=True) id = fields.String(required=True) - rls = fields.String() + + @post_load + def convert_enum_to_value( # pylint: disable=no-self-use + self, data: Dict[str, Any], **kwargs: Any # pylint: disable=unused-argument + ) -> Dict[str, Any]: + # we don't care about the enum, we want the value inside + data["type"] = data["type"].value + return data + + +class RlsRuleSchema(PermissiveSchema): + dataset = fields.Integer() + clause = fields.String(required=True) # todo other options? -class GuestTokenCreateSchema(Schema): +class GuestTokenCreateSchema(PermissiveSchema): user = fields.Nested(UserSchema) - resource = fields.Nested(ResourceSchema, required=True) + resources = fields.List(fields.Nested(ResourceSchema), required=True) + rls = fields.List(fields.Nested(RlsRuleSchema), required=True) guest_token_create_schema = GuestTokenCreateSchema() @@ -117,12 +142,12 @@ def guest_token(self) -> Response: """ try: body = guest_token_create_schema.load(request.json) - # validate stuff: - # make sure the resource id is valid + # todo validate stuff: + # make sure the resource ids are valid # make sure username doesn't reference an existing user # check rls rules for validity? token = self.appbuilder.sm.create_guest_access_token( - body["user"], [body["resource"]] + body["user"], body["resources"], body["rls"] ) return self.response(200, token=token) except ValidationError as error: diff --git a/superset/security/guest_token.py b/superset/security/guest_token.py index 60add8175400d..af86da326288c 100644 --- a/superset/security/guest_token.py +++ b/superset/security/guest_token.py @@ -34,17 +34,22 @@ class GuestTokenResourceType(Enum): class GuestTokenResource(TypedDict): type: GuestTokenResourceType id: Union[str, int] - rls: Optional[str] GuestTokenResources = List[GuestTokenResource] +class GuestTokenRlsRule(TypedDict): + dataset: Optional[str] + clause: str + + class GuestToken(TypedDict): iat: float exp: float user: GuestTokenUser resources: GuestTokenResources + rls_rules: List[GuestTokenRlsRule] class GuestUser(AnonymousUserMixin): @@ -79,3 +84,4 @@ def __init__(self, token: GuestToken, roles: List[Role]): self.last_name = user.get("last_name", "User") self.roles = roles self.resources = token["resources"] + self.rls = token.get("rls_rules", []) diff --git a/superset/security/manager.py b/superset/security/manager.py index 5993f952f3473..5ca81b2a9546e 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -69,6 +69,7 @@ GuestToken, GuestTokenResources, GuestTokenResourceType, + GuestTokenRlsRule, GuestTokenUser, GuestUser, ) @@ -1111,6 +1112,25 @@ def get_user_roles(self, user: Optional[User] = None) -> List[Role]: return [self.get_public_role()] if public_role else [] return user.roles + def get_guest_rls_filters( + self, dataset: "BaseDatasource" + ) -> List[GuestTokenRlsRule]: + """ + Retrieves the row level security filters for the current user and the dataset, + if the user is authenticated with a guest token. + :param dataset: The dataset to check against + :return: A list of filters + """ + guest_user = self.get_current_guest_user_if_guest() + if guest_user: + return [ + rule + for rule in guest_user.rls + if not rule.get("dataset") + or str(rule.get("dataset")) == str(dataset.id) + ] + return [] + def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and @@ -1119,7 +1139,7 @@ def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: :param table: The table to check against :returns: A list of filters """ - if hasattr(g, "user") and hasattr(g.user, "id"): + if hasattr(g, "user"): # pylint: disable=import-outside-toplevel from superset.connectors.sqla.models import ( RLSFilterRoles, @@ -1127,11 +1147,7 @@ def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: RowLevelSecurityFilter, ) - user_roles = ( - self.get_session.query(assoc_user_role.c.role_id) - .filter(assoc_user_role.c.user_id == g.user.get_id()) - .subquery() - ) + user_roles = [role.id for role in self.get_user_roles()] regular_filter_roles = ( self.get_session.query(RLSFilterRoles.c.rls_filter_id) .join(RowLevelSecurityFilter) @@ -1274,7 +1290,10 @@ def _get_current_epoch_time() -> float: return time.time() def create_guest_access_token( - self, user: GuestTokenUser, resources: GuestTokenResources + self, + user: GuestTokenUser, + resources: GuestTokenResources, + rls: List[GuestTokenRlsRule], ) -> bytes: secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] @@ -1286,6 +1305,7 @@ def create_guest_access_token( claims = { "user": user, "resources": resources, + "rls_rules": rls, # standard jwt claims: "iat": now, # issued at "exp": exp, # expiration time @@ -1312,6 +1332,8 @@ def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]: raise ValueError("Guest token does not contain a user claim") if token.get("resources") is None: raise ValueError("Guest token does not contain a resources claim") + if token.get("rls_rules") is None: + raise ValueError("Guest token does not contain an rls_rules claim") except Exception: # pylint: disable=broad-except # The login manager will handle sending 401s. # We don't need to send a special error message. diff --git a/tests/integration_tests/security/api_tests.py b/tests/integration_tests/security/api_tests.py index fcacd7ce668f2..d7b365985d9b2 100644 --- a/tests/integration_tests/security/api_tests.py +++ b/tests/integration_tests/security/api_tests.py @@ -77,18 +77,19 @@ def test_post_guest_token_unauthorized(self): response = self.client.post(self.uri) self.assert403(response) - def test_post_embed_token_authorized(self): + def test_post_guest_token_authorized(self): self.login(username="admin") user = {"username": "bob", "first_name": "Bob", "last_name": "Also Bob"} - resource = {"type": "dashboard", "id": "blah", "rls": "1 = 1"} - params = {"user": user, "resource": resource} + resource = {"type": "dashboard", "id": "blah"} + rls_rule = {"dataset": 1, "clause": "1=1"} + params = {"user": user, "resources": [resource], "rls": [rls_rule]} response = self.client.post( self.uri, data=json.dumps(params), content_type="application/json" ) + self.assert200(response) token = json.loads(response.data)["token"] decoded_token = jwt.decode(token, self.app.config["GUEST_TOKEN_JWT_SECRET"]) - self.assertEqual(user, decoded_token["user"]) self.assertEqual(resource, decoded_token["resources"][0]) diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 665666cb61f5b..06610bbd06de8 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -16,13 +16,19 @@ # under the License. # isort:skip_file import re -from typing import Any, Dict +from typing import Any, Dict, List, Optional +from unittest import mock import pytest from flask import g from superset import db, security_manager from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable +from superset.security.guest_token import ( + GuestTokenRlsRule, + GuestTokenResourceType, + GuestUser, +) from ..base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -66,11 +72,11 @@ def setUp(self): session = db.session # Create roles - security_manager.add_role(self.NAME_AB_ROLE) - security_manager.add_role(self.NAME_Q_ROLE) + self.role_ab = security_manager.add_role(self.NAME_AB_ROLE) + self.role_q = security_manager.add_role(self.NAME_Q_ROLE) gamma_user = security_manager.find_user(username="gamma") - gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE)) - gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE)) + gamma_user.roles.append(self.role_ab) + gamma_user.roles.append(self.role_q) self.create_user_with_roles("NoRlsRoleUser", ["Gamma"]) session.commit() @@ -205,3 +211,106 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self): assert not self.NAMES_B_REGEX.search(sql) assert not self.NAMES_Q_REGEX.search(sql) assert not self.BASE_FILTER_REGEX.search(sql) + + +RLS_ALICE_REGEX = re.compile(r"name = 'Alice'") +RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)") + + +@mock.patch.dict( + "superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True, +) +class GuestTokenRowLevelSecurityTests(SupersetTestCase): + query_obj: Dict[str, Any] = dict( + groupby=[], + metrics=None, + filter=[], + is_timeseries=False, + columns=["value"], + granularity=None, + from_dttm=None, + to_dttm=None, + extras={}, + ) + + def default_rls_rule(self): + return { + "dataset": self.get_table(name="birth_names").id, + "clause": "name = 'Alice'", + } + + def guest_user_with_rls(self, rules: Optional[List[Any]] = None) -> GuestUser: + if rules is None: + rules = [self.default_rls_rule()] + return security_manager.get_guest_user_from_token( + { + "user": {}, + "resources": [{"type": GuestTokenResourceType.DASHBOARD.value}], + "rls_rules": rules, + } + ) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_rls_filter_alters_query(self): + g.user = self.guest_user_with_rls() + tbl = self.get_table(name="birth_names") + sql = tbl.get_query_str(self.query_obj) + + self.assertRegexpMatches(sql, RLS_ALICE_REGEX) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_rls_filter_does_not_alter_unrelated_query(self): + g.user = self.guest_user_with_rls( + rules=[ + { + "dataset": self.get_table(name="birth_names").id + 1, + "clause": "name = 'Alice'", + } + ] + ) + tbl = self.get_table(name="birth_names") + sql = tbl.get_query_str(self.query_obj) + + self.assertNotRegexpMatches(sql, RLS_ALICE_REGEX) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_multiple_rls_filters_are_unionized(self): + g.user = self.guest_user_with_rls( + rules=[ + self.default_rls_rule(), + { + "dataset": self.get_table(name="birth_names").id, + "clause": "gender = 'girl'", + }, + ] + ) + tbl = self.get_table(name="birth_names") + sql = tbl.get_query_str(self.query_obj) + + self.assertRegexpMatches(sql, RLS_ALICE_REGEX) + self.assertRegexpMatches(sql, RLS_GENDER_REGEX) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_rls_filter_for_all_datasets(self): + births = self.get_table(name="birth_names") + energy = self.get_table(name="energy_usage") + guest = self.guest_user_with_rls(rules=[{"clause": "name = 'Alice'"}]) + guest.resources.append({type: "dashboard", id: energy.id}) + g.user = guest + births_sql = births.get_query_str(self.query_obj) + energy_sql = energy.get_query_str(self.query_obj) + + self.assertRegexpMatches(births_sql, RLS_ALICE_REGEX) + self.assertRegexpMatches(energy_sql, RLS_ALICE_REGEX) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_dataset_id_can_be_string(self): + dataset = self.get_table(name="birth_names") + str_id = str(dataset.id) + g.user = self.guest_user_with_rls( + rules=[{"dataset": str_id, "clause": "name = 'Alice'"}] + ) + sql = dataset.get_query_str(self.query_obj) + + self.assertRegexpMatches(sql, RLS_ALICE_REGEX) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index 2f5ad65aaea6a..46ca679deebc8 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1162,14 +1162,22 @@ class FakeRequest: class TestGuestTokens(SupersetTestCase): + def create_guest_token(self): + user = {"username": "test_guest"} + resources = [{"some": "resource"}] + rls = [{"dataset": 1, "clause": "access = 1"}] + return security_manager.create_guest_access_token(user, resources, rls) + @patch("superset.security.SupersetSecurityManager._get_current_epoch_time") def test_create_guest_access_token(self, get_time_mock): now = time.time() get_time_mock.return_value = now # so we know what it should = - user = {"any": "data"} + + user = {"username": "test_guest"} resources = [{"some": "resource"}] + rls = [{"dataset": 1, "clause": "access = 1"}] + token = security_manager.create_guest_access_token(user, resources, rls) - token = security_manager.create_guest_access_token(user, resources) # unfortunately we cannot mock time in the jwt lib decoded_token = jwt.decode( token, @@ -1186,9 +1194,7 @@ def test_create_guest_access_token(self, get_time_mock): ) def test_get_guest_user(self): - user = {"username": "test_guest"} - resources = [{"type": "dashboard", "id": 1}] - token = security_manager.create_guest_access_token(user, resources) + token = self.create_guest_token() fake_request = FakeRequest() fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token @@ -1203,9 +1209,7 @@ def test_get_guest_user_expired_token(self, get_time_mock): get_time_mock.return_value = ( time.time() - (self.app.config["GUEST_TOKEN_JWT_EXP_SECONDS"] * 1000) - 1 ) - user = {"username": "test_guest"} - resources = [{"type": "dashboard", "id": 1}] - token = security_manager.create_guest_access_token(user, resources) + token = self.create_guest_token() fake_request = FakeRequest() fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token @@ -1216,7 +1220,8 @@ def test_get_guest_user_expired_token(self, get_time_mock): def test_get_guest_user_no_user(self): user = None resources = [{"type": "dashboard", "id": 1}] - token = security_manager.create_guest_access_token(user, resources) + rls = {} + token = security_manager.create_guest_access_token(user, resources, rls) fake_request = FakeRequest() fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token guest_user = security_manager.get_guest_user_from_request(fake_request) @@ -1227,10 +1232,11 @@ def test_get_guest_user_no_user(self): def test_get_guest_user_no_resource(self): user = {"username": "test_guest"} resources = [] - token = security_manager.create_guest_access_token(user, resources) + rls = {} + token = security_manager.create_guest_access_token(user, resources, rls) fake_request = FakeRequest() fake_request.headers[current_app.config["GUEST_TOKEN_HEADER_NAME"]] = token - guest_user = security_manager.get_guest_user_from_request(fake_request) + security_manager.get_guest_user_from_request(fake_request) self.assertRaisesRegex( ValueError, "Guest token does not contain a resources claim"