Skip to content

Commit

Permalink
Set up JWT token authentication in Fast APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck committed Nov 6, 2024
1 parent 7f4b7fd commit b9a9e8a
Show file tree
Hide file tree
Showing 16 changed files with 402 additions and 55 deletions.
45 changes: 45 additions & 0 deletions airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@

from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException

log = logging.getLogger(__name__)

app: FastAPI | None = None
auth_manager: BaseAuthManager | None = None


def create_app(apps: str = "all") -> FastAPI:
Expand All @@ -42,6 +46,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_dag_bag(app)
init_views(app)
init_plugins(app)
init_auth_manager()

if "execution" in apps_list or "all" in apps_list:
task_exec_api_app = create_task_execution_api_app(app)
Expand All @@ -64,3 +69,43 @@ def purge_cached_app() -> None:
"""Remove the cached version of the app in global state."""
global app
app = None


def get_auth_manager_cls() -> type[BaseAuthManager]:
"""
Return just the auth manager class without initializing it.
Useful to save execution time if only static methods need to be called.
"""
auth_manager_cls = conf.getimport(section="core", key="auth_manager")

if not auth_manager_cls:
raise AirflowConfigException(
"No auth manager defined in the config. "
"Please specify one using section/key [core/auth_manager]."
)

return auth_manager_cls


def init_auth_manager() -> BaseAuthManager:
"""
Initialize the auth manager.
Import the user manager class and instantiate it.
"""
global auth_manager
auth_manager_cls = get_auth_manager_cls()
auth_manager = auth_manager_cls()
auth_manager.init()
return auth_manager


def get_auth_manager() -> BaseAuthManager:
"""Return the auth manager, provided it's been initialized before."""
if auth_manager is None:
raise RuntimeError(
"Auth Manager has not been initialized yet. "
"The `init_auth_manager` method needs to be called first."
)
return auth_manager
77 changes: 77 additions & 0 deletions airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import cache
from typing import Any, Callable

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jwt import InvalidTokenError
from typing_extensions import Annotated

from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@cache
def get_signer() -> JWTSigner:
return JWTSigner(
secret_key=conf.get("api", "auth_jwt_secret"),
expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"),
audience="front-apis",
)


def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser:
try:
signer = get_signer()
payload: dict[str, Any] = signer.verify_token(token_str)
return get_auth_manager().deserialize_user(payload)
except InvalidTokenError:
raise HTTPException(403, "Forbidden")


def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable:
def inner(
dag_id: str | None = None,
user: Annotated[BaseUser | None, Depends(get_user)] = None,
) -> None:
def callback():
return get_auth_manager().is_authorized_dag(
method=method, access_entity=access_entity, details=DagDetails(id=dag_id), user=user
)

_requires_access(
is_authorized_callback=callback,
)

return inner


def _requires_access(
*,
is_authorized_callback: Callable[[], bool],
) -> None:
if not is_authorized_callback():
raise HTTPException(403, "Forbidden")
36 changes: 23 additions & 13 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

from abc import abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Container, Literal, Sequence
from typing import TYPE_CHECKING, Any, Container, Generic, Literal, Sequence, TypeVar

from flask_appbuilder.menu import MenuItem
from sqlalchemy import select

from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
DagDetails,
)
Expand All @@ -37,7 +38,6 @@
from flask import Blueprint
from sqlalchemy.orm import Session

from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
Expand All @@ -59,8 +59,10 @@

ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]

T = TypeVar("T", bound=BaseUser)

class BaseAuthManager(LoggingMixin):

class BaseAuthManager(Generic[T], LoggingMixin):
"""
Class to derive in order to implement concrete auth managers.
Expand All @@ -69,7 +71,7 @@ class BaseAuthManager(LoggingMixin):
:param appbuilder: the flask app builder
"""

def __init__(self, appbuilder: AirflowAppBuilder) -> None:
def __init__(self, appbuilder: AirflowAppBuilder | None = None) -> None:
super().__init__()
self.appbuilder = appbuilder

Expand All @@ -93,9 +95,17 @@ def get_user_display_name(self) -> str:
return self.get_user_name()

@abstractmethod
def get_user(self) -> BaseUser | None:
def get_user(self) -> T | None:
"""Return the user associated to the user in session."""

@abstractmethod
def deserialize_user(self, token: dict[str, Any]) -> T:
"""Create a user object from dict."""

@abstractmethod
def serialize_user(self, user: T) -> dict[str, Any]:
"""Create a dict from a user object."""

def get_user_id(self) -> str | None:
"""Return the user ID associated to the user in session."""
user = self.get_user()
Expand Down Expand Up @@ -132,7 +142,7 @@ def is_authorized_configuration(
*,
method: ResourceMethod,
details: ConfigurationDetails | None = None,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on configuration.
Expand All @@ -148,7 +158,7 @@ def is_authorized_connection(
*,
method: ResourceMethod,
details: ConnectionDetails | None = None,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a connection.
Expand All @@ -165,7 +175,7 @@ def is_authorized_dag(
method: ResourceMethod,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a DAG.
Expand All @@ -183,7 +193,7 @@ def is_authorized_asset(
*,
method: ResourceMethod,
details: AssetDetails | None = None,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on an asset.
Expand All @@ -199,7 +209,7 @@ def is_authorized_pool(
*,
method: ResourceMethod,
details: PoolDetails | None = None,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a pool.
Expand All @@ -215,7 +225,7 @@ def is_authorized_variable(
*,
method: ResourceMethod,
details: VariableDetails | None = None,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on a variable.
Expand All @@ -230,7 +240,7 @@ def is_authorized_view(
self,
*,
access_view: AccessView,
user: BaseUser | None = None,
user: T | None = None,
) -> bool:
"""
Return whether the user is authorized to access a read-only state of the installation.
Expand All @@ -241,7 +251,7 @@ def is_authorized_view(

@abstractmethod
def is_authorized_custom_view(
self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None
self, *, method: ResourceMethod | str, resource_name: str, user: T | None = None
):
"""
Return whether the user is authorized to perform a given action on a custom view.
Expand Down
Loading

0 comments on commit b9a9e8a

Please sign in to comment.