Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Row Level Security rules for guest tokens #17836

Merged
merged 43 commits into from
Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
a21a417
helper methods and dashboard access
suddjian Dec 7, 2021
ad3572e
guest token dashboard authz
suddjian Dec 11, 2021
4fd8715
adjust csrf exempt list
suddjian Dec 11, 2021
7353826
eums don't work that way
suddjian Dec 11, 2021
5c0a7f5
Remove unnecessary import
suddjian Dec 20, 2021
2141f2d
move row level security tests to their own file
suddjian Dec 22, 2021
bfdbad5
a bit of refactoring
suddjian Dec 22, 2021
36a1224
add guest token security tests
suddjian Dec 22, 2021
7181333
refactor tests
suddjian Jan 6, 2022
27a2ec3
clean imports
suddjian Jan 6, 2022
09aed96
variable names can be too long apparently
suddjian Jan 13, 2022
c7bfa96
missing argument to get_user_roles
suddjian Jan 13, 2022
aa672b9
don't redefine builtins
suddjian Jan 13, 2022
091786d
remove unused imports
suddjian Jan 13, 2022
5e23ce9
fix test import
suddjian Jan 13, 2022
a62b2f5
default to global user when getting roles
suddjian Jan 13, 2022
e9d50c2
Merge branch 'embedded' into guest-token-authz
suddjian Jan 13, 2022
67affe4
Merge branch 'embedded' into guest-token-authz
suddjian Jan 13, 2022
4d5a691
missing import
suddjian Jan 13, 2022
13a2038
mock it
suddjian Jan 13, 2022
0be28da
test get_user_roles
suddjian Jan 13, 2022
c5bded9
infer g.user for ease of tests
suddjian Jan 14, 2022
ebf9400
remove redundant check
suddjian Jan 14, 2022
deceb33
tests for guest user security manager fns
suddjian Jan 14, 2022
128f3a0
use algo to get rid of warning messages
suddjian Jan 14, 2022
ede1367
tweaking access checks
suddjian Jan 18, 2022
38a89ae
fix guest token security tests
suddjian Jan 18, 2022
1927a1c
missing imports
suddjian Jan 18, 2022
128b90c
more tests
suddjian Jan 18, 2022
b73a81f
more testing and also some small refactoring
suddjian Jan 20, 2022
eccec4d
move validation out of parsing
suddjian Jan 20, 2022
e833ef5
fix dashboard access check again
suddjian Jan 20, 2022
1527c41
rls rules for guest tokens
suddjian Jan 4, 2022
07debee
test guest token rls rules
suddjian Jan 20, 2022
f16ae10
more flexible rls rules
suddjian Jan 20, 2022
bfafc13
lint
suddjian Jan 21, 2022
bf302f4
fix tests
suddjian Jan 21, 2022
bdd842b
fix test
suddjian Jan 21, 2022
28b69df
defaults
suddjian Jan 21, 2022
f1d9b7c
Merge remote-tracking branch 'upstream/embedded' into guest-token-rls
Jan 21, 2022
1c41a74
fix some tests
Jan 21, 2022
0853ecc
fix some tests
Jan 21, 2022
ff3e034
lint
Jan 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,14 +963,28 @@ def _get_sqla_row_level_filters(
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list)
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
filters_grouped[filter_.group_key or filter_.id].append(clause)
return [or_(*clauses) for clauses in filters_grouped.values()]
if filter_.group_key:
filter_groups[filter_.group_key].append(clause)
else:
all_filters.append(clause)

if is_feature_enabled("EMBEDDED_SUPERSET"):
for rule in security_manager.get_guest_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(rule['clause'])})"
)
all_filters.append(clause)

grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
all_filters.extend(grouped_filters)
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
Expand Down
45 changes: 35 additions & 10 deletions superset/security/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,59 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict

from flask import request, Response
from flask_appbuilder import expose
from flask_appbuilder.api import BaseApi, safe
from flask_appbuilder.security.decorators import permission_name, protect
from flask_wtf.csrf import generate_csrf
from marshmallow import fields, Schema, ValidationError
from marshmallow import EXCLUDE, fields, post_load, Schema, ValidationError
from marshmallow_enum import EnumField

from superset.extensions import event_logger
from superset.security.guest_token import GuestTokenResourceType

logger = logging.getLogger(__name__)


class UserSchema(Schema):
class PermissiveSchema(Schema):
"""
A marshmallow schema that ignores unexpected fields, instead of throwing an error.
"""

class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE


class UserSchema(PermissiveSchema):
username = fields.String()
first_name = fields.String()
last_name = fields.String()


class ResourceSchema(Schema):
type = fields.String(required=True) # todo figure out how to make this an enum
class ResourceSchema(PermissiveSchema):
type = EnumField(GuestTokenResourceType, by_value=True, required=True)
id = fields.String(required=True)
rls = fields.String()

@post_load
def convert_enum_to_value( # pylint: disable=no-self-use
self, data: Dict[str, Any], **kwargs: Any # pylint: disable=unused-argument
) -> Dict[str, Any]:
# we don't care about the enum, we want the value inside
data["type"] = data["type"].value
return data


class RlsRuleSchema(PermissiveSchema):
dataset = fields.Integer()
clause = fields.String(required=True) # todo other options?


class GuestTokenCreateSchema(Schema):
class GuestTokenCreateSchema(PermissiveSchema):
user = fields.Nested(UserSchema)
resource = fields.Nested(ResourceSchema, required=True)
resources = fields.List(fields.Nested(ResourceSchema), required=True)
rls = fields.List(fields.Nested(RlsRuleSchema), required=True)


guest_token_create_schema = GuestTokenCreateSchema()
Expand Down Expand Up @@ -117,12 +142,12 @@ def guest_token(self) -> Response:
"""
try:
body = guest_token_create_schema.load(request.json)
# validate stuff:
# make sure the resource id is valid
# todo validate stuff:
# make sure the resource ids are valid
# make sure username doesn't reference an existing user
# check rls rules for validity?
token = self.appbuilder.sm.create_guest_access_token(
body["user"], [body["resource"]]
body["user"], body["resources"], body["rls"]
)
return self.response(200, token=token)
except ValidationError as error:
Expand Down
8 changes: 7 additions & 1 deletion superset/security/guest_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,22 @@ class GuestTokenResourceType(Enum):
class GuestTokenResource(TypedDict):
type: GuestTokenResourceType
id: Union[str, int]
rls: Optional[str]


GuestTokenResources = List[GuestTokenResource]


class GuestTokenRlsRule(TypedDict):
dataset: Optional[str]
clause: str


class GuestToken(TypedDict):
iat: float
exp: float
user: GuestTokenUser
resources: GuestTokenResources
rls_rules: List[GuestTokenRlsRule]


class GuestUser(AnonymousUserMixin):
Expand Down Expand Up @@ -79,3 +84,4 @@ def __init__(self, token: GuestToken, roles: List[Role]):
self.last_name = user.get("last_name", "User")
self.roles = roles
self.resources = token["resources"]
self.rls = token.get("rls_rules", [])
36 changes: 29 additions & 7 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
GuestToken,
GuestTokenResources,
GuestTokenResourceType,
GuestTokenRlsRule,
GuestTokenUser,
GuestUser,
)
Expand Down Expand Up @@ -1111,6 +1112,25 @@ def get_user_roles(self, user: Optional[User] = None) -> List[Role]:
return [self.get_public_role()] if public_role else []
return user.roles

def get_guest_rls_filters(
self, dataset: "BaseDatasource"
) -> List[GuestTokenRlsRule]:
"""
Retrieves the row level security filters for the current user and the dataset,
if the user is authenticated with a guest token.
:param dataset: The dataset to check against
:return: A list of filters
"""
guest_user = self.get_current_guest_user_if_guest()
if guest_user:
return [
rule
for rule in guest_user.rls
if not rule.get("dataset")
or str(rule.get("dataset")) == str(dataset.id)
]
return []

def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:
"""
Retrieves the appropriate row level security filters for the current user and
Expand All @@ -1119,19 +1139,15 @@ def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:
:param table: The table to check against
:returns: A list of filters
"""
if hasattr(g, "user") and hasattr(g.user, "id"):
if hasattr(g, "user"):
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import (
RLSFilterRoles,
RLSFilterTables,
RowLevelSecurityFilter,
)

user_roles = (
self.get_session.query(assoc_user_role.c.role_id)
.filter(assoc_user_role.c.user_id == g.user.get_id())
.subquery()
)
user_roles = [role.id for role in self.get_user_roles()]
regular_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
Expand Down Expand Up @@ -1274,7 +1290,10 @@ def _get_current_epoch_time() -> float:
return time.time()

def create_guest_access_token(
self, user: GuestTokenUser, resources: GuestTokenResources
self,
user: GuestTokenUser,
resources: GuestTokenResources,
rls: List[GuestTokenRlsRule],
) -> bytes:
secret = current_app.config["GUEST_TOKEN_JWT_SECRET"]
algo = current_app.config["GUEST_TOKEN_JWT_ALGO"]
Expand All @@ -1286,6 +1305,7 @@ def create_guest_access_token(
claims = {
"user": user,
"resources": resources,
"rls_rules": rls,
# standard jwt claims:
"iat": now, # issued at
"exp": exp, # expiration time
Expand All @@ -1312,6 +1332,8 @@ def get_guest_user_from_request(self, req: Request) -> Optional[GuestUser]:
raise ValueError("Guest token does not contain a user claim")
if token.get("resources") is None:
raise ValueError("Guest token does not contain a resources claim")
if token.get("rls_rules") is None:
raise ValueError("Guest token does not contain an rls_rules claim")
except Exception: # pylint: disable=broad-except
# The login manager will handle sending 401s.
# We don't need to send a special error message.
Expand Down
9 changes: 5 additions & 4 deletions tests/integration_tests/security/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,19 @@ def test_post_guest_token_unauthorized(self):
response = self.client.post(self.uri)
self.assert403(response)

def test_post_embed_token_authorized(self):
def test_post_guest_token_authorized(self):
self.login(username="admin")
user = {"username": "bob", "first_name": "Bob", "last_name": "Also Bob"}
resource = {"type": "dashboard", "id": "blah", "rls": "1 = 1"}
params = {"user": user, "resource": resource}
resource = {"type": "dashboard", "id": "blah"}
rls_rule = {"dataset": 1, "clause": "1=1"}
params = {"user": user, "resources": [resource], "rls": [rls_rule]}

response = self.client.post(
self.uri, data=json.dumps(params), content_type="application/json"
)

self.assert200(response)
token = json.loads(response.data)["token"]
decoded_token = jwt.decode(token, self.app.config["GUEST_TOKEN_JWT_SECRET"])

self.assertEqual(user, decoded_token["user"])
self.assertEqual(resource, decoded_token["resources"][0])
119 changes: 114 additions & 5 deletions tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@
# under the License.
# isort:skip_file
import re
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from unittest import mock

import pytest
from flask import g

from superset import db, security_manager
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import (
GuestTokenRlsRule,
GuestTokenResourceType,
GuestUser,
)
from ..base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand Down Expand Up @@ -66,11 +72,11 @@ def setUp(self):
session = db.session

# Create roles
security_manager.add_role(self.NAME_AB_ROLE)
security_manager.add_role(self.NAME_Q_ROLE)
self.role_ab = security_manager.add_role(self.NAME_AB_ROLE)
self.role_q = security_manager.add_role(self.NAME_Q_ROLE)
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE))
gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE))
gamma_user.roles.append(self.role_ab)
gamma_user.roles.append(self.role_q)
self.create_user_with_roles("NoRlsRoleUser", ["Gamma"])
session.commit()

Expand Down Expand Up @@ -205,3 +211,106 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
assert not self.NAMES_B_REGEX.search(sql)
assert not self.NAMES_Q_REGEX.search(sql)
assert not self.BASE_FILTER_REGEX.search(sql)


RLS_ALICE_REGEX = re.compile(r"name = 'Alice'")
RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)")


@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True,
)
class GuestTokenRowLevelSecurityTests(SupersetTestCase):
query_obj: Dict[str, Any] = dict(
groupby=[],
metrics=None,
filter=[],
is_timeseries=False,
columns=["value"],
granularity=None,
from_dttm=None,
to_dttm=None,
extras={},
)

def default_rls_rule(self):
return {
"dataset": self.get_table(name="birth_names").id,
"clause": "name = 'Alice'",
}

def guest_user_with_rls(self, rules: Optional[List[Any]] = None) -> GuestUser:
if rules is None:
rules = [self.default_rls_rule()]
return security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": GuestTokenResourceType.DASHBOARD.value}],
"rls_rules": rules,
}
)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_query(self):
g.user = self.guest_user_with_rls()
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)

self.assertRegexpMatches(sql, RLS_ALICE_REGEX)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_does_not_alter_unrelated_query(self):
g.user = self.guest_user_with_rls(
rules=[
{
"dataset": self.get_table(name="birth_names").id + 1,
"clause": "name = 'Alice'",
}
]
)
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)

self.assertNotRegexpMatches(sql, RLS_ALICE_REGEX)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_multiple_rls_filters_are_unionized(self):
g.user = self.guest_user_with_rls(
rules=[
self.default_rls_rule(),
{
"dataset": self.get_table(name="birth_names").id,
"clause": "gender = 'girl'",
},
]
)
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)

self.assertRegexpMatches(sql, RLS_ALICE_REGEX)
self.assertRegexpMatches(sql, RLS_GENDER_REGEX)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_rls_filter_for_all_datasets(self):
births = self.get_table(name="birth_names")
energy = self.get_table(name="energy_usage")
guest = self.guest_user_with_rls(rules=[{"clause": "name = 'Alice'"}])
guest.resources.append({type: "dashboard", id: energy.id})
g.user = guest
births_sql = births.get_query_str(self.query_obj)
energy_sql = energy.get_query_str(self.query_obj)

self.assertRegexpMatches(births_sql, RLS_ALICE_REGEX)
self.assertRegexpMatches(energy_sql, RLS_ALICE_REGEX)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_dataset_id_can_be_string(self):
dataset = self.get_table(name="birth_names")
str_id = str(dataset.id)
g.user = self.guest_user_with_rls(
rules=[{"dataset": str_id, "clause": "name = 'Alice'"}]
)
sql = dataset.get_query_str(self.query_obj)

self.assertRegexpMatches(sql, RLS_ALICE_REGEX)
Loading