Skip to content

Commit

Permalink
Move the session auth backend to FAB auth manager (apache#42878)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored and pavansharma36 committed Oct 14, 2024
1 parent 4825f40 commit 4c9b824
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Session authentication backend."""

from __future__ import annotations

from functools import wraps
from typing import Any, Callable, TypeVar, cast

from flask import Response

from airflow.www.extensions.init_auth_manager import get_auth_manager

CLIENT_AUTH: tuple[str, str] | Any | None = None


def init_app(_):
"""Initialize authentication backend."""


T = TypeVar("T", bound=Callable)


def requires_authentication(function: T):
"""Decorate functions that require authentication."""

@wraps(function)
def decorated(*args, **kwargs):
if not get_auth_manager().is_logged_in():
return Response("Unauthorized", 401, {})
return function(*args, **kwargs)

return cast(T, decorated)
73 changes: 73 additions & 0 deletions providers/tests/fab/auth_manager/api/auth/backend/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch

import pytest
from flask import Response

from airflow.providers.fab.auth_manager.api.auth.backend.session import requires_authentication
from airflow.www import app as application

from dev.tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS

pytestmark = [
pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"),
]


@pytest.fixture
def app():
return application.create_app(testing=True)


mock_call = Mock()


@requires_authentication
def function_decorated():
mock_call()


@pytest.mark.db_test
class TestSessionAuth:
def setup_method(self) -> None:
mock_call.reset_mock()

@patch("airflow.providers.fab.auth_manager.api.auth.backend.session.get_auth_manager")
def test_requires_authentication_when_not_authenticated(self, mock_get_auth_manager, app):
auth_manager = Mock()
auth_manager.is_logged_in.return_value = False
mock_get_auth_manager.return_value = auth_manager
with app.test_request_context() as mock_context:
mock_context.request.authorization = None
result = function_decorated()

assert type(result) is Response
assert result.status_code == 401

@patch("airflow.providers.fab.auth_manager.api.auth.backend.session.get_auth_manager")
def test_requires_authentication_when_authenticated(self, mock_get_auth_manager, app):
auth_manager = Mock()
auth_manager.is_logged_in.return_value = True
mock_get_auth_manager.return_value = auth_manager
with app.test_request_context() as mock_context:
mock_context.request.authorization = None
function_decorated()

mock_call.assert_called_once()

0 comments on commit 4c9b824

Please sign in to comment.