Skip to content

Commit

Permalink
Implement response_model for hug
Browse files Browse the repository at this point in the history
  • Loading branch information
sayanarijit committed Jan 5, 2024
1 parent 067cf3d commit 00b9c4d
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 33 deletions.
5 changes: 0 additions & 5 deletions apphelpers/rest/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ def decorator(func):
return decorator


def not_found_on_none(func):
func.not_found_on_none = True
return func


def ignore_site_ctx(func):
func.ignore_site_ctx = True
return func
Expand Down
16 changes: 2 additions & 14 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect
from functools import wraps
from typing import Union

from converge import settings
from fastapi import APIRouter, Depends
Expand All @@ -17,23 +16,12 @@
from apphelpers.rest import endpoint as ep
from apphelpers.rest.common import User, phony
from apphelpers.sessions import SessionDBHandler
from apphelpers.utilities.types import is_optional

if settings.get("HONEYBADGER_API_KEY"):
from honeybadger import Honeybadger
from honeybadger.utils import filter_dict

try:
from types import NoneType, UnionType
except ImportError:
NoneType = type(None)
UnionType = type(Union[str, int])


def _is_optional(t) -> bool:
return (
isinstance(t, UnionType) and len(t.__args__) == 2 and t.__args__[1] == NoneType
)


def raise_not_found_on_none(f):
if getattr(f, "not_found_on_none", None) is True:
Expand Down Expand Up @@ -497,7 +485,7 @@ def build(self, method, method_args, method_kw, f):
):
method_kw["response_model"] = response_model

if _is_optional(return_type) and not _is_optional(response_model):
if is_optional(return_type) and not is_optional(response_model):
f.not_found_on_none = True

print(
Expand Down
24 changes: 19 additions & 5 deletions apphelpers/rest/hug.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
# type: ignore

import inspect
from dataclasses import asdict, dataclass

import hug
from converge import settings
from falcon import (
HTTPForbidden,
HTTPNotFound,
HTTPUnauthorized,
)
from falcon import HTTPForbidden, HTTPNotFound, HTTPUnauthorized
from hug.decorators import wraps
from pydantic import TypeAdapter

from apphelpers.db.peewee import dbtransaction
from apphelpers.errors.hug import BaseError, InvalidSessionError
from apphelpers.loggers import api_logger
from apphelpers.rest import endpoint as ep
from apphelpers.sessions import SessionDBHandler
from apphelpers.utilities.types import is_optional

if settings.get("HONEYBADGER_API_KEY"):
from honeybadger import Honeybadger
Expand Down Expand Up @@ -423,6 +422,21 @@ def choose_router(self, f):
return self.secure_router if login_required else self.router

def build(self, method, method_args, method_kw, f):
return_type = inspect.signature(f).return_annotation
response_model = getattr(f, "response_model", return_type)

if is_optional(return_type) and not is_optional(response_model):
f.not_found_on_none = True

# Let hug validate the response using return annotation
if response_model and response_model != inspect.Signature.empty:
validator = TypeAdapter(response_model)

def callable_return_type(ret):
return validator.dump_python(validator.validate_python(ret))

f.__annotations__["return"] = callable_return_type

print(
f"{method_args[0]}",
f"[{method.__name__.upper()}] => {f.__module__}:{f.__name__}",
Expand Down
8 changes: 8 additions & 0 deletions apphelpers/utilities/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
try:
from types import NoneType
except ImportError:
NoneType = type(None)


def is_optional(t) -> bool:
return len(getattr(t, "__args__", [])) == 2 and t.__args__[1] == NoneType
7 changes: 3 additions & 4 deletions fastapi_tests/app/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def get_my_uid(body=json_body):


@ep.login_required
@ep.not_found_on_none
def get_snake(name=None):
@ep.response_model(str)
def get_snake(name=None) -> Optional[str]:
return name


Expand All @@ -55,8 +55,7 @@ def get_snake_fancy(name=None) -> Optional[str]:


@ep.login_required
@ep.not_found_on_none
async def get_snake_async(name=None):
async def get_snake_async(name=None) -> Optional[str]:
return name


Expand Down
26 changes: 21 additions & 5 deletions tests/app/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Dict, Optional

import hug
import hug.directives
from pydantic import BaseModel

from apphelpers.rest import endpoint as ep
from apphelpers.rest.hug import user_id
Expand Down Expand Up @@ -30,14 +33,21 @@ def get_my_uid(uid: user_id):
return uid


@ep.not_found_on_none
def get_snake(name):
@ep.response_model(str)
def get_snake(name) -> Optional[str]:
return None


def get_snake_legacy(name) -> Optional[str]:
return None


get_snake_legacy.not_found_on_none = True


@ep.login_required
@ep.not_found_on_none
def get_secure_snake(site_id, name):
@ep.response_model(str)
def get_secure_snake(site_id, name) -> Optional[str]:
return None


Expand All @@ -64,7 +74,12 @@ def process_request(request, body):
return {"body": body, "headers": request.headers}


@ep.not_found_on_none
class RawRequest(BaseModel):
raw_body: str
headers: Dict[str, str]


@ep.response_model(RawRequest)
def process_raw_request(request):
return {"raw_body": request.stream.read().decode(), "headers": request.headers}

Expand All @@ -91,6 +106,7 @@ def setup_routes(factory):
factory.post("/me/uid")(get_my_uid)

factory.get("/snakes/{name}")(get_snake)
factory.get("/snakes-legacy/{name}")(get_snake_legacy)

factory.get("/sites/{site_id}/secure-echo/{word}")(secure_multisite_echo)
factory.get("/sites/{site_id}/echo-groups")(echo_multisite_groups)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ def test_not_found():
url = urls.base + "snakes/viper"
assert requests.get(url).status_code == 404

url = urls.base + "snakes-legacy/viper"
assert requests.get(url).status_code == 404

url = urls.base + "sites/1/snakes/viper"
assert requests.get(url, headers=headers).status_code == 404

Expand Down

0 comments on commit 00b9c4d

Please sign in to comment.