diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 86d332cc..3a782c92 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -12,8 +12,8 @@ import inflection import uritemplate -from django import __version__ as DJANGO_VERSION from django.apps import apps +from django.db.models.fields.reverse_related import ForeignObjectRel from django.urls.resolvers import ( # type: ignore _PATH_PARAMETER_COMPONENT_RE, RegexPattern, Resolver404, RoutePattern, URLPattern, URLResolver, get_resolver, @@ -341,19 +341,11 @@ def append_meta(schema, meta): return safe_ref({**schema, **meta}) -def get_field_from_model(model, field): +def _follow_field_source(model, path: List[str]): """ - this is a Django 2.2 compatibility function to access a field through a Deferred Attribute + navigate through root model via given navigation path. supports forward/reverse relations. """ - if DJANGO_VERSION.startswith('2'): - # field.field will in effect return self, i.e. a DeferredAttribute again (loop) - return model._meta.get_field(field.field_name) - else: - return field.field - - -def _follow_field_source(model, path): - field_or_property = getattr(model, path[0]) + field_or_property = getattr(model, path[0], None) if len(path) == 1: # end of traversal @@ -362,7 +354,12 @@ def _follow_field_source(model, path): elif callable(field_or_property): return field_or_property else: - return get_field_from_model(model, field_or_property) + field = model._meta.get_field(path[0]) + if isinstance(field, ForeignObjectRel): + # resolve DRF internal object to PK field as approximation + return field.get_related_field() # type: ignore + else: + return field else: if isinstance(field_or_property, property) or callable(field_or_property): if isinstance(field_or_property, property): @@ -377,7 +374,7 @@ def _follow_field_source(model, path): ) return _follow_field_source(target_model, path[1:]) else: - target_model = field_or_property.field.related_model + target_model = model._meta.get_field(path[0]).related_model return _follow_field_source(target_model, path[1:]) diff --git a/tests/test_plumbing.py b/tests/test_plumbing.py index b44d1e0d..9ad3b378 100644 --- a/tests/test_plumbing.py +++ b/tests/test_plumbing.py @@ -1,7 +1,8 @@ from django.db import models from rest_framework import serializers -from drf_spectacular.plumbing import force_instance, is_field, is_serializer +from drf_spectacular.openapi import AutoSchema +from drf_spectacular.plumbing import follow_field_source, force_instance, is_field, is_serializer def test_is_serializer(): @@ -30,3 +31,31 @@ def test_force_instance(): assert isinstance(force_instance(serializers.CharField), serializers.CharField) assert force_instance(5) == 5 assert force_instance(dict) == dict + + +def test_follow_field_source_forward_reverse(no_warnings): + class FFS1(models.Model): + field_bool = models.BooleanField() + + class FFS2(models.Model): + ffs1 = models.ForeignKey(FFS1, on_delete=models.PROTECT) + + class FFS3(models.Model): + ffs2 = models.ForeignKey(FFS2, on_delete=models.PROTECT) + field_float = models.FloatField() + + forward_field = follow_field_source(FFS3, ['ffs2', 'ffs1', 'field_bool']) + reverse_field = follow_field_source(FFS1, ['ffs2', 'ffs3', 'field_float']) + forward_model = follow_field_source(FFS3, ['ffs2', 'ffs1']) + reverse_model = follow_field_source(FFS1, ['ffs2', 'ffs3']) + + assert isinstance(forward_field, models.BooleanField) + assert isinstance(reverse_field, models.FloatField) + assert isinstance(forward_model, models.ForeignKey) + assert isinstance(reverse_model, models.AutoField) + + auto_schema = AutoSchema() + assert auto_schema._map_model_field(forward_field, None)['type'] == 'boolean' + assert auto_schema._map_model_field(reverse_field, None)['type'] == 'number' + assert auto_schema._map_model_field(forward_model, None)['type'] == 'integer' + assert auto_schema._map_model_field(reverse_model, None)['type'] == 'integer'