diff --git a/drf_spectacular/contrib/django_filters.py b/drf_spectacular/contrib/django_filters.py index 83ffbc1b..53fe4ff6 100644 --- a/drf_spectacular/contrib/django_filters.py +++ b/drf_spectacular/contrib/django_filters.py @@ -1,4 +1,4 @@ -from drf_spectacular.plumbing import build_parameter_type, get_view_model +from drf_spectacular.plumbing import build_parameter_type, follow_field_source, get_view_model from drf_spectacular.utils import OpenApiParameter try: @@ -20,7 +20,8 @@ def get_schema_operation_parameters(self, view): parameters = [] for field_name, field in filterset_class.base_filters.items(): - model_field = model._meta.get_field(field.field_name) + path = field.field_name.split('__') + model_field = follow_field_source(model, path) parameters.append(build_parameter_type( name=field_name, diff --git a/tests/contrib/test_django_filters.py b/tests/contrib/test_django_filters.py index b19dab80..143d8aec 100644 --- a/tests/contrib/test_django_filters.py +++ b/tests/contrib/test_django_filters.py @@ -23,6 +23,11 @@ class Product(models.Model): price = models.FloatField() +class SubProduct(models.Model): + sub_price = models.FloatField() + product = models.ForeignKey(Product, on_delete=models.CASCADE) + + class ProductSerializer(serializers.ModelSerializer): class Meta: model = Product @@ -30,12 +35,13 @@ class Meta: class ProductFilter(FilterSet): - min_price = NumberFilter(field_name="price", lookup_expr='gte') max_price = NumberFilter(field_name="price", lookup_expr='lte') + max_sub_price = NumberFilter(field_name="subproduct__sub_price", lookup_expr='lte') + sub = NumberFilter(field_name="subproduct", lookup_expr='exact') class Meta: model = Product - fields = ['category', 'in_stock', 'min_price', 'max_price'] + fields = ['category', 'in_stock', 'max_price'] class ProductViewset(viewsets.ReadOnlyModelViewSet): @@ -63,11 +69,24 @@ def test_django_filters(no_warnings): @pytest.mark.urls(__name__) @pytest.mark.django_db def test_django_filters_requests(no_warnings): - Product.objects.create(category='X', price=4, in_stock=True) + product = Product.objects.create(category='X', price=4, in_stock=True) + SubProduct.objects.create(sub_price=5, product=product) - response = APIClient().get('/api/products/?min_price=3') + response = APIClient().get('/api/products/?max_price=1') + assert response.status_code == 200 + assert len(response.json()) == 0 + response = APIClient().get('/api/products/?max_price=5') + assert response.status_code == 200 + assert len(response.json()) == 1 + response = APIClient().get('/api/products/?max_sub_price=1') + assert response.status_code == 200 + assert len(response.json()) == 0 + response = APIClient().get('/api/products/?max_sub_price=6') + assert response.status_code == 200 + assert len(response.json()) == 1 + response = APIClient().get('/api/products/?sub=1') assert response.status_code == 200 assert len(response.json()) == 1 - response = APIClient().get('/api/products/?min_price=5') + response = APIClient().get('/api/products/?sub=2') assert response.status_code == 200 assert len(response.json()) == 0 diff --git a/tests/contrib/test_django_filters.yml b/tests/contrib/test_django_filters.yml index 48e9f5fd..291c7133 100644 --- a/tests/contrib/test_django_filters.yml +++ b/tests/contrib/test_django_filters.yml @@ -28,11 +28,16 @@ paths: format: float description: max_price - in: query - name: min_price + name: max_sub_price schema: type: number format: float - description: min_price + description: max_sub_price + - in: query + name: sub + schema: + type: integer + description: sub tags: - products security: