diff --git a/netbox/netbox/graphql/filter_mixins.py b/netbox/netbox/graphql/filter_mixins.py index 5075e9aa282..76cfd891551 100644 --- a/netbox/netbox/graphql/filter_mixins.py +++ b/netbox/netbox/graphql/filter_mixins.py @@ -4,7 +4,7 @@ import django_filters import strawberry import strawberry_django -from django.core.exceptions import FieldDoesNotExist +from django.core.exceptions import FieldDoesNotExist, ValidationError from strawberry import auto from ipam.fields import ASNField from netbox.graphql.scalars import BigInt @@ -201,4 +201,9 @@ def wrapper(cls): class BaseFilterMixin: def filter_by_filterset(self, queryset, key): - return self.filterset(data={key: getattr(self, key)}, queryset=queryset).qs + filterset = self.filterset(data={key: getattr(self, key)}, queryset=queryset) + if not filterset.is_valid(): + # We could raise validation error but strawberry logs it all to the + # console i.e. raise ValidationError(f"{k}: {v[0]}") + return filterset.qs.none() + return filterset.qs diff --git a/netbox/netbox/tests/test_graphql.py b/netbox/netbox/tests/test_graphql.py index 2cf9ee87b9d..ab80c79c762 100644 --- a/netbox/netbox/tests/test_graphql.py +++ b/netbox/netbox/tests/test_graphql.py @@ -1,7 +1,13 @@ +import json + from django.test import override_settings from django.urls import reverse +from rest_framework import status -from utilities.testing import disable_warnings, TestCase +from core.models import ObjectType +from dcim.models import Site, Location +from users.models import ObjectPermission +from utilities.testing import disable_warnings, APITestCase, TestCase class GraphQLTestCase(TestCase): @@ -34,3 +40,45 @@ def test_graphiql_interface(self): response = self.client.get(url, **header) with disable_warnings('django.request'): self.assertHttpStatus(response, 302) # Redirect to login page + + +class GraphQLAPITestCase(APITestCase): + + @override_settings(LOGIN_REQUIRED=True) + @override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user']) + def test_graphql_filter_objects(self): + """ + Test the operation of filters for GraphQL API requests. + """ + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), + ) + Site.objects.bulk_create(sites) + Location.objects.create(site=sites[0], name='Location 1', slug='location-1'), + Location.objects.create(site=sites[1], name='Location 2', slug='location-2'), + + # Add object-level permission + obj_perm = ObjectPermission( + name='Test permission', + actions=['view'] + ) + obj_perm.save() + obj_perm.users.add(self.user) + obj_perm.object_types.add(ObjectType.objects.get_for_model(Location)) + + # A valid request should return the filtered list + url = reverse('graphql') + query = '{location_list(filters: {site_id: "' + str(sites[0].pk) + '"}) {id site {id}}}' + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertNotIn('errors', data) + self.assertEqual(len(data['data']['location_list']), 1) + + # An invalid request should return an empty list + query = '{location_list(filters: {site_id: "99999"}) {id site {id}}}' # Invalid site ID + response = self.client.post(url, data={'query': query}, format="json", **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = json.loads(response.content) + self.assertEqual(len(data['data']['location_list']), 0)