From 634681a72e9cdb1fefcf5779eceb8ec95ae6c13f Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Tue, 26 Dec 2023 13:15:23 -0500 Subject: [PATCH] Fixes #13606: Fix filtering by null for multiselect custom fields --- netbox/extras/models/customfields.py | 4 +--- netbox/extras/tests/test_customfields.py | 13 +++++++------ netbox/utilities/filters.py | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/netbox/extras/models/customfields.py b/netbox/extras/models/customfields.py index f70812bc075..ff887ddeb97 100644 --- a/netbox/extras/models/customfields.py +++ b/netbox/extras/models/customfields.py @@ -10,7 +10,6 @@ from django.core.validators import RegexValidator, ValidationError from django.db import models from django.urls import reverse -from django.utils.html import escape from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ @@ -571,8 +570,7 @@ def to_filter(self, lookup_expr=None): # Multiselect elif self.type == CustomFieldTypeChoices.TYPE_MULTISELECT: - filter_class = filters.MultiValueCharFilter - kwargs['lookup_expr'] = 'has_key' + filter_class = filters.MultiValueArrayFilter # Object elif self.type == CustomFieldTypeChoices.TYPE_OBJECT: diff --git a/netbox/extras/tests/test_customfields.py b/netbox/extras/tests/test_customfields.py index 7ac6b20358d..574452a81c5 100644 --- a/netbox/extras/tests/test_customfields.py +++ b/netbox/extras/tests/test_customfields.py @@ -1329,7 +1329,7 @@ def setUpTestData(cls): choice_set = CustomFieldChoiceSet.objects.create( name='Custom Field Choice Set 1', - extra_choices=(('a', 'A'), ('b', 'B'), ('c', 'C'), ('x', 'X')) + extra_choices=(('a', 'A'), ('b', 'B'), ('c', 'C')) ) # Integer filtering @@ -1435,7 +1435,7 @@ def setUpTestData(cls): 'cf7': 'http://a.example.com', 'cf8': 'http://a.example.com', 'cf9': 'A', - 'cf10': ['A', 'X'], + 'cf10': ['A', 'B'], 'cf11': manufacturers[0].pk, 'cf12': [manufacturers[0].pk, manufacturers[3].pk], }), @@ -1449,7 +1449,7 @@ def setUpTestData(cls): 'cf7': 'http://b.example.com', 'cf8': 'http://b.example.com', 'cf9': 'B', - 'cf10': ['B', 'X'], + 'cf10': ['B', 'C'], 'cf11': manufacturers[1].pk, 'cf12': [manufacturers[1].pk, manufacturers[3].pk], }), @@ -1463,7 +1463,7 @@ def setUpTestData(cls): 'cf7': 'http://c.example.com', 'cf8': 'http://c.example.com', 'cf9': 'C', - 'cf10': ['C', 'X'], + 'cf10': None, 'cf11': manufacturers[2].pk, 'cf12': [manufacturers[2].pk, manufacturers[3].pk], }), @@ -1531,8 +1531,9 @@ def test_filter_select(self): self.assertEqual(self.filterset({'cf_cf9': ['A', 'B']}, self.queryset).qs.count(), 2) def test_filter_multiselect(self): - self.assertEqual(self.filterset({'cf_cf10': ['A', 'B']}, self.queryset).qs.count(), 2) - self.assertEqual(self.filterset({'cf_cf10': ['X']}, self.queryset).qs.count(), 3) + self.assertEqual(self.filterset({'cf_cf10': ['A']}, self.queryset).qs.count(), 1) + self.assertEqual(self.filterset({'cf_cf10': ['A', 'C']}, self.queryset).qs.count(), 2) + self.assertEqual(self.filterset({'cf_cf10': ['null']}, self.queryset).qs.count(), 1) def test_filter_object(self): manufacturer_ids = Manufacturer.objects.values_list('id', flat=True) diff --git a/netbox/utilities/filters.py b/netbox/utilities/filters.py index 1bf17beae32..72c9124a171 100644 --- a/netbox/utilities/filters.py +++ b/netbox/utilities/filters.py @@ -9,6 +9,7 @@ __all__ = ( 'ContentTypeFilter', 'MACAddressFilter', + 'MultiValueArrayFilter', 'MultiValueCharFilter', 'MultiValueDateFilter', 'MultiValueDateTimeFilter', @@ -85,6 +86,21 @@ class MultiValueTimeFilter(django_filters.MultipleChoiceFilter): field_class = multivalue_field_factory(forms.TimeField) +@extend_schema_field(OpenApiTypes.STR) +class MultiValueArrayFilter(django_filters.MultipleChoiceFilter): + field_class = multivalue_field_factory(forms.CharField) + + def __init__(self, *args, lookup_expr='contains', **kwargs): + # Set default lookup_expr to 'contains' + super().__init__(*args, lookup_expr=lookup_expr, **kwargs) + + def get_filter_predicate(self, v): + # If filtering for null values, ignore lookup_expr + if v is None: + return {self.field_name: None} + return super().get_filter_predicate(v) + + class MACAddressFilter(django_filters.CharFilter): pass