diff --git a/tests/test_views.py b/tests/test_views.py index 9e69c1e..06fb155 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -17,6 +17,7 @@ def test_profile_list(client, db, profile, user): assert len(result["profiles"]) == 1 assert result["profiles"][0]["username"] == user.username + def test_profile_list_search(client, db, profile, user): response = client.get(f"/profiles/?name={user.first_name} {user.last_name}") result = response.json() @@ -24,6 +25,16 @@ def test_profile_list_search(client, db, profile, user): assert len(result["profiles"]) == 1 assert result["profiles"][0]["username"] == user.username + +def test_profile_list_annotation_filter(client, db, profile_factory): + profile_factory.create(user__date_joined="2020-01-01T00:00:00") + profile_factory.create(user__date_joined="2020-12-01T00:00:00") + response = client.get("/profiles/?dateJoined__gte=2020-06-01T00:00:00") + result = response.json() + assert response.status_code == 200, result + assert len(result["profiles"]) == 1 + + def test_profile_list_subset_search(client, db, profile, user): response = client.get(f"/profiles/subset/?name={user.first_name} {user.last_name}") result = response.json() diff --git a/tests/views.py b/tests/views.py index 34098a2..10a1ef8 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,6 +1,6 @@ from django.core.exceptions import ValidationError from django.contrib.auth.models import User -from django.db.models import Value +from django.db.models import F, Value from django.db.models.functions import Concat from worf.permissions import PublicEndpoint @@ -26,12 +26,14 @@ class ProfileList(ListAPI): model = Profile queryset = Profile.objects.annotate( name=Concat("user__first_name", Value(" "), "user__last_name"), + date_joined=F("user__date_joined"), ) ordering = ["pk"] serializer = ProfileSerializer permissions = [PublicEndpoint] search_fields = [] filter_fields = [ + "date_joined__gte", "tags", ] diff --git a/worf/filters.py b/worf/filters.py index 4eb233b..8216886 100644 --- a/worf/filters.py +++ b/worf/filters.py @@ -1,15 +1,53 @@ from urllib.parse import urlencode +from django.db.models.fields.related import ForeignObjectRel, RelatedField from django.http import QueryDict from url_filter.filtersets import ModelFilterSet +from url_filter.exceptions import SkipFilter -def generate_filterset(model): +class AnnotatedModelFilterSet(ModelFilterSet): + def get_filters(self): + filters = super().get_filters() + + if self.queryset is not None: + state = self._build_state() + + for name in self.queryset.query.annotations.keys(): + if name in self.Meta.exclude or name in filters: + continue + + try: + annotation_filter = self._build_annotation_filter(name, state) + except SkipFilter: + continue + + if annotation_filter is not None: + filters[name] = annotation_filter + + return filters + + def _build_annotation_filter(self, name, state): + field = self.queryset.query.annotations.get(name).output_field + + if isinstance(field, RelatedField): + if not self.Meta.allow_related: + raise SkipFilter + return self._build_filterset_from_related_field(name, field) + elif isinstance(field, ForeignObjectRel): + if not self.Meta.allow_related_reverse: + raise SkipFilter + return self._build_filterset_from_reverse_field(name, field) + + return self._build_filter_from_field(name, field) + + +def generate_filterset(model, queryset): return type( f"{model.__name__}FilterSet", - (ModelFilterSet,), - dict(Meta=type("Meta", (), dict(model=model))), + (AnnotatedModelFilterSet,), + dict(Meta=type("Meta", (), dict(model=model, queryset=queryset))), ) diff --git a/worf/views/list.py b/worf/views/list.py index fa5ee63..8b588b9 100644 --- a/worf/views/list.py +++ b/worf/views/list.py @@ -54,7 +54,7 @@ def __init__(self, *args, **kwargs): # generate a default filterset if a custom one was not provided if self.filter_set is None: - self.filter_set = generate_filterset(self.model) + self.filter_set = generate_filterset(self.model, self.queryset) # support deprecated search_fields and/or dict syntax (note that `and` does nothing) if isinstance(self.search_fields, dict):