diff --git a/apphelpers/rest/endpoint.py b/apphelpers/rest/endpoint.py index 03bf63d..43477da 100644 --- a/apphelpers/rest/endpoint.py +++ b/apphelpers/rest/endpoint.py @@ -38,11 +38,6 @@ def decorator(func): return decorator -def not_found_on_none(func): - func.not_found_on_none = True - return func - - def ignore_site_ctx(func): func.ignore_site_ctx = True return func diff --git a/apphelpers/rest/fastapi.py b/apphelpers/rest/fastapi.py index 6074c3d..f6ea869 100644 --- a/apphelpers/rest/fastapi.py +++ b/apphelpers/rest/fastapi.py @@ -1,6 +1,5 @@ import inspect from functools import wraps -from typing import Union from converge import settings from fastapi import APIRouter, Depends @@ -17,23 +16,12 @@ from apphelpers.rest import endpoint as ep from apphelpers.rest.common import User, phony from apphelpers.sessions import SessionDBHandler +from apphelpers.utilities.types import is_optional if settings.get("HONEYBADGER_API_KEY"): from honeybadger import Honeybadger from honeybadger.utils import filter_dict -try: - from types import NoneType, UnionType -except ImportError: - NoneType = type(None) - UnionType = type(Union[str, int]) - - -def _is_optional(t) -> bool: - return ( - isinstance(t, UnionType) and len(t.__args__) == 2 and t.__args__[1] == NoneType - ) - def raise_not_found_on_none(f): if getattr(f, "not_found_on_none", None) is True: @@ -497,7 +485,7 @@ def build(self, method, method_args, method_kw, f): ): method_kw["response_model"] = response_model - if _is_optional(return_type) and not _is_optional(response_model): + if is_optional(return_type) and not is_optional(response_model): f.not_found_on_none = True print( diff --git a/apphelpers/rest/hug.py b/apphelpers/rest/hug.py index c2371b0..d092740 100644 --- a/apphelpers/rest/hug.py +++ b/apphelpers/rest/hug.py @@ -1,21 +1,20 @@ # type: ignore +import inspect from dataclasses import asdict, dataclass import hug from converge import settings -from falcon import ( - HTTPForbidden, - HTTPNotFound, - HTTPUnauthorized, -) +from falcon import HTTPForbidden, HTTPNotFound, HTTPUnauthorized from hug.decorators import wraps +from pydantic import TypeAdapter from apphelpers.db.peewee import dbtransaction from apphelpers.errors.hug import BaseError, InvalidSessionError from apphelpers.loggers import api_logger from apphelpers.rest import endpoint as ep from apphelpers.sessions import SessionDBHandler +from apphelpers.utilities.types import is_optional if settings.get("HONEYBADGER_API_KEY"): from honeybadger import Honeybadger @@ -423,6 +422,21 @@ def choose_router(self, f): return self.secure_router if login_required else self.router def build(self, method, method_args, method_kw, f): + return_type = inspect.signature(f).return_annotation + response_model = getattr(f, "response_model", return_type) + + if is_optional(return_type) and not is_optional(response_model): + f.not_found_on_none = True + + # Let hug validate the response using return annotation + if response_model and response_model != inspect.Signature.empty: + validator = TypeAdapter(response_model) + + def callable_return_type(ret): + return validator.dump_python(validator.validate_python(ret)) + + f.__annotations__["return"] = callable_return_type + print( f"{method_args[0]}", f"[{method.__name__.upper()}] => {f.__module__}:{f.__name__}", diff --git a/apphelpers/utilities/types.py b/apphelpers/utilities/types.py new file mode 100644 index 0000000..8b74ff5 --- /dev/null +++ b/apphelpers/utilities/types.py @@ -0,0 +1,8 @@ +try: + from types import NoneType +except ImportError: + NoneType = type(None) + + +def is_optional(t) -> bool: + return len(getattr(t, "__args__", [])) == 2 and t.__args__[1] == NoneType diff --git a/fastapi_tests/app/endpoints.py b/fastapi_tests/app/endpoints.py index 69b590b..ab035ab 100644 --- a/fastapi_tests/app/endpoints.py +++ b/fastapi_tests/app/endpoints.py @@ -43,8 +43,8 @@ def get_my_uid(body=json_body): @ep.login_required -@ep.not_found_on_none -def get_snake(name=None): +@ep.response_model(str) +def get_snake(name=None) -> Optional[str]: return name @@ -55,8 +55,7 @@ def get_snake_fancy(name=None) -> Optional[str]: @ep.login_required -@ep.not_found_on_none -async def get_snake_async(name=None): +async def get_snake_async(name=None) -> Optional[str]: return name diff --git a/tests/app/endpoints.py b/tests/app/endpoints.py index a0e4b99..d0280b8 100644 --- a/tests/app/endpoints.py +++ b/tests/app/endpoints.py @@ -1,5 +1,8 @@ +from typing import Dict, Optional + import hug import hug.directives +from pydantic import BaseModel from apphelpers.rest import endpoint as ep from apphelpers.rest.hug import user_id @@ -30,14 +33,21 @@ def get_my_uid(uid: user_id): return uid -@ep.not_found_on_none -def get_snake(name): +@ep.response_model(str) +def get_snake(name) -> Optional[str]: + return None + + +def get_snake_legacy(name) -> Optional[str]: return None +get_snake_legacy.not_found_on_none = True + + @ep.login_required -@ep.not_found_on_none -def get_secure_snake(site_id, name): +@ep.response_model(str) +def get_secure_snake(site_id, name) -> Optional[str]: return None @@ -64,7 +74,12 @@ def process_request(request, body): return {"body": body, "headers": request.headers} -@ep.not_found_on_none +class RawRequest(BaseModel): + raw_body: str + headers: Dict[str, str] + + +@ep.response_model(RawRequest) def process_raw_request(request): return {"raw_body": request.stream.read().decode(), "headers": request.headers} @@ -91,6 +106,7 @@ def setup_routes(factory): factory.post("/me/uid")(get_my_uid) factory.get("/snakes/{name}")(get_snake) + factory.get("/snakes-legacy/{name}")(get_snake_legacy) factory.get("/sites/{site_id}/secure-echo/{word}")(secure_multisite_echo) factory.get("/sites/{site_id}/echo-groups")(echo_multisite_groups) diff --git a/tests/test_rest.py b/tests/test_rest.py index 51ebcc4..cc89d8b 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -159,6 +159,9 @@ def test_not_found(): url = urls.base + "snakes/viper" assert requests.get(url).status_code == 404 + url = urls.base + "snakes-legacy/viper" + assert requests.get(url).status_code == 404 + url = urls.base + "sites/1/snakes/viper" assert requests.get(url, headers=headers).status_code == 404