Skip to content

Commit

Permalink
Add @endpoint.response_model
Browse files Browse the repository at this point in the history
  • Loading branch information
sayanarijit committed Dec 27, 2023
1 parent fed3ef1 commit c7f378d
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
8 changes: 8 additions & 0 deletions apphelpers/rest/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,11 @@ def not_found_on_none(func):
def ignore_site_ctx(func):
func.ignore_site_ctx = True
return func


def response_model(response_model):
def decorator(func):
func.response_model = response_model
return func

return decorator
16 changes: 14 additions & 2 deletions apphelpers/rest/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
from functools import wraps
from types import NoneType, UnionType

from converge import settings
from fastapi import APIRouter, Depends
Expand All @@ -22,6 +23,12 @@
from honeybadger.utils import filter_dict


def _is_optional(t) -> bool:
return (
isinstance(t, UnionType) and len(t.__args__) == 2 and t.__args__[1] == NoneType
)


def raise_not_found_on_none(f):
if getattr(f, "not_found_on_none", None) is True:

Expand Down Expand Up @@ -470,17 +477,22 @@ def build(self, method, method_args, method_kw, f):
module = f.__module__.split(".")[-1].strip("_")
name = f.__name__.strip("_")
return_type = inspect.signature(f).return_annotation
response_model = getattr(f, "response_model", return_type)

if "operation_id" not in method_kw:
method_kw["operation_id"] = f"{name}_{module}"
if "tags" not in method_kw:
method_kw["tags"] = [module]

if (
"response_model" not in method_kw
and "response_class" not in method_kw
and return_type is not inspect.Signature.empty
and response_model is not inspect.Signature.empty
):
method_kw["response_model"] = return_type
method_kw["response_model"] = response_model

if _is_optional(return_type) and not _is_optional(response_model):
f.not_found_on_none = True

print(
f"{method_args[0]}",
Expand Down
10 changes: 9 additions & 1 deletion fastapi_tests/app/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from apphelpers.rest import endpoint as ep
from apphelpers.rest.fastapi import json_body, user, user_id, user_agent
from apphelpers.rest.fastapi import json_body, user, user_agent, user_id


def echo(word, user=user):
Expand Down Expand Up @@ -46,6 +46,12 @@ def get_snake(name=None):
return name


@ep.login_required
@ep.response_model(str)
def get_snake_fancy(name=None) -> str | None:
return name


@ep.login_required
@ep.not_found_on_none
async def get_snake_async(name=None):
Expand Down Expand Up @@ -91,6 +97,8 @@ def setup_routes(factory):
factory.get("/snakes/{name}")(get_snake)
factory.get("/snakes-async/{name}")(get_snake_async)

factory.get("/snakes-fancy/{name}")(get_snake_fancy)

factory.get("/sites/{site_id}/echo-groups")(echo_site_groups)
factory.get("/sites/{site_id}/echo-groups-async")(echo_site_groups_async)

Expand Down
6 changes: 6 additions & 0 deletions fastapi_tests/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,15 @@ def test_not_found_on_none():
url = base_url + "snakes/viper"
assert requests.get(url).status_code != 404

url = base_url + "snakes-fancy/viper"
assert requests.get(url).status_code != 404

url = base_url + "snakes"
assert requests.get(url).status_code == 404

url = base_url + "snakes-fancy"
assert requests.get(url).status_code == 404


def test_site_group_access():
url = echo_site_groups_url
Expand Down

0 comments on commit c7f378d

Please sign in to comment.