From cd32c99b1d4cc0dd55769808f55f9a74c6ba29ea Mon Sep 17 00:00:00 2001 From: John Vandenberg Date: Fri, 1 May 2020 17:20:05 +0700 Subject: [PATCH] AutoSchema: Inherit DRF AutoSchema Replace custom _map_basic_serializer with _map_field, and invoke super _map_serializer as a fallback. DRF AutoSchema PR 7257 makes all of these fields public, so Spectacular AutoSchema should implement them and use them when appropriate for interoperability. Related to https://github.com/tfranzel/drf-spectacular/issues/31 Related to https://github.com/tfranzel/drf-spectacular/issues/45 --- drf_spectacular/openapi.py | 86 +++++++++++++++++++------------------- tests/test_fields.yml | 2 + 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 4f17af52..f115d5e3 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -12,7 +12,7 @@ from rest_framework import permissions, renderers, serializers from rest_framework.fields import _UnvalidatedField, empty from rest_framework.generics import GenericAPIView -from rest_framework.schemas.inspectors import ViewInspector +from rest_framework.schemas.openapi import AutoSchema as DRFAutoSchema from rest_framework.schemas.utils import get_pk_description from rest_framework.settings import api_settings from rest_framework.views import APIView @@ -34,7 +34,7 @@ from drf_spectacular.authentication import OpenApiAuthenticationExtension -class AutoSchema(ViewInspector): +class AutoSchema(DRFAutoSchema): method_mapping = { 'get': 'retrieve', 'post': 'create', @@ -541,55 +541,28 @@ def _map_min_max(self, field, content): if field.min_value: content['minimum'] = field.min_value - def _map_serializer(self, serializer, direction): + def _map_serializer(self, serializer, direction=None): serializer = force_instance(serializer) serializer_extension = OpenApiSerializerExtension.get_match(serializer) - if serializer_extension: + if serializer_extension and direction: return serializer_extension.map_serializer(self, direction) else: - return self._map_basic_serializer(serializer, direction) - - def _map_basic_serializer(self, serializer, direction): - required = [] - properties = {} - - for field in serializer.fields.values(): - if isinstance(field, serializers.HiddenField): - continue - - if field.required: - required.append(field.field_name) + if hasattr(DRFAutoSchema, 'map_serializer'): + result = super().map_serializer(serializer) + else: + result = super()._map_serializer(serializer) - schema = self._map_serializer_field(field) + if result.get('properties'): + # Move 'type' to top + new = {'type': 'object'} + new.update(result) + result = new - if field.read_only: - schema['readOnly'] = True - if field.write_only: - schema['writeOnly'] = True - if field.allow_null: - schema['nullable'] = True - if field.default is not None and field.default != empty and not callable(field.default): - schema['default'] = field.to_representation(field.default) - if field.help_text: - schema['description'] = str(field.help_text) - self._map_field_validators(field, schema) - - # sibling entries to $ref will be ignored as it replaces itself and its context with - # the referenced object. Wrap it in a separate context. - if '$ref' in schema and len(schema) > 1: - schema = {'allOf': [{'$ref': schema.pop('$ref')}], **schema} - - properties[field.field_name] = schema - - result = { - 'type': 'object', - 'properties': properties - } - if required and (self.method != 'PATCH' or direction == 'response'): - result['required'] = required + if result.get('required') and self.method == 'PATCH' and direction == 'request': + del result['required'] - return result + return result def _map_field_validators(self, field, schema): for v in field.validators: @@ -624,6 +597,33 @@ def _map_field_validators(self, field, schema): schema['maximum'] = int(digits * '9') + 1 schema['minimum'] = -schema['maximum'] + def _map_field(self, field): + result = super()._map_field(field) + schema = self._map_serializer_field(field) + + result.update(schema) + if result.get('properties'): + result['type'] = 'object' + + # sibling entries to $ref will be ignored as it replaces itself and its context with + # the referenced object. Wrap it in a separate context. + if '$ref' in result and len(result) > 1: + return {'allOf': [{'$ref': schema.pop('$ref')}], **schema} + + new = {} + if result.get('enum'): + new['enum'] = result['enum'] + if result.get('type'): + new['type'] = result['type'] + if result.get('format'): + new['format'] = result['format'] + + new.update(result) + + return new + + map_field = _map_field + def _map_type_hint(self, method): hint = getattr(method, '_spectacular_annotation', None) or typing.get_type_hints(method).get('return') diff --git a/tests/test_fields.yml b/tests/test_fields.yml index 6f5a3d76..9855192a 100644 --- a/tests/test_fields.yml +++ b/tests/test_fields.yml @@ -161,8 +161,10 @@ components: minimum: -1000 field_file: type: string + format: binary field_img: type: string + format: binary field_date: type: string format: date