diff --git a/openapi_to_fastapi/routes.py b/openapi_to_fastapi/routes.py index d88cbb3..66a0ffa 100644 --- a/openapi_to_fastapi/routes.py +++ b/openapi_to_fastapi/routes.py @@ -1,10 +1,10 @@ import json from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union import pydantic -from fastapi import APIRouter +from fastapi import APIRouter, params from fastapi.openapi import models as oas from .model_generator import load_models @@ -35,6 +35,7 @@ class RouteInfo: tags: Optional[List[str]] = None summary: Optional[str] = None deprecated: Optional[bool] = None + dependencies: Optional[Sequence[params.Depends]] = None request_model: Optional[Type[pydantic.BaseModel]] = None response_model: Optional[Type[pydantic.BaseModel]] = None @@ -188,6 +189,7 @@ def post( response_description: Optional[str] = None, name_factory: Optional[Callable] = None, responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None, + dependencies: Optional[Sequence[params.Depends]] = None, ): """ Define implementation for a specific POST route @@ -201,6 +203,7 @@ def post( :param description: Route description. Got from OpenAPI spec by default :param response_description: Description of the response :param responses: Possible responses the route may return. Used in documentation + :param dependencies: Possible dependencies to add to the route. """ def _wrapper(fn): @@ -215,6 +218,7 @@ def _wrapper(fn): route_info.tags = tags route_info.name_factory = name_factory route_info.responses = responses + route_info.dependencies = dependencies if response_description: route_info.response_description = response_description @@ -258,5 +262,6 @@ def to_fastapi_router(self): responses=route_info.responses, tags=route_info.tags, deprecated=route_info.deprecated, + dependencies=route_info.dependencies, )(handler) return router diff --git a/openapi_to_fastapi/tests/test_router.py b/openapi_to_fastapi/tests/test_router.py index 9c6f2e4..e9dfce7 100644 --- a/openapi_to_fastapi/tests/test_router.py +++ b/openapi_to_fastapi/tests/test_router.py @@ -1,7 +1,8 @@ import pydantic import pytest -from fastapi import Header +from fastapi import Depends, Header, HTTPException, Request from pydantic import BaseModel +from starlette.status import HTTP_418_IM_A_TEAPOT from openapi_to_fastapi.model_generator import load_models from openapi_to_fastapi.routes import SpecRouter @@ -251,3 +252,36 @@ def test_custom_responses(app, specs_root): assert issubclass(model, BaseModel) assert "ok" in model.model_fields assert "errorMessage" in model.model_fields + + +def test_dependencies(app, client, specs_root): + async def teapot_dependency(request: Request): + """ + Dependency used just for testing. + """ + if request.headers.get("X-Brew") != "tea": + raise HTTPException( + HTTP_418_IM_A_TEAPOT, + "I'm a teapot", + ) + + spec_router = SpecRouter(specs_root / "definitions") + + @spec_router.post("/Company/BasicInfo", dependencies=[Depends(teapot_dependency)]) + def weather_metric(request): + return company_basic_info_resp + + app.include_router(spec_router.to_fastapi_router()) + + # Normal request, not affected by the dependency + resp = client.post( + "/Company/BasicInfo", json={"companyId": "test"}, headers={"X-Brew": "tea"} + ) + assert resp.status_code == 200, resp.json() + assert resp.json() == company_basic_info_resp + + # Custom request, affected by the dependency + resp = client.post( + "/Company/BasicInfo", json={"companyId": "test"}, headers={"X-Brew": "coffee"} + ) + assert resp.status_code == 418, resp.json()