diff --git a/enumfields/fields.py b/enumfields/fields.py index bdb9988..e3c0d49 100644 --- a/enumfields/fields.py +++ b/enumfields/fields.py @@ -10,10 +10,28 @@ from .compat import import_string from .forms import EnumChoiceField -metaclass = models.SubfieldBase if django.VERSION < (1, 8) else type +class CastOnAssignDescriptor(object): + """ + A property descriptor which ensures that `field.to_python()` is called on _every_ assignment to the field. -class EnumFieldMixin(six.with_metaclass(metaclass)): + This used to be provided by the `django.db.models.subclassing.Creator` class, which in turn + was used by the deprecated-in-Django-1.10 `SubfieldBase` class, hence the reimplementation here. + """ + + def __init__(self, field): + self.field = field + + def __get__(self, obj, type=None): + if obj is None: + return self + return obj.__dict__[self.field.name] + + def __set__(self, obj, value): + obj.__dict__[self.field.name] = self.field.to_python(value) + + +class EnumFieldMixin(object): def __init__(self, enum, **options): if isinstance(enum, six.string_types): self.enum = import_string(enum) @@ -25,6 +43,10 @@ def __init__(self, enum, **options): super(EnumFieldMixin, self).__init__(**options) + def contribute_to_class(self, cls, name): + super(EnumFieldMixin, self).contribute_to_class(cls, name) + setattr(cls, name, CastOnAssignDescriptor(self)) + def to_python(self, value): if value is None or value == '': return None diff --git a/tests/test_issue_60.py b/tests/test_issue_60.py new file mode 100644 index 0000000..576bf73 --- /dev/null +++ b/tests/test_issue_60.py @@ -0,0 +1,34 @@ +import pytest + +from .models import MyModel + +try: + from .enums import Color # Use the new location of Color enum +except ImportError: + Color = MyModel.Color # Attempt the 0.7.4 location of color enum + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_unsaved(): + obj = MyModel(color='r') + assert Color.RED == obj.color + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_saved(): + obj = MyModel(color='r') + obj.save() + assert Color.RED == obj.color + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_created(): + obj = MyModel.objects.create(color='r') + assert Color.RED == obj.color + + +@pytest.mark.django_db +def test_fields_value_is_enum_when_retrieved(): + MyModel.objects.create(color='r') + obj = MyModel.objects.all()[:1][0] # .first() not available on all Djangoes + assert Color.RED == obj.color