Skip to content

Commit

Permalink
AutoSchema: Inherit DRF AutoSchema
Browse files Browse the repository at this point in the history
Replace custom _map_basic_serializer with _map_field,
and invoke super _map_serializer as a fallback.

DRF AutoSchema PR 7257 makes all of these fields public,
so Spectacular AutoSchema should implement them and use
them when appropriate for interoperability.

Related to #31
Related to #45
  • Loading branch information
jayvdb committed May 1, 2020
1 parent c21a177 commit cd32c99
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 43 deletions.
86 changes: 43 additions & 43 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from rest_framework import permissions, renderers, serializers
from rest_framework.fields import _UnvalidatedField, empty
from rest_framework.generics import GenericAPIView
from rest_framework.schemas.inspectors import ViewInspector
from rest_framework.schemas.openapi import AutoSchema as DRFAutoSchema
from rest_framework.schemas.utils import get_pk_description
from rest_framework.settings import api_settings
from rest_framework.views import APIView
Expand All @@ -34,7 +34,7 @@
from drf_spectacular.authentication import OpenApiAuthenticationExtension


class AutoSchema(ViewInspector):
class AutoSchema(DRFAutoSchema):
method_mapping = {
'get': 'retrieve',
'post': 'create',
Expand Down Expand Up @@ -541,55 +541,28 @@ def _map_min_max(self, field, content):
if field.min_value:
content['minimum'] = field.min_value

def _map_serializer(self, serializer, direction):
def _map_serializer(self, serializer, direction=None):
serializer = force_instance(serializer)
serializer_extension = OpenApiSerializerExtension.get_match(serializer)

if serializer_extension:
if serializer_extension and direction:
return serializer_extension.map_serializer(self, direction)
else:
return self._map_basic_serializer(serializer, direction)

def _map_basic_serializer(self, serializer, direction):
required = []
properties = {}

for field in serializer.fields.values():
if isinstance(field, serializers.HiddenField):
continue

if field.required:
required.append(field.field_name)
if hasattr(DRFAutoSchema, 'map_serializer'):
result = super().map_serializer(serializer)
else:
result = super()._map_serializer(serializer)

schema = self._map_serializer_field(field)
if result.get('properties'):
# Move 'type' to top
new = {'type': 'object'}
new.update(result)
result = new

if field.read_only:
schema['readOnly'] = True
if field.write_only:
schema['writeOnly'] = True
if field.allow_null:
schema['nullable'] = True
if field.default is not None and field.default != empty and not callable(field.default):
schema['default'] = field.to_representation(field.default)
if field.help_text:
schema['description'] = str(field.help_text)
self._map_field_validators(field, schema)

# sibling entries to $ref will be ignored as it replaces itself and its context with
# the referenced object. Wrap it in a separate context.
if '$ref' in schema and len(schema) > 1:
schema = {'allOf': [{'$ref': schema.pop('$ref')}], **schema}

properties[field.field_name] = schema

result = {
'type': 'object',
'properties': properties
}
if required and (self.method != 'PATCH' or direction == 'response'):
result['required'] = required
if result.get('required') and self.method == 'PATCH' and direction == 'request':
del result['required']

return result
return result

def _map_field_validators(self, field, schema):
for v in field.validators:
Expand Down Expand Up @@ -624,6 +597,33 @@ def _map_field_validators(self, field, schema):
schema['maximum'] = int(digits * '9') + 1
schema['minimum'] = -schema['maximum']

def _map_field(self, field):
result = super()._map_field(field)
schema = self._map_serializer_field(field)

result.update(schema)
if result.get('properties'):
result['type'] = 'object'

# sibling entries to $ref will be ignored as it replaces itself and its context with
# the referenced object. Wrap it in a separate context.
if '$ref' in result and len(result) > 1:
return {'allOf': [{'$ref': schema.pop('$ref')}], **schema}

new = {}
if result.get('enum'):
new['enum'] = result['enum']
if result.get('type'):
new['type'] = result['type']
if result.get('format'):
new['format'] = result['format']

new.update(result)

return new

map_field = _map_field

def _map_type_hint(self, method):
hint = getattr(method, '_spectacular_annotation', None) or typing.get_type_hints(method).get('return')

Expand Down
2 changes: 2 additions & 0 deletions tests/test_fields.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,10 @@ components:
minimum: -1000
field_file:
type: string
format: binary
field_img:
type: string
format: binary
field_date:
type: string
format: date
Expand Down

0 comments on commit cd32c99

Please sign in to comment.