Skip to content

Commit

Permalink
Fix a bug which meant that schema generation could be incorrect if vi…
Browse files Browse the repository at this point in the history
…ewsets were ordered in a particular way
  • Loading branch information
Paul Gilmartin authored and Paul Gilmartin committed May 1, 2022
1 parent 4acdc27 commit 7dabba5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
17 changes: 9 additions & 8 deletions graph_wrap/django_rest_framework/api_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def root_type(self):
root_type = SerializerTransformer(
self._root_serializer,
self.type_mapping,
self._root_graphene_type_name,
seen_nested_serializers=self.seen_nested_serializers,
).graphene_object_type()
return root_type
Expand Down Expand Up @@ -69,14 +68,13 @@ def __init__(
self,
serializer,
type_mapping=None,
graphene_type_name='',
seen_nested_serializers=None,
):
self._serializer = serializer
self.type_mapping = type_mapping if type_mapping is not None else dict()
self.seen_nested_serializers = seen_nested_serializers if seen_nested_serializers is not None else dict()
self._graphene_type_name = (
graphene_type_name or self._build_graphene_type_name())
self.seen_nested_serializers = (
seen_nested_serializers if seen_nested_serializers is not None else dict())
self._graphene_type_name = self._build_graphene_type_name()
self._graphene_object_type_class_attrs = dict()

def graphene_object_type(self):
Expand Down Expand Up @@ -231,6 +229,9 @@ def _build_graphene_type_name(self):
model = self._field.child.Meta.model.__name__.lower()
else:
model = self._field.Meta.model.__name__.lower()
return self._get_type_number_for_model(model, serializer_cls)

def _get_type_number_for_model(self, model, serializer_cls):
type_name = '{}_type'.format(model)
types_for_model = [
t for t in self.type_mapping if t.startswith(type_name)]
Expand All @@ -252,9 +253,9 @@ def _build_graphene_type_name(self):
related_view_name = related_view_name.split('-')[0]
related_view_set = next(
(v for v in views if v.basename == related_view_name))
serializer = related_view_set.get_serializer()
model = serializer.Meta.model.__name__.lower()
return '{}_type'.format(model)
related_serializer = related_view_set.get_serializer()
model = related_serializer.Meta.model.__name__.lower()
return self._get_type_number_for_model(model, related_serializer.__class__)


class GenericValuedFieldTransformer(ScalarValuedFieldTransformer):
Expand Down
4 changes: 2 additions & 2 deletions tests/django_rest_framework_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_query_fields(self):
)

def test_author_type(self):
author_type = self.type_map['author_type']
author_type = self.type_map['author_type_2']
self.assertEqual(
{'name',
'age',
Expand Down Expand Up @@ -329,7 +329,7 @@ def test_post_query_with_fragments(self):
fragment postFragment on post_type {
content
}
fragment authorFragment on author_type_2 {
fragment authorFragment on author_type {
name
}
'''
Expand Down
2 changes: 1 addition & 1 deletion tests/django_rest_framework_api/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
AuthorViewSet, PostViewSet)

router = routers.SimpleRouter()
router.register(r'writer', AuthorViewSet)
router.register(r'post', PostViewSet)
router.register(r'writer', AuthorViewSet)

0 comments on commit 7dabba5

Please sign in to comment.