Skip to content

Commit

Permalink
test guest token rls rules
Browse files Browse the repository at this point in the history
  • Loading branch information
suddjian committed Jan 20, 2022
1 parent 1527c41 commit 07debee
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 45 deletions.
10 changes: 6 additions & 4 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,16 +1112,18 @@ def get_user_roles(self, user: Optional[User] = None) -> List[Role]:
return [self.get_public_role()] if public_role else []
return user.roles

def get_guest_rls_filters(self, table: "BaseDatasource") -> List[GuestTokenRlsRule]:
def get_guest_rls_filters(
self, dataset: "BaseDatasource"
) -> List[GuestTokenRlsRule]:
"""
Retrieves the row level security filters for the current user and the table,
Retrieves the row level security filters for the current user and the dataset,
if the user is authenticated with a guest token.
:param table: The table to check against
:param dataset: The dataset to check against
:return: A list of filters
"""
guest_user = self.get_current_guest_user_if_guest()
if guest_user:
return [rule for rule in guest_user.rls if rule["dataset"] == table.id]
return [rule for rule in guest_user.rls if rule["dataset"] == dataset.id]
return []

def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:
Expand Down
150 changes: 109 additions & 41 deletions tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
# under the License.
# isort:skip_file
import re
from typing import Any, Dict
from typing import Any, Dict, List
from unittest import mock

import pytest
from flask import g

from superset import db, security_manager
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.security.guest_token import GuestTokenRlsRule, GuestTokenResourceType
from ..base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand All @@ -38,39 +40,39 @@
)


query_obj: Dict[str, Any] = dict(
groupby=[],
metrics=None,
filter=[],
is_timeseries=False,
columns=["value"],
granularity=None,
from_dttm=None,
to_dttm=None,
extras={},
)
NAME_AB_ROLE = "NameAB"
NAME_Q_ROLE = "NameQ"
NAMES_A_REGEX = re.compile(r"name like 'A%'")
NAMES_B_REGEX = re.compile(r"name like 'B%'")
NAMES_Q_REGEX = re.compile(r"name like 'Q%'")
BASE_FILTER_REGEX = re.compile(r"gender = 'boy'")


class TestRowLevelSecurity(SupersetTestCase):
"""
Testing Row Level Security
"""

rls_entry = None
query_obj: Dict[str, Any] = dict(
groupby=[],
metrics=None,
filter=[],
is_timeseries=False,
columns=["value"],
granularity=None,
from_dttm=None,
to_dttm=None,
extras={},
)
NAME_AB_ROLE = "NameAB"
NAME_Q_ROLE = "NameQ"
NAMES_A_REGEX = re.compile(r"name like 'A%'")
NAMES_B_REGEX = re.compile(r"name like 'B%'")
NAMES_Q_REGEX = re.compile(r"name like 'Q%'")
BASE_FILTER_REGEX = re.compile(r"gender = 'boy'")

def setUp(self):
session = db.session

# Create roles
security_manager.add_role(self.NAME_AB_ROLE)
security_manager.add_role(self.NAME_Q_ROLE)
security_manager.add_role(NAME_AB_ROLE)
security_manager.add_role(NAME_Q_ROLE)
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.append(security_manager.find_role(self.NAME_AB_ROLE))
gamma_user.roles.append(security_manager.find_role(self.NAME_Q_ROLE))
gamma_user.roles.append(security_manager.find_role(NAME_AB_ROLE))
gamma_user.roles.append(security_manager.find_role(NAME_Q_ROLE))
self.create_user_with_roles("NoRlsRoleUser", ["Gamma"])
session.commit()

Expand Down Expand Up @@ -144,8 +146,8 @@ def tearDown(self):
def test_rls_filter_alters_energy_query(self):
g.user = self.get_user(username="alpha")
tbl = self.get_table(name="energy_usage")
sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [1]
sql = tbl.get_query_str(query_obj)
assert tbl.get_extra_cache_keys(query_obj) == [1]
assert "value > 1" in sql

@pytest.mark.usefixtures("load_energy_table_with_slice")
Expand All @@ -154,8 +156,8 @@ def test_rls_filter_doesnt_alter_energy_query(self):
username="admin"
) # self.login() doesn't actually set the user
tbl = self.get_table(name="energy_usage")
sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == []
sql = tbl.get_query_str(query_obj)
assert tbl.get_extra_cache_keys(query_obj) == []
assert "value > 1" not in sql

@pytest.mark.usefixtures("load_unicode_dashboard_with_slice")
Expand All @@ -164,15 +166,15 @@ def test_multiple_table_filter_alters_another_tables_query(self):
username="alpha"
) # self.login() doesn't actually set the user
tbl = self.get_table(name="unicode_test")
sql = tbl.get_query_str(self.query_obj)
assert tbl.get_extra_cache_keys(self.query_obj) == [1]
sql = tbl.get_query_str(query_obj)
assert tbl.get_extra_cache_keys(query_obj) == [1]
assert "value > 1" in sql

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_gamma_birth_names_query(self):
g.user = self.get_user(username="gamma")
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
sql = tbl.get_query_str(query_obj)

# establish that the filters are grouped together correctly with
# ANDs, ORs and parens in the correct place
Expand All @@ -185,23 +187,89 @@ def test_rls_filter_alters_gamma_birth_names_query(self):
def test_rls_filter_alters_no_role_user_birth_names_query(self):
g.user = self.get_user(username="NoRlsRoleUser")
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
sql = tbl.get_query_str(query_obj)

# gamma's filters should not be present query
assert not self.NAMES_A_REGEX.search(sql)
assert not self.NAMES_B_REGEX.search(sql)
assert not self.NAMES_Q_REGEX.search(sql)
assert not NAMES_A_REGEX.search(sql)
assert not NAMES_B_REGEX.search(sql)
assert not NAMES_Q_REGEX.search(sql)
# base query should be present
assert self.BASE_FILTER_REGEX.search(sql)
assert BASE_FILTER_REGEX.search(sql)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
g.user = self.get_user(username="admin")
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(self.query_obj)
sql = tbl.get_query_str(query_obj)

# no filters are applied for admin user
assert not self.NAMES_A_REGEX.search(sql)
assert not self.NAMES_B_REGEX.search(sql)
assert not self.NAMES_Q_REGEX.search(sql)
assert not self.BASE_FILTER_REGEX.search(sql)
assert not NAMES_A_REGEX.search(sql)
assert not NAMES_B_REGEX.search(sql)
assert not NAMES_Q_REGEX.search(sql)
assert not BASE_FILTER_REGEX.search(sql)


RLS_ALICE_REGEX = re.compile(r"name = 'Alice'")
RLS_GENDER_REGEX = re.compile(r"gender = 'girl'")


@mock.patch.dict(
"superset.extensions.feature_flag_manager._feature_flags", EMBEDDED_SUPERSET=True,
)
class GuestTokenRowLevelSecurityTests(SupersetTestCase):
def default_rls_rule(self):
return {
"dataset": self.get_table(name="birth_names").id,
"clause": "name = 'Alice'",
}

def guest_user_with_rls(self, rules: List[Any] = None):
if rules is None:
rules = [self.default_rls_rule()]
return security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": GuestTokenResourceType.DASHBOARD.value}],
"rls_rules": rules,
}
)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_alters_query(self):
g.user = self.guest_user_with_rls()
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(query_obj)

self.assertRegexpMatches(sql, RLS_ALICE_REGEX)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_rls_filter_does_not_alter_unrelated_query(self):
g.user = self.guest_user_with_rls(
rules=[
{
"dataset": self.get_table(name="birth_names").id + 1,
"clause": "name = 'Alice'",
}
]
)
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(query_obj)

self.assertNotRegexpMatches(sql, RLS_ALICE_REGEX)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_multiple_rls_filters_are_unionized(self):
g.user = self.guest_user_with_rls(
rules=[
self.default_rls_rule(),
{
"dataset": self.get_table(name="birth_names").id,
"clause": "gender = 'girl'",
},
]
)
tbl = self.get_table(name="birth_names")
sql = tbl.get_query_str(query_obj)

self.assertRegexpMatches(sql, RLS_ALICE_REGEX)
self.assertRegexpMatches(sql, RLS_GENDER_REGEX)

0 comments on commit 07debee

Please sign in to comment.