-
Notifications
You must be signed in to change notification settings - Fork 14.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace get accessible dag ids (#11027)
- Loading branch information
Showing
4 changed files
with
93 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,15 +20,17 @@ | |
import unittest | ||
from unittest import mock | ||
|
||
from flask import Flask | ||
from flask_appbuilder import SQLA, AppBuilder, Model, expose, has_access | ||
from flask_appbuilder import SQLA, Model, expose, has_access | ||
from flask_appbuilder.security.sqla import models as sqla_models | ||
from flask_appbuilder.views import BaseView, ModelView | ||
from sqlalchemy import Column, Date, Float, Integer, String | ||
|
||
from airflow import settings | ||
from airflow.exceptions import AirflowException | ||
from airflow.www.security import AirflowSecurityManager | ||
from airflow.models import DagModel | ||
from airflow.www import app as application | ||
from airflow.www.utils import CustomSQLAInterface | ||
from tests.test_utils.db import clear_db_runs | ||
from tests.test_utils.mock_security_manager import MockSecurityManager | ||
|
||
READ_WRITE = {'can_dag_read', 'can_dag_edit'} | ||
|
@@ -66,22 +68,24 @@ def some_action(self): | |
|
||
|
||
class TestSecurity(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
settings.configure_orm() | ||
cls.session = settings.Session | ||
cls.app = application.create_app(testing=True) | ||
cls.appbuilder = cls.app.appbuilder # pylint: disable=no-member | ||
cls.app.config['WTF_CSRF_ENABLED'] = False | ||
cls.security_manager = cls.appbuilder.sm | ||
cls.role_admin = cls.security_manager.find_role('Admin') | ||
cls.user = cls.appbuilder.sm.add_user( | ||
'admin', 'admin', 'user', '[email protected]', cls.role_admin, 'general' | ||
) | ||
|
||
def setUp(self): | ||
self.app = Flask(__name__) | ||
self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' | ||
self.app.config['SECRET_KEY'] = 'secret_key' | ||
self.app.config['CSRF_ENABLED'] = False | ||
self.app.config['WTF_CSRF_ENABLED'] = False | ||
self.db = SQLA(self.app) | ||
self.appbuilder = AppBuilder(self.app, | ||
self.db.session, | ||
security_manager_class=AirflowSecurityManager) | ||
self.security_manager = self.appbuilder.sm | ||
self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews") | ||
self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews") | ||
role_admin = self.security_manager.find_role('Admin') | ||
self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', '[email protected]', | ||
role_admin, 'general') | ||
|
||
log.debug("Complete setup!") | ||
|
||
def expect_user_is_in_role(self, user, rolename): | ||
|
@@ -112,13 +116,14 @@ def _has_dag_perm(self, perm, dag_id): | |
self.user) | ||
|
||
def tearDown(self): | ||
clear_db_runs() | ||
self.appbuilder = None | ||
self.app = None | ||
self.db = None | ||
log.debug("Complete teardown!") | ||
|
||
def test_init_role_baseview(self): | ||
role_name = 'MyRole1' | ||
role_name = 'MyRole3' | ||
role_perms = ['can_some_action'] | ||
role_vms = ['SomeBaseView'] | ||
self.security_manager.init_role(role_name, role_vms, role_perms) | ||
|
@@ -159,7 +164,7 @@ def test_get_user_roles(self): | |
|
||
@mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles') | ||
def test_get_all_permissions_views(self, mock_get_user_roles): | ||
role_name = 'MyRole1' | ||
role_name = 'MyRole5' | ||
role_perms = ['can_some_action'] | ||
role_vms = ['SomeBaseView'] | ||
self.security_manager.init_role(role_name, role_vms, role_perms) | ||
|
@@ -174,23 +179,27 @@ def test_get_all_permissions_views(self, mock_get_user_roles): | |
self.assertEqual(len(self.security_manager | ||
.get_all_permissions_views()), 0) | ||
|
||
@mock.patch('airflow.www.security.AirflowSecurityManager' | ||
'.get_all_permissions_views') | ||
@mock.patch('airflow.www.security.AirflowSecurityManager' | ||
'.get_user_roles') | ||
def test_get_accessible_dag_ids(self, mock_get_user_roles, | ||
mock_get_all_permissions_views): | ||
user = mock.MagicMock() | ||
def test_get_accessible_dag_ids(self): | ||
role_name = 'MyRole1' | ||
role_perms = ['can_dag_read'] | ||
role_vms = ['dag_id'] | ||
self.security_manager.init_role(role_name, role_vms, role_perms) | ||
permission_action = ['can_dag_read'] | ||
dag_id = 'dag_id' | ||
username = "Mr. User" | ||
self.security_manager.init_role(role_name, [], []) | ||
self.security_manager.sync_perm_for_dag( # type: ignore # pylint: disable=no-member | ||
dag_id, access_control={role_name: permission_action} | ||
) | ||
role = self.security_manager.find_role(role_name) | ||
user.roles = [role] | ||
user.is_anonymous = False | ||
mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')} | ||
|
||
mock_get_user_roles.return_value = [role] | ||
user = self.security_manager.add_user( | ||
username=username, | ||
first_name=username, | ||
last_name=username, | ||
email=f"{username}@fab.org", | ||
role=role, | ||
password=username, | ||
) | ||
dag_model = DagModel(dag_id="dag_id", fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *") | ||
self.session.add(dag_model) | ||
self.session.commit() | ||
self.assertEqual(self.security_manager | ||
.get_accessible_dag_ids(user), {'dag_id'}) | ||
|
||
|
@@ -235,8 +244,17 @@ def test_access_control_with_invalid_permission(self): | |
'can_varimport', # a real permission, but not a member of DAG_PERMS | ||
'can_eat_pudding', # clearly not a real permission | ||
] | ||
username = "Mrs. User" | ||
user = self.security_manager.add_user( | ||
username=username, | ||
first_name=username, | ||
last_name=username, | ||
email=f"{username}@fab.org", | ||
role=self.role_admin, | ||
password=username, | ||
) | ||
for permission in invalid_permissions: | ||
self.expect_user_is_in_role(self.user, rolename='team-a') | ||
self.expect_user_is_in_role(user, rolename='team-a') | ||
with self.assertRaises(AirflowException) as context: | ||
self.security_manager.sync_perm_for_dag( | ||
'access_control_test', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters