diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 8c61023e..d6a25f39 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -17,7 +17,7 @@ from rest_framework.schemas.utils import get_pk_description, is_list_view from drf_spectacular.app_settings import spectacular_settings -from drf_spectacular.plumbing import resolve_basic_type, warn, anyisinstance, force_serializer_instance, is_serializer, follow_field_source +from drf_spectacular.plumbing import resolve_basic_type, warn, anyisinstance, force_instance, is_serializer, follow_field_source, is_field from drf_spectacular.types import OpenApiTypes, PYTHON_TYPE_MAPPING, OPENAPI_TYPE_MAPPING from drf_spectacular.utils import PolymorphicProxySerializer @@ -545,7 +545,7 @@ def _map_polymorphic_proxy_serializer(self, method, serializer, nested): for sub_serializer in serializer.serializers: assert is_serializer(sub_serializer), 'sub-serializer must be either a Serializer or a PolymorphicProxySerializer.' - sub_serializer = force_serializer_instance(sub_serializer) + sub_serializer = force_instance(sub_serializer) if serializer.resource_type_field_name not in sub_serializer.fields: warn( @@ -638,8 +638,8 @@ def _map_field_validators(self, field, schema): def _map_type_hint(self, method): hint = getattr(method, '_spectacular_annotation', None) or typing.get_type_hints(method).get('return') - if is_serializer(hint): - return self._map_serializer_field(method, force_serializer_instance(hint)) + if is_serializer(hint) or is_field(hint): + return self._map_serializer_field(method, force_instance(hint)) elif hint in PYTHON_TYPE_MAPPING or hint in OPENAPI_TYPE_MAPPING: return resolve_basic_type(hint) else: @@ -685,7 +685,7 @@ def _get_request_body(self, path, method): if method not in ('PUT', 'PATCH', 'POST'): return {} - serializer = force_serializer_instance(self.get_request_serializer(path, method)) + serializer = force_instance(self.get_request_serializer(path, method)) if is_serializer(serializer): schema = self.resolve_serializer(method, serializer) @@ -733,7 +733,7 @@ def _get_response_bodies(self, path, method): def _get_response_for_code(self, path, method, serializer): # convenience feature: auto instantiate serializer classes - serializer = force_serializer_instance(serializer) + serializer = force_instance(serializer) if not serializer: return {'description': 'No response body'} diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 60475127..250ab878 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -2,7 +2,7 @@ import sys from django import __version__ as DJANGO_VERSION -from rest_framework import serializers +from rest_framework import fields, serializers from drf_spectacular.types import OPENAPI_TYPE_MAPPING, PYTHON_TYPE_MAPPING, OpenApiTypes from drf_spectacular.utils import PolymorphicProxySerializer @@ -16,20 +16,28 @@ def anyisinstance(obj, type_list): return any([isinstance(obj, t) for t in type_list]) -def force_serializer_instance(serializer): - if inspect.isclass(serializer) and issubclass(serializer, serializers.BaseSerializer): - return serializer() +def force_instance(serializer_or_field): + if not inspect.isclass(serializer_or_field): + return serializer_or_field + elif issubclass(serializer_or_field, (serializers.BaseSerializer, fields.Field)): + return serializer_or_field() else: - return serializer + return serializer_or_field def is_serializer(obj): return anyisinstance( - force_serializer_instance(obj), + force_instance(obj), [serializers.BaseSerializer, PolymorphicProxySerializer] ) +def is_field(obj): + # make sure obj is a serializer field and nothing else. + # guard against serializers because BaseSerializer(Field) + return isinstance(force_instance(obj), fields.Field) and not is_serializer(obj) + + def resolve_basic_type(type_): """ resolve either enum or actual type and yield schema template for modification diff --git a/tests/test_extend_schema.py b/tests/test_extend_schema.py index ba67bcf2..cc583500 100644 --- a/tests/test_extend_schema.py +++ b/tests/test_extend_schema.py @@ -39,6 +39,7 @@ class ErrorSerializer(serializers.Serializer): field_i = serializers.SerializerMethodField() field_j = serializers.SerializerMethodField() field_k = serializers.SerializerMethodField() + field_l = serializers.SerializerMethodField() @extend_schema_field(OpenApiTypes.DATETIME) def get_field_i(self, object): @@ -52,6 +53,10 @@ def get_field_j(self, object): def get_field_k(self, object): return InlineSerializer([], many=True).data + @extend_schema_field(serializers.ChoiceField(choices=['a', 'b'])) + def get_field_l(self, object): + return object.some_choice + with mock.patch('rest_framework.settings.api_settings.DEFAULT_SCHEMA_CLASS', AutoSchema): class DoesItAllViewset(viewsets.GenericViewSet): @@ -71,7 +76,7 @@ class DoesItAllViewset(viewsets.GenericViewSet): ], description='this weird endpoint needs some explaining', deprecated=True, - tags=['custom_tag'] + tags=['custom_tag'], ) def create(self, request, *args, **kwargs): return Response({}) diff --git a/tests/test_extend_schema.yml b/tests/test_extend_schema.yml index 967e73b1..94811645 100644 --- a/tests/test_extend_schema.yml +++ b/tests/test_extend_schema.yml @@ -123,6 +123,12 @@ components: items: $ref: '#/components/schemas/Inline' readOnly: true + field_l: + enum: + - a + - b + type: string + readOnly: true Inline: type: object properties: diff --git a/tests/test_plumbing.py b/tests/test_plumbing.py new file mode 100644 index 00000000..ba150161 --- /dev/null +++ b/tests/test_plumbing.py @@ -0,0 +1,32 @@ +from django.db import models +from rest_framework import serializers + +from drf_spectacular.plumbing import is_serializer, force_instance, is_field + + +def test_is_serializer(): + assert not is_serializer(serializers.SlugField) + assert not is_serializer(serializers.SlugField()) + + assert not is_serializer(models.CharField) + assert not is_serializer(models.CharField()) + + assert is_serializer(serializers.Serializer) + assert is_serializer(serializers.Serializer()) + + +def test_is_field(): + assert is_field(serializers.SlugField) + assert is_field(serializers.SlugField()) + + assert not is_field(models.CharField) + assert not is_field(models.CharField()) + + assert not is_field(serializers.Serializer) + assert not is_field(serializers.Serializer()) + + +def test_force_instance(): + assert isinstance(force_instance(serializers.CharField), serializers.CharField) + assert force_instance(5) == 5 + assert force_instance(dict) == dict