diff --git a/README.md b/README.md index f28166d..835062f 100644 --- a/README.md +++ b/README.md @@ -179,15 +179,8 @@ WORF_SERIALIZER_DEFAULT_OPTIONS = { Permissions ----------- -Permissions functions can be found in `worf.permissions`. - -These functions extend the API View, so they require `self` to be defined as a -parameter. This is done in order to allow access to `self.request` during -permission testing. - -If permissions should be granted, functions should return `int(200)`. - -If permissions fail, they should return an `HTTPException` +Permissions are callable classes that can be found in `worf.permissions`, they're passed +the `request` and `kwargs` from the view, and raise an exception if the check fails. Validators diff --git a/pytest.ini b/pytest.ini index 8755fe8..97ebc3d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] addopts = --cov - --cov-fail-under 90 + --cov-fail-under 92 --cov-report term:skip-covered --cov-report html --no-cov-on-fail diff --git a/tests/conftest.py b/tests/conftest.py index bc515d4..69d6ce1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,14 @@ def pytest_configure(): INSTALLED_APPS=[ "django.contrib.auth", "django.contrib.contenttypes", + "django.contrib.sessions", "tests", "worf", ], + MIDDLEWARE=[ + "django.contrib.sessions.middleware.SessionMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + ], PASSWORD_HASHERS=["django.contrib.auth.hashers.MD5PasswordHasher"], ROOT_URLCONF="tests.urls", SECRET_KEY="secret", @@ -36,7 +41,7 @@ def pytest_configure(): USE_L10N=False, USE_TZ=True, WORF_API_NAME="Test API", - WORF_DEBUG=True, + WORF_DEBUG=False, ) django.setup() @@ -80,7 +85,7 @@ def client_fixture(): def now_fixture(): from django.utils import timezone - return timezone.now() + return timezone.now @pytest.fixture(name="url") diff --git a/tests/test_conf.py b/tests/test_conf.py index 9c36ccf..6103065 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -5,4 +5,4 @@ def test_settings(): assert settings.WORF_API_NAME == "Test API" assert settings.WORF_API_ROOT == "/api/" assert settings.WORF_BROWSABLE_API is True - assert settings.WORF_DEBUG is True + assert settings.WORF_DEBUG is False diff --git a/tests/test_permissions.py b/tests/test_permissions.py index ec3c6b7..e5b1a9e 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -1,35 +1,45 @@ import pytest from django.contrib.auth.models import AnonymousUser, User -from django.test import RequestFactory from worf.exceptions import HTTP401, HTTP404 -from worf.permissions import Authenticated, Staff +from worf.permissions import Authenticated, PublicEndpoint, Staff -factory = RequestFactory() - - -@pytest.mark.django_db -def test_authenticated(): - request = factory.get("/") +def test_authenticated(db, rf): + permission = Authenticated() + request = rf.get("/") request.user = AnonymousUser() - assert isinstance(Authenticated(None, request), HTTP401) + + with pytest.raises(HTTP401): + assert permission(request) is None request.user = User.objects.create(username="test", password="test") - assert Authenticated(None, request) == 200 + assert permission(request) is None + +def test_public_endpoint(db, rf): + permission = PublicEndpoint() + request = rf.get("/") + request.user = AnonymousUser() + assert permission(request) is None + request.user = User.objects.create(username="test", password="test") + assert permission(request) is None -@pytest.mark.django_db -def test_staff(): - request = factory.get("/") +def test_staff(db, rf): + permission = Staff() + request = rf.get("/") request.user = AnonymousUser() - assert isinstance(Staff(None, request), HTTP404) + + with pytest.raises(HTTP404): + assert permission(request) is None request.user = User.objects.create(username="test", password="test") - assert isinstance(Staff(None, request), HTTP404) + + with pytest.raises(HTTP404): + assert permission(request) is None request.user.is_staff = True request.user.save() - assert Staff(None, request) == 200 + permission(request) diff --git a/tests/test_validators.py b/tests/test_validators.py index 3601ad3..291ac13 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -10,9 +10,7 @@ @pytest.fixture(name="profile_view") -def profile_view_fixture(db, now, profile_factory): - from django.test import RequestFactory - +def profile_view_fixture(db, now, profile_factory, rf): from tests.views import ProfileDetail profile_factory.create(email=email, phone=phone) @@ -30,11 +28,11 @@ def profile_view_fixture(db, now, profile_factory): small_integer=123, recovery_email=email, recovery_phone=phone, - last_active=now.date().isoformat(), - created_at=now.isoformat(), + last_active=now().date().isoformat(), + created_at=now().isoformat(), ) ) - view.request = RequestFactory().patch(f"/{uuid}/") + view.request = rf.patch(f"/{uuid}/") view.kwargs = dict(id=str(uuid)) return view @@ -74,7 +72,16 @@ def test_validate_bundle_accepts_nulls(profile_view): profile_view.validate_bundle("recovery_email") -def test_validate_bundle_raises_invalid_field_writes(profile_view): +def test_validate_bundle_raises_invalid_booleans(profile_view): + profile_view.set_bundle(dict(boolean="nooo")) + + with pytest.raises(ValidationError) as e: + profile_view.validate_bundle("boolean") + + assert "Field boolean accepts a boolean, got nooo" in str(e.value) + + +def test_validate_bundle_raises_invalid_fields(profile_view): profile_view.set_bundle(dict(invalid_field=email)) with pytest.raises(ValidationError) as e: @@ -83,7 +90,7 @@ def test_validate_bundle_raises_invalid_field_writes(profile_view): assert "invalid_field is not editable" in str(e.value) -def test_validate_bundle_raises_read_only_field_writes(profile_view): +def test_validate_bundle_raises_read_only_fields(profile_view): with pytest.raises(ValidationError) as e: profile_view.validate_bundle("recovery_phone") diff --git a/tests/test_views.py b/tests/test_views.py index e12023c..871d69b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,5 +1,6 @@ from datetime import timedelta from unittest.mock import patch +from uuid import uuid4 import pytest @@ -13,6 +14,13 @@ def test_profile_detail(client, db, profile, user): assert result["username"] == user.username +def test_profile_not_found(client, db, profile, user): + response = client.get(f"/profiles/{uuid4()}/") + result = response.json() + assert response.status_code == 404, result + assert result["message"] == "Not Found" + + def test_profile_delete(client, db, profile, user): response = client.delete(f"/profiles/{profile.pk}/") assert response.status_code == 204, response.content @@ -281,6 +289,27 @@ def test_profile_update_m2m_through_required_fields(client, db, method, profile, assert "Invalid skills" in result["message"] +def test_staff_detail(admin_client, profile, user): + response = admin_client.get(f"/profiles/{profile.pk}/staff/") + result = response.json() + assert response.status_code == 200, result + assert result["username"] == user.username + + +def test_staff_detail_is_not_found_for_user(user_client, profile, user): + response = user_client.get(f"/profiles/{profile.pk}/staff/") + result = response.json() + assert response.status_code == 404, result + assert result["message"] == "Not Found" + + +def test_staff_detail_is_unauthorized_for_guest(client, db, profile, user): + response = client.get(f"/profiles/{profile.pk}/staff/") + result = response.json() + assert response.status_code == 401, result + assert result["message"] == "Unauthorized" + + def test_user_detail(client, db, user): response = client.get(f"/users/{user.pk}/") result = response.json() @@ -375,11 +404,12 @@ def test_user_list_sort_desc(client, db, url, user_factory): @pytest.mark.parametrize("url_params__array_format", ["comma", "repeat"]) -def test_user_list_multisort(client, now, db, url, user_factory): - user_factory.create(username="a", date_joined=now) - user_factory.create(username="b", date_joined=now - timedelta(hours=1)) - user_factory.create(username="c", date_joined=now) - user_factory.create(username="d", date_joined=now) +def test_user_list_multisort(client, db, now, url, user_factory): + date_joined = now() + user_factory.create(username="a", date_joined=date_joined) + user_factory.create(username="b", date_joined=date_joined - timedelta(hours=1)) + user_factory.create(username="c", date_joined=date_joined) + user_factory.create(username="d", date_joined=date_joined) response = client.get(url("/users/", {"sort": ["dateJoined", "-id", "x"]})) result = response.json() assert response.status_code == 200, result @@ -390,6 +420,14 @@ def test_user_list_multisort(client, now, db, url, user_factory): assert result["users"][3]["username"] == "a" +def test_user_self(client, user_client, user): + response = client.get("/user/") + assert response.status_code == 401 + + response = user_client.get("/user/") + assert response.status_code == 200 + + def test_user_unique_create_with_existing_value(client, db, user, user_factory): user_factory.create(username="already_taken") payload = dict(username="already_taken") diff --git a/tests/urls.py b/tests/urls.py index dd7a33a..5751129 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,17 +1,13 @@ from django.urls import path -from tests.views import ( - ProfileDetail, - ProfileList, - ProfileListSubSet, - UserDetail, - UserList, -) +from tests import views urlpatterns = [ - path("profiles/", ProfileList.as_view()), - path("profiles/subset/", ProfileListSubSet.as_view()), - path("profiles//", ProfileDetail.as_view()), - path("users/", UserList.as_view()), - path("users//", UserDetail.as_view()), + path("profiles/", views.ProfileList.as_view()), + path("profiles/subset/", views.ProfileListSubSet.as_view()), + path("profiles//", views.ProfileDetail.as_view()), + path("profiles//staff/", views.StaffDetail.as_view()), + path("user/", views.UserSelf.as_view()), + path("users/", views.UserList.as_view()), + path("users//", views.UserDetail.as_view()), ] diff --git a/tests/views.py b/tests/views.py index 6008f7d..0d5e6d2 100644 --- a/tests/views.py +++ b/tests/views.py @@ -5,7 +5,8 @@ from tests.models import Profile from tests.serializers import ProfileSerializer, UserSerializer -from worf.permissions import PublicEndpoint +from worf.exceptions import AuthenticationError +from worf.permissions import Authenticated, PublicEndpoint, Staff from worf.views import CreateAPI, DeleteAPI, DetailAPI, ListAPI, UpdateAPI @@ -51,6 +52,10 @@ def validate_phone(self, value): return "+5555555555" +class StaffDetail(ProfileDetail): + permissions = [Authenticated, Staff] + + class UserList(CreateAPI, ListAPI): model = User ordering = ["pk"] @@ -82,3 +87,14 @@ class UserDetail(UpdateAPI, DetailAPI): model = User serializer = UserSerializer(exclude=["date_joined"]) permissions = [PublicEndpoint] + + +class UserSelf(DetailAPI): + model = User + serializer = UserSerializer + permissions = [PublicEndpoint] + + def get_instance(self): + if not self.request.user.is_authenticated: + raise AuthenticationError("Log in with your username and password") + return self.request.user diff --git a/worf/exceptions.py b/worf/exceptions.py index 67b8e9e..574dcec 100644 --- a/worf/exceptions.py +++ b/worf/exceptions.py @@ -42,6 +42,12 @@ class HTTP422(HTTPException): HTTP_EXCEPTIONS = (HTTP400, HTTP401, HTTP404, HTTP409, HTTP410, HTTP420, HTTP422) +class AuthenticationError(Exception): + def __init__(self, message): + super().__init__(message) + self.message = message + + class NamingThingsError(ValueError): pass @@ -55,4 +61,6 @@ class NotImplementedInWorfYet(NotImplementedError): class SerializerError(ValueError): - pass + def __init__(self, message): + super().__init__(message) + self.message = message diff --git a/worf/filters.py b/worf/filters.py index b87dbe9..d8f40b4 100644 --- a/worf/filters.py +++ b/worf/filters.py @@ -15,12 +15,12 @@ def get_filters(self): state = self._build_state() for name in self.queryset.query.annotations.keys(): - if name in self.Meta.exclude or name in filters: + if name in self.Meta.exclude or name in filters: # pragma: no cover continue try: annotation_filter = self._build_annotation_filter(name, state) - except SkipFilter: + except SkipFilter: # pragma: no cover continue if annotation_filter is not None: @@ -31,11 +31,11 @@ def get_filters(self): def _build_annotation_filter(self, name, state): field = self.queryset.query.annotations.get(name).output_field - if isinstance(field, RelatedField): + if isinstance(field, RelatedField): # pragma: no cover if not self.Meta.allow_related: raise SkipFilter return self._build_filterset_from_related_field(name, field) - elif isinstance(field, ForeignObjectRel): + elif isinstance(field, ForeignObjectRel): # pragma: no cover if not self.Meta.allow_related_reverse: raise SkipFilter return self._build_filterset_from_reverse_field(name, field) diff --git a/worf/permissions.py b/worf/permissions.py index 44aa3aa..34bb5e9 100644 --- a/worf/permissions.py +++ b/worf/permissions.py @@ -1,17 +1,22 @@ from worf.exceptions import HTTP401, HTTP404 -def Authenticated(self, request): - if not request.user.is_authenticated: - return HTTP401() - return 200 +class Authenticated: + def __call__(self, request, **kwargs): + if request.user.is_authenticated: + return + raise HTTP401() -def Staff(self, request): - if not request.user.is_authenticated or not request.user.is_staff: - return HTTP404() - return 200 +class PublicEndpoint: + def __call__(self, request, **kwargs): + pass -def PublicEndpoint(self, request): - return 200 + +class Staff: + def __call__(self, request, **kwargs): + if request.user.is_authenticated and request.user.is_staff: + return + + raise HTTP404() diff --git a/worf/views/base.py b/worf/views/base.py index 89ab9e5..2bba599 100644 --- a/worf/views/base.py +++ b/worf/views/base.py @@ -1,5 +1,4 @@ import json -import types import warnings from io import BytesIO from urllib.parse import parse_qs @@ -18,10 +17,8 @@ from worf.casing import camel_to_snake, snake_to_camel from worf.conf import settings from worf.exceptions import ( - HTTP400, - HTTP404, - HTTP422, HTTP_EXCEPTIONS, + AuthenticationError, PermissionsError, SerializerError, ) @@ -59,10 +56,10 @@ class AbstractBaseAPI(SerializeModels, ValidateFields, APIResponse): def __init__(self, *args, **kwargs): self.codepath = f"{self.__module__}.{self.__class__.__name__}" - if self.model is None: + if self.model is None: # pragma: no cover raise ImproperlyConfigured(f"Model is not set on {self.codepath}") - if not isinstance(self.permissions, list): + if not isinstance(self.permissions, list): # pragma: no cover raise ImproperlyConfigured( f"{self.codepath}.permissions must be type: list" ) @@ -76,10 +73,6 @@ def __init__(self, *args, **kwargs): ), ) - for method in self.permissions: - # append authorization functions to this class - setattr(self, method.__name__, types.MethodType(method, self)) - super().__init__(*args, **kwargs) @property @@ -97,53 +90,44 @@ def dispatch(self, request, *args, **kwargs): handler = getattr(self, method, self.http_method_not_allowed) try: - self._check_permissions() # only returns 200 or HTTP_EXCEPTIONS + self.check_permissions() self.set_bundle_from_request(request) - return handler(request, *args, **kwargs) # calls self.serialize() + return handler(request, *args, **kwargs) except HTTP_EXCEPTIONS as e: message = e.message status = e.status + except AuthenticationError as e: + message = e.message + status = 401 except ObjectDoesNotExist as e: if self.model and not isinstance(e, self.model.DoesNotExist): raise e - message = HTTP404.message - status = HTTP404.status + message = "Not Found" + status = 404 except RequestDataTooBig: self.request._body = self.request.read(None) # prevent further raises message = f"Max upload size is {filesizeformat(settings.DATA_UPLOAD_MAX_MEMORY_SIZE)}" - status = HTTP422.status + status = 422 except SerializerError as e: - message = str(e) - status = HTTP400.status + message = e.message + status = 400 except ValidationError as e: message = e.message - status = HTTP422.status + status = 422 return self.render_to_response(dict(message=message), status) - def _check_permissions(self): - """Return a permissions exception when in debug mode instead of 404.""" - for method in self.permissions: - permission_func = getattr(self, method.__name__) - response = permission_func(self.request) - if response == 200: - continue - - if settings.WORF_DEBUG: - raise PermissionsError( - "Permissions function {}.{} returned {}. You'd normally see a 404 here but WORF_DEBUG=True.".format( - method.__module__, method.__name__, response - ) - ) - + def check_permissions(self): + for perm in self.permissions: try: - raise response - except TypeError: - raise ImproperlyConfigured( - "Permissions function {}.{} must return 200 or an HTTPException".format( - method.__module__, method.__name__ - ) - ) + perm()(self.request, **self.kwargs) + except HTTP_EXCEPTIONS as e: + if settings.WORF_DEBUG: + raise PermissionsError( + f"Permission check {perm.__module__}.{perm.__name__} raised {e.__class__.__name__}. " + f"You'd normally see a {e.status} here but WORF_DEBUG=True." + ) from e + raise e def get_instance(self): return self.instance if hasattr(self, "instance") else None diff --git a/worf/views/list.py b/worf/views/list.py index be85e80..85c0f99 100644 --- a/worf/views/list.py +++ b/worf/views/list.py @@ -36,19 +36,19 @@ def __init__(self, *args, **kwargs): codepath = self.codepath - if not isinstance(self.filters, dict): + if not isinstance(self.filters, dict): # pragma: no cover raise ImproperlyConfigured(f"{codepath}.filters must be type: dict") - if not isinstance(self.ordering, list): + if not isinstance(self.ordering, list): # pragma: no cover raise ImproperlyConfigured(f"{codepath}.ordering must be type: list") - if not isinstance(self.filter_fields, list): + if not isinstance(self.filter_fields, list): # pragma: no cover raise ImproperlyConfigured(f"{codepath}.filter_fields must be type: list") - if not isinstance(self.search_fields, (dict, list)): + if not isinstance(self.search_fields, (dict, list)): # pragma: no cover raise ImproperlyConfigured(f"{codepath}.search_fields must be type: list") - if not isinstance(self.sort_fields, list): + if not isinstance(self.sort_fields, list): # pragma: no cover raise ImproperlyConfigured(f"{codepath}.sort_fields must be type: list") # generate a default filterset if a custom one was not provided