diff --git a/model_utils/tests/fields.py b/model_utils/tests/fields.py index 3f1503a..7c29aa4 100644 --- a/model_utils/tests/fields.py +++ b/model_utils/tests/fields.py @@ -1,26 +1,43 @@ +import django from django.db import models from django.utils.six import with_metaclass, string_types -class MutableField(with_metaclass(models.SubfieldBase, models.TextField)): +def mutable_from_db(value): + if value == '': + return None + try: + if isinstance(value, string_types): + return [int(i) for i in value.split(',')] + except ValueError: + pass + return value - def to_python(self, value): - if value == '': - return None - try: - if isinstance(value, string_types): - return [int(i) for i in value.split(',')] - except ValueError: - pass +def mutable_to_db(value): + if value is None: + return '' + if isinstance(value, list): + value = ','.join((str(i) for i in value)) + return str(value) - return value - def get_db_prep_save(self, value, connection): - if value is None: - return '' +if django.VERSION >= (1, 9, 0): + class MutableField(models.TextField): + def to_python(self, value): + return mutable_from_db(value) - if isinstance(value, list): - value = ','.join((str(i) for i in value)) + def from_db_value(self, value, expression, connection, context): + return mutable_from_db(value) - return super(MutableField, self).get_db_prep_save(value, connection) + def get_db_prep_save(self, value, connection): + value = super(MutableField, self).get_db_prep_save(value, connection) + return mutable_to_db(value) +else: + class MutableField(with_metaclass(models.SubfieldBase, models.TextField)): + def to_python(self, value): + return mutable_from_db(value) + + def get_db_prep_save(self, value, connection): + value = mutable_to_db(value) + return super(MutableField, self).get_db_prep_save(value, connection) diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index b1903ed..6b82541 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -204,7 +204,7 @@ class FeaturedManager(models.Manager): class Tracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() - mutable = MutableField() + mutable = MutableField(default=None) tracker = FieldTracker() @@ -249,7 +249,7 @@ class InheritedTracked(Tracked): class ModelTracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() - mutable = MutableField() + mutable = MutableField(default=None) tracker = ModelTracker()