Skip to content

Commit

Permalink
Merge branch 'pr5' with improvements. closes #5
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Mar 11, 2020
2 parents e02bc62 + 3a6dea7 commit db9b81a
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 13 deletions.
12 changes: 6 additions & 6 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from rest_framework.schemas.utils import get_pk_description, is_list_view

from drf_spectacular.app_settings import spectacular_settings
from drf_spectacular.plumbing import resolve_basic_type, warn, anyisinstance, force_serializer_instance, is_serializer, follow_field_source
from drf_spectacular.plumbing import resolve_basic_type, warn, anyisinstance, force_instance, is_serializer, follow_field_source, is_field
from drf_spectacular.types import OpenApiTypes, PYTHON_TYPE_MAPPING, OPENAPI_TYPE_MAPPING
from drf_spectacular.utils import PolymorphicProxySerializer

Expand Down Expand Up @@ -545,7 +545,7 @@ def _map_polymorphic_proxy_serializer(self, method, serializer, nested):

for sub_serializer in serializer.serializers:
assert is_serializer(sub_serializer), 'sub-serializer must be either a Serializer or a PolymorphicProxySerializer.'
sub_serializer = force_serializer_instance(sub_serializer)
sub_serializer = force_instance(sub_serializer)

if serializer.resource_type_field_name not in sub_serializer.fields:
warn(
Expand Down Expand Up @@ -638,8 +638,8 @@ def _map_field_validators(self, field, schema):
def _map_type_hint(self, method):
hint = getattr(method, '_spectacular_annotation', None) or typing.get_type_hints(method).get('return')

if is_serializer(hint):
return self._map_serializer_field(method, force_serializer_instance(hint))
if is_serializer(hint) or is_field(hint):
return self._map_serializer_field(method, force_instance(hint))
elif hint in PYTHON_TYPE_MAPPING or hint in OPENAPI_TYPE_MAPPING:
return resolve_basic_type(hint)
else:
Expand Down Expand Up @@ -685,7 +685,7 @@ def _get_request_body(self, path, method):
if method not in ('PUT', 'PATCH', 'POST'):
return {}

serializer = force_serializer_instance(self.get_request_serializer(path, method))
serializer = force_instance(self.get_request_serializer(path, method))

if is_serializer(serializer):
schema = self.resolve_serializer(method, serializer)
Expand Down Expand Up @@ -733,7 +733,7 @@ def _get_response_bodies(self, path, method):

def _get_response_for_code(self, path, method, serializer):
# convenience feature: auto instantiate serializer classes
serializer = force_serializer_instance(serializer)
serializer = force_instance(serializer)

if not serializer:
return {'description': 'No response body'}
Expand Down
20 changes: 14 additions & 6 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys

from django import __version__ as DJANGO_VERSION
from rest_framework import serializers
from rest_framework import fields, serializers

from drf_spectacular.types import OPENAPI_TYPE_MAPPING, PYTHON_TYPE_MAPPING, OpenApiTypes
from drf_spectacular.utils import PolymorphicProxySerializer
Expand All @@ -16,20 +16,28 @@ def anyisinstance(obj, type_list):
return any([isinstance(obj, t) for t in type_list])


def force_serializer_instance(serializer):
if inspect.isclass(serializer) and issubclass(serializer, serializers.BaseSerializer):
return serializer()
def force_instance(serializer_or_field):
if not inspect.isclass(serializer_or_field):
return serializer_or_field
elif issubclass(serializer_or_field, (serializers.BaseSerializer, fields.Field)):
return serializer_or_field()
else:
return serializer
return serializer_or_field


def is_serializer(obj):
return anyisinstance(
force_serializer_instance(obj),
force_instance(obj),
[serializers.BaseSerializer, PolymorphicProxySerializer]
)


def is_field(obj):
# make sure obj is a serializer field and nothing else.
# guard against serializers because BaseSerializer(Field)
return isinstance(force_instance(obj), fields.Field) and not is_serializer(obj)


def resolve_basic_type(type_):
"""
resolve either enum or actual type and yield schema template for modification
Expand Down
7 changes: 6 additions & 1 deletion tests/test_extend_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ErrorSerializer(serializers.Serializer):
field_i = serializers.SerializerMethodField()
field_j = serializers.SerializerMethodField()
field_k = serializers.SerializerMethodField()
field_l = serializers.SerializerMethodField()

@extend_schema_field(OpenApiTypes.DATETIME)
def get_field_i(self, object):
Expand All @@ -52,6 +53,10 @@ def get_field_j(self, object):
def get_field_k(self, object):
return InlineSerializer([], many=True).data

@extend_schema_field(serializers.ChoiceField(choices=['a', 'b']))
def get_field_l(self, object):
return object.some_choice


with mock.patch('rest_framework.settings.api_settings.DEFAULT_SCHEMA_CLASS', AutoSchema):
class DoesItAllViewset(viewsets.GenericViewSet):
Expand All @@ -71,7 +76,7 @@ class DoesItAllViewset(viewsets.GenericViewSet):
],
description='this weird endpoint needs some explaining',
deprecated=True,
tags=['custom_tag']
tags=['custom_tag'],
)
def create(self, request, *args, **kwargs):
return Response({})
Expand Down
6 changes: 6 additions & 0 deletions tests/test_extend_schema.yml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ components:
items:
$ref: '#/components/schemas/Inline'
readOnly: true
field_l:
enum:
- a
- b
type: string
readOnly: true
Inline:
type: object
properties:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_plumbing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from django.db import models
from rest_framework import serializers

from drf_spectacular.plumbing import is_serializer, force_instance, is_field


def test_is_serializer():
assert not is_serializer(serializers.SlugField)
assert not is_serializer(serializers.SlugField())

assert not is_serializer(models.CharField)
assert not is_serializer(models.CharField())

assert is_serializer(serializers.Serializer)
assert is_serializer(serializers.Serializer())


def test_is_field():
assert is_field(serializers.SlugField)
assert is_field(serializers.SlugField())

assert not is_field(models.CharField)
assert not is_field(models.CharField())

assert not is_field(serializers.Serializer)
assert not is_field(serializers.Serializer())


def test_force_instance():
assert isinstance(force_instance(serializers.CharField), serializers.CharField)
assert force_instance(5) == 5
assert force_instance(dict) == dict

0 comments on commit db9b81a

Please sign in to comment.