Skip to content

Commit

Permalink
Replace get accessible dag ids (#11027)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhtimmins authored Oct 1, 2020
1 parent b6d5d1e commit 427a4a8
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 78 deletions.
64 changes: 31 additions & 33 deletions airflow/www/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from flask import current_app, g
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.security.sqla.manager import SecurityManager
from flask_appbuilder.security.sqla.models import PermissionView, Role, User
from sqlalchemy import and_, or_
from sqlalchemy.orm import joinedload

from airflow import models
from airflow.exceptions import AirflowException
Expand All @@ -41,7 +43,9 @@

CAN_CREATE = 'can_create'
CAN_READ = 'can_read'
CAN_DAG_READ = 'can_dag_read'
CAN_EDIT = 'can_edit'
CAN_DAG_EDIT = 'can_dag_edit'
CAN_DELETE = 'can_delete'


Expand Down Expand Up @@ -276,60 +280,54 @@ def get_all_permissions_views(self):

def get_readable_dags(self, user):
"""Gets the DAGs readable by authenticated user."""
return self.get_accessible_dags(CAN_READ, user)
return self.get_accessible_dags([CAN_READ, CAN_DAG_READ], user)

def get_editable_dags(self, user):
"""Gets the DAGs editable by authenticated user."""
return self.get_accessible_dags(CAN_EDIT, user)
return self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT], user)

def get_readable_dag_ids(self, user):
def get_readable_dag_ids(self, user) -> Set[str]:
"""Gets the DAG IDs readable by authenticated user."""
return [dag.dag_id for dag in self.get_readable_dags(user)]
return set(dag.dag_id for dag in self.get_readable_dags(user))

def get_editable_dag_ids(self, user):
def get_editable_dag_ids(self, user) -> Set[str]:
"""Gets the DAG IDs editable by authenticated user."""
return [dag.dag_id for dag in self.get_editable_dags(user)]
return set(dag.dag_id for dag in self.get_editable_dags(user))

def get_accessible_dag_ids(self, user) -> Set[str]:
"""Gets the DAG IDs editable or readable by authenticated user."""
accessible_dags = self.get_accessible_dags([CAN_EDIT, CAN_DAG_EDIT, CAN_READ, CAN_DAG_READ], user)
return set(dag.dag_id for dag in accessible_dags)

@provide_session
def get_accessible_dags(self, user_action, user, session=None):
def get_accessible_dags(self, user_actions, user, session=None):
"""Generic function to get readable or writable DAGs for authenticated user."""
if user.is_anonymous:
return set()

user_query = (
session.query(User)
.options(
joinedload(User.roles)
.subqueryload(Role.permissions)
.options(joinedload(PermissionView.permission), joinedload(PermissionView.view_menu))
)
.filter(User.id == user.id)
.first()
)
resources = set()
for role in user.roles:
for role in user_query.roles:
for permission in role.permissions:
resource = permission.view_menu.name
action = permission.permission.name
if action == user_action:
if action in user_actions:
resources.add(resource)
if 'Dag' in resources:

if bool({'Dag', 'all_dags'}.intersection(resources)):
return session.query(DagModel)

return session.query(DagModel).filter(DagModel.dag_id.in_(resources))

def get_accessible_dag_ids(self, username=None) -> Set[str]:
"""
Return a set of dags that user has access to(either read or write).
:param username: Name of the user.
:return: A set of dag ids that the user could access.
"""
if not username:
username = g.user

if username.is_anonymous or 'Public' in username.roles:
# return an empty set if the role is public
return set()

roles = {role.name for role in username.roles}
if {'Admin', 'Viewer', 'User', 'Op'} & roles:
return self.DAG_VMS

user_perms_views = self.get_all_permissions_views()
# return a set of all dags that the user could access
return {view for perm, view in user_perms_views if perm in self.DAG_PERMS}

def has_access(self, permission, view_name, user=None) -> bool:
"""
Verify whether a given user could perform certain permission
Expand Down Expand Up @@ -414,7 +412,7 @@ def clean_perms(self):

def _merge_perm(self, permission_name, view_menu_name):
"""
Add the new permission , view_menu to ab_permission_view_role if not exists.
Add the new (permission, view_menu) to assoc_permissionview_role if it doesn't exist.
It will add the related entry to ab_permission
and ab_view_menu two meta tables as well.
Expand Down
21 changes: 10 additions & 11 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import nvd3
import sqlalchemy as sqla
from flask import (
Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request,
session as flask_session, url_for,
Markup, Response, current_app, escape, flash, g, jsonify, make_response, redirect, render_template,
request, session as flask_session, url_for,
)
from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name
from flask_appbuilder.actions import action
Expand Down Expand Up @@ -442,7 +442,7 @@ def get_int_arg(value, default=0):
end = start + dags_per_page

# Get all the dag id the user could access
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)

with create_session() as session:
# read orm_dags from the db
Expand Down Expand Up @@ -543,7 +543,7 @@ def dag_stats(self, session=None):
"""Dag statistics."""
dr = models.DagRun

allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]

Expand Down Expand Up @@ -588,7 +588,7 @@ def dag_stats(self, session=None):
@provide_session
def task_stats(self, session=None):
"""Task Statistics"""
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)

if not allowed_dag_ids:
return wwwutils.json_response({})
Expand Down Expand Up @@ -702,7 +702,7 @@ def task_stats(self, session=None):
@provide_session
def last_dagruns(self, session=None):
"""Last DAG runs"""
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)

if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
Expand Down Expand Up @@ -1385,7 +1385,7 @@ def dagrun_clear(self):
@provide_session
def blocked(self, session=None):
"""Mark Dag Blocked."""
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)

if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
Expand Down Expand Up @@ -2287,7 +2287,6 @@ def extra_links(self):
return response

task = dag.get_task(task_id)

try:
url = task.get_extra_links(dttm, link_name)
except ValueError as err:
Expand Down Expand Up @@ -2416,7 +2415,7 @@ class DagFilter(BaseFilter):
def apply(self, query, func): # noqa pylint: disable=redefined-outer-name,unused-argument
if current_app.appbuilder.sm.has_all_dags_access():
return query
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
return query.filter(self.model.dag_id.in_(filter_dag_ids))


Expand Down Expand Up @@ -3136,9 +3135,9 @@ def autocomplete(self, session=None):
dag_ids_query = dag_ids_query.filter(DagModel.is_paused)
owners_query = owners_query.filter(DagModel.is_paused)

filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
# pylint: disable=no-member
if 'all_dags' not in filter_dag_ids:
if not bool({'all_dags', 'Dag'}.intersection(filter_dag_ids)):
dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
# pylint: enable=no-member
Expand Down
84 changes: 51 additions & 33 deletions tests/www/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@
import unittest
from unittest import mock

from flask import Flask
from flask_appbuilder import SQLA, AppBuilder, Model, expose, has_access
from flask_appbuilder import SQLA, Model, expose, has_access
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.views import BaseView, ModelView
from sqlalchemy import Column, Date, Float, Integer, String

from airflow import settings
from airflow.exceptions import AirflowException
from airflow.www.security import AirflowSecurityManager
from airflow.models import DagModel
from airflow.www import app as application
from airflow.www.utils import CustomSQLAInterface
from tests.test_utils.db import clear_db_runs
from tests.test_utils.mock_security_manager import MockSecurityManager

READ_WRITE = {'can_dag_read', 'can_dag_edit'}
Expand Down Expand Up @@ -66,22 +68,24 @@ def some_action(self):


class TestSecurity(unittest.TestCase):
@classmethod
def setUpClass(cls):
settings.configure_orm()
cls.session = settings.Session
cls.app = application.create_app(testing=True)
cls.appbuilder = cls.app.appbuilder # pylint: disable=no-member
cls.app.config['WTF_CSRF_ENABLED'] = False
cls.security_manager = cls.appbuilder.sm
cls.role_admin = cls.security_manager.find_role('Admin')
cls.user = cls.appbuilder.sm.add_user(
'admin', 'admin', 'user', '[email protected]', cls.role_admin, 'general'
)

def setUp(self):
self.app = Flask(__name__)
self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
self.app.config['SECRET_KEY'] = 'secret_key'
self.app.config['CSRF_ENABLED'] = False
self.app.config['WTF_CSRF_ENABLED'] = False
self.db = SQLA(self.app)
self.appbuilder = AppBuilder(self.app,
self.db.session,
security_manager_class=AirflowSecurityManager)
self.security_manager = self.appbuilder.sm
self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
role_admin = self.security_manager.find_role('Admin')
self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', '[email protected]',
role_admin, 'general')

log.debug("Complete setup!")

def expect_user_is_in_role(self, user, rolename):
Expand Down Expand Up @@ -112,13 +116,14 @@ def _has_dag_perm(self, perm, dag_id):
self.user)

def tearDown(self):
clear_db_runs()
self.appbuilder = None
self.app = None
self.db = None
log.debug("Complete teardown!")

def test_init_role_baseview(self):
role_name = 'MyRole1'
role_name = 'MyRole3'
role_perms = ['can_some_action']
role_vms = ['SomeBaseView']
self.security_manager.init_role(role_name, role_vms, role_perms)
Expand Down Expand Up @@ -159,7 +164,7 @@ def test_get_user_roles(self):

@mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
def test_get_all_permissions_views(self, mock_get_user_roles):
role_name = 'MyRole1'
role_name = 'MyRole5'
role_perms = ['can_some_action']
role_vms = ['SomeBaseView']
self.security_manager.init_role(role_name, role_vms, role_perms)
Expand All @@ -174,23 +179,27 @@ def test_get_all_permissions_views(self, mock_get_user_roles):
self.assertEqual(len(self.security_manager
.get_all_permissions_views()), 0)

@mock.patch('airflow.www.security.AirflowSecurityManager'
'.get_all_permissions_views')
@mock.patch('airflow.www.security.AirflowSecurityManager'
'.get_user_roles')
def test_get_accessible_dag_ids(self, mock_get_user_roles,
mock_get_all_permissions_views):
user = mock.MagicMock()
def test_get_accessible_dag_ids(self):
role_name = 'MyRole1'
role_perms = ['can_dag_read']
role_vms = ['dag_id']
self.security_manager.init_role(role_name, role_vms, role_perms)
permission_action = ['can_dag_read']
dag_id = 'dag_id'
username = "Mr. User"
self.security_manager.init_role(role_name, [], [])
self.security_manager.sync_perm_for_dag( # type: ignore # pylint: disable=no-member
dag_id, access_control={role_name: permission_action}
)
role = self.security_manager.find_role(role_name)
user.roles = [role]
user.is_anonymous = False
mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')}

mock_get_user_roles.return_value = [role]
user = self.security_manager.add_user(
username=username,
first_name=username,
last_name=username,
email=f"{username}@fab.org",
role=role,
password=username,
)
dag_model = DagModel(dag_id="dag_id", fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *")
self.session.add(dag_model)
self.session.commit()
self.assertEqual(self.security_manager
.get_accessible_dag_ids(user), {'dag_id'})

Expand Down Expand Up @@ -235,8 +244,17 @@ def test_access_control_with_invalid_permission(self):
'can_varimport', # a real permission, but not a member of DAG_PERMS
'can_eat_pudding', # clearly not a real permission
]
username = "Mrs. User"
user = self.security_manager.add_user(
username=username,
first_name=username,
last_name=username,
email=f"{username}@fab.org",
role=self.role_admin,
password=username,
)
for permission in invalid_permissions:
self.expect_user_is_in_role(self.user, rolename='team-a')
self.expect_user_is_in_role(user, rolename='team-a')
with self.assertRaises(AirflowException) as context:
self.security_manager.sync_perm_for_dag(
'access_control_test',
Expand Down
2 changes: 1 addition & 1 deletion tests/www/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def prepare_dagruns(self):
state=State.RUNNING)

def test_index(self):
with assert_queries_count(40):
with assert_queries_count(43):
resp = self.client.get('/', follow_redirects=True)
self.check_content_in_response('DAGs', resp)

Expand Down

0 comments on commit 427a4a8

Please sign in to comment.