Skip to content

Commit

Permalink
Support comma separated sorts
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Jun 25, 2022
1 parent b12887c commit 6b35678
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 30 deletions.
45 changes: 27 additions & 18 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,23 @@ def test_profile_list(client, db, profile, user):
assert result["profiles"][0]["username"] == user.username


def test_profile_list_filter(client, db, profile, user):
def test_profile_list_filters(client, db, profile, user):
response = client.get(f"/profiles/?name={user.first_name} {user.last_name}")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_icontains_filter(client, db, profile, user):
def test_profile_list_icontains_filters(client, db, profile, user):
response = client.get(f"/profiles/?name__icontains={user.first_name}")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_annotation_filter(client, db, profile_factory):
def test_profile_list_annotation_filters(client, db, profile_factory):
profile_factory.create(user__date_joined="2020-01-01T00:00:00Z")
profile_factory.create(user__date_joined="2020-12-01T00:00:00Z")
response = client.get("/profiles/?dateJoined__gte=2020-06-01T00:00:00Z")
Expand All @@ -52,7 +52,7 @@ def test_profile_list_annotation_filter(client, db, profile_factory):
assert len(result["profiles"]) == 1


def test_profile_list_and_filter(client, db, profile_factory, tag_factory):
def test_profile_list_and_filters(client, db, profile_factory, tag_factory):
tag1, tag2, tag3 = tag_factory.create_batch(3)
profile_factory.create(tags=[tag1])
profile_factory.create(tags=[tag2])
Expand All @@ -65,23 +65,23 @@ def test_profile_list_and_filter(client, db, profile_factory, tag_factory):
assert len(result["profiles"]) == 1


def test_profile_list_in_array_filter(client, db, profile, user):
def test_profile_list_in_array_filters(client, db, profile, user):
response = client.get(f"/profiles/?name__in={user.first_name} {user.last_name}&name__in=Din Djarin")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_in_string_filter(client, db, profile, user):
def test_profile_list_in_string_filters(client, db, profile, user):
response = client.get(f"/profiles/?name__in={user.first_name} {user.last_name},Din Djarin")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_or_filter(client, db, profile_factory, tag_factory):
def test_profile_list_or_filters(client, db, profile_factory, tag_factory):
tag1, tag2, tag3 = tag_factory.create_batch(3)
profile_factory.create(tags=[tag1])
profile_factory.create(tags=[tag2])
Expand All @@ -94,35 +94,35 @@ def test_profile_list_or_filter(client, db, profile_factory, tag_factory):
assert len(result["profiles"]) == 3


def test_profile_list_negated_filter(client, db, profile, user):
def test_profile_list_negated_filters(client, db, profile, user):
response = client.get(f"/profiles/?firstName!={user.first_name}")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 0


def test_profile_list_negated__icontains__filter(client, db, profile, user):
def test_profile_list_negated__icontains__filters(client, db, profile, user):
response = client.get(f"/profiles/?firstName__icontains!={user.first_name}")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 0


def test_profile_list_not_in_array_filter(client, db, profile, user):
def test_profile_list_not_in_array_filters(client, db, profile, user):
response = client.get(f"/profiles/?name__in!={user.first_name} {user.last_name}&name__in!=Din Djarin")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 0


def test_profile_list_not_in_string_filter(client, db, profile, user):
def test_profile_list_not_in_string_filters(client, db, profile, user):
response = client.get(f"/profiles/?name__in!={user.first_name} {user.last_name},Din Djarin")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 0


def test_profile_list_subset_filter(client, db, profile, user):
def test_profile_list_subset_filters(client, db, profile, user):
response = client.get(f"/profiles/subset/?name={user.first_name} {user.last_name}")
result = response.json()
assert response.status_code == 200, result
Expand Down Expand Up @@ -303,14 +303,19 @@ def test_user_list(client, db, user):
assert len(result["users"]) == 1
assert result["users"][0]["id"] == user.pk
assert result["users"][0]["username"] == user.username
assert result["users"][0]["email"] == user.email


def test_user_list_fields(client, db, user):
response = client.get("/users/?fields=username")
@pytest.mark.parametrize("url,url_invalid", [
("/users/?fields=id,username", "/users/?fields=id,invalid,username"),
("/users/?fields=id&fields=username", "/users/?fields=id&fields=invalid&fields=username"),
])
def test_user_list_fields(client, db, url, url_invalid, user):
response = client.get(url)
result = response.json()
assert response.status_code == 200, result
assert result["users"] == [dict(username=user.username)]
response = client.get("/users/?fields=invalid")
assert result["users"] == [dict(id=user.pk, username=user.username)]
response = client.get(url_invalid)
result = response.json()
assert response.status_code == 400, result
assert result == dict(message="Invalid fields: OrderedSet(['invalid'])")
Expand Down Expand Up @@ -368,12 +373,16 @@ def test_user_list_sort_desc(client, db, user_factory):
assert result["users"][1]["username"] == "a"


def test_user_list_multisort(client, now, db, user_factory):
@pytest.mark.parametrize("url", [
"/users/?sort=dateJoined,-id,invalid",
"/users/?sort=dateJoined&sort=-id&sort=invalid",
])
def test_user_list_multisort(client, now, db, url, user_factory):
user_factory.create(username="a", date_joined=now)
user_factory.create(username="b", date_joined=now - timedelta(hours=1))
user_factory.create(username="c", date_joined=now)
user_factory.create(username="d", date_joined=now)
response = client.get("/users/?sort=dateJoined&sort=-id&sort=invalid")
response = client.get(url)
result = response.json()
assert response.status_code == 200, result
assert len(result["users"]) == 4
Expand Down
17 changes: 8 additions & 9 deletions worf/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from worf.casing import camel_to_snake, snake_to_camel
from worf.conf import settings
from worf.exceptions import SerializerError
from worf.shortcuts import list_param


class SerializeModels:
Expand All @@ -31,17 +32,12 @@ def get_serializer_context(self):
def get_serializer_kwargs(self):
return dict(
context=dict(request=self.request, **self.get_serializer_context()),
only=self.get_serializer_only(),
only=[
".".join(map(camel_to_snake, field.split(".")))
for field in list_param(self.bundle.get("fields", []))
],
)

def get_serializer_only(self):
only = self.bundle.get("fields")
if isinstance(only, str):
only = only.split(",")
if isinstance(only, list):
only = [".".join(map(camel_to_snake, field.split("."))) for field in only]
return only

def load_serializer(self):
try:
return self.get_serializer()
Expand Down Expand Up @@ -89,6 +85,9 @@ class Serializer(marshmallow.Schema):
FieldFile: fields.File,
}

def __init__(self, only=(), *args, **kwargs):
super().__init__(only=only or None, *args, **kwargs)

def __call__(self, **kwargs):
only = self.only
if self.only and kwargs.get("only"):
Expand Down
4 changes: 4 additions & 0 deletions worf/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ def get_current_version():
return f"{__version__}@{hash}"
except: # pragma: no cover # noqa E722 Dont crash for any reason whatsoever
return __version__


def list_param(value):
return value.split(",") if isinstance(value, str) else value
7 changes: 4 additions & 3 deletions worf/views/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from worf.conf import settings
from worf.exceptions import HTTP420
from worf.filters import apply_filterset, generate_filterset
from worf.shortcuts import list_param
from worf.views.base import AbstractBaseAPI
from worf.views.create import CreateAPI

Expand Down Expand Up @@ -131,7 +132,7 @@ def get_processed_queryset(self):
lookups = self.lookup_kwargs.items()
filterset_kwargs = {k: v for k, v in lookups if not isinstance(v, list)}
list_kwargs = {k: v for k, v in lookups if isinstance(v, list)}
ordering = self.get_ordering(self.request.GET.getlist("sort"))
ordering = self.get_ordering()

queryset = self.get_queryset()

Expand Down Expand Up @@ -159,10 +160,10 @@ def get_processed_queryset(self):

return queryset

def get_ordering(self, sorts):
def get_ordering(self):
ordering = []

for sort in sorts:
for sort in list_param(self.bundle.get("sort", [])):
field = "__".join(map(camel_to_snake, sort.lstrip("-").split(".")))
if field not in self.sort_fields:
continue
Expand Down

0 comments on commit 6b35678

Please sign in to comment.