From 577a9d77dd3ba2f18b17dec4d5310615ac989f23 Mon Sep 17 00:00:00 2001 From: Balthazar Rouberol Date: Tue, 5 Nov 2024 17:07:02 +0100 Subject: [PATCH] fab_auth_manager: allow get_user method to return the user authenticated via Kerberos (#43662) --- .../fab/auth_manager/fab_auth_manager.py | 15 +++++++++-- .../fab/auth_manager/test_fab_auth_manager.py | 26 ++++++++++++++++--- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 8a8fad6788693..e93e440f5ddfe 100644 --- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -25,7 +25,7 @@ import packaging.version from connexion import FlaskApi -from flask import Blueprint, url_for +from flask import Blueprint, g, url_for from packaging.version import Version from sqlalchemy import select from sqlalchemy.orm import Session, joinedload @@ -183,9 +183,20 @@ def get_user_display_name(self) -> str: return f"{first_name} {last_name}".strip() def get_user(self) -> User: - """Return the user associated to the user in session.""" + """ + Return the user associated to the user in session. + + Attempt to find the current user in g.user, as defined by the kerberos authentication backend. + If no such user is found, return the `current_user` local proxy object, linked to the user session. + + """ from flask_login import current_user + # If a user has gone through the Kerberos dance, the kerberos authentication manager + # has linked it with a User model, stored in g.user, and not the session. + if current_user.is_anonymous and getattr(g, "user", None) is not None and not g.user.is_anonymous: + return g.user + return current_user def init(self) -> None: diff --git a/providers/tests/fab/auth_manager/test_fab_auth_manager.py b/providers/tests/fab/auth_manager/test_fab_auth_manager.py index 91efb8428c654..d298f7667eaaf 100644 --- a/providers/tests/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/tests/fab/auth_manager/test_fab_auth_manager.py @@ -16,13 +16,14 @@ # under the License. from __future__ import annotations +from contextlib import contextmanager from itertools import chain from typing import TYPE_CHECKING from unittest import mock from unittest.mock import Mock import pytest -from flask import Flask +from flask import Flask, g from airflow.exceptions import AirflowConfigException, AirflowException @@ -72,6 +73,13 @@ } +@contextmanager +def user_set(app, user): + g.user = user + yield + g.user = None + + @pytest.fixture def auth_manager(): return FabAuthManager(None) @@ -114,12 +122,24 @@ def test_get_user_display_name( assert auth_manager.get_user_display_name() == expected @mock.patch("flask_login.utils._get_user") - def test_get_user(self, mock_current_user, auth_manager): + def test_get_user(self, mock_current_user, minimal_app_for_auth_api, auth_manager): user = Mock() user.is_anonymous.return_value = True mock_current_user.return_value = user + with minimal_app_for_auth_api.app_context(): + assert auth_manager.get_user() == user - assert auth_manager.get_user() == user + @mock.patch("flask_login.utils._get_user") + def test_get_user_from_flask_g(self, mock_current_user, minimal_app_for_auth_api, auth_manager): + session_user = Mock() + session_user.is_anonymous = True + mock_current_user.return_value = session_user + + flask_g_user = Mock() + flask_g_user.is_anonymous = False + with minimal_app_for_auth_api.app_context(): + with user_set(minimal_app_for_auth_api, flask_g_user): + assert auth_manager.get_user() == flask_g_user @pytest.mark.db_test @mock.patch.object(FabAuthManager, "get_user")