diff --git a/drf_spectacular/generators.py b/drf_spectacular/generators.py index 70c21b2a..24c6f2cc 100644 --- a/drf_spectacular/generators.py +++ b/drf_spectacular/generators.py @@ -105,7 +105,8 @@ def create_view(self, callback, method, request=None): original_cls = callback.cls callback.cls = override_view.view_replacement() - view = super().create_view(callback, method, request) + # we refrain from passing request and deal with it ourselves in parse() + view = super().create_view(callback, method, None) # drf-yasg compatibility feature. makes the view aware that we are running # schema generation and not a real request. @@ -160,13 +161,13 @@ def _initialise_endpoints(self): self.inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = self.inspector.get_api_endpoints() - def _get_paths_and_endpoints(self, request): + def _get_paths_and_endpoints(self): """ Generate (path, method, view) given (path, method, callback) for paths. """ view_endpoints = [] for path, path_regex, method, callback in self.endpoints: - view = self.create_view(callback, method, request) + view = self.create_view(callback, method) path = self.coerce_path(path, method, view) view_endpoints.append((path, path_regex, method, view)) @@ -176,7 +177,7 @@ def parse(self, input_request, public): """ Iterate endpoints generating per method path operations. """ result = {} self._initialise_endpoints() - endpoints = self._get_paths_and_endpoints(None if public else input_request) + endpoints = self._get_paths_and_endpoints() if spectacular_settings.SCHEMA_PATH_PREFIX is None: # estimate common path prefix if none was given. only use it if we encountered more @@ -193,17 +194,10 @@ def parse(self, input_request, public): path_prefix = '^' + path_prefix # make sure regex only matches from the start for path, path_regex, method, view in endpoints: - if not self.has_view_permissions(path, method, view): - continue - - if input_request: - request = input_request - else: - # mocked request to allow certain operations in get_queryset and get_serializer[_class] - # without exceptions being raised due to no request. - request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request) + view.request = spectacular_settings.GET_MOCK_REQUEST(method, path, view, input_request) - view.request = request + if not (public or self.has_view_permissions(path, method, view)): + continue if view.versioning_class and not is_versioning_supported(view.versioning_class): warn( @@ -212,8 +206,7 @@ def parse(self, input_request, public): ) elif view.versioning_class: version = ( - self.api_version # generator was explicitly versioned - or getattr(request, 'version', None) # incoming request was versioned + self.api_version # explicit version from CLI, SpecView or SpecView request or view.versioning_class.default_version # fallback ) if not version: diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index aeb2c681..5494dd69 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -872,8 +872,21 @@ def camelize_operation(path, operation): def build_mock_request(method, path, view, original_request, **kwargs): + """ build a mocked request and use original request as reference if available """ request = getattr(APIRequestFactory(), method.lower())(path=path) request = view.initialize_request(request) + if original_request: + request.user = original_request.user + request.auth = original_request.auth + # ignore headers related to authorization as it has been handled above. + # also ignore ACCEPT as the MIME type refers to SpectacularAPIView and the + # version (if available) has already been processed by SpectacularAPIView. + for name, value in original_request.META.items(): + if not name.startswith('HTTP_'): + continue + if name in ['HTTP_ACCEPT', 'HTTP_COOKIE', 'HTTP_AUTHORIZATION']: + continue + request.META[name] = value return request diff --git a/drf_spectacular/views.py b/drf_spectacular/views.py index e722d21c..11e71f10 100644 --- a/drf_spectacular/views.py +++ b/drf_spectacular/views.py @@ -67,7 +67,10 @@ def get(self, request, *args, **kwargs): return self._get_schema_response(request) def _get_schema_response(self, request): - generator = self.generator_class(urlconf=self.urlconf, api_version=self.api_version) + # version specified as parameter to the view always takes precedence. after + # that we try to source version through the schema view's own versioning_class. + version = self.api_version or request.version + generator = self.generator_class(urlconf=self.urlconf, api_version=version) return Response(generator.get_schema(request=request, public=self.serve_public)) diff --git a/tests/__init__.py b/tests/__init__.py index 812cd536..95d08f95 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,5 @@ import difflib +import json import os from drf_spectacular.validation import validate_schema @@ -35,15 +36,21 @@ def assert_schema(schema, reference_filename, transforms=None): for t in transforms or []: generated = t(generated) + assert_equal(expected, generated) + # this is more a less a sanity check as checked-in schemas should be valid anyhow + validate_schema(schema) + + +def assert_equal(a, b): + if not isinstance(a, str) or isinstance(b, str): + a = json.dumps(a, indent=4) + b = json.dumps(b, indent=4) diff = difflib.unified_diff( - expected.splitlines(True), - generated.splitlines(True), + a.splitlines(True), + b.splitlines(True), ) diff = ''.join(diff) - assert expected == generated and not diff, diff - - # this is more a less a sanity check as checked-in schemas should be valid anyhow - validate_schema(schema) + assert a == b and not diff, diff def generate_schema(route, viewset=None, view=None, view_function=None, patterns=None): diff --git a/tests/test_mock_request.py b/tests/test_mock_request.py new file mode 100644 index 00000000..836fc6ca --- /dev/null +++ b/tests/test_mock_request.py @@ -0,0 +1,104 @@ +import pytest +import yaml +from django.contrib.auth.models import User +from django.urls import include, path +from rest_framework import routers, viewsets +from rest_framework.authentication import TokenAuthentication +from rest_framework.authtoken.models import Token +from rest_framework.permissions import IsAuthenticated +from rest_framework.test import APIClient +from rest_framework.versioning import AcceptHeaderVersioning + +from drf_spectacular.generators import SchemaGenerator +from drf_spectacular.views import SpectacularAPIView +from tests.models import SimpleModel, SimpleSerializer + + +class AnotherSimpleSerializer(SimpleSerializer): + pass + + +class XViewset(viewsets.ModelViewSet): + authentication_classes = [TokenAuthentication] + permission_classes = [IsAuthenticated] + queryset = SimpleModel.objects.none() + + def get_serializer_class(self): + # make sure the mocked request possesses the correct path and + # schema endpoint path does not leak in. + assert self.request.path.startswith('/api/x/') + # make schema dependent on request method + if self.request.method == 'GET': + return SimpleSerializer + else: + return AnotherSimpleSerializer + + +router = routers.SimpleRouter() +router.register('x', XViewset) +urlpatterns = [ + path('api/', include(router.urls)), + path('api/schema-plain/', SpectacularAPIView.as_view()), + path('api/schema-authenticated/', SpectacularAPIView.as_view( + authentication_classes=[TokenAuthentication] + )), + path('api/schema-authenticated-private/', SpectacularAPIView.as_view( + authentication_classes=[TokenAuthentication], + serve_public=False, + )), + path('api/schema-versioned/', SpectacularAPIView.as_view( + versioning_class=AcceptHeaderVersioning + )) +] + + +@pytest.mark.urls(__name__) +def test_mock_request_symmetry_plain(no_warnings): + response = APIClient().get('/api/schema-plain/', **{'HTTP_X_SPECIAL_HEADER': '1'}) + assert response.status_code == 200 + schema_online = yaml.load(response.content, Loader=yaml.SafeLoader) + schema_offline = SchemaGenerator().get_schema(public=True) + assert schema_offline == schema_online + + +@pytest.mark.urls(__name__) +def test_mock_request_symmetry_version(no_warnings): + response = APIClient().get('/api/schema-versioned/', **{ + 'HTTP_ACCEPT': 'application/json; version=v2', + }) + assert response.status_code == 200 + schema_online = yaml.load(response.content, Loader=yaml.SafeLoader) + schema_offline = SchemaGenerator(api_version='v2').get_schema(public=True) + + assert schema_offline == schema_online + assert schema_online['info']['version'] == '0.0.0 (v2)' + + +@pytest.mark.parametrize(['serve_public', 'authenticated', 'url', 'expected_endpoints'], [ + (True, True, '/api/schema-authenticated/', 5), + (True, False, '/api/schema-authenticated/', 5), + (False, True, '/api/schema-authenticated-private/', None), + (False, False, '/api/schema-authenticated-private/', 3), +]) +@pytest.mark.urls(__name__) +@pytest.mark.django_db +def test_mock_request_symmetry_authentication( + no_warnings, serve_public, authenticated, url, expected_endpoints +): + user = User.objects.create(username='test') + token, _ = Token.objects.get_or_create(user=user) + auth_header = {'HTTP_AUTHORIZATION': f'Token {token}'} if authenticated else {} + response = APIClient().get(url, **auth_header) + assert response.status_code == 200 + + schema_online = yaml.load(response.content, Loader=yaml.SafeLoader) + schema_offline = SchemaGenerator().get_schema(public=serve_public) + + if expected_endpoints: + assert schema_offline == schema_online + assert len(schema_online['paths']) == expected_endpoints + else: + # authenticated & non-public case does not really make sense for + # offline generation as we have no request. + assert len(schema_online['paths']) == 5 + assert len(schema_offline['paths']) == 3