Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

16078 make GraphQL NumberFilter optional #16115

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions netbox/ipam/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,9 @@ class IPAddressTest(APIViewTestCases.APIViewTestCase):
bulk_update_data = {
'description': 'New description',
}
graphql_filter = {
'address': '192.168.0.1/24',
}

@classmethod
def setUpTestData(cls):
Expand Down
2 changes: 1 addition & 1 deletion netbox/netbox/graphql/filter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def map_strawberry_type(field):
pass
elif issubclass(type(field), django_filters.NumberFilter):
should_create_function = True
attr_type = int
attr_type = int | None
elif issubclass(type(field), django_filters.ModelMultipleChoiceFilter):
should_create_function = True
attr_type = List[str] | None
Expand Down
60 changes: 54 additions & 6 deletions netbox/utilities/testing/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,13 +440,12 @@ def _get_graphql_base_name(self):
base_name = self.model._meta.verbose_name.lower().replace(' ', '_')
return getattr(self, 'graphql_base_name', base_name)

def _build_query(self, name, **filters):
def _build_query_with_filter(self, name, filter_string):
arthanson marked this conversation as resolved.
Show resolved Hide resolved
"""
Called by either _build_query or _build_filtered_query - construct the actual
query given a name and filter string
"""
type_class = get_graphql_type_for_model(self.model)
if filters:
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
filter_string = f'({filter_string})'
else:
filter_string = ''

# Compile list of fields to include
fields_string = ''
Expand Down Expand Up @@ -492,6 +491,30 @@ def _build_query(self, name, **filters):

return query

def _build_filtered_query(self, name, **filters):
"""
Create a filtered query: i.e. ip_address_list(filters: {address: "1.1.1.1/24"}){.
"""
if filters:
filter_string = ', '.join(f'{k}: "{v}"' for k, v in filters.items())
filter_string = f'(filters: {{{filter_string}}})'
else:
filter_string = ''

return self._build_query_with_filter(name, filter_string)

def _build_query(self, name, **filters):
"""
Create a normal query - unfiltered or with a string query: i.e. site(name: "aaa"){.
"""
if filters:
filter_string = ', '.join(f'{k}:{v}' for k, v in filters.items())
filter_string = f'({filter_string})'
else:
filter_string = ''

return self._build_query_with_filter(name, filter_string)

@override_settings(LOGIN_REQUIRED=True)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
def test_graphql_get_object(self):
Expand Down Expand Up @@ -550,6 +573,31 @@ def test_graphql_list_objects(self):
self.assertNotIn('errors', data)
self.assertGreater(len(data['data'][field_name]), 0)

@override_settings(LOGIN_REQUIRED=True)
@override_settings(EXEMPT_VIEW_PERMISSIONS=['*', 'auth.user'])
def test_graphql_filter_objects(self):
if not hasattr(self, 'graphql_filter'):
return

url = reverse('graphql')
field_name = f'{self._get_graphql_base_name()}_list'
query = self._build_filtered_query(field_name, **self.graphql_filter)

# Add object-level permission
obj_perm = ObjectPermission(
name='Test permission',
actions=['view']
)
obj_perm.save()
obj_perm.users.add(self.user)
obj_perm.object_types.add(ObjectType.objects.get_for_model(self.model))

response = self.client.post(url, data={'query': query}, format="json", **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = json.loads(response.content)
self.assertNotIn('errors', data)
self.assertGreater(len(data['data'][field_name]), 0)

class APIViewTestCase(
GetObjectViewTestCase,
ListObjectsViewTestCase,
Expand Down
Loading