Skip to content

Commit

Permalink
Closes #9608: Move from drf-yasg to spectacular
Browse files Browse the repository at this point in the history
Co-authored-by: arthanson <[email protected]>
Co-authored-by: jeremystretch <[email protected]>
  • Loading branch information
arthanson and jeremystretch authored Mar 30, 2023
1 parent 1be626e commit ecd0c56
Show file tree
Hide file tree
Showing 35 changed files with 513 additions and 339 deletions.
6 changes: 3 additions & 3 deletions base_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ django-timezone-field
# https://github.com/encode/django-rest-framework
djangorestframework

# Swagger/OpenAPI schema generation for REST APIs
# https://github.com/axnsan12/drf-yasg
drf-yasg[validation]
# Sane and flexible OpenAPI 3 schema generation for Django REST framework.
# https://github.com/tfranzel/drf-spectacular
drf-spectacular

# RSS feed parser
# https://github.com/kurtmckee/feedparser
Expand Down
8 changes: 8 additions & 0 deletions netbox/circuits/api/nested_serializers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer
from drf_spectacular.types import OpenApiTypes
from rest_framework import serializers

from circuits.models import *
Expand Down Expand Up @@ -29,6 +31,9 @@ class Meta:
# Providers
#

@extend_schema_serializer(
exclude_fields=('circuit_count',),
)
class NestedProviderSerializer(WritableNestedSerializer):
url = serializers.HyperlinkedIdentityField(view_name='circuits-api:provider-detail')
circuit_count = serializers.IntegerField(read_only=True)
Expand All @@ -54,6 +59,9 @@ class Meta:
# Circuits
#

@extend_schema_serializer(
exclude_fields=('circuit_count',),
)
class NestedCircuitTypeSerializer(WritableNestedSerializer):
url = serializers.HyperlinkedIdentityField(view_name='circuits-api:circuittype-detail')
circuit_count = serializers.IntegerField(read_only=True)
Expand Down
8 changes: 4 additions & 4 deletions netbox/circuits/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class Meta:

class CircuitCircuitTerminationSerializer(WritableNestedSerializer):
url = serializers.HyperlinkedIdentityField(view_name='circuits-api:circuittermination-detail')
site = NestedSiteSerializer()
provider_network = NestedProviderNetworkSerializer()
site = NestedSiteSerializer(allow_null=True)
provider_network = NestedProviderNetworkSerializer(allow_null=True)

class Meta:
model = CircuitTermination
Expand All @@ -110,8 +110,8 @@ class CircuitSerializer(NetBoxModelSerializer):
status = ChoiceField(choices=CircuitStatusChoices, required=False)
type = NestedCircuitTypeSerializer()
tenant = NestedTenantSerializer(required=False, allow_null=True)
termination_a = CircuitCircuitTerminationSerializer(read_only=True)
termination_z = CircuitCircuitTerminationSerializer(read_only=True)
termination_a = CircuitCircuitTerminationSerializer(read_only=True, allow_null=True)
termination_z = CircuitCircuitTerminationSerializer(read_only=True, allow_null=True)

class Meta:
model = Circuit
Expand Down
224 changes: 224 additions & 0 deletions netbox/core/api/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import re
import typing

from drf_spectacular.extensions import (
OpenApiSerializerFieldExtension,
OpenApiViewExtension,
)
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import (
ComponentRegistry,
ResolvedComponent,
build_basic_type,
build_media_type_object,
build_object_type,
is_serializer,
)
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema
from rest_framework.relations import ManyRelatedField

from netbox.api.fields import ChoiceField, SerializedPKRelatedField
from netbox.api.serializers import WritableNestedSerializer

# see netbox.api.routers.NetBoxRouter
BULK_ACTIONS = ("bulk_destroy", "bulk_partial_update", "bulk_update")
WRITABLE_ACTIONS = ("PATCH", "POST", "PUT")


class FixTimeZoneSerializerField(OpenApiSerializerFieldExtension):
target_class = 'timezone_field.rest_framework.TimeZoneSerializerField'

def map_serializer_field(self, auto_schema, direction):
return build_basic_type(OpenApiTypes.STR)


class ChoiceFieldFix(OpenApiSerializerFieldExtension):
target_class = 'netbox.api.fields.ChoiceField'

def map_serializer_field(self, auto_schema, direction):
if direction == 'request':
return build_basic_type(OpenApiTypes.STR)

elif direction == "response":
return build_object_type(
properties={
"value": build_basic_type(OpenApiTypes.STR),
"label": build_basic_type(OpenApiTypes.STR),
}
)


class NetBoxAutoSchema(AutoSchema):
"""
Overrides to drf_spectacular.openapi.AutoSchema to fix following issues:
1. bulk serializers cause operation_id conflicts with non-bulk ones
2. bulk operations should specify a list
3. bulk operations don't have filter params
4. bulk operations don't have pagination
5. bulk delete should specify input
"""

writable_serializers = {}

@property
def is_bulk_action(self):
if hasattr(self.view, "action") and self.view.action in BULK_ACTIONS:
return True
else:
return False

def get_operation_id(self):
"""
bulk serializers cause operation_id conflicts with non-bulk ones
bulk operations cause id conflicts in spectacular resulting in numerous:
Warning: operationId "xxx" has collisions [xxx]. "resolving with numeral suffixes"
code is modified from drf_spectacular.openapi.AutoSchema.get_operation_id
"""
if self.is_bulk_action:
tokenized_path = self._tokenize_path()
# replace dashes as they can be problematic later in code generation
tokenized_path = [t.replace('-', '_') for t in tokenized_path]

if self.method == 'GET' and self._is_list_view():
# this shouldn't happen, but keeping it here to follow base code
action = 'list'
else:
# action = self.method_mapping[self.method.lower()]
# use bulk name so partial_update -> bulk_partial_update
action = self.view.action.lower()

if not tokenized_path:
tokenized_path.append('root')

if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
tokenized_path.append('formatted')

return '_'.join(tokenized_path + [action])

# if not bulk - just return normal id
return super().get_operation_id()

def get_request_serializer(self) -> typing.Any:
# bulk operations should specify a list
serializer = super().get_request_serializer()

if self.is_bulk_action:
return type(serializer)(many=True)

# handle mapping for Writable serializers - adapted from dansheps original code
# for drf-yasg
if serializer is not None and self.method in WRITABLE_ACTIONS:
writable_class = self.get_writable_class(serializer)
if writable_class is not None:
if hasattr(serializer, "child"):
child_serializer = self.get_writable_class(serializer.child)
serializer = writable_class(context=serializer.context, child=child_serializer)
else:
serializer = writable_class(context=serializer.context)

return serializer

def get_response_serializers(self) -> typing.Any:
# bulk operations should specify a list
response_serializers = super().get_response_serializers()

if self.is_bulk_action:
return type(response_serializers)(many=True)

return response_serializers

def get_serializer_ref_name(self, serializer):
# from drf-yasg.utils
"""Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer')
:param serializer: Serializer instance
:return: Serializer's ``ref_name`` or ``None`` for inline serializer
:rtype: str or None
"""
serializer_meta = getattr(serializer, 'Meta', None)
serializer_name = type(serializer).__name__
if hasattr(serializer_meta, 'ref_name'):
ref_name = serializer_meta.ref_name
elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer):
ref_name = None
else:
ref_name = serializer_name
if ref_name.endswith('Serializer'):
ref_name = ref_name[: -len('Serializer')]
return ref_name

def get_writable_class(self, serializer):
properties = {}
fields = {} if hasattr(serializer, 'child') else serializer.fields

for child_name, child in fields.items():
if isinstance(child, (ChoiceField, WritableNestedSerializer)):
properties[child_name] = None
elif isinstance(child, ManyRelatedField) and isinstance(child.child_relation, SerializedPKRelatedField):
properties[child_name] = None

if not properties:
return None

if type(serializer) not in self.writable_serializers:
writable_name = 'Writable' + type(serializer).__name__
meta_class = getattr(type(serializer), 'Meta', None)
if meta_class:
ref_name = 'Writable' + self.get_serializer_ref_name(serializer)
writable_meta = type('Meta', (meta_class,), {'ref_name': ref_name})
properties['Meta'] = writable_meta

self.writable_serializers[type(serializer)] = type(writable_name, (type(serializer),), properties)

writable_class = self.writable_serializers[type(serializer)]
return writable_class

def get_filter_backends(self):
# bulk operations don't have filter params
if self.is_bulk_action:
return []
return super().get_filter_backends()

def _get_paginator(self):
# bulk operations don't have pagination
if self.is_bulk_action:
return None
return super()._get_paginator()

def _get_request_body(self, direction='request'):
# bulk delete should specify input
if (not self.is_bulk_action) or (self.method != 'DELETE'):
return super()._get_request_body(direction)

# rest from drf_spectacular.openapi.AutoSchema._get_request_body
# but remove the unsafe method check

request_serializer = self.get_request_serializer()

if isinstance(request_serializer, dict):
content = []
request_body_required = True
for media_type, serializer in request_serializer.items():
schema, partial_request_body_required = self._get_request_for_media_type(serializer, direction)
examples = self._get_examples(serializer, direction, media_type)
if schema is None:
continue
content.append((media_type, schema, examples))
request_body_required &= partial_request_body_required
else:
schema, request_body_required = self._get_request_for_media_type(request_serializer, direction)
if schema is None:
return None
content = [
(media_type, schema, self._get_examples(request_serializer, direction, media_type))
for media_type in self.map_parsers()
]

request_body = {
'content': {
media_type: build_media_type_object(schema, examples) for media_type, schema, examples in content
}
}
if request_body_required:
request_body['required'] = request_body_required
return request_body
1 change: 1 addition & 0 deletions netbox/core/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ class CoreConfig(AppConfig):

def ready(self):
from . import data_backends, search
from core.api import schema # noqa: E402
Loading

0 comments on commit ecd0c56

Please sign in to comment.