diff --git a/enumfields/fields.py b/enumfields/fields.py index bdb9988..e3c0d49 100644 --- a/enumfields/fields.py +++ b/enumfields/fields.py @@ -10,10 +10,28 @@ from .compat import import_string from .forms import EnumChoiceField -metaclass = models.SubfieldBase if django.VERSION < (1, 8) else type +class CastOnAssignDescriptor(object): + """ + A property descriptor which ensures that `field.to_python()` is called on _every_ assignment to the field. -class EnumFieldMixin(six.with_metaclass(metaclass)): + This used to be provided by the `django.db.models.subclassing.Creator` class, which in turn + was used by the deprecated-in-Django-1.10 `SubfieldBase` class, hence the reimplementation here. + """ + + def __init__(self, field): + self.field = field + + def __get__(self, obj, type=None): + if obj is None: + return self + return obj.__dict__[self.field.name] + + def __set__(self, obj, value): + obj.__dict__[self.field.name] = self.field.to_python(value) + + +class EnumFieldMixin(object): def __init__(self, enum, **options): if isinstance(enum, six.string_types): self.enum = import_string(enum) @@ -25,6 +43,10 @@ def __init__(self, enum, **options): super(EnumFieldMixin, self).__init__(**options) + def contribute_to_class(self, cls, name): + super(EnumFieldMixin, self).contribute_to_class(cls, name) + setattr(cls, name, CastOnAssignDescriptor(self)) + def to_python(self, value): if value is None or value == '': return None diff --git a/tests/test_django_admin.py b/tests/test_django_admin.py index 42a539c..ba677a1 100644 --- a/tests/test_django_admin.py +++ b/tests/test_django_admin.py @@ -5,8 +5,6 @@ import pytest from django.core.urlresolvers import reverse -from enumfields import EnumIntegerField - from .enums import Color, IntegerEnum, Taste, ZeroEnum from .models import MyModel @@ -71,9 +69,3 @@ def test_model_admin_filter(admin_client, q_color, q_taste, q_int_enum): count = int(re.search('(\d+) my model', response.content.decode('utf8')).group(1)) # and compare it to what we expect. assert count == MyModel.objects.filter(**lookup).count() - - -def test_django_admin_lookup_value_for_integer_enum_field(): - field = EnumIntegerField(Taste) - - assert field.get_prep_value(str(Taste.BITTER)) == 3, "get_prep_value should be able to convert from strings" diff --git a/tests/test_issue_60.py b/tests/test_issue_60.py new file mode 100644 index 0000000..c2449c7 --- /dev/null +++ b/tests/test_issue_60.py @@ -0,0 +1,34 @@ +import pytest + +from .models import MyModel + +try: + from .enums import Color # Use the new location of Color enum +except ImportError: + Color = MyModel.Color # Attempt the 0.7.4 location of color enum + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_unsaved(): + obj = MyModel(color='r') + assert Color.RED == obj.color + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_saved(): + obj = MyModel(color='r') + obj.save() + assert Color.RED == obj.color + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_created(): + obj = MyModel.objects.create(color='r') + assert Color.RED == obj.color + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_retrieved(): + MyModel.objects.create(color='r') + obj = MyModel.objects.first() + assert Color.RED == obj.color