diff --git a/netbox/tenancy/filtersets.py b/netbox/tenancy/filtersets.py index 0f4900f546f..8bc659a885e 100644 --- a/netbox/tenancy/filtersets.py +++ b/netbox/tenancy/filtersets.py @@ -91,6 +91,19 @@ class ContactAssignmentFilterSet(ChangeLoggedModelFilterSet): queryset=Contact.objects.all(), label=_('Contact (ID)'), ) + group_id = TreeNodeMultipleChoiceFilter( + queryset=ContactGroup.objects.all(), + field_name='contact__group', + lookup_expr='in', + label=_('Contact group (ID)'), + ) + group = TreeNodeMultipleChoiceFilter( + queryset=ContactGroup.objects.all(), + field_name='contact__group', + lookup_expr='in', + to_field_name='slug', + label=_('Contact group (slug)'), + ) role_id = django_filters.ModelMultipleChoiceFilter( queryset=ContactRole.objects.all(), label=_('Contact role (ID)'), diff --git a/netbox/tenancy/tests/test_filtersets.py b/netbox/tenancy/tests/test_filtersets.py index e427c90ce42..d7337396e04 100644 --- a/netbox/tenancy/tests/test_filtersets.py +++ b/netbox/tenancy/tests/test_filtersets.py @@ -1,5 +1,7 @@ +from django.contrib.contenttypes.models import ContentType from django.test import TestCase +from dcim.models import Manufacturer, Site from tenancy.filtersets import * from tenancy.models import * from utilities.testing import ChangeLoggedFilterSetTests @@ -192,3 +194,72 @@ def test_group(self): self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) params = {'group': [group[0].slug, group[1].slug]} self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + +class ContactAssignmentTestCase(TestCase, ChangeLoggedFilterSetTests): + queryset = ContactAssignment.objects.all() + filterset = ContactAssignmentFilterSet + + @classmethod + def setUpTestData(cls): + + manufacturer = Manufacturer.objects.create(name='Manufacturer 1', slug='manufacturer-1') + sites = ( + Site(name='Site 1', slug='site-1'), + Site(name='Site 2', slug='site-2'), + Site(name='Site 3', slug='site-3'), + ) + Site.objects.bulk_create(sites) + + contact_groups = ( + ContactGroup(name='Contact Group 1', slug='contact-group-1'), + ContactGroup(name='Contact Group 2', slug='contact-group-2'), + ContactGroup(name='Contact Group 3', slug='contact-group-3'), + ) + for contactgroup in contact_groups: + contactgroup.save() + + contact_roles = ( + ContactRole(name='Contact Role 1', slug='contact-role-1'), + ContactRole(name='Contact Role 2', slug='contact-role-2'), + ContactRole(name='Contact Role 3', slug='contact-role-3'), + ) + ContactRole.objects.bulk_create(contact_roles) + + contacts = ( + Contact(name='Contact 1', group=contact_groups[0]), + Contact(name='Contact 2', group=contact_groups[1]), + Contact(name='Contact 3', group=contact_groups[2]), + ) + Contact.objects.bulk_create(contacts) + + assignments = ( + ContactAssignment(object=sites[0], contact=contacts[0], role=contact_roles[0]), + ContactAssignment(object=sites[1], contact=contacts[1], role=contact_roles[1]), + ContactAssignment(object=sites[2], contact=contacts[2], role=contact_roles[2]), + ContactAssignment(object=manufacturer, contact=contacts[2], role=contact_roles[2]), + ) + ContactAssignment.objects.bulk_create(assignments) + + def test_content_type(self): + params = {'content_type_id': ContentType.objects.get_by_natural_key('dcim', 'site')} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3) + + def test_contact(self): + contacts = Contact.objects.all()[:2] + params = {'contact_id': [contacts[0].pk, contacts[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_group(self): + group = ContactGroup.objects.all()[:2] + params = {'group_id': [group[0].pk, group[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + params = {'group': [group[0].slug, group[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + + def test_role(self): + role = ContactRole.objects.all()[:2] + params = {'role_id': [role[0].pk, role[1].pk]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2) + params = {'role': [role[0].slug, role[1].slug]} + self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)