Skip to content

Commit

Permalink
File upload support
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Nov 17, 2021
1 parent bdb324b commit cf6e448
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 34 deletions.
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Table of contents
- [UpdateAPI](#updateapi)
- [Bundle loading](#bundle-loading)
- [Field casing](#field-casing)
- [File uploads](#file-uploads)
- [Internal naming](#internal-naming)
- [Credits](#credits)

Expand All @@ -56,8 +57,9 @@ Roadmap
- [x] Abstracting serializers away from model methods
- [x] Declarative marshmallow-based serialization
- [x] More support for different HTTP methods
- [x] File upload support on POST
- [ ] File upload support on PATCH/PUT
- [ ] Support for user-generated validators
- [ ] Better file upload support
- [ ] Better test coverage
- [ ] Browsable API docs

Expand Down Expand Up @@ -310,6 +312,14 @@ This will be strictly translated by the API, and acronyms are not considered:
- `API_strict == apiStrict`


File uploads
------------

File uploads are supported via `POST` using `multipart/form-data` requests.

Support for `PATCH`/`PUT` is on the roadmap.


Internal naming
---------------

Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[pytest]
addopts =
--cov
--cov-fail-under 78
--cov-fail-under 80
--cov-report term:skip-covered
--cov-report html
--no-cov-on-fail
Expand Down
1 change: 1 addition & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class Profile(models.Model):
id = models.UUIDField(primary_key=True, default=uuid4)

avatar = models.FileField(upload_to="avatars/", blank=True)
email = models.CharField(max_length=64)
phone = models.CharField(max_length=64)

Expand Down
4 changes: 4 additions & 0 deletions tests/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,19 @@ class ProfileSerializer(Serializer):
skills = fields.Nested("RatedSkillSerializer", attribute="ratedskill_set", many=True)
team = fields.Nested("TeamSerializer")
tags = fields.Nested("TagSerializer", many=True)
user = fields.Nested("UserSerializer")

class Meta:
fields = [
"username",
"avatar",
"email",
"phone",
"role",
"skills",
"team",
"tags",
"user",
]


Expand Down
18 changes: 18 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
from datetime import timedelta
from unittest.mock import patch

from django.core.files.uploadedfile import SimpleUploadedFile
from django.test.client import MULTIPART_CONTENT


@patch('django.core.files.storage.FileSystemStorage.save')
def test_profile_create(mock_save, client, db, role, user):
avatar = SimpleUploadedFile("avatar.jpg", b"", content_type="image/jpeg")
mock_save.return_value = "avatar.jpg"
payload = dict(avatar=avatar, role=role.pk, user=user.pk)
response = client.post(f"/profiles/", data=payload, content_type=MULTIPART_CONTENT)
result = response.json()
assert response.status_code == 201, result
assert result["avatar"] == "/avatar.jpg"
assert result["role"]["id"] == role.pk
assert result["role"]["name"] == role.name
assert result["user"]["username"] == user.username


def test_profile_detail(client, db, profile, user):
Expand Down
6 changes: 3 additions & 3 deletions tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from django.db.models.functions import Concat

from worf.permissions import PublicEndpoint
from worf.views import DeleteAPI, DetailAPI, DetailUpdateAPI, ListAPI, UpdateAPI
from worf.views import CreateAPI, DeleteAPI, DetailAPI, ListAPI, UpdateAPI

from tests.models import Profile
from tests.serializers import ProfileSerializer, UserSerializer


class ProfileList(ListAPI):
class ProfileList(CreateAPI, ListAPI):
model = Profile
queryset = Profile.objects.annotate(
name=Concat("user__first_name", Value(" "), "user__last_name"),
Expand Down Expand Up @@ -64,7 +64,7 @@ class UserList(ListAPI):
]


class UserDetail(DetailUpdateAPI):
class UserDetail(UpdateAPI, DetailAPI):
model = User
serializer = UserSerializer
permissions = [PublicEndpoint]
5 changes: 5 additions & 0 deletions worf/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from marshmallow.fields import * # noqa: F401, F403


class File(marshmallow.fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
return value.url if value.name else None


class Nested(marshmallow.fields.Nested):
def _serialize(self, nested_obj, attr, obj, **kwargs):
if isinstance(nested_obj, Manager):
Expand Down
6 changes: 6 additions & 0 deletions worf/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db.models.fields.files import FieldFile

from worf import fields # noqa: F401
from worf.casing import snake_to_camel
Expand Down Expand Up @@ -33,6 +34,11 @@ class Serializer(marshmallow.Schema):

OPTIONS_CLASS = SerializerOptions

TYPE_MAPPING = {
**marshmallow.Schema.TYPE_MAPPING,
FieldFile: fields.File,
}

def __repr__(self):
return f"<{self.__class__.__name__}()>"

Expand Down
3 changes: 3 additions & 0 deletions worf/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ def validate_bundle(self, key):
elif isinstance(field, models.BooleanField):
self.bundle[key] = self._validate_boolean(key)

elif isinstance(field, models.FileField):
pass # Django will raise an exception if handled improperly

elif isinstance(field, models.ForeignKey):
pass # Django will raise an exception if handled improperly

Expand Down
55 changes: 26 additions & 29 deletions worf/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,6 @@ def _check_permissions(self):
)
)

def _assemble_bundle_from_request_body(self):

if self.request.content_type == "multipart/form-data":
# Avoid RawPostDataException
# TODO investigate why test did not catch this error
raw_bundle = {}

elif self.request.body: # and self.request.body != b'--BoUnDaRyStRiNg--\r\n':
try:
raw_bundle = json.loads(self.request.body)
except json.decoder.JSONDecodeError:
# print("\n\n~~~~~~~~~~~~~~~~~~~~~~", self.request.body, '\n\n')
raw_bundle = {}
else:
raw_bundle = {}

self.set_bundle(raw_bundle)

def _get_lookup_field(self, field):
related = field.find("__")

Expand Down Expand Up @@ -182,29 +164,44 @@ def validate_lookup_field_values(self):
):
self.validate_numeric(url_kwarg)

def set_bundle_from_querystring(self):
def flatten_bundle(self, raw_bundle):
# parse_qs gives us a dictionary where all values are lists
qs = parse_qs(self.request.META["QUERY_STRING"])

raw_bundle = {}

for key, value in qs.items():
raw_bundle[key] = value[0] if len(value) == 1 else value

self.set_bundle(raw_bundle)
return {
key: value[0] if len(value) == 1 else value
for key, value in raw_bundle.items()
}

def set_bundle(self, raw_bundle):
self.bundle = {}
self.keymap = {}

if not raw_bundle:
return # No need to loop or set self.bundle again if it's empty
return

for key in raw_bundle.keys():
field = camel_to_snake(key)
self.bundle[field] = raw_bundle[key]
self.keymap[field] = key

def set_bundle_from_querystring(self):
raw_bundle = self.flatten_bundle(parse_qs(self.request.META["QUERY_STRING"]))

self.set_bundle(raw_bundle)

def set_bundle_from_request_body(self):
raw_bundle = {}

if self.request.content_type == "multipart/form-data":
raw_bundle.update(self.flatten_bundle(self.request.POST))
raw_bundle.update(self.flatten_bundle(self.request.FILES))
elif self.request.body:
try:
raw_bundle = json.loads(self.request.body)
except json.decoder.JSONDecodeError:
pass

self.set_bundle(raw_bundle)

def dispatch(self, request, *args, **kwargs):
method = request.method.lower()
handler = self.http_method_not_allowed
Expand All @@ -214,7 +211,7 @@ def dispatch(self, request, *args, **kwargs):

try:
self._check_permissions() # only returns 200 or HTTP_EXCEPTIONS
self._assemble_bundle_from_request_body() # sets self.bundle
self.set_bundle_from_request_body() # sets self.bundle
return handler(request, *args, **kwargs) # calls self.serialize()
except HTTP_EXCEPTIONS as e:
return self.render_to_response(dict(message=e.message), e.status)
Expand Down

0 comments on commit cf6e448

Please sign in to comment.