diff --git a/tests/test_views.py b/tests/test_views.py index 0793e5d..7cad909 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -27,7 +27,7 @@ def test_profile_list(client, db, profile, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_filter(client, db, profile, user): +def test_profile_list_filters(client, db, profile, user): response = client.get(f"/profiles/?name={user.first_name} {user.last_name}") result = response.json() assert response.status_code == 200, result @@ -35,7 +35,7 @@ def test_profile_list_filter(client, db, profile, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_icontains_filter(client, db, profile, user): +def test_profile_list_icontains_filters(client, db, profile, user): response = client.get(f"/profiles/?name__icontains={user.first_name}") result = response.json() assert response.status_code == 200, result @@ -43,7 +43,7 @@ def test_profile_list_icontains_filter(client, db, profile, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_annotation_filter(client, db, profile_factory): +def test_profile_list_annotation_filters(client, db, profile_factory): profile_factory.create(user__date_joined="2020-01-01T00:00:00Z") profile_factory.create(user__date_joined="2020-12-01T00:00:00Z") response = client.get("/profiles/?dateJoined__gte=2020-06-01T00:00:00Z") @@ -52,7 +52,7 @@ def test_profile_list_annotation_filter(client, db, profile_factory): assert len(result["profiles"]) == 1 -def test_profile_list_and_filter(client, db, profile_factory, tag_factory): +def test_profile_list_and_filters(client, db, profile_factory, tag_factory): tag1, tag2, tag3 = tag_factory.create_batch(3) profile_factory.create(tags=[tag1]) profile_factory.create(tags=[tag2]) @@ -65,7 +65,7 @@ def test_profile_list_and_filter(client, db, profile_factory, tag_factory): assert len(result["profiles"]) == 1 -def test_profile_list_in_array_filter(client, db, profile, user): +def test_profile_list_in_array_filters(client, db, profile, user): response = client.get(f"/profiles/?name__in={user.first_name} {user.last_name}&name__in=Din Djarin") result = response.json() assert response.status_code == 200, result @@ -73,7 +73,7 @@ def test_profile_list_in_array_filter(client, db, profile, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_in_string_filter(client, db, profile, user): +def test_profile_list_in_string_filters(client, db, profile, user): response = client.get(f"/profiles/?name__in={user.first_name} {user.last_name},Din Djarin") result = response.json() assert response.status_code == 200, result @@ -81,7 +81,7 @@ def test_profile_list_in_string_filter(client, db, profile, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_or_filter(client, db, profile_factory, tag_factory): +def test_profile_list_or_filters(client, db, profile_factory, tag_factory): tag1, tag2, tag3 = tag_factory.create_batch(3) profile_factory.create(tags=[tag1]) profile_factory.create(tags=[tag2]) @@ -94,35 +94,35 @@ def test_profile_list_or_filter(client, db, profile_factory, tag_factory): assert len(result["profiles"]) == 3 -def test_profile_list_negated_filter(client, db, profile, user): +def test_profile_list_negated_filters(client, db, profile, user): response = client.get(f"/profiles/?firstName!={user.first_name}") result = response.json() assert response.status_code == 200, result assert len(result["profiles"]) == 0 -def test_profile_list_negated__icontains__filter(client, db, profile, user): +def test_profile_list_negated__icontains__filters(client, db, profile, user): response = client.get(f"/profiles/?firstName__icontains!={user.first_name}") result = response.json() assert response.status_code == 200, result assert len(result["profiles"]) == 0 -def test_profile_list_not_in_array_filter(client, db, profile, user): +def test_profile_list_not_in_array_filters(client, db, profile, user): response = client.get(f"/profiles/?name__in!={user.first_name} {user.last_name}&name__in!=Din Djarin") result = response.json() assert response.status_code == 200, result assert len(result["profiles"]) == 0 -def test_profile_list_not_in_string_filter(client, db, profile, user): +def test_profile_list_not_in_string_filters(client, db, profile, user): response = client.get(f"/profiles/?name__in!={user.first_name} {user.last_name},Din Djarin") result = response.json() assert response.status_code == 200, result assert len(result["profiles"]) == 0 -def test_profile_list_subset_filter(client, db, profile, user): +def test_profile_list_subset_filters(client, db, profile, user): response = client.get(f"/profiles/subset/?name={user.first_name} {user.last_name}") result = response.json() assert response.status_code == 200, result @@ -303,14 +303,19 @@ def test_user_list(client, db, user): assert len(result["users"]) == 1 assert result["users"][0]["id"] == user.pk assert result["users"][0]["username"] == user.username + assert result["users"][0]["email"] == user.email -def test_user_list_fields(client, db, user): - response = client.get("/users/?fields=username") +@pytest.mark.parametrize("url,url_invalid", [ + ("/users/?fields=id,username", "/users/?fields=id,invalid,username"), + ("/users/?fields=id&fields=username", "/users/?fields=id&fields=invalid&fields=username"), +]) +def test_user_list_fields(client, db, url, url_invalid, user): + response = client.get(url) result = response.json() assert response.status_code == 200, result - assert result["users"] == [dict(username=user.username)] - response = client.get("/users/?fields=invalid") + assert result["users"] == [dict(id=user.pk, username=user.username)] + response = client.get(url_invalid) result = response.json() assert response.status_code == 400, result assert result == dict(message="Invalid fields: OrderedSet(['invalid'])") @@ -368,12 +373,16 @@ def test_user_list_sort_desc(client, db, user_factory): assert result["users"][1]["username"] == "a" -def test_user_list_multisort(client, now, db, user_factory): +@pytest.mark.parametrize("url", [ + "/users/?sort=dateJoined,-id,invalid", + "/users/?sort=dateJoined&sort=-id&sort=invalid", +]) +def test_user_list_multisort(client, now, db, url, user_factory): user_factory.create(username="a", date_joined=now) user_factory.create(username="b", date_joined=now - timedelta(hours=1)) user_factory.create(username="c", date_joined=now) user_factory.create(username="d", date_joined=now) - response = client.get("/users/?sort=dateJoined&sort=-id&sort=invalid") + response = client.get(url) result = response.json() assert response.status_code == 200, result assert len(result["users"]) == 4 diff --git a/worf/serializers.py b/worf/serializers.py index 3c630be..c81716d 100644 --- a/worf/serializers.py +++ b/worf/serializers.py @@ -7,6 +7,7 @@ from worf.casing import camel_to_snake, snake_to_camel from worf.conf import settings from worf.exceptions import SerializerError +from worf.shortcuts import list_param class SerializeModels: @@ -31,17 +32,12 @@ def get_serializer_context(self): def get_serializer_kwargs(self): return dict( context=dict(request=self.request, **self.get_serializer_context()), - only=self.get_serializer_only(), + only=[ + ".".join(map(camel_to_snake, field.split("."))) + for field in list_param(self.bundle.get("fields", [])) + ], ) - def get_serializer_only(self): - only = self.bundle.get("fields") - if isinstance(only, str): - only = only.split(",") - if isinstance(only, list): - only = [".".join(map(camel_to_snake, field.split("."))) for field in only] - return only - def load_serializer(self): try: return self.get_serializer() @@ -89,6 +85,9 @@ class Serializer(marshmallow.Schema): FieldFile: fields.File, } + def __init__(self, only=(), *args, **kwargs): + super().__init__(only=only or None, *args, **kwargs) + def __call__(self, **kwargs): only = self.only if self.only and kwargs.get("only"): diff --git a/worf/shortcuts.py b/worf/shortcuts.py index 08c4f3a..4835bf4 100644 --- a/worf/shortcuts.py +++ b/worf/shortcuts.py @@ -13,3 +13,7 @@ def get_current_version(): return f"{__version__}@{hash}" except: # pragma: no cover # noqa E722 Dont crash for any reason whatsoever return __version__ + + +def list_param(value): + return value.split(",") if isinstance(value, str) else value diff --git a/worf/views/list.py b/worf/views/list.py index 359c967..be85e80 100644 --- a/worf/views/list.py +++ b/worf/views/list.py @@ -10,6 +10,7 @@ from worf.conf import settings from worf.exceptions import HTTP420 from worf.filters import apply_filterset, generate_filterset +from worf.shortcuts import list_param from worf.views.base import AbstractBaseAPI from worf.views.create import CreateAPI @@ -131,7 +132,7 @@ def get_processed_queryset(self): lookups = self.lookup_kwargs.items() filterset_kwargs = {k: v for k, v in lookups if not isinstance(v, list)} list_kwargs = {k: v for k, v in lookups if isinstance(v, list)} - ordering = self.get_ordering(self.request.GET.getlist("sort")) + ordering = self.get_ordering() queryset = self.get_queryset() @@ -159,10 +160,10 @@ def get_processed_queryset(self): return queryset - def get_ordering(self, sorts): + def get_ordering(self): ordering = [] - for sort in sorts: + for sort in list_param(self.bundle.get("sort", [])): field = "__".join(map(camel_to_snake, sort.lstrip("-").split("."))) if field not in self.sort_fields: continue