Skip to content

Commit

Permalink
[copy] fix: Row Level Security get_rls_filters func SELECT statement (#…
Browse files Browse the repository at this point in the history
…9541)

* fix: Row Level Security get_rls_filters func SELECT statement

* More general RowLevelSecurityTests case to avoid improper ids matching
  • Loading branch information
axelet authored Apr 15, 2020
1 parent d81f720 commit ef5e11f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def get_rls_filters(self, table: "BaseDatasource"):
.subquery()
)
filter_roles = (
db.session.query(RLSFilterRoles.c.id)
db.session.query(RLSFilterRoles.c.rls_filter_id)
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
Expand Down
9 changes: 5 additions & 4 deletions tests/security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,10 +833,11 @@ def setUp(self):
self.rls_entry.table = (
session.query(SqlaTable).filter_by(table_name="birth_names").first()
)
self.rls_entry.clause = "gender = 'male'"
self.rls_entry.clause = "gender = 'boy'"
self.rls_entry.roles.append(
security_manager.find_role("Gamma")
) # db.session.query(Role).filter_by(name="Gamma").first())
self.rls_entry.roles.append(security_manager.find_role("Alpha"))
db.session.add(self.rls_entry)

db.session.commit()
Expand All @@ -849,7 +850,7 @@ def tearDown(self):
# Do another test to make sure it doesn't alter another query
def test_rls_filter_alters_query(self):
g.user = self.get_user(
username="gamma"
username="alpha"
) # self.login() doesn't actually set the user
tbl = self.get_table_by_name("birth_names")
query_obj = dict(
Expand All @@ -864,7 +865,7 @@ def test_rls_filter_alters_query(self):
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertIn("gender = 'male'", sql)
self.assertIn("gender = 'boy'", sql)

def test_rls_filter_doesnt_alter_query(self):
g.user = self.get_user(
Expand All @@ -883,4 +884,4 @@ def test_rls_filter_doesnt_alter_query(self):
extras={},
)
sql = tbl.get_query_str(query_obj)
self.assertNotIn("gender = 'male'", sql)
self.assertNotIn("gender = 'boy'", sql)

0 comments on commit ef5e11f

Please sign in to comment.