Skip to content

Commit

Permalink
Raise from permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Aug 30, 2022
1 parent d56f837 commit 399a4ce
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 106 deletions.
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[pytest]
addopts =
--cov
--cov-fail-under 90
--cov-fail-under 92
--cov-report term:skip-covered
--cov-report html
--no-cov-on-fail
Expand Down
9 changes: 7 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,14 @@ def pytest_configure():
INSTALLED_APPS=[
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"tests",
"worf",
],
MIDDLEWARE=[
"django.contrib.sessions.middleware.SessionMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
],
PASSWORD_HASHERS=["django.contrib.auth.hashers.MD5PasswordHasher"],
ROOT_URLCONF="tests.urls",
SECRET_KEY="secret",
Expand All @@ -36,7 +41,7 @@ def pytest_configure():
USE_L10N=False,
USE_TZ=True,
WORF_API_NAME="Test API",
WORF_DEBUG=True,
WORF_DEBUG=False,
)

django.setup()
Expand Down Expand Up @@ -80,7 +85,7 @@ def client_fixture():
def now_fixture():
from django.utils import timezone

return timezone.now()
return timezone.now


@pytest.fixture(name="url")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ def test_settings():
assert settings.WORF_API_NAME == "Test API"
assert settings.WORF_API_ROOT == "/api/"
assert settings.WORF_BROWSABLE_API is True
assert settings.WORF_DEBUG is True
assert settings.WORF_DEBUG is False
28 changes: 12 additions & 16 deletions tests/test_permissions.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,31 @@
import pytest

from django.contrib.auth.models import AnonymousUser, User
from django.test import RequestFactory

from worf.exceptions import HTTP401, HTTP404
from worf.permissions import Authenticated, Staff

factory = RequestFactory()


@pytest.mark.django_db
def test_authenticated():
request = factory.get("/")

def test_authenticated(db, rf):
request = rf.get("/")
request.user = AnonymousUser()
assert isinstance(Authenticated(None, request), HTTP401)
with pytest.raises(HTTP401):
Authenticated(request)

request.user = User.objects.create(username="test", password="test")
assert Authenticated(None, request) == 200

Authenticated(request)

@pytest.mark.django_db
def test_staff():
request = factory.get("/")

def test_staff(db, rf):
request = rf.get("/")
request.user = AnonymousUser()
assert isinstance(Staff(None, request), HTTP404)
with pytest.raises(HTTP404):
Staff(request)

request.user = User.objects.create(username="test", password="test")
assert isinstance(Staff(None, request), HTTP404)
with pytest.raises(HTTP404):
Staff(request)

request.user.is_staff = True
request.user.save()
assert Staff(None, request) == 200
Staff(request)
23 changes: 15 additions & 8 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


@pytest.fixture(name="profile_view")
def profile_view_fixture(db, now, profile_factory):
from django.test import RequestFactory

def profile_view_fixture(db, now, profile_factory, rf):
from tests.views import ProfileDetail

profile_factory.create(email=email, phone=phone)
Expand All @@ -30,11 +28,11 @@ def profile_view_fixture(db, now, profile_factory):
small_integer=123,
recovery_email=email,
recovery_phone=phone,
last_active=now.date().isoformat(),
created_at=now.isoformat(),
last_active=now().date().isoformat(),
created_at=now().isoformat(),
)
)
view.request = RequestFactory().patch(f"/{uuid}/")
view.request = rf.patch(f"/{uuid}/")
view.kwargs = dict(id=str(uuid))
return view

Expand Down Expand Up @@ -74,7 +72,16 @@ def test_validate_bundle_accepts_nulls(profile_view):
profile_view.validate_bundle("recovery_email")


def test_validate_bundle_raises_invalid_field_writes(profile_view):
def test_validate_bundle_raises_invalid_booleans(profile_view):
profile_view.set_bundle(dict(boolean="nooo"))

with pytest.raises(ValidationError) as e:
profile_view.validate_bundle("boolean")

assert "Field boolean accepts a boolean, got nooo" in str(e.value)


def test_validate_bundle_raises_invalid_fields(profile_view):
profile_view.set_bundle(dict(invalid_field=email))

with pytest.raises(ValidationError) as e:
Expand All @@ -83,7 +90,7 @@ def test_validate_bundle_raises_invalid_field_writes(profile_view):
assert "invalid_field is not editable" in str(e.value)


def test_validate_bundle_raises_read_only_field_writes(profile_view):
def test_validate_bundle_raises_read_only_fields(profile_view):
with pytest.raises(ValidationError) as e:
profile_view.validate_bundle("recovery_phone")

Expand Down
48 changes: 43 additions & 5 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import timedelta
from unittest.mock import patch
from uuid import uuid4

import pytest

Expand All @@ -13,6 +14,13 @@ def test_profile_detail(client, db, profile, user):
assert result["username"] == user.username


def test_profile_not_found(client, db, profile, user):
response = client.get(f"/profiles/{uuid4()}/")
result = response.json()
assert response.status_code == 404, result
assert result["message"] == "Not Found"


def test_profile_delete(client, db, profile, user):
response = client.delete(f"/profiles/{profile.pk}/")
assert response.status_code == 204, response.content
Expand Down Expand Up @@ -281,6 +289,27 @@ def test_profile_update_m2m_through_required_fields(client, db, method, profile,
assert "Invalid skills" in result["message"]


def test_staff_detail(admin_client, profile, user):
response = admin_client.get(f"/profiles/{profile.pk}/staff/")
result = response.json()
assert response.status_code == 200, result
assert result["username"] == user.username


def test_staff_detail_is_not_found_for_user(user_client, profile, user):
response = user_client.get(f"/profiles/{profile.pk}/staff/")
result = response.json()
assert response.status_code == 404, result
assert result["message"] == "Not Found"


def test_staff_detail_is_unauthorized_for_guest(client, db, profile, user):
response = client.get(f"/profiles/{profile.pk}/staff/")
result = response.json()
assert response.status_code == 401, result
assert result["message"] == "Unauthorized"


def test_user_detail(client, db, user):
response = client.get(f"/users/{user.pk}/")
result = response.json()
Expand Down Expand Up @@ -375,11 +404,12 @@ def test_user_list_sort_desc(client, db, url, user_factory):


@pytest.mark.parametrize("url_params__array_format", ["comma", "repeat"])
def test_user_list_multisort(client, now, db, url, user_factory):
user_factory.create(username="a", date_joined=now)
user_factory.create(username="b", date_joined=now - timedelta(hours=1))
user_factory.create(username="c", date_joined=now)
user_factory.create(username="d", date_joined=now)
def test_user_list_multisort(client, db, now, url, user_factory):
date_joined = now()
user_factory.create(username="a", date_joined=date_joined)
user_factory.create(username="b", date_joined=date_joined - timedelta(hours=1))
user_factory.create(username="c", date_joined=date_joined)
user_factory.create(username="d", date_joined=date_joined)
response = client.get(url("/users/", {"sort": ["dateJoined", "-id", "x"]}))
result = response.json()
assert response.status_code == 200, result
Expand All @@ -390,6 +420,14 @@ def test_user_list_multisort(client, now, db, url, user_factory):
assert result["users"][3]["username"] == "a"


def test_user_self(client, user_client, user):
response = client.get("/user/")
assert response.status_code == 401

response = user_client.get("/user/")
assert response.status_code == 200


def test_user_unique_create_with_existing_value(client, db, user, user_factory):
user_factory.create(username="already_taken")
payload = dict(username="already_taken")
Expand Down
20 changes: 8 additions & 12 deletions tests/urls.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from django.urls import path

from tests.views import (
ProfileDetail,
ProfileList,
ProfileListSubSet,
UserDetail,
UserList,
)
from tests import views

urlpatterns = [
path("profiles/", ProfileList.as_view()),
path("profiles/subset/", ProfileListSubSet.as_view()),
path("profiles/<str:id>/", ProfileDetail.as_view()),
path("users/", UserList.as_view()),
path("users/<str:id>/", UserDetail.as_view()),
path("profiles/", views.ProfileList.as_view()),
path("profiles/subset/", views.ProfileListSubSet.as_view()),
path("profiles/<uuid:id>/", views.ProfileDetail.as_view()),
path("profiles/<uuid:id>/staff/", views.StaffDetail.as_view()),
path("user/", views.UserSelf.as_view()),
path("users/", views.UserList.as_view()),
path("users/<int:id>/", views.UserDetail.as_view()),
]
18 changes: 17 additions & 1 deletion tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from tests.models import Profile
from tests.serializers import ProfileSerializer, UserSerializer
from worf.permissions import PublicEndpoint
from worf.exceptions import AuthenticationError
from worf.permissions import Authenticated, PublicEndpoint, Staff
from worf.views import CreateAPI, DeleteAPI, DetailAPI, ListAPI, UpdateAPI


Expand Down Expand Up @@ -51,6 +52,10 @@ def validate_phone(self, value):
return "+5555555555"


class StaffDetail(ProfileDetail):
permissions = [Authenticated, Staff]


class UserList(CreateAPI, ListAPI):
model = User
ordering = ["pk"]
Expand Down Expand Up @@ -82,3 +87,14 @@ class UserDetail(UpdateAPI, DetailAPI):
model = User
serializer = UserSerializer(exclude=["date_joined"])
permissions = [PublicEndpoint]


class UserSelf(DetailAPI):
model = User
serializer = UserSerializer
permissions = [PublicEndpoint]

def get_instance(self):
if not self.request.user.is_authenticated:
raise AuthenticationError("Log in with your username and password")
return self.request.user
10 changes: 9 additions & 1 deletion worf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class HTTP422(HTTPException):
HTTP_EXCEPTIONS = (HTTP400, HTTP401, HTTP404, HTTP409, HTTP410, HTTP420, HTTP422)


class AuthenticationError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message


class NamingThingsError(ValueError):
pass

Expand All @@ -55,4 +61,6 @@ class NotImplementedInWorfYet(NotImplementedError):


class SerializerError(ValueError):
pass
def __init__(self, message):
super().__init__(message)
self.message = message
8 changes: 4 additions & 4 deletions worf/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def get_filters(self):
state = self._build_state()

for name in self.queryset.query.annotations.keys():
if name in self.Meta.exclude or name in filters:
if name in self.Meta.exclude or name in filters: # pragma: no cover
continue

try:
annotation_filter = self._build_annotation_filter(name, state)
except SkipFilter:
except SkipFilter: # pragma: no cover
continue

if annotation_filter is not None:
Expand All @@ -31,11 +31,11 @@ def get_filters(self):
def _build_annotation_filter(self, name, state):
field = self.queryset.query.annotations.get(name).output_field

if isinstance(field, RelatedField):
if isinstance(field, RelatedField): # pragma: no cover
if not self.Meta.allow_related:
raise SkipFilter
return self._build_filterset_from_related_field(name, field)
elif isinstance(field, ForeignObjectRel):
elif isinstance(field, ForeignObjectRel): # pragma: no cover
if not self.Meta.allow_related_reverse:
raise SkipFilter
return self._build_filterset_from_reverse_field(name, field)
Expand Down
25 changes: 15 additions & 10 deletions worf/permissions.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from worf.exceptions import HTTP401, HTTP404


def Authenticated(self, request):
if not request.user.is_authenticated:
return HTTP401()
return 200
class Authenticated:
def __init__(self, request, **kwargs):
if request.user.is_authenticated:
return

raise HTTP401()

def Staff(self, request):
if not request.user.is_authenticated or not request.user.is_staff:
return HTTP404()
return 200

class PublicEndpoint:
def __init__(self, request, **kwargs):
pass

def PublicEndpoint(self, request):
return 200

class Staff:
def __init__(self, request, **kwargs):
if request.user.is_authenticated and request.user.is_staff:
return

raise HTTP404()
Loading

0 comments on commit 399a4ce

Please sign in to comment.