diff --git a/README.md b/README.md index d3789a6..3d8e476 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Table of contents - [UpdateAPI](#updateapi) - [Bundle loading](#bundle-loading) - [Field casing](#field-casing) + - [File uploads](#file-uploads) - [Internal naming](#internal-naming) - [Credits](#credits) @@ -56,8 +57,9 @@ Roadmap - [x] Abstracting serializers away from model methods - [x] Declarative marshmallow-based serialization - [x] More support for different HTTP methods +- [x] File upload support on POST +- [ ] File upload support on PATCH/PUT - [ ] Support for user-generated validators -- [ ] Better file upload support - [ ] Better test coverage - [ ] Browsable API docs @@ -310,6 +312,14 @@ This will be strictly translated by the API, and acronyms are not considered: - `API_strict == apiStrict` +File uploads +------------ + +File uploads are supported via `POST` using `multipart/form-data` requests. + +Support for `PATCH`/`PUT` is on the roadmap. + + Internal naming --------------- diff --git a/pytest.ini b/pytest.ini index 964537e..6fcf2a7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] addopts = --cov - --cov-fail-under 78 + --cov-fail-under 80 --cov-report term:skip-covered --cov-report html --no-cov-on-fail diff --git a/tests/models.py b/tests/models.py index 7e876d2..ab353d1 100644 --- a/tests/models.py +++ b/tests/models.py @@ -7,6 +7,7 @@ class Profile(models.Model): id = models.UUIDField(primary_key=True, default=uuid4) + avatar = models.FileField(upload_to="avatars/", blank=True) email = models.CharField(max_length=64) phone = models.CharField(max_length=64) diff --git a/tests/serializers.py b/tests/serializers.py index 25d5a9b..c5cb47c 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -21,15 +21,19 @@ class ProfileSerializer(Serializer): skills = fields.Nested("RatedSkillSerializer", attribute="ratedskill_set", many=True) team = fields.Nested("TeamSerializer") tags = fields.Nested("TagSerializer", many=True) + user = fields.Nested("UserSerializer") class Meta: fields = [ "username", + "avatar", "email", + "phone", "role", "skills", "team", "tags", + "user", ] diff --git a/tests/test_views.py b/tests/test_views.py index 16381fe..11b304d 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,4 +1,22 @@ from datetime import timedelta +from unittest.mock import patch + +from django.core.files.uploadedfile import SimpleUploadedFile +from django.test.client import MULTIPART_CONTENT + + +@patch('django.core.files.storage.FileSystemStorage.save') +def test_profile_create(mock_save, client, db, role, user): + avatar = SimpleUploadedFile("avatar.jpg", b"", content_type="image/jpeg") + mock_save.return_value = "avatar.jpg" + payload = dict(avatar=avatar, role=role.pk, user=user.pk) + response = client.post(f"/profiles/", data=payload, content_type=MULTIPART_CONTENT) + result = response.json() + assert response.status_code == 201, result + assert result["avatar"] == "/avatar.jpg" + assert result["role"]["id"] == role.pk + assert result["role"]["name"] == role.name + assert result["user"]["username"] == user.username def test_profile_detail(client, db, profile, user): diff --git a/tests/views.py b/tests/views.py index 798bcc6..e278639 100644 --- a/tests/views.py +++ b/tests/views.py @@ -4,13 +4,13 @@ from django.db.models.functions import Concat from worf.permissions import PublicEndpoint -from worf.views import DeleteAPI, DetailAPI, DetailUpdateAPI, ListAPI, UpdateAPI +from worf.views import CreateAPI, DeleteAPI, DetailAPI, ListAPI, UpdateAPI from tests.models import Profile from tests.serializers import ProfileSerializer, UserSerializer -class ProfileList(ListAPI): +class ProfileList(CreateAPI, ListAPI): model = Profile queryset = Profile.objects.annotate( name=Concat("user__first_name", Value(" "), "user__last_name"), @@ -64,7 +64,7 @@ class UserList(ListAPI): ] -class UserDetail(DetailUpdateAPI): +class UserDetail(UpdateAPI, DetailAPI): model = User serializer = UserSerializer permissions = [PublicEndpoint] diff --git a/worf/fields.py b/worf/fields.py index 9bb8fed..ad8d3ef 100644 --- a/worf/fields.py +++ b/worf/fields.py @@ -4,6 +4,11 @@ from marshmallow.fields import * # noqa: F401, F403 +class File(marshmallow.fields.Field): + def _serialize(self, value, attr, obj, **kwargs): + return value.url if value.name else None + + class Nested(marshmallow.fields.Nested): def _serialize(self, nested_obj, attr, obj, **kwargs): if isinstance(nested_obj, Manager): diff --git a/worf/serializers.py b/worf/serializers.py index 9ffaced..c48d216 100644 --- a/worf/serializers.py +++ b/worf/serializers.py @@ -2,6 +2,7 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.db.models.fields.files import FieldFile from worf import fields # noqa: F401 from worf.casing import snake_to_camel @@ -33,6 +34,11 @@ class Serializer(marshmallow.Schema): OPTIONS_CLASS = SerializerOptions + TYPE_MAPPING = { + **marshmallow.Schema.TYPE_MAPPING, + FieldFile: fields.File, + } + def __repr__(self): return f"<{self.__class__.__name__}()>" diff --git a/worf/validators.py b/worf/validators.py index 87cb741..122148f 100644 --- a/worf/validators.py +++ b/worf/validators.py @@ -218,6 +218,9 @@ def validate_bundle(self, key): elif isinstance(field, models.BooleanField): self.bundle[key] = self._validate_boolean(key) + elif isinstance(field, models.FileField): + pass # Django will raise an exception if handled improperly + elif isinstance(field, models.ForeignKey): pass # Django will raise an exception if handled improperly diff --git a/worf/views/base.py b/worf/views/base.py index 8b520c9..d9c9e90 100644 --- a/worf/views/base.py +++ b/worf/views/base.py @@ -116,24 +116,6 @@ def _check_permissions(self): ) ) - def _assemble_bundle_from_request_body(self): - - if self.request.content_type == "multipart/form-data": - # Avoid RawPostDataException - # TODO investigate why test did not catch this error - raw_bundle = {} - - elif self.request.body: # and self.request.body != b'--BoUnDaRyStRiNg--\r\n': - try: - raw_bundle = json.loads(self.request.body) - except json.decoder.JSONDecodeError: - # print("\n\n~~~~~~~~~~~~~~~~~~~~~~", self.request.body, '\n\n') - raw_bundle = {} - else: - raw_bundle = {} - - self.set_bundle(raw_bundle) - def _get_lookup_field(self, field): related = field.find("__") @@ -182,29 +164,44 @@ def validate_lookup_field_values(self): ): self.validate_numeric(url_kwarg) - def set_bundle_from_querystring(self): + def flatten_bundle(self, raw_bundle): # parse_qs gives us a dictionary where all values are lists - qs = parse_qs(self.request.META["QUERY_STRING"]) - - raw_bundle = {} - - for key, value in qs.items(): - raw_bundle[key] = value[0] if len(value) == 1 else value - - self.set_bundle(raw_bundle) + return { + key: value[0] if len(value) == 1 else value + for key, value in raw_bundle.items() + } def set_bundle(self, raw_bundle): self.bundle = {} self.keymap = {} if not raw_bundle: - return # No need to loop or set self.bundle again if it's empty + return for key in raw_bundle.keys(): field = camel_to_snake(key) self.bundle[field] = raw_bundle[key] self.keymap[field] = key + def set_bundle_from_querystring(self): + raw_bundle = self.flatten_bundle(parse_qs(self.request.META["QUERY_STRING"])) + + self.set_bundle(raw_bundle) + + def set_bundle_from_request_body(self): + raw_bundle = {} + + if self.request.content_type == "multipart/form-data": + raw_bundle.update(self.flatten_bundle(self.request.POST)) + raw_bundle.update(self.flatten_bundle(self.request.FILES)) + elif self.request.body: + try: + raw_bundle = json.loads(self.request.body) + except json.decoder.JSONDecodeError: + pass + + self.set_bundle(raw_bundle) + def dispatch(self, request, *args, **kwargs): method = request.method.lower() handler = self.http_method_not_allowed @@ -214,7 +211,7 @@ def dispatch(self, request, *args, **kwargs): try: self._check_permissions() # only returns 200 or HTTP_EXCEPTIONS - self._assemble_bundle_from_request_body() # sets self.bundle + self.set_bundle_from_request_body() # sets self.bundle return handler(request, *args, **kwargs) # calls self.serialize() except HTTP_EXCEPTIONS as e: return self.render_to_response(dict(message=e.message), e.status)