Skip to content

Commit

Permalink
fab_auth_manager: allow get_user method to return the user authentica…
Browse files Browse the repository at this point in the history
…ted via Kerberos (apache#43662)
  • Loading branch information
brouberol authored and ellisms committed Nov 13, 2024
1 parent b86a747 commit 577a9d7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import packaging.version
from connexion import FlaskApi
from flask import Blueprint, url_for
from flask import Blueprint, g, url_for
from packaging.version import Version
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload
Expand Down Expand Up @@ -183,9 +183,20 @@ def get_user_display_name(self) -> str:
return f"{first_name} {last_name}".strip()

def get_user(self) -> User:
"""Return the user associated to the user in session."""
"""
Return the user associated to the user in session.
Attempt to find the current user in g.user, as defined by the kerberos authentication backend.
If no such user is found, return the `current_user` local proxy object, linked to the user session.
"""
from flask_login import current_user

# If a user has gone through the Kerberos dance, the kerberos authentication manager
# has linked it with a User model, stored in g.user, and not the session.
if current_user.is_anonymous and getattr(g, "user", None) is not None and not g.user.is_anonymous:
return g.user

return current_user

def init(self) -> None:
Expand Down
26 changes: 23 additions & 3 deletions providers/tests/fab/auth_manager/test_fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
from __future__ import annotations

from contextlib import contextmanager
from itertools import chain
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import Mock

import pytest
from flask import Flask
from flask import Flask, g

from airflow.exceptions import AirflowConfigException, AirflowException

Expand Down Expand Up @@ -72,6 +73,13 @@
}


@contextmanager
def user_set(app, user):
g.user = user
yield
g.user = None


@pytest.fixture
def auth_manager():
return FabAuthManager(None)
Expand Down Expand Up @@ -114,12 +122,24 @@ def test_get_user_display_name(
assert auth_manager.get_user_display_name() == expected

@mock.patch("flask_login.utils._get_user")
def test_get_user(self, mock_current_user, auth_manager):
def test_get_user(self, mock_current_user, minimal_app_for_auth_api, auth_manager):
user = Mock()
user.is_anonymous.return_value = True
mock_current_user.return_value = user
with minimal_app_for_auth_api.app_context():
assert auth_manager.get_user() == user

assert auth_manager.get_user() == user
@mock.patch("flask_login.utils._get_user")
def test_get_user_from_flask_g(self, mock_current_user, minimal_app_for_auth_api, auth_manager):
session_user = Mock()
session_user.is_anonymous = True
mock_current_user.return_value = session_user

flask_g_user = Mock()
flask_g_user.is_anonymous = False
with minimal_app_for_auth_api.app_context():
with user_set(minimal_app_for_auth_api, flask_g_user):
assert auth_manager.get_user() == flask_g_user

@pytest.mark.db_test
@mock.patch.object(FabAuthManager, "get_user")
Expand Down

0 comments on commit 577a9d7

Please sign in to comment.