Skip to content

Commit

Permalink
Fix type hint support for functools cached_property wrapped funcs
Browse files Browse the repository at this point in the history
functools added a `cached_property` wrapper in Python 3.8. Support type
hinting for properties wrapped in both Django's `cached_property` and
functool's `cached_property`.
  • Loading branch information
jalaziz committed Sep 6, 2021
1 parent 1407059 commit 84da467
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
15 changes: 12 additions & 3 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from django.utils.functional import Promise, cached_property
from django.utils.module_loading import import_string
from django.utils.version import PY38
from rest_framework import exceptions, fields, mixins, serializers, versioning
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
Expand All @@ -46,6 +47,11 @@
class Choices: # type: ignore
pass

if PY38:
CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) # type: ignore
else:
CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore

T = TypeVar('T')


Expand Down Expand Up @@ -373,7 +379,7 @@ def _follow_field_source(model, path: List[str]):
# end of traversal
if isinstance(field_or_property, property):
return field_or_property.fget
elif isinstance(field_or_property, cached_property):
elif isinstance(field_or_property, CACHED_PROPERTY_FUNCS):
return field_or_property.func
elif callable(field_or_property):
return field_or_property
Expand All @@ -397,10 +403,13 @@ def _follow_field_source(model, path: List[str]):
else:
return field
else:
if isinstance(field_or_property, (property, cached_property)) or callable(field_or_property):
if (
isinstance(field_or_property, (property,) + CACHED_PROPERTY_FUNCS)
or callable(field_or_property)
):
if isinstance(field_or_property, property):
target_model = _follow_return_type(field_or_property.fget)
elif isinstance(field_or_property, cached_property):
elif isinstance(field_or_property, CACHED_PROPERTY_FUNCS):
target_model = _follow_return_type(field_or_property.func)
else:
target_model = _follow_return_type(field_or_property)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import tempfile
import uuid
from datetime import date, datetime, timedelta
Expand Down Expand Up @@ -112,6 +113,10 @@ def field_model_property_float(self) -> float:
def field_model_cached_property_float(self) -> float:
return 1.337

@functools.cached_property # type: ignore
def field_model_py_cached_property_float(self) -> float:
return 1.337

@property
def field_list(self):
return [1.1, 2.2, 3.3]
Expand All @@ -134,6 +139,10 @@ def sub_object(self) -> SubObject:
def sub_object_cached(self) -> SubObject:
return SubObject(self)

@functools.cached_property # type: ignore
def sub_object_py_cached(self) -> SubObject:
return SubObject(self)

@property
def optional_sub_object(self) -> Optional[SubObject]:
return SubObject(self)
Expand Down Expand Up @@ -199,6 +208,8 @@ def get_field_method_object(self, obj) -> dict:

field_model_cached_property_float = serializers.ReadOnlyField()

field_model_py_cached_property_float = serializers.ReadOnlyField()

field_dict_int = serializers.DictField(
child=serializers.IntegerField(),
source='field_json',
Expand All @@ -217,6 +228,14 @@ def get_field_method_object(self, obj) -> dict:
field_sub_object_cached_nested_calculated = serializers.ReadOnlyField(source='sub_object_cached.nested.calculated')
field_sub_object_cached_model_int = serializers.ReadOnlyField(source='sub_object_cached.model_instance.field_int')

field_sub_object_py_cached_calculated = serializers.ReadOnlyField(source='sub_object_py_cached.calculated')
field_sub_object_py_cached_nested_calculated = serializers.ReadOnlyField(
source='sub_object_py_cached.nested.calculated',
)
field_sub_object_py_cached_model_int = serializers.ReadOnlyField(
source='sub_object_py_cached.model_instance.field_int',
)

# typing.Optional
field_optional_sub_object_calculated = serializers.ReadOnlyField(
source='optional_sub_object.calculated',
Expand Down
17 changes: 17 additions & 0 deletions tests/test_fields.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ components:
type: number
format: float
readOnly: true
field_model_py_cached_property_float:
type: number
format: float
readOnly: true
field_dict_int:
type: object
additionalProperties:
Expand All @@ -192,6 +196,15 @@ components:
field_sub_object_cached_model_int:
type: integer
readOnly: true
field_sub_object_py_cached_calculated:
type: integer
readOnly: true
field_sub_object_py_cached_nested_calculated:
type: integer
readOnly: true
field_sub_object_py_cached_model_int:
type: integer
readOnly: true
field_optional_sub_object_calculated:
type: integer
readOnly: true
Expand Down Expand Up @@ -303,6 +316,7 @@ components:
- field_method_object
- field_model_cached_property_float
- field_model_property_float
- field_model_py_cached_property_float
- field_o2o
- field_optional_sub_object_calculated
- field_posint
Expand All @@ -324,6 +338,9 @@ components:
- field_sub_object_model_int
- field_sub_object_nested_calculated
- field_sub_object_optional_int
- field_sub_object_py_cached_calculated
- field_sub_object_py_cached_model_int
- field_sub_object_py_cached_nested_calculated
- field_text
- field_time
- field_url
Expand Down

0 comments on commit 84da467

Please sign in to comment.