Skip to content

Commit

Permalink
Annotation filters
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Nov 8, 2021
1 parent ee6d9df commit cf01378
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 5 deletions.
11 changes: 11 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@ def test_profile_list(client, db, profile, user):
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_search(client, db, profile, user):
response = client.get(f"/profiles/?name={user.first_name} {user.last_name}")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_annotation_filter(client, db, profile_factory):
profile_factory.create(user__date_joined="2020-01-01T00:00:00")
profile_factory.create(user__date_joined="2020-12-01T00:00:00")
response = client.get("/profiles/?dateJoined__gte=2020-06-01T00:00:00")
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1


def test_profile_list_subset_search(client, db, profile, user):
response = client.get(f"/profiles/subset/?name={user.first_name} {user.last_name}")
result = response.json()
Expand Down
4 changes: 3 additions & 1 deletion tests/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.core.exceptions import ValidationError
from django.contrib.auth.models import User
from django.db.models import Value
from django.db.models import F, Value
from django.db.models.functions import Concat

from worf.permissions import PublicEndpoint
Expand All @@ -26,12 +26,14 @@ class ProfileList(ListAPI):
model = Profile
queryset = Profile.objects.annotate(
name=Concat("user__first_name", Value(" "), "user__last_name"),
date_joined=F("user__date_joined"),
)
ordering = ["pk"]
serializer = ProfileSerializer
permissions = [PublicEndpoint]
search_fields = []
filter_fields = [
"date_joined__gte",
"tags",
]

Expand Down
44 changes: 41 additions & 3 deletions worf/filters.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,53 @@
from urllib.parse import urlencode

from django.db.models.fields.related import ForeignObjectRel, RelatedField
from django.http import QueryDict

from url_filter.filtersets import ModelFilterSet
from url_filter.exceptions import SkipFilter


def generate_filterset(model):
class AnnotatedModelFilterSet(ModelFilterSet):
def get_filters(self):
filters = super().get_filters()

if self.queryset is not None:
state = self._build_state()

for name in self.queryset.query.annotations.keys():
if name in self.Meta.exclude or name in filters:
continue

try:
annotation_filter = self._build_annotation_filter(name, state)
except SkipFilter:
continue

if annotation_filter is not None:
filters[name] = annotation_filter

return filters

def _build_annotation_filter(self, name, state):
field = self.queryset.query.annotations.get(name).output_field

if isinstance(field, RelatedField):
if not self.Meta.allow_related:
raise SkipFilter
return self._build_filterset_from_related_field(name, field)
elif isinstance(field, ForeignObjectRel):
if not self.Meta.allow_related_reverse:
raise SkipFilter
return self._build_filterset_from_reverse_field(name, field)

return self._build_filter_from_field(name, field)


def generate_filterset(model, queryset):
return type(
f"{model.__name__}FilterSet",
(ModelFilterSet,),
dict(Meta=type("Meta", (), dict(model=model))),
(AnnotatedModelFilterSet,),
dict(Meta=type("Meta", (), dict(model=model, queryset=queryset))),
)


Expand Down
2 changes: 1 addition & 1 deletion worf/views/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, *args, **kwargs):

# generate a default filterset if a custom one was not provided
if self.filter_set is None:
self.filter_set = generate_filterset(self.model)
self.filter_set = generate_filterset(self.model, self.queryset)

# support deprecated search_fields and/or dict syntax (note that `and` does nothing)
if isinstance(self.search_fields, dict):
Expand Down

0 comments on commit cf01378

Please sign in to comment.