Skip to content

Commit

Permalink
bugfix forward/reverse model traversal #323
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Mar 5, 2021
1 parent d8c416a commit 6f12e8d
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 8 deletions.
3 changes: 3 additions & 0 deletions drf_spectacular/contrib/django_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def resolve_filter_field(self, auto_schema, model, filterset_class, field_name,
# default to string if nothing else works
schema = build_basic_type(OpenApiTypes.STR)

# primary keys are usually non-editable (readOnly=True) and map_model_field correctly
# signals that attribute. however this does not apply in this context.
schema.pop('readOnly', None)
# enrich schema with additional info from filter_field
enum = schema.pop('enum', None)
if 'choices' in filter_field.extra:
Expand Down
5 changes: 3 additions & 2 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,9 @@ def _map_serializer_field(self, field, direction, collect_meta=True):
# estimates the relating model field and jumps to it's target model PK field.
# also differentiate as source can be direct (pk) or relation field (model).
model_field = follow_field_source(model, source)
if anyisinstance(model_field, [models.ForeignKey, models.ManyToManyField]):
model_field = model_field.target_field
if callable(model_field):
# follow_field_source bailed with a warning. be graceful and default to str.
model_field = models.TextField()

# primary keys are usually non-editable (readOnly=True) and map_model_field correctly
# signals that attribute. however this does not apply in the context of relations.
Expand Down
20 changes: 18 additions & 2 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import inflection
import uritemplate
from django.apps import apps
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor, ManyToManyDescriptor, ReverseManyToOneDescriptor,
ReverseOneToOneDescriptor,
)
from django.db.models.fields.reverse_related import ForeignObjectRel
from django.urls.resolvers import ( # type: ignore
_PATH_PARAMETER_COMPONENT_RE, RegexPattern, Resolver404, RoutePattern, URLPattern, URLResolver,
Expand Down Expand Up @@ -359,11 +363,23 @@ def _follow_field_source(model, path: List[str]):
return field_or_property.func
elif callable(field_or_property):
return field_or_property
elif isinstance(field_or_property, ManyToManyDescriptor):
if field_or_property.reverse:
return field_or_property.rel.target_field # m2m reverse
else:
return field_or_property.field.target_field # m2m forward
elif isinstance(field_or_property, ReverseOneToOneDescriptor):
return field_or_property.related.target_field # o2o reverse
elif isinstance(field_or_property, ReverseManyToOneDescriptor):
return field_or_property.rel.target_field # type: ignore # foreign reverse
elif isinstance(field_or_property, ForwardManyToOneDescriptor):
return field_or_property.field.target_field # type: ignore # o2o & foreign forward
else:
field = model._meta.get_field(path[0])
if isinstance(field, ForeignObjectRel):
# resolve DRF internal object to PK field as approximation
return field.get_related_field() # type: ignore
# case only occurs when relations are traversed in reverse and
# not via the related_name (default: X_set) but the model name.
return field.target_field
else:
return field
else:
Expand Down
10 changes: 6 additions & 4 deletions tests/test_plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ def test_force_instance():

def test_follow_field_source_forward_reverse(no_warnings):
class FFS1(models.Model):
id = models.UUIDField(primary_key=True)
field_bool = models.BooleanField()

class FFS2(models.Model):
ffs1 = models.ForeignKey(FFS1, on_delete=models.PROTECT)

class FFS3(models.Model):
id = models.CharField(primary_key=True, max_length=3)
ffs2 = models.ForeignKey(FFS2, on_delete=models.PROTECT)
field_float = models.FloatField()

Expand All @@ -63,14 +65,14 @@ class FFS3(models.Model):

assert isinstance(forward_field, models.BooleanField)
assert isinstance(reverse_field, models.FloatField)
assert isinstance(forward_model, models.ForeignKey)
assert isinstance(reverse_model, models.AutoField)
assert isinstance(forward_model, models.UUIDField)
assert isinstance(reverse_model, models.CharField)

auto_schema = AutoSchema()
assert auto_schema._map_model_field(forward_field, None)['type'] == 'boolean'
assert auto_schema._map_model_field(reverse_field, None)['type'] == 'number'
assert auto_schema._map_model_field(forward_model, None)['type'] == 'integer'
assert auto_schema._map_model_field(reverse_model, None)['type'] == 'integer'
assert auto_schema._map_model_field(forward_model, None)['type'] == 'string'
assert auto_schema._map_model_field(reverse_model, None)['type'] == 'string'


def test_detype_patterns_with_module_includes(no_warnings):
Expand Down
122 changes: 122 additions & 0 deletions tests/test_regressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class M1(models.Model):
pass # pragma: no cover

class M2(models.Model):
id = models.UUIDField()
m1_r = models.ForeignKey(M1, on_delete=models.CASCADE)
m1_rw = models.ForeignKey(M1, on_delete=models.CASCADE)

Expand Down Expand Up @@ -86,6 +87,127 @@ class M3Viewset(viewsets.ReadOnlyModelViewSet):
assert properties['m2']['type'] == 'integer'


def test_serializer_reverse_relations_including_read_only(no_warnings):
class M5(models.Model):
pass

class M5One(models.Model):
id = models.CharField(primary_key=True, max_length=10)
field = models.OneToOneField(M5, on_delete=models.CASCADE)

class M5Many(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4)
field = models.ManyToManyField(M5)

class M5Foreign(models.Model):
id = models.FloatField(primary_key=True)
field = models.ForeignKey(M5, on_delete=models.CASCADE)

class XSerializer(serializers.ModelSerializer):
m5foreign_set_explicit = serializers.PrimaryKeyRelatedField(
many=True, source='m5foreign_set', queryset=M5Foreign.objects.all()
)
m5foreign_set_ro = serializers.PrimaryKeyRelatedField(
many=True, source='m5foreign_set', read_only=True,
)
m5many_set_explicit = serializers.PrimaryKeyRelatedField(
many=True, source='m5many_set', queryset=M5Many.objects.all()
)
m5many_set_ro = serializers.PrimaryKeyRelatedField(
many=True, source='m5many_set', read_only=True,
)
m5one_ro = serializers.PrimaryKeyRelatedField(
source='m5one', read_only=True,
)

class Meta:
model = M5
fields = [
'm5many_set',
'm5many_set_explicit',
'm5many_set_ro',
'm5foreign_set',
'm5foreign_set_explicit',
'm5foreign_set_ro',
'm5one',
'm5one_ro',
]

class TestViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
queryset = M5.objects.all()
serializer_class = XSerializer

schema = generate_schema('/x/', TestViewSet)
properties = schema['components']['schemas']['X']['properties']

m5many_pk = {'type': 'string', 'format': 'uuid'}
assert properties['m5many_set']['items'] == m5many_pk
assert properties['m5many_set_ro']['items'] == m5many_pk
assert properties['m5many_set_explicit']['items'] == m5many_pk

m5foreign_pk = {'type': 'number', 'format': 'float'}
assert properties['m5foreign_set']['items'] == m5foreign_pk
assert properties['m5foreign_set_ro']['items'] == m5foreign_pk
assert properties['m5foreign_set_explicit']['items'] == m5foreign_pk

assert properties['m5one'] == {'type': 'string'}
assert properties['m5one_ro'] == {'readOnly': True, 'type': 'string'}


def test_serializer_forward_relations_including_read_only(no_warnings):
class M6One(models.Model):
id = models.CharField(primary_key=True, max_length=10)

class M6Many(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4)

class M6Foreign(models.Model):
id = models.FloatField(primary_key=True)

class M6(models.Model):
field_one = models.OneToOneField(M6One, on_delete=models.CASCADE)
field_many = models.ManyToManyField(M6Many)
field_foreign = models.ForeignKey(M6Foreign, on_delete=models.CASCADE)

class XSerializer(serializers.ModelSerializer):
field_one_ro = serializers.PrimaryKeyRelatedField(
source='field_one', read_only=True
)
field_foreign_ro = serializers.PrimaryKeyRelatedField(
source='field_foreign', read_only=True
)
field_many_ro = serializers.PrimaryKeyRelatedField(
source='field_many', read_only=True, many=True
)

class Meta:
model = M6
fields = [
'field_one',
'field_one_ro',
'field_many',
'field_many_ro',
'field_foreign',
'field_foreign_ro',
]

class TestViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
queryset = M6.objects.all()
serializer_class = XSerializer

schema = generate_schema('/x/', TestViewSet)
properties = schema['components']['schemas']['X']['properties']

assert properties['field_one'] == {'type': 'string'}
assert properties['field_one_ro'] == {'type': 'string', 'readOnly': True}
assert properties['field_foreign'] == {'type': 'number', 'format': 'float'}
assert properties['field_foreign_ro'] == {'type': 'number', 'format': 'float', 'readOnly': True}
assert properties['field_many'] == {'type': 'array', 'items': {'type': 'string', 'format': 'uuid'}}
assert properties['field_many_ro'] == {
'type': 'array', 'items': {'type': 'string', 'format': 'uuid'}, 'readOnly': True
}


def test_path_implicit_required(no_warnings):
class M2Serializer(serializers.Serializer):
pass # pragma: no cover
Expand Down

0 comments on commit 6f12e8d

Please sign in to comment.