Skip to content

Commit

Permalink
fix: permission sqlalchemy events (apache#21454)
Browse files Browse the repository at this point in the history
(cherry picked from commit 64d216a)
  • Loading branch information
dpgaspar authored and zef committed Sep 15, 2022
1 parent 309d506 commit 4ca96dc
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 19 deletions.
56 changes: 50 additions & 6 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,34 @@ def _delete_pvm_on_sqla_event( # pylint: disable=too-many-arguments
view_menu_table.delete().where(view_menu_table.c.id == pvm.view_menu_id)
)

def _find_permission_on_sqla_event(
self, connection: Connection, name: str
) -> Permission:
permission_table = self.permission_model.__table__ # pylint: disable=no-member

permission_ = connection.execute(
permission_table.select().where(permission_table.c.name == name)
).fetchone()
permission = Permission()
permission.metadata = None
permission.id = permission_.id
permission.name = permission_.name
return permission

def _find_view_menu_on_sqla_event(
self, connection: Connection, name: str
) -> ViewMenu:
view_menu_table = self.viewmenu_model.__table__ # pylint: disable=no-member

view_menu_ = connection.execute(
view_menu_table.select().where(view_menu_table.c.name == name)
).fetchone()
view_menu = ViewMenu()
view_menu.metadata = None
view_menu.id = view_menu_.id
view_menu.name = view_menu_.name
return view_menu

def _insert_pvm_on_sqla_event(
self,
mapper: Mapper,
Expand Down Expand Up @@ -1511,20 +1539,36 @@ def _insert_pvm_on_sqla_event(
permission = self.find_permission(permission_name)
view_menu = self.find_view_menu(view_menu_name)
if not permission:
connection.execute(permission_table.insert().values(name=permission_name))
permission = self.find_permission(permission_name)
_ = connection.execute(
permission_table.insert().values(name=permission_name)
)
permission = self._find_permission_on_sqla_event(
connection, permission_name
)
self.on_permission_after_insert(mapper, connection, permission)
if not view_menu:
connection.execute(view_menu_table.insert().values(name=view_menu_name))
view_menu = self.find_view_menu(view_menu_name)
_ = connection.execute(view_menu_table.insert().values(name=view_menu_name))
view_menu = self._find_view_menu_on_sqla_event(connection, view_menu_name)
self.on_view_menu_after_insert(mapper, connection, view_menu)
connection.execute(
permission_view_table.insert().values(
permission_id=permission.id, view_menu_id=view_menu.id
)
)
permission = self.find_permission_view_menu(permission_name, view_menu_name)
self.on_permission_view_after_insert(mapper, connection, permission)
permission_view = connection.execute(
permission_view_table.select().where(
permission_view_table.c.permission_id == permission.id,
permission_view_table.c.view_menu_id == view_menu.id,
)
).fetchone()
permission_view_model = PermissionView()
permission_view_model.metadata = None
permission_view_model.id = permission_view.id
permission_view_model.permission_id = permission.id
permission_view_model.view_menu_id = view_menu.id
permission_view_model.permission = permission
permission_view_model.view_menu = view_menu
self.on_permission_view_after_insert(mapper, connection, permission_view_model)

def on_role_after_update(
self, mapper: Mapper, connection: Connection, target: Role
Expand Down
21 changes: 8 additions & 13 deletions tests/integration_tests/security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,13 @@ def test_after_insert_dataset(self):
self.assertIsNotNone(pvm_schema)

# assert on permission hooks
view_menu_dataset = security_manager.find_view_menu(
f"[tmp_db1].[tmp_perm_table](id:{table.id})"
)
view_menu_schema = security_manager.find_view_menu(f"[tmp_db1].[tmp_schema]")
security_manager.on_view_menu_after_insert.assert_has_calls(
[
call(ANY, ANY, view_menu_dataset),
call(ANY, ANY, view_menu_schema),
]
)
call_args = security_manager.on_permission_view_after_insert.call_args
assert call_args.args[2].id == pvm_schema.id

security_manager.on_permission_view_after_insert.assert_has_calls(
[
call(ANY, ANY, pvm_dataset),
call(ANY, ANY, pvm_schema),
call(ANY, ANY, ANY),
call(ANY, ANY, ANY),
]
)

Expand Down Expand Up @@ -289,9 +282,11 @@ def test_after_insert_database(self):
# Assert the hook is called
security_manager.on_permission_view_after_insert.assert_has_calls(
[
call(ANY, ANY, tmp_db1_pvm),
call(ANY, ANY, ANY),
]
)
call_args = security_manager.on_permission_view_after_insert.call_args
assert call_args.args[2].id == tmp_db1_pvm.id
session.delete(tmp_db1)
session.commit()

Expand Down

0 comments on commit 4ca96dc

Please sign in to comment.