diff --git a/AUTHORS.rst b/AUTHORS.rst index f1a3ea3..93d1003 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -18,5 +18,8 @@ Rinat Shigapov Ryan Kaskel Simon Meers sayane +Tony Aldridge +Travis Swicegood Trey Hunner zyegfryed +Filipe Ximenes diff --git a/CHANGES.rst b/CHANGES.rst index 3d9bb66..ae10b14 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,29 @@ CHANGES master (unreleased) ------------------- +* `get_subclass()` method is now available on both managers and + querysets. Thanks Travis Swicegood. Merge of GH-82. + +* Indexing into a ``Choices`` instance now translates database representations + to human-readable choice names, rather than simply indexing into an array of + choice tuples. (Indexing into ``Choices`` was previously not documented.) + +* Fix bug in `InheritanceManager` with grandchild classes on Django 1.6+; + `select_subclasses('child', 'child__grandchild')` would only ever get to the + child class. Thanks Keryn Knight for report and proposed fix. + +* MonitorField now accepts a 'when' parameter. It will update only when the field + changes to one of the values specified. + + +1.5.0 (2013.08.29) +------------------ + +* `Choices` now accepts option-groupings. Fixes GH-14. + +* `Choices` can now be added to other `Choices` or to any iterable, and can be + compared for equality with itself. Thanks Tony Aldridge. (Merge of GH-76.) + * `Choices` now `__contains__` its Python identifier values. Thanks Keryn Knight. (Merge of GH-69). @@ -13,6 +36,8 @@ master (unreleased) * Fixed ``FieldTracker`` usage on inherited models. Fixes GH-57. +* Added mutable field support to ``FieldTracker`` (Merge of GH-73, fixes GH-74) + 1.4.0 (2013.06.03) ------------------ diff --git a/README.rst b/README.rst index ccc554c..34aa4b3 100644 --- a/README.rst +++ b/README.rst @@ -42,3 +42,4 @@ pull requests tracked in it are closed, but all new issues should be filed at GitHub.) .. _BitBucket: https://bitbucket.org/carljm/django-model-utils/overview + diff --git a/docs/fields.rst b/docs/fields.rst index 9a0dbaf..5da4ab4 100644 --- a/docs/fields.rst +++ b/docs/fields.rst @@ -53,6 +53,19 @@ field changes: (A ``MonitorField`` can monitor any type of field for changes, not only a ``StatusField``.) +If a list is passed to the ``when`` parameter, the field will only +update when it matches one of the specified values: + +.. code-block:: python + + from model_utils.fields import MonitorField, StatusField + + class Article(models.Model): + STATUS = Choices('draft', 'published') + + status = StatusField() + published_at = MonitorField(monitor='status', when=['published']) + SplitField ---------- diff --git a/docs/utilities.rst b/docs/utilities.rst index 0c3ca61..9f9bb93 100644 --- a/docs/utilities.rst +++ b/docs/utilities.rst @@ -15,7 +15,6 @@ Choices class Article(models.Model): STATUS = Choices('draft', 'published') - # ... status = models.CharField(choices=STATUS, default=STATUS.draft, max_length=20) A ``Choices`` object is initialized with any number of choices. In the @@ -34,7 +33,6 @@ representation. In this case you can provide choices as two-tuples: class Article(models.Model): STATUS = Choices(('draft', _('draft')), ('published', _('published'))) - # ... status = models.CharField(choices=STATUS, default=STATUS.draft, max_length=20) But what if your database representation of choices is constrained in @@ -52,7 +50,38 @@ the third is the human-readable version: class Article(models.Model): STATUS = Choices((0, 'draft', _('draft')), (1, 'published', _('published'))) - # ... + status = models.IntegerField(choices=STATUS, default=STATUS.draft) + +You can index into a ``Choices`` instance to translate a database +representation to its display name: + +.. code-block:: python + + status_display = Article.STATUS[article.status] + +Option groups can also be used with ``Choices``; in that case each +argument is a tuple consisting of the option group name and a list of +options, where each option in the list is either a string, a two-tuple, +or a triple as outlined above. For example: + +.. code-block:: python + + from model_utils import Choices + + class Article(models.Model): + STATUS = Choices(('Visible', ['new', 'archived']), ('Invisible', ['draft', 'deleted'])) + +Choices can be concatenated with the ``+`` operator, both to other Choices +instances and other iterable objects that could be converted into Choices: + +.. code-block:: python + + from model_utils import Choices + + GENERIC_CHOICES = Choices((0, 'draft', _('draft')), (1, 'published', _('published'))) + + class Article(models.Model): + STATUS = GENERIC_CHOICES + [(2, 'featured', _('featured'))] status = models.IntegerField(choices=STATUS, default=STATUS.draft) diff --git a/model_utils/__init__.py b/model_utils/__init__.py index 36586ec..a3da699 100644 --- a/model_utils/__init__.py +++ b/model_utils/__init__.py @@ -1,4 +1,4 @@ from .choices import Choices from .tracker import FieldTracker, ModelTracker -__version__ = '1.4.0.post1' +__version__ = '1.5.0.post1' diff --git a/model_utils/choices.py b/model_utils/choices.py index b9c1c52..d48ba90 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -36,52 +36,118 @@ class Choices(object): identifier, the database representation itself is available as an attribute on the ``Choices`` object, returning itself.) + Option groups can also be used with ``Choices``; in that case each + argument is a tuple consisting of the option group name and a list + of options, where each option in the list is either a string, a + two-tuple, or a triple as outlined above. + """ def __init__(self, *choices): - self._full = [] - self._choices = [] - self._choice_dict = {} - for choice in self.equalize(choices): - self._full.append(choice) - self._choices.append((choice[0], choice[2])) - self._choice_dict[choice[1]] = choice[0] + # list of choices expanded to triples - can include optgroups + self._triples = [] + # list of choices as (db, human-readable) - can include optgroups + self._doubles = [] + # dictionary mapping db representation to human-readable + self._display_map = {} + # dictionary mapping Python identifier to db representation + self._identifier_map = {} + # set of db representations + self._db_values = set() + + self._process(choices) + + + def _store(self, triple, triple_collector, double_collector): + self._identifier_map[triple[1]] = triple[0] + self._display_map[triple[0]] = triple[2] + self._db_values.add(triple[0]) + triple_collector.append(triple) + double_collector.append((triple[0], triple[2])) + + + def _process(self, choices, triple_collector=None, double_collector=None): + if triple_collector is None: + triple_collector = self._triples + if double_collector is None: + double_collector = self._doubles + + store = lambda c: self._store(c, triple_collector, double_collector) - def equalize(self, choices): for choice in choices: if isinstance(choice, (list, tuple)): if len(choice) == 3: - yield choice + store(choice) elif len(choice) == 2: - yield (choice[0], choice[0], choice[1]) + if isinstance(choice[1], (list, tuple)): + # option group + group_name = choice[0] + subchoices = choice[1] + tc = [] + triple_collector.append((group_name, tc)) + dc = [] + double_collector.append((group_name, dc)) + self._process(subchoices, tc, dc) + else: + store((choice[0], choice[0], choice[1])) else: - raise ValueError("Choices can't handle a list/tuple of length %s, only 2 or 3" - % len(choice)) + raise ValueError( + "Choices can't take a list of length %s, only 2 or 3" + % len(choice) + ) else: - yield (choice, choice, choice) + store((choice, choice, choice)) + def __len__(self): - return len(self._choices) + return len(self._doubles) + def __iter__(self): - return iter(self._choices) + return iter(self._doubles) + def __getattr__(self, attname): try: - return self._choice_dict[attname] + return self._identifier_map[attname] except KeyError: raise AttributeError(attname) - def __getitem__(self, index): - return self._choices[index] + + def __getitem__(self, key): + return self._display_map[key] + + + def __add__(self, other): + if isinstance(other, self.__class__): + other = other._triples + else: + other = list(other) + return Choices(*(self._triples + other)) + + + def __radd__(self, other): + # radd is never called for matching types, so we don't check here + other = list(other) + return Choices(*(other + self._triples)) + + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._triples == other._triples + return False + def __repr__(self): - return '%s(%s)' % (self.__class__.__name__, - ', '.join(("%s" % repr(i) for i in self._full))) + return '%s(%s)' % ( + self.__class__.__name__, + ', '.join(("%s" % repr(i) for i in self._triples)) + ) + def __contains__(self, item): - if item in self._choice_dict.values(): - return True + return item in self._db_values + def __deepcopy__(self, memo): - return self.__class__(*copy.deepcopy(self._full, memo)) + return self.__class__(*copy.deepcopy(self._triples, memo)) diff --git a/model_utils/fields.py b/model_utils/fields.py index 5aebb3c..83e9e9a 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -82,6 +82,10 @@ class MonitorField(models.DateTimeField): raise TypeError( '%s requires a "monitor" argument' % self.__class__.__name__) self.monitor = monitor + when = kwargs.pop('when', None) + if when is not None: + when = set(when) + self.when = when super(MonitorField, self).__init__(*args, **kwargs) def contribute_to_class(self, cls, name): @@ -101,8 +105,9 @@ class MonitorField(models.DateTimeField): previous = getattr(model_instance, self.monitor_attname, None) current = self.get_monitored_value(model_instance) if previous != current: - setattr(model_instance, self.attname, value) - self._save_initial(model_instance.__class__, model_instance) + if self.when is None or current in self.when: + setattr(model_instance, self.attname, value) + self._save_initial(model_instance.__class__, model_instance) return super(MonitorField, self).pre_save(model_instance, add) diff --git a/model_utils/managers.py b/model_utils/managers.py index 231bffa..3c72585 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -46,9 +46,12 @@ class InheritanceQuerySet(QuerySet): def iterator(self): iter = super(InheritanceQuerySet, self).iterator() if getattr(self, 'subclasses', False): + # sort the subclass names longest first, + # so with 'a' and 'a__b' it goes as deep as possible + subclasses = sorted(self.subclasses, key=len, reverse=True) for obj in iter: sub_obj = None - for s in self.subclasses: + for s in subclasses: sub_obj = self._get_sub_obj_recurse(obj, s) if sub_obj: break @@ -93,6 +96,8 @@ class InheritanceQuerySet(QuerySet): else: return node + def get_subclass(self, *args, **kwargs): + return self.select_subclasses().get(*args, **kwargs) class InheritanceManager(models.Manager): @@ -105,8 +110,7 @@ class InheritanceManager(models.Manager): return self.get_query_set().select_subclasses(*subclasses) def get_subclass(self, *args, **kwargs): - return self.get_query_set().select_subclasses().get(*args, **kwargs) - + return self.get_query_set().get_subclass(*args, **kwargs) class QueryManager(models.Manager): diff --git a/model_utils/tests/fields.py b/model_utils/tests/fields.py new file mode 100644 index 0000000..3f1503a --- /dev/null +++ b/model_utils/tests/fields.py @@ -0,0 +1,26 @@ +from django.db import models +from django.utils.six import with_metaclass, string_types + + +class MutableField(with_metaclass(models.SubfieldBase, models.TextField)): + + def to_python(self, value): + if value == '': + return None + + try: + if isinstance(value, string_types): + return [int(i) for i in value.split(',')] + except ValueError: + pass + + return value + + def get_db_prep_save(self, value, connection): + if value is None: + return '' + + if isinstance(value, list): + value = ','.join((str(i) for i in value)) + + return super(MutableField, self).get_db_prep_save(value, connection) diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index 06b3dea..80537ed 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -1,10 +1,14 @@ +from __future__ import unicode_literals + from django.db import models +from django.utils.encoding import python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel from model_utils.tracker import FieldTracker, ModelTracker from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager from model_utils.fields import SplitField, MonitorField, StatusField +from model_utils.tests.fields import MutableField from model_utils import Choices @@ -14,6 +18,7 @@ class InheritanceManagerTestRelated(models.Model): +@python_2_unicode_compatible class InheritanceManagerTestParent(models.Model): # FileField is just a handy descriptor-using field. Refs #6. non_related_field_using_descriptor = models.FileField(upload_to="test") @@ -23,11 +28,17 @@ class InheritanceManagerTestParent(models.Model): objects = InheritanceManager() + def __str__(self): + return "%s(%s)" % ( + self.__class__.__name__[len('InheritanceManagerTest'):], + self.pk, + ) + + class InheritanceManagerTestChild1(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") normal_field_2 = models.TextField() - pass class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1): @@ -66,6 +77,18 @@ class Monitored(models.Model): +class MonitorWhen(models.Model): + name = models.CharField(max_length=25) + name_changed = MonitorField(monitor="name", when=["Jose", "Maria"]) + + + +class MonitorWhenEmpty(models.Model): + name = models.CharField(max_length=25) + name_changed = MonitorField(monitor="name", when=[]) + + + class Status(StatusModel): STATUS = Choices( ("active", _("active")), @@ -234,6 +257,7 @@ class Spot(models.Model): class Tracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() + mutable = MutableField() tracker = FieldTracker() @@ -278,6 +302,7 @@ class InheritedTracked(Tracked): class ModelTracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() + mutable = MutableField() tracker = ModelTracker() diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index de3d1ce..2d86596 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -23,12 +23,11 @@ from model_utils.tests.models import ( InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent, InheritanceManagerTestChild1, InheritanceManagerTestChild2, TimeStamp, Post, Article, Status, - StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded, + StatusPlainTuple, TimeFrame, Monitored, MonitorWhen, MonitorWhenEmpty, StatusManagerAdded, TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot, ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple, InheritedModelTracked, - Tracked, TrackedFK, TrackedNotDefault, TrackedNonFieldAttr, - TrackedMultiple, InheritedTracked, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled) - + Tracked, TrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple, + InheritedTracked, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled) class GetExcerptTests(TestCase): @@ -171,6 +170,75 @@ class MonitorFieldTests(TestCase): MonitorField() + +class MonitorWhenFieldTests(TestCase): + """ + Will record changes only when name is 'Jose' or 'Maria' + """ + def setUp(self): + self.instance = MonitorWhen(name='Charlie') + self.created = self.instance.name_changed + + + def test_save_no_change(self): + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + + def test_save_changed_to_Jose(self): + self.instance.name = 'Jose' + self.instance.save() + self.assertTrue(self.instance.name_changed > self.created) + + + def test_save_changed_to_Maria(self): + self.instance.name = 'Maria' + self.instance.save() + self.assertTrue(self.instance.name_changed > self.created) + + + def test_save_changed_to_Pedro(self): + self.instance.name = 'Pedro' + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + + def test_double_save(self): + self.instance.name = 'Jose' + self.instance.save() + changed = self.instance.name_changed + self.instance.save() + self.assertEqual(self.instance.name_changed, changed) + + + +class MonitorWhenEmptyFieldTests(TestCase): + """ + Monitor should never be updated id when is an empty list. + """ + def setUp(self): + self.instance = MonitorWhenEmpty(name='Charlie') + self.created = self.instance.name_changed + + + def test_save_no_change(self): + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + + def test_save_changed_to_Jose(self): + self.instance.name = 'Jose' + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + + def test_save_changed_to_Maria(self): + self.instance.name = 'Maria' + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + + class StatusFieldTests(TestCase): def test_status_with_default_filled(self): @@ -201,7 +269,7 @@ class ChoicesTests(TestCase): def test_indexing(self): - self.assertEqual(self.STATUS[1], ('PUBLISHED', 'PUBLISHED')) + self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') def test_iteration(self): @@ -223,10 +291,12 @@ class ChoicesTests(TestCase): with self.assertRaises(ValueError): Choices(('a',)) + def test_contains_value(self): self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS) + def test_doesnt_contain_value(self): self.assertFalse('UNPUBLISHED' in self.STATUS) @@ -236,6 +306,32 @@ class ChoicesTests(TestCase): list(copy.deepcopy(self.STATUS))) + def test_equality(self): + self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) + + + def test_inequality(self): + self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) + self.assertNotEqual(self.STATUS, Choices('DRAFT')) + + + def test_composability(self): + self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) + self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) + self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) + + + def test_option_groups(self): + c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) + self.assertEqual( + list(c), + [ + ('group a', [('one', 'one'), ('two', 'two')]), + ('group b', [('three', 'three')]), + ], + ) + + class LabelChoicesTests(ChoicesTests): def setUp(self): self.STATUS = Choices( @@ -254,7 +350,7 @@ class LabelChoicesTests(ChoicesTests): def test_indexing(self): - self.assertEqual(self.STATUS[1], ('PUBLISHED', 'is published')) + self.assertEqual(self.STATUS['PUBLISHED'], 'is published') def test_default(self): @@ -269,6 +365,23 @@ class LabelChoicesTests(ChoicesTests): self.assertEqual(len(self.STATUS), 3) + def test_equality(self): + self.assertEqual(self.STATUS, Choices( + ('DRAFT', 'is draft'), + ('PUBLISHED', 'is published'), + 'DELETED', + )) + + + def test_inequality(self): + self.assertNotEqual(self.STATUS, [ + ('DRAFT', 'is draft'), + ('PUBLISHED', 'is published'), + 'DELETED' + ]) + self.assertNotEqual(self.STATUS, Choices('DRAFT')) + + def test_repr(self): self.assertEqual(repr(self.STATUS), "Choices" + repr(( ('DRAFT', 'DRAFT', 'is draft'), @@ -276,6 +389,7 @@ class LabelChoicesTests(ChoicesTests): ('DELETED', 'DELETED', 'DELETED'), ))) + def test_contains_value(self): self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS) @@ -283,13 +397,46 @@ class LabelChoicesTests(ChoicesTests): # and the internal representation are both DELETED. self.assertTrue('DELETED' in self.STATUS) + def test_doesnt_contain_value(self): self.assertFalse('UNPUBLISHED' in self.STATUS) + def test_doesnt_contain_display_value(self): self.assertFalse('is draft' in self.STATUS) + def test_composability(self): + self.assertEqual( + Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'), + self.STATUS + ) + + self.assertEqual( + (('DRAFT', 'is draft',),) + Choices(('PUBLISHED', 'is published'), 'DELETED'), + self.STATUS + ) + + self.assertEqual( + Choices(('DRAFT', 'is draft',)) + (('PUBLISHED', 'is published'), 'DELETED'), + self.STATUS + ) + + + def test_option_groups(self): + c = Choices( + ('group a', [(1, 'one'), (2, 'two')]), + ['group b', ((3, 'three'),)] + ) + self.assertEqual( + list(c), + [ + ('group a', [(1, 'one'), (2, 'two')]), + ('group b', [(3, 'three')]), + ], + ) + + class IdentifierChoicesTests(ChoicesTests): def setUp(self): @@ -307,7 +454,7 @@ class IdentifierChoicesTests(ChoicesTests): def test_indexing(self): - self.assertEqual(self.STATUS[1], (1, 'is published')) + self.assertEqual(self.STATUS[1], 'is published') def test_getattr(self): @@ -325,20 +472,88 @@ class IdentifierChoicesTests(ChoicesTests): (2, 'DELETED', 'is deleted'), ))) + def test_contains_value(self): self.assertTrue(0 in self.STATUS) self.assertTrue(1 in self.STATUS) self.assertTrue(2 in self.STATUS) + def test_doesnt_contain_value(self): self.assertFalse(3 in self.STATUS) + def test_doesnt_contain_display_value(self): self.assertFalse('is draft' in self.STATUS) + def test_doesnt_contain_python_attr(self): self.assertFalse('PUBLISHED' in self.STATUS) + + def test_equality(self): + self.assertEqual(self.STATUS, Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published'), + (2, 'DELETED', 'is deleted') + )) + + + def test_inequality(self): + self.assertNotEqual(self.STATUS, [ + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published'), + (2, 'DELETED', 'is deleted') + ]) + self.assertNotEqual(self.STATUS, Choices('DRAFT')) + + + def test_composability(self): + self.assertEqual( + Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published') + ) + Choices( + (2, 'DELETED', 'is deleted'), + ), + self.STATUS + ) + + self.assertEqual( + Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published') + ) + ( + (2, 'DELETED', 'is deleted'), + ), + self.STATUS + ) + + self.assertEqual( + ( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published') + ) + Choices( + (2, 'DELETED', 'is deleted'), + ), + self.STATUS + ) + + + def test_option_groups(self): + c = Choices( + ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), + ['group b', ((3, 'THREE', 'three'),)] + ) + self.assertEqual( + list(c), + [ + ('group a', [(1, 'one'), (2, 'two')]), + ('group b', [(3, 'three')]), + ], + ) + + class InheritanceManagerTests(TestCase): def setUp(self): self.child1 = InheritanceManagerTestChild1.objects.create() @@ -401,8 +616,26 @@ class InheritanceManagerTests(TestCase): self.assertEqual( set( self.get_manager().select_subclasses( - "inheritancemanagertestchild1__" - "inheritancemanagertestgrandchild1" + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1" + ) + ), + children, + ) + + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_children_and_grandchildren(self): + children = set([ + self.child1, + InheritanceManagerTestParent(pk=self.child2.pk), + self.grandchild1, + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ]) + self.assertEqual( + set( + self.get_manager().select_subclasses( + "inheritancemanagertestchild1", + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1" ) ), children, @@ -415,6 +648,12 @@ class InheritanceManagerTests(TestCase): self.child1) + def test_get_subclass_on_queryset(self): + self.assertEqual( + self.get_manager().all().get_subclass(pk=self.child1.pk), + self.child1) + + def test_prior_select_related(self): with self.assertNumQueries(1): obj = self.get_manager().select_related( @@ -767,52 +1006,61 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertChanged(name=None, number=None) self.instance.name = '' self.assertChanged(name=None, number=None) + self.instance.mutable = [1,2,3] + self.assertChanged(name=None, number=None, mutable=None) def test_pre_save_has_changed(self): - self.assertHasChanged(name=True, number=False) + self.assertHasChanged(name=True, number=False, mutable=False) self.instance.name = 'new age' - self.assertHasChanged(name=True, number=False) + self.assertHasChanged(name=True, number=False, mutable=False) self.instance.number = 7 self.assertHasChanged(name=True, number=True) + self.instance.mutable = [1,2,3] + self.assertHasChanged(name=True, number=True, mutable=True) def test_first_save(self): - self.assertHasChanged(name=True, number=False) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='', number=None, id=None) + self.assertHasChanged(name=True, number=False, mutable=False) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='', number=None, id=None, mutable=None) self.assertChanged(name=None) self.instance.name = 'retro' self.instance.number = 4 - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='retro', number=4, id=None) - self.assertChanged(name=None, number=None) + self.instance.mutable = [1,2,3] + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) + self.assertChanged(name=None, number=None, mutable=None) # Django 1.4 doesn't have update_fields if django.VERSION >= (1, 5, 0): self.instance.save(update_fields=[]) - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='retro', number=4, id=None) - self.assertChanged(name=None, number=None) + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) + self.assertChanged(name=None, number=None, mutable=None) with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) def test_post_save_has_changed(self): - self.update_instance(name='retro', number=4) - self.assertHasChanged(name=False, number=False) + self.update_instance(name='retro', number=4, mutable=[1,2,3]) + self.assertHasChanged(name=False, number=False, mutable=False) self.instance.name = 'new age' self.assertHasChanged(name=True, number=False) self.instance.number = 8 self.assertHasChanged(name=True, number=True) + self.instance.mutable[1] = 4 + self.assertHasChanged(name=True, number=True, mutable=True) self.instance.name = 'retro' - self.assertHasChanged(name=False, number=True) + self.assertHasChanged(name=False, number=True, mutable=True) def test_post_save_previous(self): - self.update_instance(name='retro', number=4) + self.update_instance(name='retro', number=4, mutable=[1,2,3]) self.instance.name = 'new age' - self.assertPrevious(name='retro', number=4) + self.assertPrevious(name='retro', number=4, mutable=[1,2,3]) + self.instance.mutable[1] = 4 + self.assertPrevious(name='retro', number=4, mutable=[1,2,3]) def test_post_save_changed(self): - self.update_instance(name='retro', number=4) + self.update_instance(name='retro', number=4, mutable=[1,2,3]) self.assertChanged() self.instance.name = 'new age' self.assertChanged(name='retro') @@ -820,36 +1068,48 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertChanged(name='retro', number=4) self.instance.name = 'retro' self.assertChanged(number=4) + self.instance.mutable[1] = 4 + self.assertChanged(number=4, mutable=[1,2,3]) + self.instance.mutable = [1,2,3] + self.assertChanged(number=4) def test_current(self): - self.assertCurrent(id=None, name='', number=None) + self.assertCurrent(id=None, name='', number=None, mutable=None) self.instance.name = 'new age' - self.assertCurrent(id=None, name='new age', number=None) + self.assertCurrent(id=None, name='new age', number=None, mutable=None) self.instance.number = 8 - self.assertCurrent(id=None, name='new age', number=8) + self.assertCurrent(id=None, name='new age', number=8, mutable=None) + self.instance.mutable = [1,2,3] + self.assertCurrent(id=None, name='new age', number=8, mutable=[1,2,3]) + self.instance.mutable[1] = 4 + self.assertCurrent(id=None, name='new age', number=8, mutable=[1,4,3]) self.instance.save() - self.assertCurrent(id=self.instance.id, name='new age', number=8) + self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1,4,3]) @skipUnless( django.VERSION >= (1, 5, 0), "Django 1.4 doesn't have update_fields") def test_update_fields(self): - self.update_instance(name='retro', number=4) + self.update_instance(name='retro', number=4, mutable=[1,2,3]) self.assertChanged() self.instance.name = 'new age' self.instance.number = 8 - self.assertChanged(name='retro', number=4) + self.instance.mutable = [4,5,6] + self.assertChanged(name='retro', number=4, mutable=[1,2,3]) self.instance.save(update_fields=[]) - self.assertChanged(name='retro', number=4) + self.assertChanged(name='retro', number=4, mutable=[1,2,3]) self.instance.save(update_fields=['name']) in_db = self.tracked_class.objects.get(id=self.instance.id) self.assertEqual(in_db.name, self.instance.name) self.assertNotEqual(in_db.number, self.instance.number) - self.assertChanged(number=4) + self.assertChanged(number=4, mutable=[1,2,3]) self.instance.save(update_fields=['number']) + self.assertChanged(mutable=[1,2,3]) + self.instance.save(update_fields=['mutable']) self.assertChanged() in_db = self.tracked_class.objects.get(id=self.instance.id) self.assertEqual(in_db.name, self.instance.name) self.assertEqual(in_db.number, self.instance.number) + self.assertEqual(in_db.mutable, self.instance.mutable) class FieldTrackedModelCustomTests(FieldTrackerTestCase, @@ -1141,24 +1401,27 @@ class ModelTrackerTests(FieldTrackerTests): self.assertChanged() self.instance.name = '' self.assertChanged() + self.instance.mutable = [1,2,3] + self.assertChanged() def test_first_save(self): - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='', number=None, id=None) + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='', number=None, id=None, mutable=None) self.assertChanged() self.instance.name = 'retro' self.instance.number = 4 - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='retro', number=4, id=None) + self.instance.mutable = [1,2,3] + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) self.assertChanged() # Django 1.4 doesn't have update_fields if django.VERSION >= (1, 5, 0): self.instance.save(update_fields=[]) - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='retro', number=4, id=None) + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) self.assertChanged() with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 6c29bc0..77eaaa4 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -1,4 +1,7 @@ from __future__ import unicode_literals + +from copy import deepcopy + from django.db import models from django.core.exceptions import FieldError @@ -20,8 +23,12 @@ class FieldInstanceTracker(object): else: self.saved_data.update(**self.current(fields=fields)) + # preventing mutable fields side effects + for field, field_value in self.saved_data.items(): + self.saved_data[field] = deepcopy(field_value) + def current(self, fields=None): - """Return dict of current values for all tracked fields""" + """Returns dict of current values for all tracked fields""" if fields is None: fields = self.fields return dict((f, self.get_field_value(f)) for f in fields) @@ -34,7 +41,7 @@ class FieldInstanceTracker(object): raise FieldError('field "%s" not tracked' % field) def previous(self, field): - """Return currently saved value of given field""" + """Returns currently saved value of given field""" return self.saved_data.get(field) def changed(self): @@ -54,7 +61,7 @@ class FieldTracker(object): self.fields = fields def get_field_map(self, cls): - """Return dict mapping fields names to model attribute names""" + """Returns dict mapping fields names to model attribute names""" field_map = dict((field, field) for field in self.fields) all_fields = dict((f.name, f.attname) for f in cls._meta.local_fields) field_map.update(**dict((k, v) for (k, v) in all_fields.items()