Skip to content

Commit

Permalink
bugfix mock request asymmetry #370 #250
Browse files Browse the repository at this point in the history
we now always generate a mocked request, not just when schema is
genenerated over CLI. use incoming request as reference if available.
  • Loading branch information
tfranzel committed May 2, 2021
1 parent ced40a0 commit 2eaac13
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 23 deletions.
25 changes: 9 additions & 16 deletions drf_spectacular/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def create_view(self, callback, method, request=None):
original_cls = callback.cls
callback.cls = override_view.view_replacement()

view = super().create_view(callback, method, request)
# we refrain from passing request and deal with it ourselves in parse()
view = super().create_view(callback, method, None)

# drf-yasg compatibility feature. makes the view aware that we are running
# schema generation and not a real request.
Expand Down Expand Up @@ -160,13 +161,13 @@ def _initialise_endpoints(self):
self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
self.endpoints = self.inspector.get_api_endpoints()

def _get_paths_and_endpoints(self, request):
def _get_paths_and_endpoints(self):
"""
Generate (path, method, view) given (path, method, callback) for paths.
"""
view_endpoints = []
for path, path_regex, method, callback in self.endpoints:
view = self.create_view(callback, method, request)
view = self.create_view(callback, method)
path = self.coerce_path(path, method, view)
view_endpoints.append((path, path_regex, method, view))

Expand All @@ -176,7 +177,7 @@ def parse(self, input_request, public):
""" Iterate endpoints generating per method path operations. """
result = {}
self._initialise_endpoints()
endpoints = self._get_paths_and_endpoints(None if public else input_request)
endpoints = self._get_paths_and_endpoints()

if spectacular_settings.SCHEMA_PATH_PREFIX is None:
# estimate common path prefix if none was given. only use it if we encountered more
Expand All @@ -193,17 +194,10 @@ def parse(self, input_request, public):
path_prefix = '^' + path_prefix # make sure regex only matches from the start

for path, path_regex, method, view in endpoints:
if not self.has_view_permissions(path, method, view):
continue

if input_request:
request = input_request
else:
# mocked request to allow certain operations in get_queryset and get_serializer[_class]
# without exceptions being raised due to no request.
request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request)
view.request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request)

view.request = request
if not (public or self.has_view_permissions(path, method, view)):
continue

if view.versioning_class and not is_versioning_supported(view.versioning_class):
warn(
Expand All @@ -212,8 +206,7 @@ def parse(self, input_request, public):
)
elif view.versioning_class:
version = (
self.api_version # generator was explicitly versioned
or getattr(request, 'version', None) # incoming request was versioned
self.api_version # explicit version from CLI, SpecView or SpecView request
or view.versioning_class.default_version # fallback
)
if not version:
Expand Down
13 changes: 13 additions & 0 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,8 +872,21 @@ def camelize_operation(path, operation):


def build_mock_request(method, path, view, original_request, **kwargs):
""" build a mocked request and use original request as reference if available """
request = getattr(APIRequestFactory(), method.lower())(path=path)
request = view.initialize_request(request)
if original_request:
request.user = original_request.user
request.auth = original_request.auth
# ignore headers related to authorization as it has been handled above.
# also ignore ACCEPT as the MIME type refers to SpectacularAPIView and the
# version (if available) has already been processed by SpectacularAPIView.
for name, value in original_request.META.items():
if not name.startswith('HTTP_'):
continue
if name in ['HTTP_ACCEPT', 'HTTP_COOKIE', 'HTTP_AUTHORIZATION']:
continue
request.META[name] = value
return request


Expand Down
5 changes: 4 additions & 1 deletion drf_spectacular/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ def get(self, request, *args, **kwargs):
return self._get_schema_response(request)

def _get_schema_response(self, request):
generator = self.generator_class(urlconf=self.urlconf, api_version=self.api_version)
# version specified as parameter to the view always takes precedence. after
# that we try to source version through the schema view's own versioning_class.
version = self.api_version or request.version
generator = self.generator_class(urlconf=self.urlconf, api_version=version)
return Response(generator.get_schema(request=request, public=self.serve_public))


Expand Down
19 changes: 13 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import difflib
import json
import os

from drf_spectacular.validation import validate_schema
Expand Down Expand Up @@ -35,15 +36,21 @@ def assert_schema(schema, reference_filename, transforms=None):
for t in transforms or []:
generated = t(generated)

assert_equal(expected, generated)
# this is more a less a sanity check as checked-in schemas should be valid anyhow
validate_schema(schema)


def assert_equal(a, b):
if not isinstance(a, str) or isinstance(b, str):
a = json.dumps(a, indent=4)
b = json.dumps(b, indent=4)
diff = difflib.unified_diff(
expected.splitlines(True),
generated.splitlines(True),
a.splitlines(True),
b.splitlines(True),
)
diff = ''.join(diff)
assert expected == generated and not diff, diff

# this is more a less a sanity check as checked-in schemas should be valid anyhow
validate_schema(schema)
assert a == b and not diff, diff


def generate_schema(route, viewset=None, view=None, view_function=None, patterns=None):
Expand Down
104 changes: 104 additions & 0 deletions tests/test_mock_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
import yaml
from django.contrib.auth.models import User
from django.urls import include, path
from rest_framework import routers, viewsets
from rest_framework.authentication import TokenAuthentication
from rest_framework.authtoken.models import Token
from rest_framework.permissions import IsAuthenticated
from rest_framework.test import APIClient
from rest_framework.versioning import AcceptHeaderVersioning

from drf_spectacular.generators import SchemaGenerator
from drf_spectacular.views import SpectacularAPIView
from tests.models import SimpleModel, SimpleSerializer


class AnotherSimpleSerializer(SimpleSerializer):
pass


class XViewset(viewsets.ModelViewSet):
authentication_classes = [TokenAuthentication]
permission_classes = [IsAuthenticated]
queryset = SimpleModel.objects.none()

def get_serializer_class(self):
# make sure the mocked request possesses the correct path and
# schema endpoint path does not leak in.
assert self.request.path.startswith('/api/x/')
# make schema dependent on request method
if self.request.method == 'GET':
return SimpleSerializer
else:
return AnotherSimpleSerializer


router = routers.SimpleRouter()
router.register('x', XViewset)
urlpatterns = [
path('api/', include(router.urls)),
path('api/schema-plain/', SpectacularAPIView.as_view()),
path('api/schema-authenticated/', SpectacularAPIView.as_view(
authentication_classes=[TokenAuthentication]
)),
path('api/schema-authenticated-private/', SpectacularAPIView.as_view(
authentication_classes=[TokenAuthentication],
serve_public=False,
)),
path('api/schema-versioned/', SpectacularAPIView.as_view(
versioning_class=AcceptHeaderVersioning
))
]


@pytest.mark.urls(__name__)
def test_mock_request_symmetry_plain(no_warnings):
response = APIClient().get('/api/schema-plain/', **{'HTTP_X_SPECIAL_HEADER': '1'})
assert response.status_code == 200
schema_online = yaml.load(response.content, Loader=yaml.SafeLoader)
schema_offline = SchemaGenerator().get_schema(public=True)
assert schema_offline == schema_online


@pytest.mark.urls(__name__)
def test_mock_request_symmetry_version(no_warnings):
response = APIClient().get('/api/schema-versioned/', **{
'HTTP_ACCEPT': 'application/json; version=v2',
})
assert response.status_code == 200
schema_online = yaml.load(response.content, Loader=yaml.SafeLoader)
schema_offline = SchemaGenerator(api_version='v2').get_schema(public=True)

assert schema_offline == schema_online
assert schema_online['info']['version'] == '0.0.0 (v2)'


@pytest.mark.parametrize(['serve_public', 'authenticated', 'url', 'expected_endpoints'], [
(True, True, '/api/schema-authenticated/', 5),
(True, False, '/api/schema-authenticated/', 5),
(False, True, '/api/schema-authenticated-private/', None),
(False, False, '/api/schema-authenticated-private/', 3),
])
@pytest.mark.urls(__name__)
@pytest.mark.django_db
def test_mock_request_symmetry_authentication(
no_warnings, serve_public, authenticated, url, expected_endpoints
):
user = User.objects.create(username='test')
token, _ = Token.objects.get_or_create(user=user)
auth_header = {'HTTP_AUTHORIZATION': f'Token {token}'} if authenticated else {}
response = APIClient().get(url, **auth_header)
assert response.status_code == 200

schema_online = yaml.load(response.content, Loader=yaml.SafeLoader)
schema_offline = SchemaGenerator().get_schema(public=serve_public)

if expected_endpoints:
assert schema_offline == schema_online
assert len(schema_online['paths']) == expected_endpoints
else:
# authenticated & non-public case does not really make sense for
# offline generation as we have no request.
assert len(schema_online['paths']) == 5
assert len(schema_offline['paths']) == 3

0 comments on commit 2eaac13

Please sign in to comment.