diff --git a/pytest.ini b/pytest.ini index 7139a36..964537e 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] addopts = --cov - --cov-fail-under 75 + --cov-fail-under 78 --cov-report term:skip-covered --cov-report html --no-cov-on-fail diff --git a/tests/conftest.py b/tests/conftest.py index 46c58a7..d7a8a35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,12 @@ import pytest from pytest_factoryboy import register -import django -from django.conf import settings -from django.utils import timezone - def pytest_configure(): """Initialize Django settings.""" + import django + from django.conf import settings + settings.configure( SECRET_KEY="secret", DEBUG=True, @@ -61,16 +60,33 @@ def pytest_configure(): register(UserFactory, "user") +@pytest.fixture(name="admin_client") +def admin_client_fixture(db, admin_user): + from worf.testing import ApiClient + + client = ApiClient() + client.force_login(admin_user) + return client + + +@pytest.fixture(name="client") +def client_fixture(): + from worf.testing import ApiClient + + return ApiClient() + + @pytest.fixture(name="now") def now_fixture(): - return timezone.now() + from django.utils import timezone + return timezone.now() -@pytest.fixture(name="profile_url") -def profile_url_fixture(profile): - return f"/profiles/{profile.pk}/" +@pytest.fixture(name="user_client") +def user_client_fixture(db, user): + from worf.testing import ApiClient -@pytest.fixture(name="user_url") -def user_url_fixture(user): - return f"/users/{user.pk}/" + client = ApiClient() + client.force_login(user) + return client diff --git a/tests/models.py b/tests/models.py index 61e0f5f..516f8d2 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,12 +1,21 @@ +from uuid import uuid4 + from django.db import models from django.contrib.auth.models import User -class DummyModel(models.Model): - id = models.CharField(max_length=64, primary_key=True) +class Profile(models.Model): + id = models.UUIDField(primary_key=True, default=uuid4) + email = models.CharField(max_length=64) phone = models.CharField(max_length=64) + user = models.OneToOneField(User, on_delete=models.CASCADE) + role = models.ForeignKey("Role", on_delete=models.CASCADE) + team = models.ForeignKey("Team", blank=True, null=True, on_delete=models.SET_NULL) + skills = models.ManyToManyField("Skill", through="RatedSkill") + tags = models.ManyToManyField("Tag") + def api(self): return dict(id=self.id, email=self.email, phone=self.phone) @@ -72,11 +81,3 @@ def api(self): class Team(models.Model): name = models.CharField(max_length=64) - - -class Profile(models.Model): - user = models.OneToOneField(User, on_delete=models.CASCADE) - role = models.ForeignKey(Role, on_delete=models.CASCADE) - team = models.ForeignKey(Team, blank=True, null=True, on_delete=models.SET_NULL) - skills = models.ManyToManyField(Skill, through=RatedSkill) - tags = models.ManyToManyField(Tag) diff --git a/tests/test_validators.py b/tests/test_validators.py index c62f640..89d6733 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,82 +1,65 @@ import pytest -from uuid import UUID + +from uuid import uuid4 from django.core.exceptions import ValidationError -from django.test import RequestFactory -from tests.models import DummyModel -from tests.views import DummyAPI +uuid = uuid4() +email = "something@example.com" +phone = "(555) 555-5555" + -test_uuid = "ce6a5b4f-599d-4442-8a74-d7a8d2b54854" -test_email = "something@example.com" -test_phone = "(555) 555-5555" +@pytest.fixture(name="profile_view") +def profile_view_fixture(db, profile_factory): + from django.test import RequestFactory + from tests.views import ProfileDetail -@pytest.fixture -@pytest.mark.django_db -def view(): - view = DummyAPI() + profile_factory.create(email=email, phone=phone) + view = ProfileDetail() view.bundle = { - "id": test_uuid, - "email": test_email, - "phone": test_phone, + "id": str(uuid), + "email": email, + "phone": phone, } - DummyModel.objects.create(id=UUID(test_uuid), email=test_email, phone=test_phone) - view.request = RequestFactory().patch(f"/{test_uuid}/") - view.kwargs = dict(id=test_uuid) + view.request = RequestFactory().patch(f"/{uuid}/") + view.kwargs = dict(id=str(uuid)) + view.serializer = None return view -@pytest.mark.django_db -def test_validate_bundle(view): - assert view.validate_bundle("id") - assert view.validate_bundle("email") - assert view.validate_bundle("phone") +def test_validate_bundle(profile_view): + assert profile_view.validate_bundle("id") + assert profile_view.validate_bundle("email") + assert profile_view.validate_bundle("phone") -@pytest.mark.django_db -def test_validate_uuid_accepts_str(view): - string = "ce6a5b4f-599d-4442-8a74-d7a8d2b54854" - result = view.validate_uuid(string) - assert result == UUID(string) +def test_validate_uuid_accepts_str(profile_view): + assert profile_view.validate_uuid(str(uuid)) == uuid -@pytest.mark.django_db -def test_validate_uuid_accepts_uuid(view): - uuid = UUID("ce6a5b4f-599d-4442-8a74-d7a8d2b54854") - result = view.validate_uuid(uuid) - assert result == uuid +def test_validate_uuid_accepts_uuid(profile_view): + assert profile_view.validate_uuid(uuid) == uuid -@pytest.mark.django_db -def test_validate_uuid_raises_error(view): - string = "not-a-uuid" +def test_validate_uuid_raises_error(profile_view): with pytest.raises(ValidationError): - view.validate_uuid(string) + profile_view.validate_uuid("not-a-uuid") -@pytest.mark.django_db -def test_validate_email_passes(view): - email = "something@example.com" - result = view.validate_email(email) - assert email == result +def test_validate_email_passes(profile_view): + assert profile_view.validate_email(email) == email -@pytest.mark.django_db -def test_validate_email_raises_error(view): - email = "fake.example@com" +def test_validate_email_raises_error(profile_view): with pytest.raises(ValidationError): - view.validate_email(email) + profile_view.validate_email("fake.example@com") -@pytest.mark.django_db -def test_validate_custom_field_passes(view): - phone = "(555) 555-5555" - assert view.validate_phone(phone) == "+5555555555" +def test_validate_custom_field_passes(profile_view): + assert profile_view.validate_phone(phone) == "+5555555555" -@pytest.mark.django_db -def test_validate_custom_field_raises_error(view): - phone = "invalid number" +def test_validate_custom_field_raises_error(profile_view): with pytest.raises(ValidationError): - view.validate_phone(phone) + profile_view.validate_phone("invalid number") diff --git a/tests/test_views.py b/tests/test_views.py index aad6cc0..16381fe 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,15 +1,19 @@ from datetime import timedelta -json_type = "application/json" - -def test_profile_detail(client, db, profile_url, user): - response = client.get(profile_url) +def test_profile_detail(client, db, profile, user): + response = client.get(f"/profiles/{profile.pk}/") result = response.json() assert response.status_code == 200, result assert result["username"] == user.username +def test_profile_delete(client, db, profile, user): + response = client.delete(f"/profiles/{profile.pk}/") + assert response.status_code == 204, response.content + assert response.content == b"" + + def test_profile_list(client, db, profile, user): response = client.get("/profiles/") result = response.json() @@ -18,7 +22,7 @@ def test_profile_list(client, db, profile, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_search(client, db, profile, user): +def test_profile_list_filter(client, db, profile, user): response = client.get(f"/profiles/?name={user.first_name} {user.last_name}") result = response.json() assert response.status_code == 200, result @@ -35,13 +39,6 @@ def test_profile_list_annotation_filter(client, db, profile_factory): assert len(result["profiles"]) == 1 -def test_profile_list_subset_search(client, db, profile, user): - response = client.get(f"/profiles/subset/?name={user.first_name} {user.last_name}") - result = response.json() - assert response.status_code == 200, result - assert len(result["profiles"]) == 0 - - def test_profile_list_and_filter(client, db, profile_factory, tag_factory): tag1, tag2, tag3 = tag_factory.create_batch(3) profile_factory.create(tags=[tag1]) @@ -68,38 +65,45 @@ def test_profile_list_or_filter(client, db, profile_factory, tag_factory): assert len(result["profiles"]) == 3 -def test_profile_update_fk(client, db, profile_url, role, team): +def test_profile_list_subset_filter(client, db, profile, user): + response = client.get(f"/profiles/subset/?name={user.first_name} {user.last_name}") + result = response.json() + assert response.status_code == 200, result + assert len(result["profiles"]) == 0 + + +def test_profile_update_fk(client, db, profile, role, team): payload = dict(role=role.pk, team=team.pk) - response = client.patch(profile_url, payload, content_type=json_type) + response = client.patch(f"/profiles/{profile.pk}/", payload) result = response.json() assert response.status_code == 200, result assert result["role"]["name"] == role.name assert result["team"]["name"] == team.name -def test_profile_update_fk_invalid_role(client, db, profile_url, role, team): - response = client.patch(profile_url, dict(role=123), content_type=json_type) +def test_profile_update_fk_invalid_role(client, db, profile, role, team): + response = client.patch(f"/profiles/{profile.pk}/", dict(role=123)) result = response.json() assert response.status_code == 422, result assert result["message"] == "Invalid role" -def test_profile_update_fk_role_is_not_nullable(client, db, profile_url, role, team): - response = client.patch(profile_url, dict(role=None), content_type=json_type) +def test_profile_update_fk_role_is_not_nullable(client, db, profile, role, team): + response = client.patch(f"/profiles/{profile.pk}/", dict(role=None)) result = response.json() assert response.status_code == 422, result assert result["message"] == "Invalid role" -def test_profile_update_fk_team_is_nullable(client, db, profile_url, role, team): - response = client.patch(profile_url, dict(team=None), content_type=json_type) +def test_profile_update_fk_team_is_nullable(client, db, profile, role, team): + response = client.patch(f"/profiles/{profile.pk}/", dict(team=None)) result = response.json() assert response.status_code == 200, result assert result["team"] is None -def test_profile_update_m2m(client, db, profile_url, tag): - response = client.patch(profile_url, dict(tags=[tag.pk]), content_type=json_type) +def test_profile_update_m2m(client, db, profile, tag): + response = client.patch(f"/profiles/{profile.pk}/", dict(tags=[tag.pk])) result = response.json() assert response.status_code == 200, result assert len(result["tags"]) == 1 @@ -107,31 +111,31 @@ def test_profile_update_m2m(client, db, profile_url, tag): assert result["tags"][0]["name"] == tag.name -def test_profile_update_m2m_can_be_empty(client, db, profile_url, tag): - response = client.patch(profile_url, dict(tags=[]), content_type=json_type) +def test_profile_update_m2m_can_be_empty(client, db, profile, tag): + response = client.patch(f"/profiles/{profile.pk}/", dict(tags=[])) result = response.json() assert response.status_code == 200, result assert len(result["tags"]) == 0 -def test_profile_update_m2m_is_not_nullable(client, db, profile_url, tag): - response = client.patch(profile_url, dict(tags=None), content_type=json_type) +def test_profile_update_m2m_is_not_nullable(client, db, profile, tag): + response = client.patch(f"/profiles/{profile.pk}/", dict(tags=None)) result = response.json() assert response.status_code == 422, result assert "tags accepts an array, got None" in result["message"] -def test_profile_update_m2m_must_be_pks(client, db, profile_url, tag): +def test_profile_update_m2m_must_be_pks(client, db, profile, tag): payload = dict(tags=["invalid"]) - response = client.patch(profile_url, payload, content_type=json_type) + response = client.patch(f"/profiles/{profile.pk}/", payload) result = response.json() assert response.status_code == 422, result assert "Invalid tags" in result["message"] -def test_profile_update_m2m_through(client, db, profile_url, skill): +def test_profile_update_m2m_through(client, db, profile, skill): payload = dict(skills=[dict(id=skill.pk, rating=4)]) - response = client.patch(profile_url, payload, content_type=json_type) + response = client.patch(f"/profiles/{profile.pk}/", payload) result = response.json() assert response.status_code == 200, result assert len(result["skills"]) == 1 @@ -140,46 +144,46 @@ def test_profile_update_m2m_through(client, db, profile_url, skill): assert result["skills"][0]["rating"] == 4 -def test_profile_update_m2m_through_can_be_empty(client, db, profile_url, skill): - response = client.patch(profile_url, dict(skills=[]), content_type=json_type) +def test_profile_update_m2m_through_can_be_empty(client, db, profile, skill): + response = client.patch(f"/profiles/{profile.pk}/", dict(skills=[])) result = response.json() assert response.status_code == 200, result assert len(result["skills"]) == 0 -def test_profile_update_m2m_through_is_not_nullable(client, db, profile_url, skill): - response = client.patch(profile_url, dict(skills=None), content_type=json_type) +def test_profile_update_m2m_through_is_not_nullable(client, db, profile, skill): + response = client.patch(f"/profiles/{profile.pk}/", dict(skills=None)) result = response.json() assert response.status_code == 422, result assert "skills accepts an array, got None" in result["message"] -def test_profile_update_m2m_through_must_be_dicts(client, db, profile_url, skill): +def test_profile_update_m2m_through_must_be_dicts(client, db, profile, skill): payload = dict(skills=["invalid"]) - response = client.patch(profile_url, payload, content_type=json_type) + response = client.patch(f"/profiles/{profile.pk}/", payload) result = response.json() assert response.status_code == 422, result assert "Invalid skills" == result["message"] -def test_profile_update_m2m_through_ids_must_be_pks(client, db, profile_url, skill): +def test_profile_update_m2m_through_ids_must_be_pks(client, db, profile, skill): payload = dict(skills=[dict(id="invalid")]) - response = client.patch(profile_url, payload, content_type=json_type) + response = client.patch(f"/profiles/{profile.pk}/", payload) result = response.json() assert response.status_code == 422, result assert "Invalid skills" in result["message"] -def test_profile_update_m2m_through_required_fields(client, db, profile_url, skill): +def test_profile_update_m2m_through_required_fields(client, db, profile, skill): payload = dict(skills=[dict(id=skill.pk)]) - response = client.patch(profile_url, payload, content_type=json_type) + response = client.patch(f"/profiles/{profile.pk}/", payload) result = response.json() assert response.status_code == 422, result assert "Invalid skills" in result["message"] -def test_user_detail(client, db, user, user_url): - response = client.get(user_url) +def test_user_detail(client, db, user): + response = client.get(f"/users/{user.pk}/") result = response.json() assert response.status_code == 200, result assert result["username"] == user.username @@ -202,21 +206,21 @@ def test_user_list_filters(client, db, user_factory): user2 = user_factory.create(date_joined=february) user3 = user_factory.create(email="test3@test.com", date_joined=march) - # search by email + # filter by email response = client.get(f"/users/?email={user1.email}") assert len(response.json()["users"]) == 1 - # search by string + # filter by string response = client.get("/users/?q=example.com") assert len(response.json()["users"]) == 2 response = client.get(f"/users/?q={user3.email}") assert len(response.json()["users"]) == 1 - # search by username -- filter ignored, username is not in filter_fields + # filter by username -- filter ignored, username is not in filter_fields response = client.get(f"/users/?username={user2.username}") assert len(response.json()["users"]) == 3 - # search by date joined + # filter by date joined gte, lte = "2021-02-01T00:00:00Z", "2021-02-15T00:00:00Z" response = client.get(f"/users/?dateJoined__gte={gte}&date_joined__lte={lte}") assert len(response.json()["users"]) == 1 @@ -260,9 +264,18 @@ def test_user_list_multisort(client, now, db, user_factory): assert result["users"][3]["username"] == "a" -def test_user_update(client, db, user_url): +def test_user_patch(client, db, user): + payload = dict(username="testtest", email="something@example.com") + response = client.patch(f"/users/{user.pk}/", payload) + result = response.json() + assert response.status_code == 200, result + assert result["username"] == "testtest" + assert result["email"] == "something@example.com" + + +def test_user_update(client, db, user): payload = dict(username="testtest", email="something@example.com") - response = client.patch(user_url, payload, content_type=json_type) + response = client.put(f"/users/{user.pk}/", payload) result = response.json() assert response.status_code == 200, result assert result["username"] == "testtest" diff --git a/tests/urls.py b/tests/urls.py index 9d59b16..62cc1a9 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,7 +1,6 @@ from django.urls import include, path from tests.views import ( - DummyAPI, ProfileDetail, ProfileList, ProfileListSubSet, @@ -10,11 +9,9 @@ ) urlpatterns = [ - path("", DummyAPI.as_view()), path("profiles/", ProfileList.as_view()), path("profiles/subset/", ProfileListSubSet.as_view()), path("profiles//", ProfileDetail.as_view()), path("users/", UserList.as_view()), path("users//", UserDetail.as_view()), - path("/", DummyAPI.as_view()), ] diff --git a/tests/views.py b/tests/views.py index d8d39c8..798bcc6 100644 --- a/tests/views.py +++ b/tests/views.py @@ -4,24 +4,12 @@ from django.db.models.functions import Concat from worf.permissions import PublicEndpoint -from worf.views import DetailUpdateAPI, ListAPI +from worf.views import DeleteAPI, DetailAPI, DetailUpdateAPI, ListAPI, UpdateAPI -from tests.models import DummyModel, Profile +from tests.models import Profile from tests.serializers import ProfileSerializer, UserSerializer -class DummyAPI(DetailUpdateAPI): - model = DummyModel - permissions = [PublicEndpoint] - - def validate_phone(self, value): - try: - assert value == "(555) 555-5555" - except AssertionError: - raise ValidationError("{value} is not a valid phone number") - return "+5555555555" - - class ProfileList(ListAPI): model = Profile queryset = Profile.objects.annotate( @@ -43,11 +31,18 @@ class ProfileListSubSet(ProfileList): queryset = ProfileList.queryset.none() -class ProfileDetail(DetailUpdateAPI): +class ProfileDetail(DeleteAPI, UpdateAPI, DetailAPI): model = Profile serializer = ProfileSerializer permissions = [PublicEndpoint] + def validate_phone(self, value): + try: + assert value == "(555) 555-5555" + except AssertionError: + raise ValidationError("{value} is not a valid phone number") + return "+5555555555" + class UserList(ListAPI): model = User diff --git a/worf/assigns.py b/worf/assigns.py new file mode 100644 index 0000000..ab341d6 --- /dev/null +++ b/worf/assigns.py @@ -0,0 +1,91 @@ +from django.core.exceptions import ValidationError +from django.db import models +from django.db.utils import IntegrityError + +from worf.casing import snake_to_camel + + +class AssignAttributes: + def save(self, instance, bundle): + items = [ + (key, getattr(self.model, key), value) for key, value in bundle.items() + ] + + for key, attr, value in items: + if isinstance(value, models.Model): + setattr(instance, key, value) + continue + + if isinstance(attr.field, models.ForeignKey): + self.set_foreign_key(instance, key, value) + continue + + if isinstance(attr.field, models.ManyToManyField): + continue + + setattr(instance, key, value) + + instance.save() + + for key, attr, value in items: + if isinstance(attr.field, models.ManyToManyField): + if not attr.through._meta.auto_created: + self.set_many_to_many_with_through(instance, key, value) + continue + + self.set_many_to_many(instance, key, value) + + def set_foreign_key(self, instance, key, value): + related_model = self.get_related_model(key) + try: + related_instance = ( + related_model.objects.get(pk=value) if value is not None else None + ) + except related_model.DoesNotExist as e: + raise ValidationError(f"Invalid {snake_to_camel(key)}") from e + setattr(instance, key, related_instance) + + def set_many_to_many(self, instance, key, value): + try: + getattr(instance, key).set(value) + except (IntegrityError, ValueError) as e: + raise ValidationError(f"Invalid {snake_to_camel(key)}") from e + + def set_many_to_many_with_through(self, instance, key, value): + try: + attr = getattr(self.model, key) + + through_model = attr.through + model_name = self.model._meta.model_name + target_field_name = attr.field.m2m_target_field_name() + reverse_name = attr.field.m2m_reverse_name() + + getattr(instance, key).clear() + + through_model.objects.bulk_create( + [ + through_model( + **{ + item_key: item_value + for item_key, item_value in item.items() + if item_key != target_field_name + }, + **{ + model_name: instance, + reverse_name: item[target_field_name], + }, + ) + for item in value + ] + ) + except (AttributeError, IntegrityError, ValueError) as e: + raise ValidationError(f"Invalid {snake_to_camel(key)}") from e + + def validate(self): + for key in self.bundle.keys(): + self.validate_bundle(key) + + field = self.model._meta.get_field(key) + + if self.bundle[key] is None and not field.null: + raise ValidationError(f"Invalid {snake_to_camel(key)}") diff --git a/worf/lookups.py b/worf/lookups.py new file mode 100644 index 0000000..19de334 --- /dev/null +++ b/worf/lookups.py @@ -0,0 +1,19 @@ +class FindInstance: + lookup_field = "id" + lookup_url_kwarg = "id" + queryset = None + + def get_instance(self): + self.lookup_kwargs = {self.lookup_field: self.kwargs[self.lookup_url_kwarg]} + + self.validate_lookup_field_values() + + if not hasattr(self, "instance"): + self.instance = self.get_queryset().get(**self.lookup_kwargs) + + return self.instance + + def get_queryset(self): + if self.queryset is None: + return self.model.objects.all() + return self.queryset.all() diff --git a/worf/testing.py b/worf/testing.py new file mode 100644 index 0000000..496d98e --- /dev/null +++ b/worf/testing.py @@ -0,0 +1,10 @@ +from functools import partialmethod + +from django.test.client import Client + + +class ApiClient(Client): + delete = partialmethod(Client.delete, content_type="application/json") + patch = partialmethod(Client.patch, content_type="application/json") + post = partialmethod(Client.post, content_type="application/json") + put = partialmethod(Client.put, content_type="application/json") diff --git a/worf/transformers.py b/worf/transformers.py deleted file mode 100644 index b646013..0000000 --- a/worf/transformers.py +++ /dev/null @@ -1,7 +0,0 @@ -def transform_to_dict(choices): - """Convert various choices list of lists into json object.""" - # TODO this is copied from context_processors. - transformed_choices = dict() - for each in choices: - transformed_choices[each[0]] = each[1] - return transformed_choices diff --git a/worf/validators.py b/worf/validators.py index 6cdbaaf..3897a52 100644 --- a/worf/validators.py +++ b/worf/validators.py @@ -197,6 +197,9 @@ def validate_bundle(self, key): elif hasattr(self, f"validate_{key}"): self.bundle[key] = getattr(self, f"validate_{key}")(self.bundle[key]) + elif isinstance(field, models.UUIDField): + self.bundle[key] = self.validate_uuid(self.bundle[key]) + elif isinstance(field, (models.CharField, models.TextField, models.SlugField)): self.bundle[key] = self._validate_string(key, field.max_length) diff --git a/worf/views/__init__.py b/worf/views/__init__.py index fc48c44..c3a2e37 100644 --- a/worf/views/__init__.py +++ b/worf/views/__init__.py @@ -1,4 +1,6 @@ from worf.views.base import APIResponse, AbstractBaseAPI # noqa from worf.views.create import CreateAPI # noqa +from worf.views.delete import DeleteAPI # noqa from worf.views.detail import DetailAPI, DetailUpdateAPI # noqa from worf.views.list import ListAPI, ListCreateAPI # noqa +from worf.views.update import UpdateAPI # noqa diff --git a/worf/views/base.py b/worf/views/base.py index 4712cee..4ab390f 100644 --- a/worf/views/base.py +++ b/worf/views/base.py @@ -11,7 +11,7 @@ ValidationError, ) from django.db import models -from django.http import JsonResponse +from django.http import HttpResponse, JsonResponse from django.middleware.gzip import GZipMiddleware from django.views import View from django.views.decorators.cache import never_cache @@ -42,7 +42,7 @@ def render_to_response(self, data=None, status_code=None): msg += "render_to_response, nor did its serializer method" raise ImproperlyConfigured(msg) - response = JsonResponse(payload) + response = JsonResponse(payload) if payload != "" else HttpResponse() # except TypeError: # TODO add something meaningful to the stack trace diff --git a/worf/views/create.py b/worf/views/create.py index 61d4795..2b4c956 100644 --- a/worf/views/create.py +++ b/worf/views/create.py @@ -1,33 +1,19 @@ -from django.core.exceptions import ValidationError - -from worf.casing import snake_to_camel +from worf.assigns import AssignAttributes from worf.views.base import AbstractBaseAPI -class CreateAPI(AbstractBaseAPI): - def serialize(self): - return {} +class CreateAPI(AssignAttributes, AbstractBaseAPI): + def create(self): + instance = self.get_instance() + self.validate() + self.save(instance, self.bundle) + instance.refresh_from_db() + return instance + + def get_instance(self): + return self.model() def post(self, request, *args, **kwargs): new_instance = self.create() serializer = self.get_serializer() return self.render_to_response(serializer.read(new_instance), 201) - - def create(self): - self.validate() - - return self.model.objects.create(**self.bundle) - - def validate(self): - create_fields = self.get_serializer().create() - - for key in self.bundle.keys(): - self.validate_bundle(key) - # ignore create_fields for now if it's empty - # this should be moved into validate bundle - if create_fields and key not in create_fields: - raise ValidationError( - "{} not allowed when creating {}".format( - snake_to_camel(key), self.name - ) - ) diff --git a/worf/views/delete.py b/worf/views/delete.py new file mode 100644 index 0000000..ccc34e8 --- /dev/null +++ b/worf/views/delete.py @@ -0,0 +1,10 @@ +from worf.views.base import AbstractBaseAPI + + +class DeleteAPI(AbstractBaseAPI): + def delete(self, request, *args, **kwargs): + self.destroy() + return self.render_to_response("", 204) + + def destroy(self): + self.get_instance().delete() diff --git a/worf/views/detail.py b/worf/views/detail.py index 57b9542..f6b7459 100644 --- a/worf/views/detail.py +++ b/worf/views/detail.py @@ -1,16 +1,11 @@ -from django.core.exceptions import ImproperlyConfigured, ValidationError -from django.db import models -from django.db.utils import IntegrityError +from django.core.exceptions import ImproperlyConfigured -from worf.casing import snake_to_camel +from worf.lookups import FindInstance from worf.views.base import AbstractBaseAPI +from worf.views.update import UpdateAPI -class DetailAPI(AbstractBaseAPI): - lookup_field = "id" - lookup_url_kwarg = "id" - queryset = None - +class DetailAPI(FindInstance, AbstractBaseAPI): def get(self, request, *args, **kwargs): return self.render_to_response() @@ -22,107 +17,8 @@ def serialize(self): raise ImproperlyConfigured(f"{serializer} did not return a dictionary") return payload - def get_queryset(self): - if self.queryset is None: - return self.model.objects.all() - return self.queryset.all() - - def get_instance(self): - self.lookup_kwargs = {self.lookup_field: self.kwargs[self.lookup_url_kwarg]} - - self.validate_lookup_field_values() - - if not hasattr(self, "instance"): - self.instance = self.get_queryset().get(**self.lookup_kwargs) - - return self.instance - - def set_foreign_key(self, instance, key): - related_model = self.get_related_model(key) - try: - related_instance = ( - related_model.objects.get(pk=self.bundle[key]) - if self.bundle[key] is not None - else None - ) - except related_model.DoesNotExist as e: - raise ValidationError(f"Invalid {snake_to_camel(key)}") from e - setattr(instance, key, related_instance) - - def set_many_to_many(self, instance, key): - try: - getattr(instance, key).set(self.bundle[key]) - except (IntegrityError, ValueError) as e: - raise ValidationError(f"Invalid {snake_to_camel(key)}") from e - - def set_many_to_many_with_through(self, instance, key): - try: - attr = getattr(self.model, key) - - through_model = attr.through - model_name = self.model._meta.model_name - target_field_name = attr.field.m2m_target_field_name() - reverse_name = attr.field.m2m_reverse_name() - - getattr(instance, key).clear() - - through_model.objects.bulk_create( - [ - through_model( - **{ - key: value - for key, value in item.items() - if key != target_field_name - }, - **{ - model_name: instance, - reverse_name: item[target_field_name], - }, - ) - for item in self.bundle[key] - ] - ) - except (AttributeError, IntegrityError, ValueError) as e: - raise ValidationError(f"Invalid {snake_to_camel(key)}") from e - - def update(self): - instance = self.get_instance() - - self.validate() - - for key in self.bundle.keys(): - attr = getattr(self.model, key) - - if isinstance(attr.field, models.ForeignKey): - self.set_foreign_key(instance, key) - continue - - if isinstance(attr.field, models.ManyToManyField): - if not attr.through._meta.auto_created: - self.set_many_to_many_with_through(instance, key) - continue - - self.set_many_to_many(instance, key) - continue - - setattr(instance, key, self.bundle[key]) - - instance.save() - instance.refresh_from_db() - - return instance - - def validate(self): - for key in self.bundle.keys(): - self.validate_bundle(key) - - field = self.model._meta.get_field(key) - - if self.bundle[key] is None and not field.null: - raise ValidationError(f"Invalid {snake_to_camel(key)}") - -class DetailUpdateAPI(DetailAPI): +class DetailUpdateAPI(UpdateAPI, DetailAPI): def patch(self, request, *args, **kwargs): self.update() return self.get(request) diff --git a/worf/views/list.py b/worf/views/list.py index b0462f5..b4b1ee1 100644 --- a/worf/views/list.py +++ b/worf/views/list.py @@ -29,9 +29,6 @@ class ListAPI(AbstractBaseAPI): max_per_page = None num_pages = 1 - def get(self, request, *args, **kwargs): - return self.render_to_response() - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -63,6 +60,9 @@ def __init__(self, *args, **kwargs): ) self.search_fields = self.search_fields.get("or", []) + def get(self, request, *args, **kwargs): + return self.render_to_response() + def _set_base_lookup_kwargs(self): # Filters set directly on the class self.lookup_kwargs.update(self.filters) @@ -243,5 +243,5 @@ def serialize(self): return payload -class ListCreateAPI(ListAPI, CreateAPI): +class ListCreateAPI(CreateAPI, ListAPI): pass diff --git a/worf/views/update.py b/worf/views/update.py new file mode 100644 index 0000000..1879975 --- /dev/null +++ b/worf/views/update.py @@ -0,0 +1,20 @@ +from worf.assigns import AssignAttributes +from worf.lookups import FindInstance +from worf.views.base import AbstractBaseAPI + + +class UpdateAPI(AssignAttributes, FindInstance, AbstractBaseAPI): + def patch(self, request, *args, **kwargs): + self.update() + return self.render_to_response() + + def put(self, request, *args, **kwargs): + self.update() + return self.render_to_response() + + def update(self): + instance = self.get_instance() + self.validate() + self.save(instance, self.bundle) + instance.refresh_from_db() + return instance