Skip to content

Commit

Permalink
Switch permissions to classes
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Aug 30, 2022
1 parent d56f837 commit ed4195f
Show file tree
Hide file tree
Showing 14 changed files with 177 additions and 115 deletions.
11 changes: 2 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,8 @@ WORF_SERIALIZER_DEFAULT_OPTIONS = {
Permissions
-----------

Permissions functions can be found in `worf.permissions`.

These functions extend the API View, so they require `self` to be defined as a
parameter. This is done in order to allow access to `self.request` during
permission testing.

If permissions should be granted, functions should return `int(200)`.

If permissions fail, they should return an `HTTPException`
Permissions are callable classes that can be found in `worf.permissions`, they're passed
the `request` and `kwargs` from the view, and raise an exception if the check fails.


Validators
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[pytest]
addopts =
--cov
--cov-fail-under 90
--cov-fail-under 92
--cov-report term:skip-covered
--cov-report html
--no-cov-on-fail
Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def pytest_configure():
INSTALLED_APPS=[
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"tests",
"worf",
],
MIDDLEWARE=[
"django.contrib.sessions.middleware.SessionMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
],
PASSWORD_HASHERS=["django.contrib.auth.hashers.MD5PasswordHasher"],
ROOT_URLCONF="tests.urls",
SECRET_KEY="secret",
Expand All @@ -36,7 +41,7 @@ def pytest_configure():
USE_L10N=False,
USE_TZ=True,
WORF_API_NAME="Test API",
WORF_DEBUG=True,
WORF_DEBUG=False,
)

django.setup()
Expand Down Expand Up @@ -80,7 +85,7 @@ def client_fixture():
def now_fixture():
from django.utils import timezone

return timezone.now()
return timezone.now


@pytest.fixture(name="url")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ def test_settings():
assert settings.WORF_API_NAME == "Test API"
assert settings.WORF_API_ROOT == "/api/"
assert settings.WORF_BROWSABLE_API is True
assert settings.WORF_DEBUG is True
assert settings.WORF_DEBUG is False
42 changes: 26 additions & 16 deletions tests/test_permissions.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,45 @@
import pytest

from django.contrib.auth.models import AnonymousUser, User
from django.test import RequestFactory

from worf.exceptions import HTTP401, HTTP404
from worf.permissions import Authenticated, Staff
from worf.permissions import Authenticated, PublicEndpoint, Staff

factory = RequestFactory()


@pytest.mark.django_db
def test_authenticated():
request = factory.get("/")

def test_authenticated(db, rf):
permission = Authenticated()
request = rf.get("/")
request.user = AnonymousUser()
assert isinstance(Authenticated(None, request), HTTP401)

with pytest.raises(HTTP401):
assert permission(request) is None

request.user = User.objects.create(username="test", password="test")
assert Authenticated(None, request) == 200
assert permission(request) is None


def test_public_endpoint(db, rf):
permission = PublicEndpoint()
request = rf.get("/")
request.user = AnonymousUser()
assert permission(request) is None
request.user = User.objects.create(username="test", password="test")
assert permission(request) is None

@pytest.mark.django_db
def test_staff():
request = factory.get("/")

def test_staff(db, rf):
permission = Staff()
request = rf.get("/")
request.user = AnonymousUser()
assert isinstance(Staff(None, request), HTTP404)

with pytest.raises(HTTP404):
assert permission(request) is None

request.user = User.objects.create(username="test", password="test")
assert isinstance(Staff(None, request), HTTP404)

with pytest.raises(HTTP404):
assert permission(request) is None

request.user.is_staff = True
request.user.save()
assert Staff(None, request) == 200
permission(request)
23 changes: 15 additions & 8 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


@pytest.fixture(name="profile_view")
def profile_view_fixture(db, now, profile_factory):
from django.test import RequestFactory

def profile_view_fixture(db, now, profile_factory, rf):
from tests.views import ProfileDetail

profile_factory.create(email=email, phone=phone)
Expand All @@ -30,11 +28,11 @@ def profile_view_fixture(db, now, profile_factory):
small_integer=123,
recovery_email=email,
recovery_phone=phone,
last_active=now.date().isoformat(),
created_at=now.isoformat(),
last_active=now().date().isoformat(),
created_at=now().isoformat(),
)
)
view.request = RequestFactory().patch(f"/{uuid}/")
view.request = rf.patch(f"/{uuid}/")
view.kwargs = dict(id=str(uuid))
return view

Expand Down Expand Up @@ -74,7 +72,16 @@ def test_validate_bundle_accepts_nulls(profile_view):
profile_view.validate_bundle("recovery_email")


def test_validate_bundle_raises_invalid_field_writes(profile_view):
def test_validate_bundle_raises_invalid_booleans(profile_view):
profile_view.set_bundle(dict(boolean="nooo"))

with pytest.raises(ValidationError) as e:
profile_view.validate_bundle("boolean")

assert "Field boolean accepts a boolean, got nooo" in str(e.value)


def test_validate_bundle_raises_invalid_fields(profile_view):
profile_view.set_bundle(dict(invalid_field=email))

with pytest.raises(ValidationError) as e:
Expand All @@ -83,7 +90,7 @@ def test_validate_bundle_raises_invalid_field_writes(profile_view):
assert "invalid_field is not editable" in str(e.value)


def test_validate_bundle_raises_read_only_field_writes(profile_view):
def test_validate_bundle_raises_read_only_fields(profile_view):
with pytest.raises(ValidationError) as e:
profile_view.validate_bundle("recovery_phone")

Expand Down
48 changes: 43 additions & 5 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import timedelta
from unittest.mock import patch
from uuid import uuid4

import pytest

Expand All @@ -13,6 +14,13 @@ def test_profile_detail(client, db, profile, user):
assert result["username"] == user.username


def test_profile_not_found(client, db, profile, user):
response = client.get(f"/profiles/{uuid4()}/")
result = response.json()
assert response.status_code == 404, result
assert result["message"] == "Not Found"


def test_profile_delete(client, db, profile, user):
response = client.delete(f"/profiles/{profile.pk}/")
assert response.status_code == 204, response.content
Expand Down Expand Up @@ -281,6 +289,27 @@ def test_profile_update_m2m_through_required_fields(client, db, method, profile,
assert "Invalid skills" in result["message"]


def test_staff_detail(admin_client, profile, user):
response = admin_client.get(f"/profiles/{profile.pk}/staff/")
result = response.json()
assert response.status_code == 200, result
assert result["username"] == user.username


def test_staff_detail_is_not_found_for_user(user_client, profile, user):
response = user_client.get(f"/profiles/{profile.pk}/staff/")
result = response.json()
assert response.status_code == 404, result
assert result["message"] == "Not Found"


def test_staff_detail_is_unauthorized_for_guest(client, db, profile, user):
response = client.get(f"/profiles/{profile.pk}/staff/")
result = response.json()
assert response.status_code == 401, result
assert result["message"] == "Unauthorized"


def test_user_detail(client, db, user):
response = client.get(f"/users/{user.pk}/")
result = response.json()
Expand Down Expand Up @@ -375,11 +404,12 @@ def test_user_list_sort_desc(client, db, url, user_factory):


@pytest.mark.parametrize("url_params__array_format", ["comma", "repeat"])
def test_user_list_multisort(client, now, db, url, user_factory):
user_factory.create(username="a", date_joined=now)
user_factory.create(username="b", date_joined=now - timedelta(hours=1))
user_factory.create(username="c", date_joined=now)
user_factory.create(username="d", date_joined=now)
def test_user_list_multisort(client, db, now, url, user_factory):
date_joined = now()
user_factory.create(username="a", date_joined=date_joined)
user_factory.create(username="b", date_joined=date_joined - timedelta(hours=1))
user_factory.create(username="c", date_joined=date_joined)
user_factory.create(username="d", date_joined=date_joined)
response = client.get(url("/users/", {"sort": ["dateJoined", "-id", "x"]}))
result = response.json()
assert response.status_code == 200, result
Expand All @@ -390,6 +420,14 @@ def test_user_list_multisort(client, now, db, url, user_factory):
assert result["users"][3]["username"] == "a"


def test_user_self(client, user_client, user):
response = client.get("/user/")
assert response.status_code == 401

response = user_client.get("/user/")
assert response.status_code == 200


def test_user_unique_create_with_existing_value(client, db, user, user_factory):
user_factory.create(username="already_taken")
payload = dict(username="already_taken")
Expand Down
20 changes: 8 additions & 12 deletions tests/urls.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from django.urls import path

from tests.views import (
ProfileDetail,
ProfileList,
ProfileListSubSet,
UserDetail,
UserList,
)
from tests import views

urlpatterns = [
path("profiles/", ProfileList.as_view()),
path("profiles/subset/", ProfileListSubSet.as_view()),
path("profiles/<str:id>/", ProfileDetail.as_view()),
path("users/", UserList.as_view()),
path("users/<str:id>/", UserDetail.as_view()),
path("profiles/", views.ProfileList.as_view()),
path("profiles/subset/", views.ProfileListSubSet.as_view()),
path("profiles/<uuid:id>/", views.ProfileDetail.as_view()),
path("profiles/<uuid:id>/staff/", views.StaffDetail.as_view()),
path("user/", views.UserSelf.as_view()),
path("users/", views.UserList.as_view()),
path("users/<int:id>/", views.UserDetail.as_view()),
]
18 changes: 17 additions & 1 deletion tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from tests.models import Profile
from tests.serializers import ProfileSerializer, UserSerializer
from worf.permissions import PublicEndpoint
from worf.exceptions import AuthenticationError
from worf.permissions import Authenticated, PublicEndpoint, Staff
from worf.views import CreateAPI, DeleteAPI, DetailAPI, ListAPI, UpdateAPI


Expand Down Expand Up @@ -51,6 +52,10 @@ def validate_phone(self, value):
return "+5555555555"


class StaffDetail(ProfileDetail):
permissions = [Authenticated, Staff]


class UserList(CreateAPI, ListAPI):
model = User
ordering = ["pk"]
Expand Down Expand Up @@ -82,3 +87,14 @@ class UserDetail(UpdateAPI, DetailAPI):
model = User
serializer = UserSerializer(exclude=["date_joined"])
permissions = [PublicEndpoint]


class UserSelf(DetailAPI):
model = User
serializer = UserSerializer
permissions = [PublicEndpoint]

def get_instance(self):
if not self.request.user.is_authenticated:
raise AuthenticationError("Log in with your username and password")
return self.request.user
10 changes: 9 additions & 1 deletion worf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class HTTP422(HTTPException):
HTTP_EXCEPTIONS = (HTTP400, HTTP401, HTTP404, HTTP409, HTTP410, HTTP420, HTTP422)


class AuthenticationError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message


class NamingThingsError(ValueError):
pass

Expand All @@ -55,4 +61,6 @@ class NotImplementedInWorfYet(NotImplementedError):


class SerializerError(ValueError):
pass
def __init__(self, message):
super().__init__(message)
self.message = message
8 changes: 4 additions & 4 deletions worf/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def get_filters(self):
state = self._build_state()

for name in self.queryset.query.annotations.keys():
if name in self.Meta.exclude or name in filters:
if name in self.Meta.exclude or name in filters: # pragma: no cover
continue

try:
annotation_filter = self._build_annotation_filter(name, state)
except SkipFilter:
except SkipFilter: # pragma: no cover
continue

if annotation_filter is not None:
Expand All @@ -31,11 +31,11 @@ def get_filters(self):
def _build_annotation_filter(self, name, state):
field = self.queryset.query.annotations.get(name).output_field

if isinstance(field, RelatedField):
if isinstance(field, RelatedField): # pragma: no cover
if not self.Meta.allow_related:
raise SkipFilter
return self._build_filterset_from_related_field(name, field)
elif isinstance(field, ForeignObjectRel):
elif isinstance(field, ForeignObjectRel): # pragma: no cover
if not self.Meta.allow_related_reverse:
raise SkipFilter
return self._build_filterset_from_reverse_field(name, field)
Expand Down
Loading

0 comments on commit ed4195f

Please sign in to comment.