diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py index 2b07c4de2d3aa..b8ace0e1de8bb 100644 --- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py +++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py @@ -40,14 +40,14 @@ if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse, UpdateMask - from airflow.www.security import AirflowSecurityManager + from airflow.www.security_manager import AirflowSecurityManagerV2 -def _check_action_and_resource(sm: AirflowSecurityManager, perms: list[tuple[str, str]]) -> None: +def _check_action_and_resource(sm: AirflowSecurityManagerV2, perms: list[tuple[str, str]]) -> None: """ Check if the action or resource exists and otherwise raise 400. - This function is intended for use in the REST API because it raise 400 + This function is intended for use in the REST API because it raises an HTTP error 400 """ for action, resource in perms: if not sm.get_action(action): diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index a512804b4ca94..66b79b0a26daf 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -24,9 +24,11 @@ from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: + from flask import Flask + from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import CLICommand - from airflow.www.security import AirflowSecurityManager + from airflow.www.security_manager import AirflowSecurityManagerV2 class BaseAuthManager(LoggingMixin): @@ -36,8 +38,9 @@ class BaseAuthManager(LoggingMixin): Auth managers are responsible for any user management related operation such as login, logout, authz, ... """ - def __init__(self): - self._security_manager: AirflowSecurityManager | None = None + def __init__(self, app: Flask) -> None: + self._security_manager: AirflowSecurityManagerV2 | None = None + self.app = app @staticmethod def get_cli_commands() -> list[CLICommand]: @@ -80,22 +83,24 @@ def get_security_manager_override_class(self) -> type: Return the security manager override class. The security manager override class is responsible for overriding the default security manager - class airflow.www.security.AirflowSecurityManager with a custom implementation. This class is - essentially inherited from airflow.www.security.AirflowSecurityManager. + class airflow.www.security_manager.AirflowSecurityManagerV2 with a custom implementation. + This class is essentially inherited from airflow.www.security_manager.AirflowSecurityManagerV2. - By default, return an empty class. + By default, return the generic AirflowSecurityManagerV2. """ - return object + from airflow.www.security_manager import AirflowSecurityManagerV2 + + return AirflowSecurityManagerV2 @property - def security_manager(self) -> AirflowSecurityManager: + def security_manager(self) -> AirflowSecurityManagerV2: """Get the security manager.""" if not self._security_manager: raise AirflowException("Security manager not defined.") return self._security_manager @security_manager.setter - def security_manager(self, security_manager: AirflowSecurityManager): + def security_manager(self, security_manager: AirflowSecurityManagerV2): """ Set the security manager. diff --git a/airflow/auth/managers/fab/cli_commands/role_command.py b/airflow/auth/managers/fab/cli_commands/role_command.py index ce69c7f201dfb..fc828ede771f6 100644 --- a/airflow/auth/managers/fab/cli_commands/role_command.py +++ b/airflow/auth/managers/fab/cli_commands/role_command.py @@ -29,7 +29,7 @@ from airflow.utils import cli as cli_utils from airflow.utils.cli import suppress_logs_and_warning from airflow.utils.providers_configuration_loader import providers_configuration_loaded -from airflow.www.security import EXISTING_ROLES +from airflow.www.security_manager import EXISTING_ROLES if TYPE_CHECKING: from airflow.auth.managers.fab.models import Action, Permission, Resource, Role diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 354a073be2e6d..27c96d6dd596c 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import warnings from typing import TYPE_CHECKING from airflow import AirflowException @@ -90,8 +91,24 @@ def is_logged_in(self) -> bool: def get_security_manager_override_class(self) -> type: """Return the security manager override.""" from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride - - return FabAirflowSecurityManagerOverride + from airflow.www.security import AirflowSecurityManager + + sm_from_config = self.app.config.get("SECURITY_MANAGER_CLASS") + if sm_from_config: + if not issubclass(sm_from_config, AirflowSecurityManager): + raise Exception( + """Your CUSTOM_SECURITY_MANAGER must extend FabAirflowSecurityManagerOverride, + not FAB's own security manager.""" + ) + if not issubclass(sm_from_config, FabAirflowSecurityManagerOverride): + warnings.warn( + "Please make your custom security manager inherit from " + "FabAirflowSecurityManagerOverride instead of AirflowSecurityManager.", + DeprecationWarning, + ) + return sm_from_config + + return FabAirflowSecurityManagerOverride # default choice def url_for(self, *args, **kwargs): """Wrapper to allow mocking without having to import at the top of the file.""" diff --git a/airflow/auth/managers/fab/security_manager/modules/__init__.py b/airflow/auth/managers/fab/security_manager/modules/__init__.py deleted file mode 100644 index 217e5db960782..0000000000000 --- a/airflow/auth/managers/fab/security_manager/modules/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/auth/managers/fab/security_manager/modules/db.py b/airflow/auth/managers/fab/security_manager/modules/db.py deleted file mode 100644 index 77f1a8205f385..0000000000000 --- a/airflow/auth/managers/fab/security_manager/modules/db.py +++ /dev/null @@ -1,557 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import logging -import uuid - -from flask_appbuilder import const -from flask_appbuilder.models.sqla import Base -from sqlalchemy import func, inspect, select -from sqlalchemy.exc import MultipleResultsFound -from werkzeug.security import generate_password_hash - -from airflow import AirflowException -from airflow.auth.managers.fab.models import Action, Permission, Resource, Role - -log = logging.getLogger(__name__) - - -class FabAirflowSecurityManagerOverrideDb: - """ - FabAirflowSecurityManagerOverride is split into multiple classes to avoid having one massive class. - - This class contains all methods in - airflow.auth.managers.fab.security_manager.override.FabAirflowSecurityManagerOverride related to the - database. - - :param appbuilder: The appbuilder. - """ - - # Models - role_model = Role - permission_model = Permission - action_model = Action - resource_model = Resource - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.appbuilder = kwargs["appbuilder"] - - @property - def get_session(self): - return self.appbuilder.get_session - - def create_db(self): - """ - Create the database. - - Creates admin and public roles if they don't exist. - """ - if not self.appbuilder.update_perms: - log.debug("Skipping db since appbuilder disables update_perms") - return - try: - engine = self.get_session.get_bind(mapper=None, clause=None) - inspector = inspect(engine) - if "ab_user" not in inspector.get_table_names(): - log.info(const.LOGMSG_INF_SEC_NO_DB) - Base.metadata.create_all(engine) - log.info(const.LOGMSG_INF_SEC_ADD_DB) - - roles_mapping = self.appbuilder.app.config.get("FAB_ROLES_MAPPING", {}) - for pk, name in roles_mapping.items(): - self.update_role(pk, name) - for role_name in self._builtin_roles: - self.add_role(role_name) - if self.auth_role_admin not in self._builtin_roles: - self.add_role(self.auth_role_admin) - self.add_role(self.auth_role_public) - if self.count_users() == 0 and self.auth_role_public != self.auth_role_admin: - log.warning(const.LOGMSG_WAR_SEC_NO_USER) - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_CREATE_DB, e) - exit(1) - - """ - ----------- - Role entity - ----------- - """ - - def update_role(self, role_id, name: str) -> Role | None: - """Update a role in the database.""" - role = self.get_session.get(self.role_model, role_id) - if not role: - return None - try: - role.name = name - self.get_session.merge(role) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_UPD_ROLE, role) - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_UPD_ROLE, e) - self.get_session.rollback() - return None - return role - - def add_role(self, name: str) -> Role: - """Add a role in the database.""" - role = self.find_role(name) - if role is None: - try: - role = self.role_model() - role.name = name - self.get_session.add(role) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_ADD_ROLE, name) - return role - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_ROLE, e) - self.get_session.rollback() - return role - - def find_role(self, name): - """ - Find a role in the database. - - :param name: the role name - """ - return self.get_session.query(self.role_model).filter_by(name=name).one_or_none() - - def get_all_roles(self): - return self.get_session.query(self.role_model).all() - - def get_public_role(self): - return self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none() - - def delete_role(self, role_name: str) -> None: - """ - Delete the given Role. - - :param role_name: the name of a role in the ab_role table - """ - session = self.get_session - role = session.query(Role).filter(Role.name == role_name).first() - if role: - log.info("Deleting role '%s'", role_name) - session.delete(role) - session.commit() - else: - raise AirflowException(f"Role named '{role_name}' does not exist") - - """ - ----------- - User entity - ----------- - """ - - def add_user( - self, - username, - first_name, - last_name, - email, - role, - password="", - hashed_password="", - ): - """Generic function to create user.""" - try: - user = self.user_model() - user.first_name = first_name - user.last_name = last_name - user.username = username - user.email = email - user.active = True - user.roles = role if isinstance(role, list) else [role] - if hashed_password: - user.password = hashed_password - else: - user.password = generate_password_hash(password) - self.get_session.add(user) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_ADD_USER, username) - return user - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_USER, e) - self.get_session.rollback() - return False - - def load_user(self, user_id): - """Load user by ID.""" - return self.get_user_by_id(int(user_id)) - - def get_user_by_id(self, pk): - return self.get_session.get(self.user_model, pk) - - def count_users(self): - """Return the number of users in the database.""" - return self.get_session.query(func.count(self.user_model.id)).scalar() - - def add_register_user(self, username, first_name, last_name, email, password="", hashed_password=""): - """ - Add a registration request for the user. - - :rtype : RegisterUser - """ - register_user = self.registeruser_model() - register_user.username = username - register_user.email = email - register_user.first_name = first_name - register_user.last_name = last_name - if hashed_password: - register_user.password = hashed_password - else: - register_user.password = generate_password_hash(password) - register_user.registration_hash = str(uuid.uuid1()) - try: - self.get_session.add(register_user) - self.get_session.commit() - return register_user - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_REGISTER_USER, e) - self.get_session.rollback() - return None - - def find_user(self, username=None, email=None): - """Finds user by username or email.""" - if username: - try: - if self.auth_username_ci: - return ( - self.get_session.query(self.user_model) - .filter(func.lower(self.user_model.username) == func.lower(username)) - .one_or_none() - ) - else: - return ( - self.get_session.query(self.user_model) - .filter(func.lower(self.user_model.username) == func.lower(username)) - .one_or_none() - ) - except MultipleResultsFound: - log.error("Multiple results found for user %s", username) - return None - elif email: - try: - return self.get_session.query(self.user_model).filter_by(email=email).one_or_none() - except MultipleResultsFound: - log.error("Multiple results found for user with email %s", email) - return None - - def find_register_user(self, registration_hash): - return self.get_session.scalar( - select(self.registeruser_mode) - .where(self.registeruser_model.registration_hash == registration_hash) - .limit(1) - ) - - def update_user(self, user): - try: - self.get_session.merge(user) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_UPD_USER, user) - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_UPD_USER, e) - self.get_session.rollback() - return False - - def del_register_user(self, register_user): - """ - Deletes registration object from database. - - :param register_user: RegisterUser object to delete - """ - try: - self.get_session.delete(register_user) - self.get_session.commit() - return True - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_DEL_REGISTER_USER, e) - self.get_session.rollback() - return False - - def get_all_users(self): - return self.get_session.query(self.user_model).all() - - """ - ------------- - Action entity - ------------- - """ - - def get_action(self, name: str) -> Action: - """ - Gets an existing action record. - - :param name: name - :return: Action record, if it exists - """ - return self.get_session.query(self.action_model).filter_by(name=name).one_or_none() - - def create_action(self, name): - """ - Adds an action to the backend, model action. - - :param name: - name of the action: 'can_add','can_edit' etc... - """ - action = self.get_action(name) - if action is None: - try: - action = self.action_model() - action.name = name - self.get_session.add(action) - self.get_session.commit() - return action - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_PERMISSION, e) - self.get_session.rollback() - return action - - def delete_action(self, name: str) -> bool: - """ - Deletes a permission action. - - :param name: Name of action to delete (e.g. can_read). - """ - action = self.get_action(name) - if not action: - log.warning(const.LOGMSG_WAR_SEC_DEL_PERMISSION, name) - return False - try: - perms = ( - self.get_session.query(self.permission_model) - .filter(self.permission_model.action == action) - .all() - ) - if perms: - log.warning(const.LOGMSG_WAR_SEC_DEL_PERM_PVM, action, perms) - return False - self.get_session.delete(action) - self.get_session.commit() - return True - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_DEL_PERMISSION, e) - self.get_session.rollback() - return False - - """ - --------------- - Resource entity - --------------- - """ - - def get_resource(self, name: str) -> Resource: - """ - Returns a resource record by name, if it exists. - - :param name: Name of resource - """ - return self.get_session.query(self.resource_model).filter_by(name=name).one_or_none() - - def create_resource(self, name) -> Resource: - """ - Create a resource with the given name. - - :param name: The name of the resource to create created. - :return: The FAB resource created. - """ - resource = self.get_resource(name) - if resource is None: - try: - resource = self.resource_model() - resource.name = name - self.get_session.add(resource) - self.get_session.commit() - return resource - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_VIEWMENU, e) - self.get_session.rollback() - return resource - - def get_all_resources(self) -> list[Resource]: - """ - Gets all existing resource records. - - :return: List of all resources - """ - return self.get_session.query(self.resource_model).all() - - def delete_resource(self, name: str) -> bool: - """ - Deletes a Resource from the backend. - - :param name: - name of the resource - """ - resource = self.get_resource(name) - if not resource: - log.warning(const.LOGMSG_WAR_SEC_DEL_VIEWMENU, name) - return False - try: - perms = ( - self.get_session.query(self.permission_model) - .filter(self.permission_model.resource == resource) - .all() - ) - if perms: - log.warning(const.LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM, resource, perms) - return False - self.get_session.delete(resource) - self.get_session.commit() - return True - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_DEL_PERMISSION, e) - self.get_session.rollback() - return False - - """ - --------------- - Permission entity - --------------- - """ - - def get_permission( - self, - action_name: str, - resource_name: str, - ) -> Permission | None: - """ - Gets a permission made with the given action->resource pair, if the permission already exists. - - :param action_name: Name of action - :param resource_name: Name of resource - :return: The existing permission - """ - action = self.get_action(action_name) - resource = self.get_resource(resource_name) - if action and resource: - return ( - self.get_session.query(self.permission_model) - .filter_by(action=action, resource=resource) - .one_or_none() - ) - return None - - def get_resource_permissions(self, resource: Resource) -> Permission: - """ - Retrieve permission pairs associated with a specific resource object. - - :param resource: Object representing a single resource. - :return: Action objects representing resource->action pair - """ - return self.get_session.query(self.permission_model).filter_by(resource_id=resource.id).all() - - def create_permission(self, action_name, resource_name) -> Permission | None: - """ - Adds a permission on a resource to the backend. - - :param action_name: - name of the action to add: 'can_add','can_edit' etc... - :param resource_name: - name of the resource to add - """ - if not (action_name and resource_name): - return None - perm = self.get_permission(action_name, resource_name) - if perm: - return perm - resource = self.create_resource(resource_name) - action = self.create_action(action_name) - perm = self.permission_model() - perm.resource_id, perm.action_id = resource.id, action.id - try: - self.get_session.add(perm) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_ADD_PERMVIEW, perm) - return perm - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_PERMVIEW, e) - self.get_session.rollback() - return None - - def delete_permission(self, action_name: str, resource_name: str) -> None: - """ - Deletes the permission linking an action->resource pair. - - Doesn't delete the underlying action or resource. - - :param action_name: Name of existing action - :param resource_name: Name of existing resource - :return: None - """ - if not (action_name and resource_name): - return - perm = self.get_permission(action_name, resource_name) - if not perm: - return - roles = ( - self.get_session.query(self.role_model).filter(self.role_model.permissions.contains(perm)).first() - ) - if roles: - log.warning(const.LOGMSG_WAR_SEC_DEL_PERMVIEW, resource_name, action_name, roles) - return - try: - # delete permission on resource - self.get_session.delete(perm) - self.get_session.commit() - # if no more permission on permission view, delete permission - if not self.get_session.query(self.permission_model).filter_by(action=perm.action).all(): - self.delete_action(perm.action.name) - log.info(const.LOGMSG_INF_SEC_DEL_PERMVIEW, action_name, resource_name) - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_DEL_PERMVIEW, e) - self.get_session.rollback() - - def add_permission_to_role(self, role: Role, permission: Permission | None) -> None: - """ - Add an existing permission pair to a role. - - :param role: The role about to get a new permission. - :param permission: The permission pair to add to a role. - :return: None - """ - if permission and permission not in role.permissions: - try: - role.permissions.append(permission) - self.get_session.merge(role) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_ADD_PERMROLE, permission, role.name) - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_ADD_PERMROLE, e) - self.get_session.rollback() - - def remove_permission_from_role(self, role: Role, permission: Permission) -> None: - """ - Remove a permission pair from a role. - - :param role: User role containing permissions. - :param permission: Object representing resource-> action pair - """ - if permission in role.permissions: - try: - role.permissions.remove(permission) - self.get_session.merge(role) - self.get_session.commit() - log.info(const.LOGMSG_INF_SEC_DEL_PERMROLE, permission, role.name) - except Exception as e: - log.error(const.LOGMSG_ERR_SEC_DEL_PERMROLE, e) - self.get_session.rollback() diff --git a/airflow/auth/managers/fab/security_manager/modules/oauth.py b/airflow/auth/managers/fab/security_manager/modules/oauth.py deleted file mode 100644 index e1a10de84d08e..0000000000000 --- a/airflow/auth/managers/fab/security_manager/modules/oauth.py +++ /dev/null @@ -1,186 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import base64 -import json -import logging - -import re2 -from flask import session - -log = logging.getLogger(__name__) - - -class FabAirflowSecurityManagerOverrideOauth: - """ - FabAirflowSecurityManagerOverride is split into multiple classes to avoid having one massive class. - - This class contains all methods in - airflow.auth.managers.fab.security_manager.override.FabAirflowSecurityManagerOverride related to the - oauth authentication. - """ - - def get_oauth_user_info(self, provider, resp): - """ - Get the OAuth user information from different OAuth APIs. - - All providers have different ways to retrieve user info. - """ - # for GITHUB - if provider == "github" or provider == "githublocal": - me = self.oauth_remotes[provider].get("user") - data = me.json() - log.debug("User info from GitHub: %s", data) - return {"username": "github_" + data.get("login")} - # for twitter - if provider == "twitter": - me = self.oauth_remotes[provider].get("account/settings.json") - data = me.json() - log.debug("User info from Twitter: %s", data) - return {"username": "twitter_" + data.get("screen_name", "")} - # for linkedin - if provider == "linkedin": - me = self.oauth_remotes[provider].get( - "people/~:(id,email-address,first-name,last-name)?format=json" - ) - data = me.json() - log.debug("User info from LinkedIn: %s", data) - return { - "username": "linkedin_" + data.get("id", ""), - "email": data.get("email-address", ""), - "first_name": data.get("firstName", ""), - "last_name": data.get("lastName", ""), - } - # for Google - if provider == "google": - me = self.oauth_remotes[provider].get("userinfo") - data = me.json() - log.debug("User info from Google: %s", data) - return { - "username": "google_" + data.get("id", ""), - "first_name": data.get("given_name", ""), - "last_name": data.get("family_name", ""), - "email": data.get("email", ""), - } - # for Azure AD Tenant. Azure OAuth response contains - # JWT token which has user info. - # JWT token needs to be base64 decoded. - # https://docs.microsoft.com/en-us/azure/active-directory/develop/ - # active-directory-protocols-oauth-code - if provider == "azure": - log.debug("Azure response received : %s", resp) - id_token = resp["id_token"] - log.debug(str(id_token)) - me = FabAirflowSecurityManagerOverrideOauth._azure_jwt_token_parse(id_token) - log.debug("Parse JWT token : %s", me) - return { - "name": me.get("name", ""), - "email": me["upn"], - "first_name": me.get("given_name", ""), - "last_name": me.get("family_name", ""), - "id": me["oid"], - "username": me["oid"], - "role_keys": me.get("roles", []), - } - # for OpenShift - if provider == "openshift": - me = self.oauth_remotes[provider].get("apis/user.openshift.io/v1/users/~") - data = me.json() - log.debug("User info from OpenShift: %s", data) - return {"username": "openshift_" + data.get("metadata").get("name")} - # for Okta - if provider == "okta": - me = self.oauth_remotes[provider].get("userinfo") - data = me.json() - log.debug("User info from Okta: %s", data) - return { - "username": "okta_" + data.get("sub", ""), - "first_name": data.get("given_name", ""), - "last_name": data.get("family_name", ""), - "email": data.get("email", ""), - "role_keys": data.get("groups", []), - } - # for Keycloak - if provider in ["keycloak", "keycloak_before_17"]: - me = self.oauth_remotes[provider].get("openid-connect/userinfo") - me.raise_for_status() - data = me.json() - log.debug("User info from Keycloak: %s", data) - return { - "username": data.get("preferred_username", ""), - "first_name": data.get("given_name", ""), - "last_name": data.get("family_name", ""), - "email": data.get("email", ""), - } - else: - return {} - - @staticmethod - def oauth_token_getter(): - """Authentication (OAuth) token getter function.""" - token = session.get("oauth") - log.debug("Token Get: %s", token) - return token - - @staticmethod - def _azure_parse_jwt(token): - """ - Parse Azure JWT token content. - - :param token: the JWT token - - :meta private: - """ - jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$" - matches = re2.search(jwt_token_parts, token) - if not matches or len(matches.groups()) < 3: - log.error("Unable to parse token.") - return {} - return { - "header": matches.group(1), - "Payload": matches.group(2), - "Sig": matches.group(3), - } - - @staticmethod - def _azure_jwt_token_parse(self, token): - """ - Parse and decode Azure JWT token. - - :param token: the JWT token - - :meta private: - """ - jwt_split_token = FabAirflowSecurityManagerOverrideOauth._azure_parse_jwt(token) - if not jwt_split_token: - return - - jwt_payload = jwt_split_token["Payload"] - # Prepare for base64 decoding - payload_b64_string = jwt_payload - payload_b64_string += "=" * (4 - (len(jwt_payload) % 4)) - decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii")) - - if not decoded_payload: - log.error("Payload of id_token could not be base64 url decoded.") - return - - jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8")) - - return jwt_decoded_payload diff --git a/airflow/auth/managers/fab/security_manager/override.py b/airflow/auth/managers/fab/security_manager/override.py index 011f43e86bb91..66032dc166834 100644 --- a/airflow/auth/managers/fab/security_manager/override.py +++ b/airflow/auth/managers/fab/security_manager/override.py @@ -17,26 +17,33 @@ # under the License. from __future__ import annotations +import base64 +import json import logging +import uuid import warnings from functools import cached_property from typing import TYPE_CHECKING -from flask import flash, g +import re2 +from flask import flash, g, session from flask_appbuilder import const from flask_appbuilder.const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID, AUTH_REMOTE_USER +from flask_appbuilder.models.sqla import Base from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import lazy_gettext from flask_jwt_extended import JWTManager from flask_login import LoginManager from itsdangerous import want_bytes from markupsafe import Markup +from sqlalchemy import func, inspect, select +from sqlalchemy.exc import MultipleResultsFound from werkzeug.security import generate_password_hash +from airflow import AirflowException from airflow.auth.managers.fab.models import Action, Permission, RegisterUser, Resource, Role from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser -from airflow.auth.managers.fab.security_manager.modules.db import FabAirflowSecurityManagerOverrideDb -from airflow.auth.managers.fab.security_manager.modules.oauth import FabAirflowSecurityManagerOverrideOauth +from airflow.www.security_manager import AirflowSecurityManagerV2 from airflow.www.session import AirflowDatabaseSessionInterface if TYPE_CHECKING: @@ -53,9 +60,7 @@ MAX_NUM_DATABASE_USER_SESSIONS = 50000 -class FabAirflowSecurityManagerOverride( - FabAirflowSecurityManagerOverrideDb, FabAirflowSecurityManagerOverrideOauth -): +class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): """ This security manager overrides the default AirflowSecurityManager security manager. @@ -64,29 +69,6 @@ class FabAirflowSecurityManagerOverride( the AirflowSecurityManager should be defined here instead of AirflowSecurityManager. :param appbuilder: The appbuilder. - :param actionmodelview: The obj instance for action model view. - :param authdbview: The class for auth db view. - :param authldapview: The class for auth ldap view. - :param authoauthview: The class for auth oauth view. - :param authoidview: The class for auth oid view. - :param authremoteuserview: The class for auth remote user view. - :param permissionmodelview: The class for permission model view. - :param registeruser_view: The class for register user view. - :param registeruserdbview: The class for register user db view. - :param registeruseroauthview: The class for register user oauth view. - :param registerusermodelview: The class for register user model view. - :param registeruseroidview: The class for register user oid view. - :param resetmypasswordview: The class for reset my password view. - :param resetpasswordview: The class for reset password view. - :param rolemodelview: The class for role model view. - :param user_model: The user model. - :param userinfoeditview: The class for user info edit view. - :param userdbmodelview: The class for user db model view. - :param userldapmodelview: The class for user ldap model view. - :param useroauthmodelview: The class for user oauth model view. - :param useroidmodelview: The class for user oid model view. - :param userremoteusermodelview: The class for user remote user model view. - :param userstatschartview: The class for user stats chart view. """ """ The obj instance for authentication view """ @@ -103,37 +85,16 @@ class FabAirflowSecurityManagerOverride( """ Initialized (remote_app) providers dict {'provider_name', OBJ } """ oauth_allow_list: dict[str, list] = {} - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.appbuilder = kwargs["appbuilder"] - self.actionmodelview = kwargs["actionmodelview"] - self.authdbview = kwargs["authdbview"] - self.authldapview = kwargs["authldapview"] - self.authoauthview = kwargs["authoauthview"] - self.authoidview = kwargs["authoidview"] - self.authremoteuserview = kwargs["authremoteuserview"] - self.permissionmodelview = kwargs["permissionmodelview"] - self.registeruser_view = kwargs["registeruser_view"] - self.registeruserdbview = kwargs["registeruserdbview"] - self.registeruseroauthview = kwargs["registeruseroauthview"] - self.registerusermodelview = kwargs["registerusermodelview"] - self.registeruseroidview = kwargs["registeruseroidview"] - self.resetmypasswordview = kwargs["resetmypasswordview"] - self.resetpasswordview = kwargs["resetpasswordview"] - self.rolemodelview = kwargs["rolemodelview"] - self.user_model = kwargs["user_model"] - self.userinfoeditview = kwargs["userinfoeditview"] - self.userdbmodelview = kwargs["userdbmodelview"] - self.userldapmodelview = kwargs["userldapmodelview"] - self.useroauthmodelview = kwargs["useroauthmodelview"] - self.useroidmodelview = kwargs["useroidmodelview"] - self.userremoteusermodelview = kwargs["userremoteusermodelview"] - self.userstatschartview = kwargs["userstatschartview"] + def __init__(self, appbuilder): + # done in super, but we need it before we can call super. + self.appbuilder = appbuilder self._init_config() self._init_auth() self._init_data_model() + # can only call super once data model init has been done + # because of the view.datamodel hack that's done in the init there. + super().__init__(appbuilder=appbuilder) self._builtin_roles: dict = self.create_builtin_roles() @@ -314,10 +275,6 @@ def reset_user_sessions(self, user: User) -> None: "warning", ) - def load_user(self, user_id): - """Load user by ID.""" - return self.get_user_by_id(int(user_id)) - def load_user_jwt(self, _jwt_header, jwt_data): identity = jwt_data["sub"] user = self.load_user(identity) @@ -444,7 +401,7 @@ def _init_auth(self): provider_name = provider["name"] log.debug("OAuth providers init %s", provider_name) obj_provider = self.oauth.register(provider_name, **provider["remote_app"]) - obj_provider._tokengetter = FabAirflowSecurityManagerOverrideOauth.oauth_token_getter + obj_provider._tokengetter = self.oauth_token_getter if not self.oauth_user_info: self.oauth_user_info = self.get_oauth_user_info # Whitelist only users with matching emails @@ -474,3 +431,650 @@ def _init_data_model(self): self.actionmodelview.datamodel = SQLAInterface(self.action_model) self.resourcemodelview.datamodel = SQLAInterface(self.resource_model) self.permissionmodelview.datamodel = SQLAInterface(self.permission_model) + + def create_db(self): + """ + Create the database. + + Creates admin and public roles if they don't exist. + """ + if not self.appbuilder.update_perms: + log.debug("Skipping db since appbuilder disables update_perms") + return + try: + engine = self.get_session.get_bind(mapper=None, clause=None) + inspector = inspect(engine) + if "ab_user" not in inspector.get_table_names(): + log.info(const.LOGMSG_INF_SEC_NO_DB) + Base.metadata.create_all(engine) + log.info(const.LOGMSG_INF_SEC_ADD_DB) + + roles_mapping = self.appbuilder.app.config.get("FAB_ROLES_MAPPING", {}) + for pk, name in roles_mapping.items(): + self.update_role(pk, name) + for role_name in self._builtin_roles: + self.add_role(role_name) + if self.auth_role_admin not in self._builtin_roles: + self.add_role(self.auth_role_admin) + self.add_role(self.auth_role_public) + if self.count_users() == 0 and self.auth_role_public != self.auth_role_admin: + log.warning(const.LOGMSG_WAR_SEC_NO_USER) + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_CREATE_DB, e) + exit(1) + + """ + ----------- + Role entity + ----------- + """ + + def update_role(self, role_id, name: str) -> Role | None: + """Update a role in the database.""" + role = self.get_session.get(self.role_model, role_id) + if not role: + return None + try: + role.name = name + self.get_session.merge(role) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_UPD_ROLE, role) + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_UPD_ROLE, e) + self.get_session.rollback() + return None + return role + + def add_role(self, name: str) -> Role: + """Add a role in the database.""" + role = self.find_role(name) + if role is None: + try: + role = self.role_model() + role.name = name + self.get_session.add(role) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_ADD_ROLE, name) + return role + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_ROLE, e) + self.get_session.rollback() + return role + + def find_role(self, name): + """ + Find a role in the database. + + :param name: the role name + """ + return self.get_session.query(self.role_model).filter_by(name=name).one_or_none() + + def get_all_roles(self): + return self.get_session.query(self.role_model).all() + + def get_public_role(self): + return self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none() + + def delete_role(self, role_name: str) -> None: + """ + Delete the given Role. + + :param role_name: the name of a role in the ab_role table + """ + session = self.get_session + role = session.query(Role).filter(Role.name == role_name).first() + if role: + log.info("Deleting role '%s'", role_name) + session.delete(role) + session.commit() + else: + raise AirflowException(f"Role named '{role_name}' does not exist") + + """ + ----------- + User entity + ----------- + """ + + def add_user( + self, + username, + first_name, + last_name, + email, + role, + password="", + hashed_password="", + ): + """Generic function to create user.""" + try: + user = self.user_model() + user.first_name = first_name + user.last_name = last_name + user.username = username + user.email = email + user.active = True + user.roles = role if isinstance(role, list) else [role] + if hashed_password: + user.password = hashed_password + else: + user.password = generate_password_hash(password) + self.get_session.add(user) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_ADD_USER, username) + return user + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_USER, e) + self.get_session.rollback() + return False + + def load_user(self, user_id): + """Load user by ID.""" + return self.get_user_by_id(int(user_id)) + + def get_user_by_id(self, pk): + return self.get_session.get(self.user_model, pk) + + def count_users(self): + """Return the number of users in the database.""" + return self.get_session.query(func.count(self.user_model.id)).scalar() + + def add_register_user(self, username, first_name, last_name, email, password="", hashed_password=""): + """ + Add a registration request for the user. + + :rtype : RegisterUser + """ + register_user = self.registeruser_model() + register_user.username = username + register_user.email = email + register_user.first_name = first_name + register_user.last_name = last_name + if hashed_password: + register_user.password = hashed_password + else: + register_user.password = generate_password_hash(password) + register_user.registration_hash = str(uuid.uuid1()) + try: + self.get_session.add(register_user) + self.get_session.commit() + return register_user + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_REGISTER_USER, e) + self.get_session.rollback() + return None + + def find_user(self, username=None, email=None): + """Finds user by username or email.""" + if username: + try: + if self.auth_username_ci: + return ( + self.get_session.query(self.user_model) + .filter(func.lower(self.user_model.username) == func.lower(username)) + .one_or_none() + ) + else: + return ( + self.get_session.query(self.user_model) + .filter(func.lower(self.user_model.username) == func.lower(username)) + .one_or_none() + ) + except MultipleResultsFound: + log.error("Multiple results found for user %s", username) + return None + elif email: + try: + return self.get_session.query(self.user_model).filter_by(email=email).one_or_none() + except MultipleResultsFound: + log.error("Multiple results found for user with email %s", email) + return None + + def find_register_user(self, registration_hash): + return self.get_session.scalar( + select(self.registeruser_mode) + .where(self.registeruser_model.registration_hash == registration_hash) + .limit(1) + ) + + def update_user(self, user): + try: + self.get_session.merge(user) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_UPD_USER, user) + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_UPD_USER, e) + self.get_session.rollback() + return False + + def del_register_user(self, register_user): + """ + Deletes registration object from database. + + :param register_user: RegisterUser object to delete + """ + try: + self.get_session.delete(register_user) + self.get_session.commit() + return True + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_DEL_REGISTER_USER, e) + self.get_session.rollback() + return False + + def get_all_users(self): + return self.get_session.query(self.user_model).all() + + """ + ------------- + Action entity + ------------- + """ + + def get_action(self, name: str) -> Action: + """ + Gets an existing action record. + + :param name: name + :return: Action record, if it exists + """ + return self.get_session.query(self.action_model).filter_by(name=name).one_or_none() + + def create_action(self, name): + """ + Adds an action to the backend, model action. + + :param name: + name of the action: 'can_add','can_edit' etc... + """ + action = self.get_action(name) + if action is None: + try: + action = self.action_model() + action.name = name + self.get_session.add(action) + self.get_session.commit() + return action + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_PERMISSION, e) + self.get_session.rollback() + return action + + def delete_action(self, name: str) -> bool: + """ + Deletes a permission action. + + :param name: Name of action to delete (e.g. can_read). + """ + action = self.get_action(name) + if not action: + log.warning(const.LOGMSG_WAR_SEC_DEL_PERMISSION, name) + return False + try: + perms = ( + self.get_session.query(self.permission_model) + .filter(self.permission_model.action == action) + .all() + ) + if perms: + log.warning(const.LOGMSG_WAR_SEC_DEL_PERM_PVM, action, perms) + return False + self.get_session.delete(action) + self.get_session.commit() + return True + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_DEL_PERMISSION, e) + self.get_session.rollback() + return False + + """ + --------------- + Resource entity + --------------- + """ + + def get_resource(self, name: str) -> Resource: + """ + Returns a resource record by name, if it exists. + + :param name: Name of resource + """ + return self.get_session.query(self.resource_model).filter_by(name=name).one_or_none() + + def create_resource(self, name) -> Resource: + """ + Create a resource with the given name. + + :param name: The name of the resource to create created. + :return: The FAB resource created. + """ + resource = self.get_resource(name) + if resource is None: + try: + resource = self.resource_model() + resource.name = name + self.get_session.add(resource) + self.get_session.commit() + return resource + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_VIEWMENU, e) + self.get_session.rollback() + return resource + + def get_all_resources(self) -> list[Resource]: + """ + Gets all existing resource records. + + :return: List of all resources + """ + return self.get_session.query(self.resource_model).all() + + def delete_resource(self, name: str) -> bool: + """ + Deletes a Resource from the backend. + + :param name: + name of the resource + """ + resource = self.get_resource(name) + if not resource: + log.warning(const.LOGMSG_WAR_SEC_DEL_VIEWMENU, name) + return False + try: + perms = ( + self.get_session.query(self.permission_model) + .filter(self.permission_model.resource == resource) + .all() + ) + if perms: + log.warning(const.LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM, resource, perms) + return False + self.get_session.delete(resource) + self.get_session.commit() + return True + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_DEL_PERMISSION, e) + self.get_session.rollback() + return False + + """ + --------------- + Permission entity + --------------- + """ + + def get_permission( + self, + action_name: str, + resource_name: str, + ) -> Permission | None: + """ + Gets a permission made with the given action->resource pair, if the permission already exists. + + :param action_name: Name of action + :param resource_name: Name of resource + :return: The existing permission + """ + action = self.get_action(action_name) + resource = self.get_resource(resource_name) + if action and resource: + return ( + self.get_session.query(self.permission_model) + .filter_by(action=action, resource=resource) + .one_or_none() + ) + return None + + def get_resource_permissions(self, resource: Resource) -> Permission: + """ + Retrieve permission pairs associated with a specific resource object. + + :param resource: Object representing a single resource. + :return: Action objects representing resource->action pair + """ + return self.get_session.query(self.permission_model).filter_by(resource_id=resource.id).all() + + def create_permission(self, action_name, resource_name) -> Permission | None: + """ + Adds a permission on a resource to the backend. + + :param action_name: + name of the action to add: 'can_add','can_edit' etc... + :param resource_name: + name of the resource to add + """ + if not (action_name and resource_name): + return None + perm = self.get_permission(action_name, resource_name) + if perm: + return perm + resource = self.create_resource(resource_name) + action = self.create_action(action_name) + perm = self.permission_model() + perm.resource_id, perm.action_id = resource.id, action.id + try: + self.get_session.add(perm) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_ADD_PERMVIEW, perm) + return perm + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_PERMVIEW, e) + self.get_session.rollback() + return None + + def delete_permission(self, action_name: str, resource_name: str) -> None: + """ + Deletes the permission linking an action->resource pair. + + Doesn't delete the underlying action or resource. + + :param action_name: Name of existing action + :param resource_name: Name of existing resource + :return: None + """ + if not (action_name and resource_name): + return + perm = self.get_permission(action_name, resource_name) + if not perm: + return + roles = ( + self.get_session.query(self.role_model).filter(self.role_model.permissions.contains(perm)).first() + ) + if roles: + log.warning(const.LOGMSG_WAR_SEC_DEL_PERMVIEW, resource_name, action_name, roles) + return + try: + # delete permission on resource + self.get_session.delete(perm) + self.get_session.commit() + # if no more permission on permission view, delete permission + if not self.get_session.query(self.permission_model).filter_by(action=perm.action).all(): + self.delete_action(perm.action.name) + log.info(const.LOGMSG_INF_SEC_DEL_PERMVIEW, action_name, resource_name) + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_DEL_PERMVIEW, e) + self.get_session.rollback() + + def add_permission_to_role(self, role: Role, permission: Permission | None) -> None: + """ + Add an existing permission pair to a role. + + :param role: The role about to get a new permission. + :param permission: The permission pair to add to a role. + :return: None + """ + if permission and permission not in role.permissions: + try: + role.permissions.append(permission) + self.get_session.merge(role) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_ADD_PERMROLE, permission, role.name) + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_ADD_PERMROLE, e) + self.get_session.rollback() + + def remove_permission_from_role(self, role: Role, permission: Permission) -> None: + """ + Remove a permission pair from a role. + + :param role: User role containing permissions. + :param permission: Object representing resource-> action pair + """ + if permission in role.permissions: + try: + role.permissions.remove(permission) + self.get_session.merge(role) + self.get_session.commit() + log.info(const.LOGMSG_INF_SEC_DEL_PERMROLE, permission, role.name) + except Exception as e: + log.error(const.LOGMSG_ERR_SEC_DEL_PERMROLE, e) + self.get_session.rollback() + + def get_oauth_user_info(self, provider, resp): + """ + Get the OAuth user information from different OAuth APIs. + + All providers have different ways to retrieve user info. + """ + # for GITHUB + if provider == "github" or provider == "githublocal": + me = self.oauth_remotes[provider].get("user") + data = me.json() + log.debug("User info from GitHub: %s", data) + return {"username": "github_" + data.get("login")} + # for twitter + if provider == "twitter": + me = self.oauth_remotes[provider].get("account/settings.json") + data = me.json() + log.debug("User info from Twitter: %s", data) + return {"username": "twitter_" + data.get("screen_name", "")} + # for linkedin + if provider == "linkedin": + me = self.oauth_remotes[provider].get( + "people/~:(id,email-address,first-name,last-name)?format=json" + ) + data = me.json() + log.debug("User info from LinkedIn: %s", data) + return { + "username": "linkedin_" + data.get("id", ""), + "email": data.get("email-address", ""), + "first_name": data.get("firstName", ""), + "last_name": data.get("lastName", ""), + } + # for Google + if provider == "google": + me = self.oauth_remotes[provider].get("userinfo") + data = me.json() + log.debug("User info from Google: %s", data) + return { + "username": "google_" + data.get("id", ""), + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data.get("email", ""), + } + # for Azure AD Tenant. Azure OAuth response contains + # JWT token which has user info. + # JWT token needs to be base64 decoded. + # https://docs.microsoft.com/en-us/azure/active-directory/develop/ + # active-directory-protocols-oauth-code + if provider == "azure": + log.debug("Azure response received : %s", resp) + id_token = resp["id_token"] + log.debug(str(id_token)) + me = self._azure_jwt_token_parse(id_token) + log.debug("Parse JWT token : %s", me) + return { + "name": me.get("name", ""), + "email": me["upn"], + "first_name": me.get("given_name", ""), + "last_name": me.get("family_name", ""), + "id": me["oid"], + "username": me["oid"], + "role_keys": me.get("roles", []), + } + # for OpenShift + if provider == "openshift": + me = self.oauth_remotes[provider].get("apis/user.openshift.io/v1/users/~") + data = me.json() + log.debug("User info from OpenShift: %s", data) + return {"username": "openshift_" + data.get("metadata").get("name")} + # for Okta + if provider == "okta": + me = self.oauth_remotes[provider].get("userinfo") + data = me.json() + log.debug("User info from Okta: %s", data) + return { + "username": "okta_" + data.get("sub", ""), + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data.get("email", ""), + "role_keys": data.get("groups", []), + } + # for Keycloak + if provider in ["keycloak", "keycloak_before_17"]: + me = self.oauth_remotes[provider].get("openid-connect/userinfo") + me.raise_for_status() + data = me.json() + log.debug("User info from Keycloak: %s", data) + return { + "username": data.get("preferred_username", ""), + "first_name": data.get("given_name", ""), + "last_name": data.get("family_name", ""), + "email": data.get("email", ""), + } + else: + return {} + + @staticmethod + def oauth_token_getter(): + """Authentication (OAuth) token getter function.""" + token = session.get("oauth") + log.debug("Token Get: %s", token) + return token + + @staticmethod + def _azure_parse_jwt(token): + """ + Parse Azure JWT token content. + + :param token: the JWT token + + :meta private: + """ + jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$" + matches = re2.search(jwt_token_parts, token) + if not matches or len(matches.groups()) < 3: + log.error("Unable to parse token.") + return {} + return { + "header": matches.group(1), + "Payload": matches.group(2), + "Sig": matches.group(3), + } + + @staticmethod + def _azure_jwt_token_parse(token): + """ + Parse and decode Azure JWT token. + + :param token: the JWT token + + :meta private: + """ + jwt_split_token = FabAirflowSecurityManagerOverride._azure_parse_jwt(token) + if not jwt_split_token: + return + + jwt_payload = jwt_split_token["Payload"] + # Prepare for base64 decoding + payload_b64_string = jwt_payload + payload_b64_string += "=" * (4 - (len(jwt_payload) % 4)) + decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii")) + + if not decoded_payload: + log.error("Payload of id_token could not be base64 url decoded.") + return + + jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8")) + + return jwt_decoded_payload diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 7a926f273693a..b603cf08be7c0 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -692,7 +692,7 @@ def _sync_perm_for_dag(cls, dag: DAG, session: Session = NEW_SESSION): root_dag_id = dag.parent_dag.dag_id if dag.parent_dag else dag.dag_id cls.logger().debug("Syncing DAG permissions: %s to the DB", root_dag_id) - from airflow.www.security import ApplessAirflowSecurityManager + from airflow.www.security_appless import ApplessAirflowSecurityManager security_manager = ApplessAirflowSecurityManager(session=session) security_manager.sync_perm_for_dag(root_dag_id, dag.access_control) diff --git a/airflow/www/extensions/init_appbuilder.py b/airflow/www/extensions/init_appbuilder.py index 2f8fe473aa846..12474142d1893 100644 --- a/airflow/www/extensions/init_appbuilder.py +++ b/airflow/www/extensions/init_appbuilder.py @@ -40,13 +40,15 @@ from airflow import settings from airflow.configuration import conf -from airflow.www.extensions.init_auth_manager import get_auth_manager +from airflow.www.extensions.init_auth_manager import get_auth_manager, init_auth_manager if TYPE_CHECKING: + from flask import Flask from flask_appbuilder import BaseView from flask_appbuilder.security.manager import BaseSecurityManager from sqlalchemy.orm import Session + # This product contains a modified portion of 'Flask App Builder' developed by Daniel Vaz Gaspar. # (https://github.com/dpgaspar/Flask-AppBuilder). # Copyright 2013, Daniel Vaz Gaspar @@ -655,22 +657,13 @@ def _process_inner_views(self): view.get_init_inner_views().append(v) -def init_appbuilder(app) -> AirflowAppBuilder: +def init_appbuilder(app: Flask) -> AirflowAppBuilder: """Init `Flask App Builder `__.""" - from airflow.www.security import AirflowSecurityManager - - security_manager_class = app.config.get("SECURITY_MANAGER_CLASS") or AirflowSecurityManager - - if not issubclass(security_manager_class, AirflowSecurityManager): - raise Exception( - """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager, - not FAB's security manager.""" - ) - + auth_manager = init_auth_manager(app) return AirflowAppBuilder( app=app, session=settings.Session, - security_manager_class=security_manager_class, + security_manager_class=auth_manager.get_security_manager_override_class(), base_template="airflow/main.html", update_perms=conf.getboolean("webserver", "UPDATE_FAB_PERMS"), auth_rate_limited=conf.getboolean("webserver", "AUTH_RATE_LIMITED", fallback=True), diff --git a/airflow/www/extensions/init_auth_manager.py b/airflow/www/extensions/init_auth_manager.py index 24ae020862dc9..32db0f2cdc907 100644 --- a/airflow/www/extensions/init_auth_manager.py +++ b/airflow/www/extensions/init_auth_manager.py @@ -18,13 +18,16 @@ from typing import TYPE_CHECKING -from airflow.compat.functools import cache from airflow.configuration import conf from airflow.exceptions import AirflowConfigException if TYPE_CHECKING: + from flask import Flask + from airflow.auth.managers.base_auth_manager import BaseAuthManager +auth_manager: BaseAuthManager | None = None + def get_auth_manager_cls() -> type[BaseAuthManager]: """Returns just the auth manager class without initializing it. @@ -42,13 +45,22 @@ def get_auth_manager_cls() -> type[BaseAuthManager]: return auth_manager_cls -@cache -def get_auth_manager() -> BaseAuthManager: - """ - Initialize auth manager. +def init_auth_manager(app: Flask) -> BaseAuthManager: + """Initialize the auth manager with the given flask app object. - Import the user manager class, instantiate it and return it. + Import the user manager class and instantiate it. """ + global auth_manager auth_manager_cls = get_auth_manager_cls() + auth_manager = auth_manager_cls(app) + return auth_manager - return auth_manager_cls() + +def get_auth_manager() -> BaseAuthManager: + """Returns the auth manager, provided it's been initialized before.""" + if auth_manager is None: + raise Exception( + "Auth Manager has not been initialized yet. " + "The `init_auth_manager` method needs to be called first." + ) + return auth_manager diff --git a/airflow/www/fab_security/manager.py b/airflow/www/fab_security/manager.py index 0b2eb3492a1ca..2015a572ed2bf 100644 --- a/airflow/www/fab_security/manager.py +++ b/airflow/www/fab_security/manager.py @@ -196,6 +196,9 @@ def get_roles_from_keys(self, role_keys: list[str]) -> set[Role]: log.warning("Can't find role specified in AUTH_ROLES_MAPPING: %s", fab_role_name) return _roles + def add_role(self, name: str) -> Role: + raise NotImplementedError + @property def auth_type_provider_name(self): provider_to_auth_type = {AUTH_DB: "db", AUTH_LDAP: "ldap"} @@ -1074,6 +1077,12 @@ def add_permissions_menu(self, resource_name): role_admin = self.find_role(self.auth_role_admin) self.add_permission_to_role(role_admin, perm) + def get_resource(self, name: str) -> Resource: + raise NotImplementedError + + def get_action(self, name: str) -> Action: + raise NotImplementedError + def security_cleanup(self, baseviews, menus): """ Will cleanup all unused permissions from the database. diff --git a/airflow/www/security.py b/airflow/www/security.py index 0a8d6cbf1e8ef..0f0a0f5aed310 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -16,785 +16,25 @@ # under the License. from __future__ import annotations -import warnings -from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING -from flask import g -from sqlalchemy import or_, select -from sqlalchemy.orm import joinedload +from deprecated import deprecated -from airflow.auth.managers.fab.models import Permission, Resource, Role, User -from airflow.auth.managers.fab.views.permissions import ( - ActionModelView, - PermissionPairModelView, - ResourceModelView, -) -from airflow.auth.managers.fab.views.roles_list import CustomRoleModelView -from airflow.auth.managers.fab.views.user import ( - CustomUserDBModelView, - CustomUserLDAPModelView, - CustomUserOAuthModelView, - CustomUserOIDModelView, - CustomUserRemoteUserModelView, -) -from airflow.auth.managers.fab.views.user_edit import ( - CustomResetMyPasswordView, - CustomResetPasswordView, - CustomUserInfoEditView, -) -from airflow.auth.managers.fab.views.user_stats import CustomUserStatsChartView -from airflow.exceptions import AirflowException, RemovedInAirflow3Warning -from airflow.models import DagBag, DagModel -from airflow.security import permissions -from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.www.extensions.init_auth_manager import get_auth_manager -from airflow.www.fab_security.sqla.manager import SecurityManager -from airflow.www.utils import CustomSQLAInterface - -EXISTING_ROLES = { - "Admin", - "Viewer", - "User", - "Op", - "Public", -} +from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride if TYPE_CHECKING: - from sqlalchemy.orm import Session - - SecurityManagerOverride: type = object -else: - # Fetch the security manager override from the auth manager - SecurityManagerOverride = get_auth_manager().get_security_manager_override_class() - - -class AirflowSecurityManager(SecurityManagerOverride, SecurityManager, LoggingMixin): - """Custom security manager, which introduces a permission model adapted to Airflow.""" - - ########################################################################### - # PERMISSIONS - ########################################################################### - - # [START security_viewer_perms] - VIEWER_PERMISSIONS = [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_DEPENDENCIES), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_JOB), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_MY_PASSWORD), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_MY_PASSWORD), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_MY_PROFILE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_MY_PROFILE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_SLA_MISS), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_BROWSE_MENU), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_DEPENDENCIES), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DATASET), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS_MENU), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_JOB), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_AUDIT_LOG), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_PLUGIN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_SLA_MISS), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TASK_INSTANCE), - ] - # [END security_viewer_perms] - - # [START security_user_perms] - USER_PERMISSIONS = [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ] - # [END security_user_perms] - - # [START security_op_perms] - OP_PERMISSIONS = [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_ADMIN_MENU), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CONFIG), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_XCOM), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_XCOM), - ] - # [END security_op_perms] - - ADMIN_PERMISSIONS = [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_RESCHEDULE), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TASK_RESCHEDULE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TRIGGER), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TRIGGER), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_PASSWORD), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_PASSWORD), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE), - ] - - # global resource for dag-level access - DAG_RESOURCES = {permissions.RESOURCE_DAG} - DAG_ACTIONS = permissions.DAG_ACTIONS - - ########################################################################### - # DEFAULT ROLE CONFIGURATIONS - ########################################################################### - - ROLE_CONFIGS: list[dict[str, Any]] = [ - {"role": "Public", "perms": []}, - {"role": "Viewer", "perms": VIEWER_PERMISSIONS}, - { - "role": "User", - "perms": VIEWER_PERMISSIONS + USER_PERMISSIONS, - }, - { - "role": "Op", - "perms": VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS, - }, - { - "role": "Admin", - "perms": VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS + ADMIN_PERMISSIONS, - }, - ] - - actionmodelview = ActionModelView - permissionmodelview = PermissionPairModelView - rolemodelview = CustomRoleModelView - resourcemodelview = ResourceModelView - userdbmodelview = CustomUserDBModelView - resetmypasswordview = CustomResetMyPasswordView - resetpasswordview = CustomResetPasswordView - userinfoeditview = CustomUserInfoEditView - userldapmodelview = CustomUserLDAPModelView - useroauthmodelview = CustomUserOAuthModelView - userremoteusermodelview = CustomUserRemoteUserModelView - useroidmodelview = CustomUserOIDModelView - userstatschartview = CustomUserStatsChartView - - def __init__(self, appbuilder) -> None: - super().__init__( - appbuilder=appbuilder, - actionmodelview=self.actionmodelview, - authdbview=self.authdbview, - authldapview=self.authldapview, - authoauthview=self.authoauthview, - authoidview=self.authoidview, - authremoteuserview=self.authremoteuserview, - permissionmodelview=self.permissionmodelview, - registeruser_view=self.registeruser_view, - registeruserdbview=self.registeruserdbview, - registeruseroauthview=self.registeruseroauthview, - registerusermodelview=self.registerusermodelview, - registeruseroidview=self.registeruseroidview, - resetmypasswordview=self.resetmypasswordview, - resetpasswordview=self.resetpasswordview, - rolemodelview=self.rolemodelview, - user_model=self.user_model, - userinfoeditview=self.userinfoeditview, - userdbmodelview=self.userdbmodelview, - userldapmodelview=self.userldapmodelview, - useroauthmodelview=self.useroauthmodelview, - useroidmodelview=self.useroidmodelview, - userremoteusermodelview=self.userremoteusermodelview, - userstatschartview=self.userstatschartview, - ) - - # Go and fix up the SQLAInterface used from the stock one to our subclass. - # This is needed to support the "hack" where we had to edit - # FieldConverter.conversion_table in place in airflow.www.utils - for attr in dir(self): - if attr.endswith("view"): - view = getattr(self, attr, None) - if view and getattr(view, "datamodel", None): - view.datamodel = CustomSQLAInterface(view.datamodel.obj) - self.perms = None - - def _get_root_dag_id(self, dag_id: str) -> str: - if "." in dag_id: - dm = self.appbuilder.get_session.execute( - select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id) - ).one() - return dm.root_dag_id or dm.dag_id - return dag_id - - def init_role(self, role_name, perms) -> None: - """ - Initialize the role with actions and related resources. - - :param role_name: - :param perms: - :return: - """ - warnings.warn( - "`init_role` has been deprecated. Please use `bulk_sync_roles` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - self.bulk_sync_roles([{"role": role_name, "perms": perms}]) - - def bulk_sync_roles(self, roles: Iterable[dict[str, Any]]) -> None: - """Sync the provided roles and permissions.""" - existing_roles = self._get_all_roles_with_permissions() - non_dag_perms = self._get_all_non_dag_permissions() - - for config in roles: - role_name = config["role"] - perms = config["perms"] - role = existing_roles.get(role_name) or self.add_role(role_name) - - for action_name, resource_name in perms: - perm = non_dag_perms.get((action_name, resource_name)) or self.create_permission( - action_name, resource_name - ) - - if perm not in role.permissions: - self.add_permission_to_role(role, perm) - - @staticmethod - def get_user_roles(user=None): - """ - Get all the roles associated with the user. - - :param user: the ab_user in FAB model. - :return: a list of roles associated with the user. - """ - if user is None: - user = g.user - return user.roles - - def get_readable_dags(self, user) -> Iterable[DagModel]: - """Gets the DAGs readable by authenticated user.""" - warnings.warn( - "`get_readable_dags` has been deprecated. Please use `get_readable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) - - def get_editable_dags(self, user) -> Iterable[DagModel]: - """Gets the DAGs editable by authenticated user.""" - warnings.warn( - "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) - - @provide_session - def get_accessible_dags( - self, - user_actions: Container[str] | None, - user, - session: Session = NEW_SESSION, - ) -> Iterable[DagModel]: - warnings.warn( - "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=3, - ) - dag_ids = self.get_accessible_dag_ids(user, user_actions, session) - return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) - - def get_readable_dag_ids(self, user) -> set[str]: - """Gets the DAG IDs readable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_READ]) - - def get_editable_dag_ids(self, user) -> set[str]: - """Gets the DAG IDs editable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_EDIT]) - - @provide_session - def get_accessible_dag_ids( - self, - user, - user_actions: Container[str] | None = None, - session: Session = NEW_SESSION, - ) -> set[str]: - """Generic function to get readable or writable DAGs for user.""" - if not user_actions: - user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] - - if not get_auth_manager().is_logged_in(): - roles = user.roles - else: - if (permissions.ACTION_CAN_EDIT in user_actions and self.can_edit_all_dags(user)) or ( - permissions.ACTION_CAN_READ in user_actions and self.can_read_all_dags(user) - ): - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - user_query = session.scalar( - select(User) - .options( - joinedload(User.roles) - .subqueryload(Role.permissions) - .options(joinedload(Permission.action), joinedload(Permission.resource)) - ) - .where(User.id == user.id) - ) - roles = user_query.roles - - resources = set() - for role in roles: - for permission in role.permissions: - action = permission.action.name - if action in user_actions: - resource = permission.resource.name - if resource == permissions.RESOURCE_DAG: - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - - if resource.startswith(permissions.RESOURCE_DAG_PREFIX): - resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) - else: - resources.add(resource) - return { - dag.dag_id - for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) - } - - def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: - """Checks if user has read or write access to some dags.""" - if dag_id and dag_id != "~": - root_dag_id = self._get_root_dag_id(dag_id) - return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) - - user = g.user - if action == permissions.ACTION_CAN_READ: - return any(self.get_readable_dag_ids(user)) - return any(self.get_editable_dag_ids(user)) - - def can_read_dag(self, dag_id: str, user=None) -> bool: - """Determines whether a user has DAG read access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_READ, dag_resource_name, user=user) - - def can_edit_dag(self, dag_id: str, user=None) -> bool: - """Determines whether a user has DAG edit access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_EDIT, dag_resource_name, user=user) - - def can_delete_dag(self, dag_id: str, user=None) -> bool: - """Determines whether a user has DAG delete access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_DELETE, dag_resource_name, user=user) - - def prefixed_dag_id(self, dag_id: str) -> str: - """Returns the permission name for a DAG id.""" - warnings.warn( - "`prefixed_dag_id` has been deprecated. " - "Please use `airflow.security.permissions.resource_name_for_dag` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - root_dag_id = self._get_root_dag_id(dag_id) - return permissions.resource_name_for_dag(root_dag_id) - - def is_dag_resource(self, resource_name: str) -> bool: - """Determines if a resource belongs to a DAG or all DAGs.""" - if resource_name == permissions.RESOURCE_DAG: - return True - return resource_name.startswith(permissions.RESOURCE_DAG_PREFIX) - - def has_access(self, action_name: str, resource_name: str, user=None) -> bool: - """ - Verify whether a given user could perform a certain action on the given resource. - - Example actions might include can_read, can_write, can_delete, etc. - - :param action_name: action_name on resource (e.g can_read, can_edit). - :param resource_name: name of view-menu or resource. - :param user: user name - :return: Whether user could perform certain action on the resource. - :rtype bool - """ - if not user: - user = g.user - if (action_name, resource_name) in user.perms: - return True - - if self.is_dag_resource(resource_name): - if (action_name, permissions.RESOURCE_DAG) in user.perms: - return True - return (action_name, resource_name) in user.perms - - return False - - def _has_role(self, role_name_or_list: Container, user) -> bool: - """Whether the user has this role name.""" - if not isinstance(role_name_or_list, list): - role_name_or_list = [role_name_or_list] - return any(r.name in role_name_or_list for r in user.roles) - - def has_all_dags_access(self, user) -> bool: - """ - Has all the dag access in any of the 3 cases. - - 1. Role needs to be in (Admin, Viewer, User, Op). - 2. Has can_read action on dags resource. - 3. Has can_edit action on dags resource. - """ - if not user: - user = g.user - return ( - self._has_role(["Admin", "Viewer", "Op", "User"], user) - or self.can_read_all_dags(user) - or self.can_edit_all_dags(user) - ) - - def can_edit_all_dags(self, user=None) -> bool: - """Has can_edit action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG, user) - - def can_read_all_dags(self, user=None) -> bool: - """Has can_read action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user) + pass - def clean_perms(self) -> None: - """FAB leaves faulty permissions that need to be cleaned up.""" - self.log.debug("Cleaning faulty perms") - sesh = self.appbuilder.get_session - perms = sesh.query(Permission).filter( - or_( - Permission.action == None, # noqa - Permission.resource == None, # noqa - ) - ) - # Since FAB doesn't define ON DELETE CASCADE on these tables, we need - # to delete the _object_ so that SQLA knows to delete the many-to-many - # relationship object too. :( - deleted_count = 0 - for perm in perms: - sesh.delete(perm) - deleted_count += 1 - sesh.commit() - if deleted_count: - self.log.info("Deleted %s faulty permissions", deleted_count) - - def _merge_perm(self, action_name: str, resource_name: str) -> None: - """ - Add the new (action, resource) to assoc_permission_role if it doesn't exist. - - It will add the related entry to ab_permission and ab_resource two meta tables as well. - - :param action_name: Name of the action - :param resource_name: Name of the resource - :return: - """ - action = self.get_action(action_name) - resource = self.get_resource(resource_name) - perm = None - if action and resource: - perm = self.appbuilder.get_session.scalar( - select(self.permission_model).filter_by(action=action, resource=resource).limit(1) - ) - if not perm and action_name and resource_name: - self.create_permission(action_name, resource_name) - - def add_homepage_access_to_custom_roles(self) -> None: - """ - Add Website.can_read access to all custom roles. - - :return: None. - """ - website_permission = self.create_permission(permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE) - custom_roles = [role for role in self.get_all_roles() if role.name not in EXISTING_ROLES] - for role in custom_roles: - self.add_permission_to_role(role, website_permission) - - self.appbuilder.get_session.commit() - - def get_all_permissions(self) -> set[tuple[str, str]]: - """Returns all permissions as a set of tuples with the action and resource names.""" - return set( - self.appbuilder.get_session.execute( - select(self.action_model.name, self.resource_model.name) - .join(self.permission_model.action) - .join(self.permission_model.resource) - ) - ) - - def _get_all_non_dag_permissions(self) -> dict[tuple[str, str], Permission]: - """ - Get permissions except those that are for specific DAGs. - - Returns a dict with a key of (action_name, resource_name) and value of permission - with all permissions except those that are for specific DAGs. - """ - return { - (action_name, resource_name): viewmodel - for action_name, resource_name, viewmodel in ( - self.appbuilder.get_session.execute( - select(self.action_model.name, self.resource_model.name, self.permission_model) - .join(self.permission_model.action) - .join(self.permission_model.resource) - .where(~self.resource_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%")) - ) - ) - } - - def _get_all_roles_with_permissions(self) -> dict[str, Role]: - """Returns a dict with a key of role name and value of role with early loaded permissions.""" - return { - r.name: r - for r in self.appbuilder.get_session.scalars( - select(self.role_model).options(joinedload(self.role_model.permissions)) - ).unique() - } - - def create_dag_specific_permissions(self) -> None: - """ - Add permissions to all DAGs. - - Creates 'can_read', 'can_edit', and 'can_delete' permissions for all - DAGs, along with any `access_control` permissions provided in them. - - This does iterate through ALL the DAGs, which can be slow. See `sync_perm_for_dag` - if you only need to sync a single DAG. - - :return: None. - """ - perms = self.get_all_permissions() - dagbag = DagBag(read_dags_from_db=True) - dagbag.collect_dags_from_db() - dags = dagbag.dags.values() - - for dag in dags: - root_dag_id = dag.parent_dag.dag_id if dag.parent_dag else dag.dag_id - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - for action_name in self.DAG_ACTIONS: - if (action_name, dag_resource_name) not in perms: - self._merge_perm(action_name, dag_resource_name) - - if dag.access_control is not None: - self.sync_perm_for_dag(dag_resource_name, dag.access_control) - - def update_admin_permission(self) -> None: - """ - Add missing permissions to the table for admin. - - Admin should get all the permissions, except the dag permissions - because Admin already has Dags permission. - Add the missing ones to the table for admin. - - :return: None. - """ - session = self.appbuilder.get_session - dag_resources = session.scalars( - select(Resource).where(Resource.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%")) - ) - resource_ids = [resource.id for resource in dag_resources] - - perms = session.scalars(select(Permission).where(~Permission.resource_id.in_(resource_ids))) - perms = [p for p in perms if p.action and p.resource] - - admin = self.find_role("Admin") - admin.permissions = list(set(admin.permissions) | set(perms)) - - session.commit() - - def sync_roles(self) -> None: - """ - Initialize default and custom roles with related permissions. - - 1. Init the default role(Admin, Viewer, User, Op, public) - with related permissions. - 2. Init the custom role(dag-user) with related permissions. - - :return: None. - """ - # Create global all-dag permissions - self.create_perm_vm_for_all_dag() - - # Sync the default roles (Admin, Viewer, User, Op, public) with related permissions - self.bulk_sync_roles(self.ROLE_CONFIGS) - - self.add_homepage_access_to_custom_roles() - # init existing roles, the rest role could be created through UI. - self.update_admin_permission() - self.clean_perms() - - def sync_resource_permissions(self, perms: Iterable[tuple[str, str]] | None = None) -> None: - """Populates resource-based permissions.""" - if not perms: - return - - for action_name, resource_name in perms: - self.create_resource(resource_name) - self.create_permission(action_name, resource_name) - - def sync_perm_for_dag( - self, - dag_id: str, - access_control: dict[str, Collection[str]] | None = None, - ) -> None: - """ - Sync permissions for given dag id. - - The dag id surely exists in our dag bag as only / refresh button or DagBag will call this function. - - :param dag_id: the ID of the DAG whose permissions should be updated - :param access_control: a dict where each key is a rolename and - each value is a set() of action names (e.g., - {'can_read'} - :return: - """ - dag_resource_name = permissions.resource_name_for_dag(dag_id) - for dag_action_name in self.DAG_ACTIONS: - self.create_permission(dag_action_name, dag_resource_name) - - if access_control is not None: - self.log.debug("Syncing DAG-level permissions for DAG '%s'", dag_resource_name) - self._sync_dag_view_permissions(dag_resource_name, access_control) - else: - self.log.debug( - "Not syncing DAG-level permissions for DAG '%s' as access control is unset.", - dag_resource_name, - ) - - def _sync_dag_view_permissions(self, dag_id: str, access_control: dict[str, Collection[str]]) -> None: - """ - Set the access policy on the given DAG's ViewModel. - - :param dag_id: the ID of the DAG whose permissions should be updated - :param access_control: a dict where each key is a rolename and - each value is a set() of action names (e.g. {'can_read'}) - """ - dag_resource_name = permissions.resource_name_for_dag(dag_id) - - def _get_or_create_dag_permission(action_name: str) -> Permission | None: - perm = self.get_permission(action_name, dag_resource_name) - if not perm: - self.log.info("Creating new action '%s' on resource '%s'", action_name, dag_resource_name) - perm = self.create_permission(action_name, dag_resource_name) - - return perm - - def _revoke_stale_permissions(resource: Resource): - existing_dag_perms = self.get_resource_permissions(resource) - for perm in existing_dag_perms: - non_admin_roles = [role for role in perm.role if role.name != "Admin"] - for role in non_admin_roles: - target_perms_for_role = access_control.get(role.name, ()) - if perm.action.name not in target_perms_for_role: - self.log.info( - "Revoking '%s' on DAG '%s' for role '%s'", - perm.action, - dag_resource_name, - role.name, - ) - self.remove_permission_from_role(role, perm) - - resource = self.get_resource(dag_resource_name) - if resource: - _revoke_stale_permissions(resource) - - for rolename, action_names in access_control.items(): - role = self.find_role(rolename) - if not role: - raise AirflowException( - f"The access_control mapping for DAG '{dag_id}' includes a role named " - f"'{rolename}', but that role does not exist" - ) - - action_names = set(action_names) - invalid_action_names = action_names - self.DAG_ACTIONS - if invalid_action_names: - raise AirflowException( - f"The access_control map for DAG '{dag_resource_name}' includes " - f"the following invalid permissions: {invalid_action_names}; " - f"The set of valid permissions is: {self.DAG_ACTIONS}" - ) - - for action_name in action_names: - dag_perm = _get_or_create_dag_permission(action_name) - if dag_perm: - self.add_permission_to_role(role, dag_perm) - - def create_perm_vm_for_all_dag(self) -> None: - """Create perm-vm if not exist and insert into FAB security model for all-dags.""" - # create perm for global logical dag - for resource_name in self.DAG_RESOURCES: - for action_name in self.DAG_ACTIONS: - self._merge_perm(action_name, resource_name) - - def check_authorization( - self, - perms: Sequence[tuple[str, str]] | None = None, - dag_id: str | None = None, - ) -> bool: - """Checks that the logged in user has the specified permissions.""" - if not perms: - return True - - for perm in perms: - if perm in ( - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ): - can_access_all_dags = self.has_access(*perm) - if not can_access_all_dags: - action = perm[0] - if not self.can_access_some_dags(action, dag_id): - return False - elif not self.has_access(*perm): - return False - - return True - - -class FakeAppBuilder: - """Stand-in class to replace a Flask App Builder. +@deprecated( + reason="If you want to override the security manager, you should inherit from " + "`airflow.auth.managers.fab.security_manager.override.FabAirflowSecurityManagerOverride` " + "instead" +) +class AirflowSecurityManager(FabAirflowSecurityManagerOverride): + """Placeholder, just here to avoid breaking the code of users who inherit from this. - The only purpose is to provide the ``self.appbuilder.get_session`` interface - for ``ApplessAirflowSecurityManager`` so it can be used without a real Flask - app, which is slow to create. + Do not use if writing new code. """ - def __init__(self, session: Session | None = None) -> None: - self.get_session = session - - -class ApplessAirflowSecurityManager(AirflowSecurityManager): - """Security Manager that doesn't need the whole flask app.""" - - def __init__(self, session: Session | None = None): - self.appbuilder = FakeAppBuilder(session) + ... diff --git a/airflow/www/security_appless.py b/airflow/www/security_appless.py new file mode 100644 index 0000000000000..32233c1fcd58e --- /dev/null +++ b/airflow/www/security_appless.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride + +if TYPE_CHECKING: + from flask_session import Session + + +class FakeAppBuilder: + """Stand-in class to replace a Flask App Builder. + + The only purpose is to provide the ``self.appbuilder.get_session`` interface + for ``ApplessAirflowSecurityManager`` so it can be used without a real Flask + app, which is slow to create. + """ + + def __init__(self, session: Session | None = None) -> None: + self.get_session = session + + +class ApplessAirflowSecurityManager(FabAirflowSecurityManagerOverride): + """Security Manager that doesn't need the whole flask app.""" + + def __init__(self, session: Session | None = None): + self.appbuilder = FakeAppBuilder(session) diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py new file mode 100644 index 0000000000000..172857098ac7d --- /dev/null +++ b/airflow/www/security_manager.py @@ -0,0 +1,753 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence + +from flask import g +from sqlalchemy import or_, select +from sqlalchemy.orm import joinedload + +from airflow.auth.managers.fab.models import Permission, Resource, Role, User +from airflow.auth.managers.fab.views.permissions import ( + ActionModelView, + PermissionPairModelView, + ResourceModelView, +) +from airflow.auth.managers.fab.views.roles_list import CustomRoleModelView +from airflow.auth.managers.fab.views.user import ( + CustomUserDBModelView, + CustomUserLDAPModelView, + CustomUserOAuthModelView, + CustomUserOIDModelView, + CustomUserRemoteUserModelView, +) +from airflow.auth.managers.fab.views.user_edit import ( + CustomResetMyPasswordView, + CustomResetPasswordView, + CustomUserInfoEditView, +) +from airflow.auth.managers.fab.views.user_stats import CustomUserStatsChartView +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning +from airflow.models import DagBag, DagModel +from airflow.security import permissions +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager +from airflow.www.fab_security.sqla.manager import SecurityManager +from airflow.www.utils import CustomSQLAInterface + +EXISTING_ROLES = { + "Admin", + "Viewer", + "User", + "Op", + "Public", +} + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + +class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): + """Custom security manager, which introduces a permission model adapted to Airflow. + + It's named V2 to differentiate it from the obsolete airflow.www.security.AirflowSecurityManager. + """ + + ########################################################################### + # PERMISSIONS + ########################################################################### + + # [START security_viewer_perms] + VIEWER_PERMISSIONS = [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_DEPENDENCIES), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_JOB), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_MY_PASSWORD), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_MY_PASSWORD), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_MY_PROFILE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_MY_PROFILE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_SLA_MISS), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_BROWSE_MENU), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_DEPENDENCIES), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CLUSTER_ACTIVITY), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS_MENU), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_JOB), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_AUDIT_LOG), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_PLUGIN), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_SLA_MISS), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TASK_INSTANCE), + ] + # [END security_viewer_perms] + + # [START security_user_perms] + USER_PERMISSIONS = [ + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ] + # [END security_user_perms] + + # [START security_op_perms] + OP_PERMISSIONS = [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_ADMIN_MENU), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CONFIG), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CONNECTION), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_POOL), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_VARIABLE), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_XCOM), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_XCOM), + ] + # [END security_op_perms] + + ADMIN_PERMISSIONS = [ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_RESCHEDULE), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TASK_RESCHEDULE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TRIGGER), + (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TRIGGER), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_PASSWORD), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_PASSWORD), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE), + ] + + # global resource for dag-level access + DAG_RESOURCES = {permissions.RESOURCE_DAG} + DAG_ACTIONS = permissions.DAG_ACTIONS + + ########################################################################### + # DEFAULT ROLE CONFIGURATIONS + ########################################################################### + + ROLE_CONFIGS: list[dict[str, Any]] = [ + {"role": "Public", "perms": []}, + {"role": "Viewer", "perms": VIEWER_PERMISSIONS}, + { + "role": "User", + "perms": VIEWER_PERMISSIONS + USER_PERMISSIONS, + }, + { + "role": "Op", + "perms": VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS, + }, + { + "role": "Admin", + "perms": VIEWER_PERMISSIONS + USER_PERMISSIONS + OP_PERMISSIONS + ADMIN_PERMISSIONS, + }, + ] + + actionmodelview = ActionModelView + permissionmodelview = PermissionPairModelView + rolemodelview = CustomRoleModelView + resourcemodelview = ResourceModelView + userdbmodelview = CustomUserDBModelView + resetmypasswordview = CustomResetMyPasswordView + resetpasswordview = CustomResetPasswordView + userinfoeditview = CustomUserInfoEditView + userldapmodelview = CustomUserLDAPModelView + useroauthmodelview = CustomUserOAuthModelView + userremoteusermodelview = CustomUserRemoteUserModelView + useroidmodelview = CustomUserOIDModelView + userstatschartview = CustomUserStatsChartView + + def __init__(self, appbuilder) -> None: + super().__init__(appbuilder=appbuilder) + + # Go and fix up the SQLAInterface used from the stock one to our subclass. + # This is needed to support the "hack" where we had to edit + # FieldConverter.conversion_table in place in airflow.www.utils + for attr in dir(self): + if attr.endswith("view"): + view = getattr(self, attr, None) + if view and getattr(view, "datamodel", None): + view.datamodel = CustomSQLAInterface(view.datamodel.obj) + self.perms = None + + def _get_root_dag_id(self, dag_id: str) -> str: + if "." in dag_id: + dm = self.appbuilder.get_session.execute( + select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id) + ).one() + return dm.root_dag_id or dm.dag_id + return dag_id + + def init_role(self, role_name, perms) -> None: + """ + Initialize the role with actions and related resources. + + :param role_name: + :param perms: + :return: + """ + warnings.warn( + "`init_role` has been deprecated. Please use `bulk_sync_roles` instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + self.bulk_sync_roles([{"role": role_name, "perms": perms}]) + + def bulk_sync_roles(self, roles: Iterable[dict[str, Any]]) -> None: + """Sync the provided roles and permissions.""" + existing_roles = self._get_all_roles_with_permissions() + non_dag_perms = self._get_all_non_dag_permissions() + + for config in roles: + role_name = config["role"] + perms = config["perms"] + role = existing_roles.get(role_name) or self.add_role(role_name) + + for action_name, resource_name in perms: + perm = non_dag_perms.get((action_name, resource_name)) or self.create_permission( + action_name, resource_name + ) + + if perm not in role.permissions: + self.add_permission_to_role(role, perm) + + @staticmethod + def get_user_roles(user=None): + """ + Get all the roles associated with the user. + + :param user: the ab_user in FAB model. + :return: a list of roles associated with the user. + """ + if user is None: + user = g.user + return user.roles + + def get_readable_dags(self, user) -> Iterable[DagModel]: + """Gets the DAGs readable by authenticated user.""" + warnings.warn( + "`get_readable_dags` has been deprecated. Please use `get_readable_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) + + def get_editable_dags(self, user) -> Iterable[DagModel]: + """Gets the DAGs editable by authenticated user.""" + warnings.warn( + "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) + + @provide_session + def get_accessible_dags( + self, + user_actions: Container[str] | None, + user, + session: Session = NEW_SESSION, + ) -> Iterable[DagModel]: + warnings.warn( + "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + dag_ids = self.get_accessible_dag_ids(user, user_actions, session) + return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) + + def get_readable_dag_ids(self, user) -> set[str]: + """Gets the DAG IDs readable by authenticated user.""" + return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_READ]) + + def get_editable_dag_ids(self, user) -> set[str]: + """Gets the DAG IDs editable by authenticated user.""" + return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_EDIT]) + + @provide_session + def get_accessible_dag_ids( + self, + user, + user_actions: Container[str] | None = None, + session: Session = NEW_SESSION, + ) -> set[str]: + """Generic function to get readable or writable DAGs for user.""" + if not user_actions: + user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] + + if not get_auth_manager().is_logged_in(): + roles = user.roles + else: + if (permissions.ACTION_CAN_EDIT in user_actions and self.can_edit_all_dags(user)) or ( + permissions.ACTION_CAN_READ in user_actions and self.can_read_all_dags(user) + ): + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + user_query = session.scalar( + select(User) + .options( + joinedload(User.roles) + .subqueryload(Role.permissions) + .options(joinedload(Permission.action), joinedload(Permission.resource)) + ) + .where(User.id == user.id) + ) + roles = user_query.roles + + resources = set() + for role in roles: + for permission in role.permissions: + action = permission.action.name + if action in user_actions: + resource = permission.resource.name + if resource == permissions.RESOURCE_DAG: + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + if resource.startswith(permissions.RESOURCE_DAG_PREFIX): + resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) + else: + resources.add(resource) + return { + dag.dag_id + for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) + } + + def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: + """Checks if user has read or write access to some dags.""" + if dag_id and dag_id != "~": + root_dag_id = self._get_root_dag_id(dag_id) + return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) + + user = g.user + if action == permissions.ACTION_CAN_READ: + return any(self.get_readable_dag_ids(user)) + return any(self.get_editable_dag_ids(user)) + + def can_read_dag(self, dag_id: str, user=None) -> bool: + """Determines whether a user has DAG read access.""" + root_dag_id = self._get_root_dag_id(dag_id) + dag_resource_name = permissions.resource_name_for_dag(root_dag_id) + return self.has_access(permissions.ACTION_CAN_READ, dag_resource_name, user=user) + + def can_edit_dag(self, dag_id: str, user=None) -> bool: + """Determines whether a user has DAG edit access.""" + root_dag_id = self._get_root_dag_id(dag_id) + dag_resource_name = permissions.resource_name_for_dag(root_dag_id) + return self.has_access(permissions.ACTION_CAN_EDIT, dag_resource_name, user=user) + + def can_delete_dag(self, dag_id: str, user=None) -> bool: + """Determines whether a user has DAG delete access.""" + root_dag_id = self._get_root_dag_id(dag_id) + dag_resource_name = permissions.resource_name_for_dag(root_dag_id) + return self.has_access(permissions.ACTION_CAN_DELETE, dag_resource_name, user=user) + + def prefixed_dag_id(self, dag_id: str) -> str: + """Returns the permission name for a DAG id.""" + warnings.warn( + "`prefixed_dag_id` has been deprecated. " + "Please use `airflow.security.permissions.resource_name_for_dag` instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + root_dag_id = self._get_root_dag_id(dag_id) + return permissions.resource_name_for_dag(root_dag_id) + + def is_dag_resource(self, resource_name: str) -> bool: + """Determines if a resource belongs to a DAG or all DAGs.""" + if resource_name == permissions.RESOURCE_DAG: + return True + return resource_name.startswith(permissions.RESOURCE_DAG_PREFIX) + + def has_access(self, action_name: str, resource_name: str, user=None) -> bool: + """ + Verify whether a given user could perform a certain action on the given resource. + + Example actions might include can_read, can_write, can_delete, etc. + + :param action_name: action_name on resource (e.g can_read, can_edit). + :param resource_name: name of view-menu or resource. + :param user: user name + :return: Whether user could perform certain action on the resource. + :rtype bool + """ + if not user: + user = g.user + if (action_name, resource_name) in user.perms: + return True + + if self.is_dag_resource(resource_name): + if (action_name, permissions.RESOURCE_DAG) in user.perms: + return True + return (action_name, resource_name) in user.perms + + return False + + def _has_role(self, role_name_or_list: Container, user) -> bool: + """Whether the user has this role name.""" + if not isinstance(role_name_or_list, list): + role_name_or_list = [role_name_or_list] + return any(r.name in role_name_or_list for r in user.roles) + + def has_all_dags_access(self, user) -> bool: + """ + Has all the dag access in any of the 3 cases. + + 1. Role needs to be in (Admin, Viewer, User, Op). + 2. Has can_read action on dags resource. + 3. Has can_edit action on dags resource. + """ + if not user: + user = g.user + return ( + self._has_role(["Admin", "Viewer", "Op", "User"], user) + or self.can_read_all_dags(user) + or self.can_edit_all_dags(user) + ) + + def can_edit_all_dags(self, user=None) -> bool: + """Has can_edit action on DAG resource.""" + return self.has_access(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG, user) + + def can_read_all_dags(self, user=None) -> bool: + """Has can_read action on DAG resource.""" + return self.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user) + + def clean_perms(self) -> None: + """FAB leaves faulty permissions that need to be cleaned up.""" + self.log.debug("Cleaning faulty perms") + sesh = self.appbuilder.get_session + perms = sesh.query(Permission).filter( + or_( + Permission.action == None, # noqa + Permission.resource == None, # noqa + ) + ) + # Since FAB doesn't define ON DELETE CASCADE on these tables, we need + # to delete the _object_ so that SQLA knows to delete the many-to-many + # relationship object too. :( + + deleted_count = 0 + for perm in perms: + sesh.delete(perm) + deleted_count += 1 + sesh.commit() + if deleted_count: + self.log.info("Deleted %s faulty permissions", deleted_count) + + def _merge_perm(self, action_name: str, resource_name: str) -> None: + """ + Add the new (action, resource) to assoc_permission_role if it doesn't exist. + + It will add the related entry to ab_permission and ab_resource two meta tables as well. + + :param action_name: Name of the action + :param resource_name: Name of the resource + :return: + """ + action = self.get_action(action_name) + resource = self.get_resource(resource_name) + perm = None + if action and resource: + perm = self.appbuilder.get_session.scalar( + select(self.permission_model).filter_by(action=action, resource=resource).limit(1) + ) + if not perm and action_name and resource_name: + self.create_permission(action_name, resource_name) + + def add_homepage_access_to_custom_roles(self) -> None: + """ + Add Website.can_read access to all custom roles. + + :return: None. + """ + website_permission = self.create_permission(permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE) + custom_roles = [role for role in self.get_all_roles() if role.name not in EXISTING_ROLES] + for role in custom_roles: + self.add_permission_to_role(role, website_permission) + + self.appbuilder.get_session.commit() + + def get_all_permissions(self) -> set[tuple[str, str]]: + """Returns all permissions as a set of tuples with the action and resource names.""" + return set( + self.appbuilder.get_session.execute( + select(self.action_model.name, self.resource_model.name) + .join(self.permission_model.action) + .join(self.permission_model.resource) + ) + ) + + def _get_all_non_dag_permissions(self) -> dict[tuple[str, str], Permission]: + """ + Get permissions except those that are for specific DAGs. + + Returns a dict with a key of (action_name, resource_name) and value of permission + with all permissions except those that are for specific DAGs. + """ + return { + (action_name, resource_name): viewmodel + for action_name, resource_name, viewmodel in ( + self.appbuilder.get_session.execute( + select(self.action_model.name, self.resource_model.name, self.permission_model) + .join(self.permission_model.action) + .join(self.permission_model.resource) + .where(~self.resource_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%")) + ) + ) + } + + def _get_all_roles_with_permissions(self) -> dict[str, Role]: + """Returns a dict with a key of role name and value of role with early loaded permissions.""" + return { + r.name: r + for r in self.appbuilder.get_session.scalars( + select(self.role_model).options(joinedload(self.role_model.permissions)) + ).unique() + } + + def create_dag_specific_permissions(self) -> None: + """ + Add permissions to all DAGs. + + Creates 'can_read', 'can_edit', and 'can_delete' permissions for all + DAGs, along with any `access_control` permissions provided in them. + + This does iterate through ALL the DAGs, which can be slow. See `sync_perm_for_dag` + if you only need to sync a single DAG. + + :return: None. + """ + perms = self.get_all_permissions() + dagbag = DagBag(read_dags_from_db=True) + dagbag.collect_dags_from_db() + dags = dagbag.dags.values() + + for dag in dags: + root_dag_id = dag.parent_dag.dag_id if dag.parent_dag else dag.dag_id + dag_resource_name = permissions.resource_name_for_dag(root_dag_id) + for action_name in self.DAG_ACTIONS: + if (action_name, dag_resource_name) not in perms: + self._merge_perm(action_name, dag_resource_name) + + if dag.access_control is not None: + self.sync_perm_for_dag(dag_resource_name, dag.access_control) + + def update_admin_permission(self) -> None: + """ + Add missing permissions to the table for admin. + + Admin should get all the permissions, except the dag permissions + because Admin already has Dags permission. + Add the missing ones to the table for admin. + + :return: None. + """ + session = self.appbuilder.get_session + dag_resources = session.scalars( + select(Resource).where(Resource.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%")) + ) + resource_ids = [resource.id for resource in dag_resources] + + perms = session.scalars(select(Permission).where(~Permission.resource_id.in_(resource_ids))) + perms = [p for p in perms if p.action and p.resource] + + admin = self.find_role("Admin") + admin.permissions = list(set(admin.permissions) | set(perms)) + + session.commit() + + def sync_roles(self) -> None: + """ + Initialize default and custom roles with related permissions. + + 1. Init the default role(Admin, Viewer, User, Op, public) + with related permissions. + 2. Init the custom role(dag-user) with related permissions. + + :return: None. + """ + # Create global all-dag permissions + self.create_perm_vm_for_all_dag() + + # Sync the default roles (Admin, Viewer, User, Op, public) with related permissions + self.bulk_sync_roles(self.ROLE_CONFIGS) + + self.add_homepage_access_to_custom_roles() + # init existing roles, the rest role could be created through UI. + self.update_admin_permission() + self.clean_perms() + + def sync_resource_permissions(self, perms: Iterable[tuple[str, str]] | None = None) -> None: + """Populates resource-based permissions.""" + if not perms: + return + + for action_name, resource_name in perms: + self.create_resource(resource_name) + self.create_permission(action_name, resource_name) + + def sync_perm_for_dag( + self, + dag_id: str, + access_control: dict[str, Collection[str]] | None = None, + ) -> None: + """ + Sync permissions for given dag id. + + The dag id surely exists in our dag bag as only / refresh button or DagBag will call this function. + + :param dag_id: the ID of the DAG whose permissions should be updated + :param access_control: a dict where each key is a rolename and + each value is a set() of action names (e.g., + {'can_read'} + :return: + """ + dag_resource_name = permissions.resource_name_for_dag(dag_id) + for dag_action_name in self.DAG_ACTIONS: + self.create_permission(dag_action_name, dag_resource_name) + + if access_control is not None: + self.log.debug("Syncing DAG-level permissions for DAG '%s'", dag_resource_name) + self._sync_dag_view_permissions(dag_resource_name, access_control) + else: + self.log.debug( + "Not syncing DAG-level permissions for DAG '%s' as access control is unset.", + dag_resource_name, + ) + + def _sync_dag_view_permissions(self, dag_id: str, access_control: dict[str, Collection[str]]) -> None: + """ + Set the access policy on the given DAG's ViewModel. + + :param dag_id: the ID of the DAG whose permissions should be updated + :param access_control: a dict where each key is a rolename and + each value is a set() of action names (e.g. {'can_read'}) + """ + dag_resource_name = permissions.resource_name_for_dag(dag_id) + + def _get_or_create_dag_permission(action_name: str) -> Permission | None: + perm = self.get_permission(action_name, dag_resource_name) + if not perm: + self.log.info("Creating new action '%s' on resource '%s'", action_name, dag_resource_name) + perm = self.create_permission(action_name, dag_resource_name) + + return perm + + def _revoke_stale_permissions(resource: Resource): + existing_dag_perms = self.get_resource_permissions(resource) + for perm in existing_dag_perms: + non_admin_roles = [role for role in perm.role if role.name != "Admin"] + for role in non_admin_roles: + target_perms_for_role = access_control.get(role.name, ()) + if perm.action.name not in target_perms_for_role: + self.log.info( + "Revoking '%s' on DAG '%s' for role '%s'", + perm.action, + dag_resource_name, + role.name, + ) + self.remove_permission_from_role(role, perm) + + resource = self.get_resource(dag_resource_name) + if resource: + _revoke_stale_permissions(resource) + + for rolename, action_names in access_control.items(): + role = self.find_role(rolename) + if not role: + raise AirflowException( + f"The access_control mapping for DAG '{dag_id}' includes a role named " + f"'{rolename}', but that role does not exist" + ) + + action_names = set(action_names) + invalid_action_names = action_names - self.DAG_ACTIONS + if invalid_action_names: + raise AirflowException( + f"The access_control map for DAG '{dag_resource_name}' includes " + f"the following invalid permissions: {invalid_action_names}; " + f"The set of valid permissions is: {self.DAG_ACTIONS}" + ) + + for action_name in action_names: + dag_perm = _get_or_create_dag_permission(action_name) + if dag_perm: + self.add_permission_to_role(role, dag_perm) + + def create_perm_vm_for_all_dag(self) -> None: + """Create perm-vm if not exist and insert into FAB security model for all-dags.""" + # create perm for global logical dag + for resource_name in self.DAG_RESOURCES: + for action_name in self.DAG_ACTIONS: + self._merge_perm(action_name, resource_name) + + def check_authorization( + self, + perms: Sequence[tuple[str, str]] | None = None, + dag_id: str | None = None, + ) -> bool: + """Checks that the logged in user has the specified permissions.""" + if not perms: + return True + + for perm in perms: + if perm in ( + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), + ): + can_access_all_dags = self.has_access(*perm) + if not can_access_all_dags: + action = perm[0] + if not self.can_access_some_dags(action, dag_id): + return False + elif not self.has_access(*perm): + return False + + return True diff --git a/docs/apache-airflow/security/access-control.rst b/docs/apache-airflow/security/access-control.rst index 7692a54cb5bf9..bef8b34f428f4 100644 --- a/docs/apache-airflow/security/access-control.rst +++ b/docs/apache-airflow/security/access-control.rst @@ -51,7 +51,7 @@ Viewer ^^^^^^ ``Viewer`` users have limited read permissions: -.. exampleinclude:: /../../airflow/www/security.py +.. exampleinclude:: /../../airflow/www/security_manager.py :language: python :start-after: [START security_viewer_perms] :end-before: [END security_viewer_perms] @@ -60,7 +60,7 @@ User ^^^^ ``User`` users have ``Viewer`` permissions plus additional permissions: -.. exampleinclude:: /../../airflow/www/security.py +.. exampleinclude:: /../../airflow/www/security_manager.py :language: python :start-after: [START security_user_perms] :end-before: [END security_user_perms] @@ -69,7 +69,7 @@ Op ^^ ``Op`` users have ``User`` permissions plus additional permissions: -.. exampleinclude:: /../../airflow/www/security.py +.. exampleinclude:: /../../airflow/www/security_manager.py :language: python :start-after: [START security_op_perms] :end-before: [END security_op_perms] diff --git a/docs/apache-airflow/security/webserver.rst b/docs/apache-airflow/security/webserver.rst index d990afa4d4536..282eabaa5b4d6 100644 --- a/docs/apache-airflow/security/webserver.rst +++ b/docs/apache-airflow/security/webserver.rst @@ -169,7 +169,7 @@ Here is an example of what you might have in your webserver_config.py: .. code-block:: python - from airflow.www.security import AirflowSecurityManager + from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride from flask_appbuilder.security.manager import AUTH_OAUTH import os @@ -200,7 +200,7 @@ Here is an example of what you might have in your webserver_config.py: ] - class CustomSecurityManager(AirflowSecurityManager): + class CustomSecurityManager(FabAirflowSecurityManagerOverride): pass @@ -213,7 +213,7 @@ webserver_config.py itself if you wish. .. code-block:: python - from airflow.www.security import AirflowSecurityManager + from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride import logging from typing import Any, List, Union import os @@ -244,7 +244,7 @@ webserver_config.py itself if you wish. return list(set(team_role_map.get(team, FAB_PUBLIC_ROLE) for team in team_list)) - class GithubTeamAuthorizer(AirflowSecurityManager): + class GithubTeamAuthorizer(FabAirflowSecurityManagerOverride): # In this example, the oauth provider == 'github'. # If you ever want to support other providers, see how it is done here: diff --git a/tests/api_connexion/endpoints/test_role_and_permission_endpoint.py b/tests/api_connexion/endpoints/test_role_and_permission_endpoint.py index 0a62de1043fb9..bdede1f16ffc0 100644 --- a/tests/api_connexion/endpoints/test_role_and_permission_endpoint.py +++ b/tests/api_connexion/endpoints/test_role_and_permission_endpoint.py @@ -21,7 +21,7 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.auth.managers.fab.models import Role from airflow.security import permissions -from airflow.www.security import EXISTING_ROLES +from airflow.www.security_manager import EXISTING_ROLES from tests.test_utils.api_connexion_utils import ( assert_401, create_role, diff --git a/tests/auth/managers/fab/security_manager/test_override.py b/tests/auth/managers/fab/security_manager/test_override.py index 9f931e6a2de42..9d63954767ea0 100644 --- a/tests/auth/managers/fab/security_manager/test_override.py +++ b/tests/auth/managers/fab/security_manager/test_override.py @@ -17,113 +17,35 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import Mock -import pytest - -from airflow.auth.managers.fab.models import User from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride -appbuilder = Mock() -actionmodelview = Mock() -authdbview = Mock() -authldapview = Mock() -authoauthview = Mock() -authoidview = Mock() -authremoteuserview = Mock() -permissionmodelview = Mock() -registeruser_view = Mock() -registeruserdbview = Mock() -registeruseroauthview = Mock() -registerusermodelview = Mock() -registeruseroidview = Mock() -resetmypasswordview = Mock() -resetpasswordview = Mock() -rolemodelview = Mock() -user_model = User -userinfoeditview = Mock() -userdbmodelview = Mock() -userldapmodelview = Mock() -useroauthmodelview = Mock() -useroidmodelview = Mock() -userremoteusermodelview = Mock() -userstatschartview = Mock() - - -@pytest.fixture -def security_manager_override(): - class SubSecurityManager: - def __init__(self, **kwargs): - pass - - class EmptySecurityManager(FabAirflowSecurityManagerOverride, SubSecurityManager): - def __init__(self, appbuilder): - super().__init__( - appbuilder=appbuilder, - actionmodelview=actionmodelview, - authdbview=authdbview, - authldapview=authldapview, - authoauthview=authoauthview, - authoidview=authoidview, - authremoteuserview=authremoteuserview, - permissionmodelview=permissionmodelview, - registeruser_view=registeruser_view, - registeruserdbview=registeruserdbview, - registeruseroauthview=registeruseroauthview, - registerusermodelview=registerusermodelview, - registeruseroidview=registeruseroidview, - resetmypasswordview=resetmypasswordview, - resetpasswordview=resetpasswordview, - rolemodelview=rolemodelview, - user_model=user_model, - userinfoeditview=userinfoeditview, - userdbmodelview=userdbmodelview, - userldapmodelview=userldapmodelview, - useroauthmodelview=useroauthmodelview, - useroidmodelview=useroidmodelview, - userremoteusermodelview=userremoteusermodelview, - userstatschartview=userstatschartview, - ) - - with mock.patch( - "airflow.auth.managers.fab.security_manager.override.LoginManager" - ) as mock_login_manager, mock.patch( - "airflow.auth.managers.fab.security_manager.override.JWTManager" - ) as mock_jwt_manager, mock.patch.object( - FabAirflowSecurityManagerOverride, "create_db" - ): - mock_login_manager_instance = Mock() - mock_login_manager.return_value = mock_login_manager_instance - mock_jwt_manager_instance = Mock() - mock_jwt_manager.return_value = mock_jwt_manager_instance - appbuilder.app.config = MagicMock() +class EmptySecurityManager(FabAirflowSecurityManagerOverride): + # noinspection PyMissingConstructor + # super() not called on purpose to avoid the whole chain of init calls + def __init__(self): + pass - security_manager_override = EmptySecurityManager(appbuilder) - mock_login_manager.assert_called_once_with(appbuilder.app) - mock_login_manager_instance.user_loader.assert_called_once_with(security_manager_override.load_user) - mock_jwt_manager_instance.init_app.assert_called_once_with(appbuilder.app) - mock_jwt_manager_instance.user_lookup_loader.assert_called_once_with( - security_manager_override.load_user_jwt - ) - - return security_manager_override +class TestFabAirflowSecurityManagerOverride: + def test_load_user(self): + sm = EmptySecurityManager() + sm.get_user_by_id = Mock() + sm.load_user("123") -class TestFabAirflowSecurityManagerOverride: - def test_load_user(self, security_manager_override): - mock_get_user_by_id = Mock() - security_manager_override.get_user_by_id = mock_get_user_by_id - security_manager_override.load_user("123") - mock_get_user_by_id.assert_called_once_with(123) + sm.get_user_by_id.assert_called_once_with(123) @mock.patch("airflow.auth.managers.fab.security_manager.override.g", spec={}) - def test_load_user_jwt(self, mock_g, security_manager_override): + def test_load_user_jwt(self, mock_g): + sm = EmptySecurityManager() mock_user = Mock() - mock_load_user = Mock(return_value=mock_user) - security_manager_override.load_user = mock_load_user - actual_user = security_manager_override.load_user_jwt(None, {"sub": "test_identity"}) - mock_load_user.assert_called_once_with("test_identity") + sm.load_user = Mock(return_value=mock_user) + + actual_user = sm.load_user_jwt(None, {"sub": "test_identity"}) + + sm.load_user.assert_called_once_with("test_identity") assert actual_user is mock_user assert mock_g.user is mock_user diff --git a/tests/auth/managers/fab/test_fab_auth_manager.py b/tests/auth/managers/fab/test_fab_auth_manager.py index 66c279510e55e..4802b23cf253a 100644 --- a/tests/auth/managers/fab/test_fab_auth_manager.py +++ b/tests/auth/managers/fab/test_fab_auth_manager.py @@ -25,12 +25,14 @@ from airflow.auth.managers.fab.fab_auth_manager import FabAuthManager from airflow.auth.managers.fab.models import User from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride -from airflow.www.security import ApplessAirflowSecurityManager +from airflow.www.security_appless import ApplessAirflowSecurityManager @pytest.fixture def auth_manager(): - auth_manager = FabAuthManager() + app_mock = Mock(name="flask_app") + app_mock.config.get.return_value = None # this is called to get the security manager override (if any) + auth_manager = FabAuthManager(app_mock) auth_manager.security_manager = ApplessAirflowSecurityManager() return auth_manager diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index f339e7d33fa90..c506133b43e23 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -20,7 +20,8 @@ from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.exceptions import AirflowException -from airflow.www.security import ApplessAirflowSecurityManager +from airflow.www.security_appless import ApplessAirflowSecurityManager +from airflow.www.security_manager import AirflowSecurityManagerV2 @pytest.fixture @@ -35,12 +36,13 @@ def is_logged_in(self) -> bool: def get_url_login(self, **kwargs) -> str: raise NotImplementedError() - return EmptyAuthManager() + # noinspection PyTypeChecker + return EmptyAuthManager(None) class TestBaseAuthManager: def test_get_security_manager_override_class_return_empty_class(self, auth_manager): - assert auth_manager.get_security_manager_override_class() is object + assert auth_manager.get_security_manager_override_class() is AirflowSecurityManagerV2 def test_get_security_manager_not_defined(self, auth_manager): with pytest.raises(AirflowException, match="Security manager not defined."): diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index e0a999b9d8a95..f7b36dbb4794d 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -42,7 +42,7 @@ from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.dates import timezone as tz from airflow.utils.session import create_session -from airflow.www.security import ApplessAirflowSecurityManager +from airflow.www.security_appless import ApplessAirflowSecurityManager from tests import cluster_policies from tests.models import TEST_DAGS_FOLDER from tests.test_utils import db @@ -901,7 +901,7 @@ def _sync_to_db(): _sync_to_db() mock_sync_perm_for_dag.assert_called_once_with(dag, session=session) - @patch("airflow.www.security.ApplessAirflowSecurityManager") + @patch("airflow.www.security_appless.ApplessAirflowSecurityManager") def test_sync_perm_for_dag(self, mock_security_manager): """ Test that dagbag._sync_perm_for_dag will call ApplessAirflowSecurityManager.sync_perm_for_dag diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index 7360ed18f6e74..836ad95d7baf0 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.www.security import EXISTING_ROLES +from airflow.www.security_manager import EXISTING_ROLES @contextmanager diff --git a/tests/test_utils/mock_security_manager.py b/tests/test_utils/mock_security_manager.py index f55f5e31bd94b..e8050f83ef7e9 100644 --- a/tests/test_utils/mock_security_manager.py +++ b/tests/test_utils/mock_security_manager.py @@ -16,10 +16,10 @@ # under the License. from __future__ import annotations -from airflow.www.security import AirflowSecurityManager +from airflow.www.security_manager import AirflowSecurityManagerV2 -class MockSecurityManager(AirflowSecurityManager): +class MockSecurityManager(AirflowSecurityManagerV2): VIEWER_VMS = { "Airflow", } diff --git a/tests/www/test_security.py b/tests/www/test_security.py index 8dc0ebec26542..b70aad536bce4 100644 --- a/tests/www/test_security.py +++ b/tests/www/test_security.py @@ -873,7 +873,7 @@ def test_create_dag_specific_permissions(session, security_manager, monkeypatch, dagbag_class_mock.return_value = dagbag_mock import airflow.www.security - monkeypatch.setitem(airflow.www.security.__dict__, "DagBag", dagbag_class_mock) + monkeypatch.setitem(airflow.www.security_manager.__dict__, "DagBag", dagbag_class_mock) security_manager._sync_dag_view_permissions = mock.Mock() for dag in sample_dags: