Merge branch 'master' into fix-choices-deepcopy

* master: (23 commits)
  only accepting iterables to the when field
  adding 'when' parameter to MonitorField
  Update AUTHORS and changelog.
  Add test to verify get_subclass() on QuerySet
  Refactor to make sure get_subclass() is on QuerySet
  Fixed indexing into Choices so its useful.
  Fix bug with child/grandchild select_subclasses in Django 1.6+; thanks Keryn Knight.
  fixed code block
  Bump version for dev.
  Bump version for 1.5.0 release.
  Add option-groups capability to Choices.
  Add Changelog note about Choices equality/addition.
  Added tests to improve coverage
  Alphabetised authors
  Removed redundant inequality method on Choices
  Moved documentation for Choices field to the right place
  Corrected typo
  Added self to Authors file
  Added equality methods to Choices objects, and overrode + for Choices for easy concatenation with other Choices and choice-like iterables. Also wrote tests for them, and extended the readme to reflect this
  Fix typo noted by @silonov
  ...

Conflicts:
	model_utils/choices.py
This commit is contained in:
Carl Meyer 2013-10-11 13:23:56 -06:00
commit 3211c92a4d
13 changed files with 548 additions and 81 deletions

View file

@ -18,5 +18,8 @@ Rinat Shigapov <rinatshigapov@gmail.com>
Ryan Kaskel <dev@ryankaskel.com>
Simon Meers <simon@simonmeers.com>
sayane
Tony Aldridge <zaragopha@hotmail.com>
Travis Swicegood <travis@domain51.com>
Trey Hunner <trey@treyhunner.com>
zyegfryed
Filipe Ximenes <filipeximenes@gmail.com>

View file

@ -4,6 +4,29 @@ CHANGES
master (unreleased)
-------------------
* `get_subclass()` method is now available on both managers and
querysets. Thanks Travis Swicegood. Merge of GH-82.
* Indexing into a ``Choices`` instance now translates database representations
to human-readable choice names, rather than simply indexing into an array of
choice tuples. (Indexing into ``Choices`` was previously not documented.)
* Fix bug in `InheritanceManager` with grandchild classes on Django 1.6+;
`select_subclasses('child', 'child__grandchild')` would only ever get to the
child class. Thanks Keryn Knight for report and proposed fix.
* MonitorField now accepts a 'when' parameter. It will update only when the field
changes to one of the values specified.
1.5.0 (2013.08.29)
------------------
* `Choices` now accepts option-groupings. Fixes GH-14.
* `Choices` can now be added to other `Choices` or to any iterable, and can be
compared for equality with itself. Thanks Tony Aldridge. (Merge of GH-76.)
* `Choices` now `__contains__` its Python identifier values. Thanks Keryn
Knight. (Merge of GH-69).
@ -13,6 +36,8 @@ master (unreleased)
* Fixed ``FieldTracker`` usage on inherited models. Fixes GH-57.
* Added mutable field support to ``FieldTracker`` (Merge of GH-73, fixes GH-74)
1.4.0 (2013.06.03)
------------------

View file

@ -42,3 +42,4 @@ pull requests tracked in it are closed, but all new issues should be filed at
GitHub.)
.. _BitBucket: https://bitbucket.org/carljm/django-model-utils/overview

View file

@ -53,6 +53,19 @@ field changes:
(A ``MonitorField`` can monitor any type of field for changes, not only a
``StatusField``.)
If a list is passed to the ``when`` parameter, the field will only
update when it matches one of the specified values:
.. code-block:: python
from model_utils.fields import MonitorField, StatusField
class Article(models.Model):
STATUS = Choices('draft', 'published')
status = StatusField()
published_at = MonitorField(monitor='status', when=['published'])
SplitField
----------

View file

@ -15,7 +15,6 @@ Choices
class Article(models.Model):
STATUS = Choices('draft', 'published')
# ...
status = models.CharField(choices=STATUS, default=STATUS.draft, max_length=20)
A ``Choices`` object is initialized with any number of choices. In the
@ -34,7 +33,6 @@ representation. In this case you can provide choices as two-tuples:
class Article(models.Model):
STATUS = Choices(('draft', _('draft')), ('published', _('published')))
# ...
status = models.CharField(choices=STATUS, default=STATUS.draft, max_length=20)
But what if your database representation of choices is constrained in
@ -52,7 +50,38 @@ the third is the human-readable version:
class Article(models.Model):
STATUS = Choices((0, 'draft', _('draft')), (1, 'published', _('published')))
# ...
status = models.IntegerField(choices=STATUS, default=STATUS.draft)
You can index into a ``Choices`` instance to translate a database
representation to its display name:
.. code-block:: python
status_display = Article.STATUS[article.status]
Option groups can also be used with ``Choices``; in that case each
argument is a tuple consisting of the option group name and a list of
options, where each option in the list is either a string, a two-tuple,
or a triple as outlined above. For example:
.. code-block:: python
from model_utils import Choices
class Article(models.Model):
STATUS = Choices(('Visible', ['new', 'archived']), ('Invisible', ['draft', 'deleted']))
Choices can be concatenated with the ``+`` operator, both to other Choices
instances and other iterable objects that could be converted into Choices:
.. code-block:: python
from model_utils import Choices
GENERIC_CHOICES = Choices((0, 'draft', _('draft')), (1, 'published', _('published')))
class Article(models.Model):
STATUS = GENERIC_CHOICES + [(2, 'featured', _('featured'))]
status = models.IntegerField(choices=STATUS, default=STATUS.draft)

View file

@ -1,4 +1,4 @@
from .choices import Choices
from .tracker import FieldTracker, ModelTracker
__version__ = '1.4.0.post1'
__version__ = '1.5.0.post1'

View file

@ -36,52 +36,118 @@ class Choices(object):
identifier, the database representation itself is available as an
attribute on the ``Choices`` object, returning itself.)
Option groups can also be used with ``Choices``; in that case each
argument is a tuple consisting of the option group name and a list
of options, where each option in the list is either a string, a
two-tuple, or a triple as outlined above.
"""
def __init__(self, *choices):
self._full = []
self._choices = []
self._choice_dict = {}
for choice in self.equalize(choices):
self._full.append(choice)
self._choices.append((choice[0], choice[2]))
self._choice_dict[choice[1]] = choice[0]
# list of choices expanded to triples - can include optgroups
self._triples = []
# list of choices as (db, human-readable) - can include optgroups
self._doubles = []
# dictionary mapping db representation to human-readable
self._display_map = {}
# dictionary mapping Python identifier to db representation
self._identifier_map = {}
# set of db representations
self._db_values = set()
self._process(choices)
def _store(self, triple, triple_collector, double_collector):
self._identifier_map[triple[1]] = triple[0]
self._display_map[triple[0]] = triple[2]
self._db_values.add(triple[0])
triple_collector.append(triple)
double_collector.append((triple[0], triple[2]))
def _process(self, choices, triple_collector=None, double_collector=None):
if triple_collector is None:
triple_collector = self._triples
if double_collector is None:
double_collector = self._doubles
store = lambda c: self._store(c, triple_collector, double_collector)
def equalize(self, choices):
for choice in choices:
if isinstance(choice, (list, tuple)):
if len(choice) == 3:
yield choice
store(choice)
elif len(choice) == 2:
yield (choice[0], choice[0], choice[1])
if isinstance(choice[1], (list, tuple)):
# option group
group_name = choice[0]
subchoices = choice[1]
tc = []
triple_collector.append((group_name, tc))
dc = []
double_collector.append((group_name, dc))
self._process(subchoices, tc, dc)
else:
store((choice[0], choice[0], choice[1]))
else:
raise ValueError("Choices can't handle a list/tuple of length %s, only 2 or 3"
% len(choice))
raise ValueError(
"Choices can't take a list of length %s, only 2 or 3"
% len(choice)
)
else:
yield (choice, choice, choice)
store((choice, choice, choice))
def __len__(self):
return len(self._choices)
return len(self._doubles)
def __iter__(self):
return iter(self._choices)
return iter(self._doubles)
def __getattr__(self, attname):
try:
return self._choice_dict[attname]
return self._identifier_map[attname]
except KeyError:
raise AttributeError(attname)
def __getitem__(self, index):
return self._choices[index]
def __getitem__(self, key):
return self._display_map[key]
def __add__(self, other):
if isinstance(other, self.__class__):
other = other._triples
else:
other = list(other)
return Choices(*(self._triples + other))
def __radd__(self, other):
# radd is never called for matching types, so we don't check here
other = list(other)
return Choices(*(other + self._triples))
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._triples == other._triples
return False
def __repr__(self):
return '%s(%s)' % (self.__class__.__name__,
', '.join(("%s" % repr(i) for i in self._full)))
return '%s(%s)' % (
self.__class__.__name__,
', '.join(("%s" % repr(i) for i in self._triples))
)
def __contains__(self, item):
if item in self._choice_dict.values():
return True
return item in self._db_values
def __deepcopy__(self, memo):
return self.__class__(*copy.deepcopy(self._full, memo))
return self.__class__(*copy.deepcopy(self._triples, memo))

View file

@ -82,6 +82,10 @@ class MonitorField(models.DateTimeField):
raise TypeError(
'%s requires a "monitor" argument' % self.__class__.__name__)
self.monitor = monitor
when = kwargs.pop('when', None)
if when is not None:
when = set(when)
self.when = when
super(MonitorField, self).__init__(*args, **kwargs)
def contribute_to_class(self, cls, name):
@ -101,8 +105,9 @@ class MonitorField(models.DateTimeField):
previous = getattr(model_instance, self.monitor_attname, None)
current = self.get_monitored_value(model_instance)
if previous != current:
setattr(model_instance, self.attname, value)
self._save_initial(model_instance.__class__, model_instance)
if self.when is None or current in self.when:
setattr(model_instance, self.attname, value)
self._save_initial(model_instance.__class__, model_instance)
return super(MonitorField, self).pre_save(model_instance, add)

View file

@ -46,9 +46,12 @@ class InheritanceQuerySet(QuerySet):
def iterator(self):
iter = super(InheritanceQuerySet, self).iterator()
if getattr(self, 'subclasses', False):
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(self.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in self.subclasses:
for s in subclasses:
sub_obj = self._get_sub_obj_recurse(obj, s)
if sub_obj:
break
@ -93,6 +96,8 @@ class InheritanceQuerySet(QuerySet):
else:
return node
def get_subclass(self, *args, **kwargs):
return self.select_subclasses().get(*args, **kwargs)
class InheritanceManager(models.Manager):
@ -105,8 +110,7 @@ class InheritanceManager(models.Manager):
return self.get_query_set().select_subclasses(*subclasses)
def get_subclass(self, *args, **kwargs):
return self.get_query_set().select_subclasses().get(*args, **kwargs)
return self.get_query_set().get_subclass(*args, **kwargs)
class QueryManager(models.Manager):

View file

@ -0,0 +1,26 @@
from django.db import models
from django.utils.six import with_metaclass, string_types
class MutableField(with_metaclass(models.SubfieldBase, models.TextField)):
def to_python(self, value):
if value == '':
return None
try:
if isinstance(value, string_types):
return [int(i) for i in value.split(',')]
except ValueError:
pass
return value
def get_db_prep_save(self, value, connection):
if value is None:
return ''
if isinstance(value, list):
value = ','.join((str(i) for i in value))
return super(MutableField, self).get_db_prep_save(value, connection)

View file

@ -1,10 +1,14 @@
from __future__ import unicode_literals
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
from django.utils.translation import ugettext_lazy as _
from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel
from model_utils.tracker import FieldTracker, ModelTracker
from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager
from model_utils.fields import SplitField, MonitorField, StatusField
from model_utils.tests.fields import MutableField
from model_utils import Choices
@ -14,6 +18,7 @@ class InheritanceManagerTestRelated(models.Model):
@python_2_unicode_compatible
class InheritanceManagerTestParent(models.Model):
# FileField is just a handy descriptor-using field. Refs #6.
non_related_field_using_descriptor = models.FileField(upload_to="test")
@ -23,11 +28,17 @@ class InheritanceManagerTestParent(models.Model):
objects = InheritanceManager()
def __str__(self):
return "%s(%s)" % (
self.__class__.__name__[len('InheritanceManagerTest'):],
self.pk,
)
class InheritanceManagerTestChild1(InheritanceManagerTestParent):
non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
normal_field_2 = models.TextField()
pass
class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1):
@ -66,6 +77,18 @@ class Monitored(models.Model):
class MonitorWhen(models.Model):
name = models.CharField(max_length=25)
name_changed = MonitorField(monitor="name", when=["Jose", "Maria"])
class MonitorWhenEmpty(models.Model):
name = models.CharField(max_length=25)
name_changed = MonitorField(monitor="name", when=[])
class Status(StatusModel):
STATUS = Choices(
("active", _("active")),
@ -234,6 +257,7 @@ class Spot(models.Model):
class Tracked(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
mutable = MutableField()
tracker = FieldTracker()
@ -278,6 +302,7 @@ class InheritedTracked(Tracked):
class ModelTracked(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
mutable = MutableField()
tracker = ModelTracker()

View file

@ -23,12 +23,11 @@ from model_utils.tests.models import (
InheritanceManagerTestGrandChild1_2,
InheritanceManagerTestParent, InheritanceManagerTestChild1,
InheritanceManagerTestChild2, TimeStamp, Post, Article, Status,
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
StatusPlainTuple, TimeFrame, Monitored, MonitorWhen, MonitorWhenEmpty, StatusManagerAdded,
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot,
ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple, InheritedModelTracked,
Tracked, TrackedFK, TrackedNotDefault, TrackedNonFieldAttr,
TrackedMultiple, InheritedTracked, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled)
Tracked, TrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple,
InheritedTracked, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled)
class GetExcerptTests(TestCase):
@ -171,6 +170,75 @@ class MonitorFieldTests(TestCase):
MonitorField()
class MonitorWhenFieldTests(TestCase):
"""
Will record changes only when name is 'Jose' or 'Maria'
"""
def setUp(self):
self.instance = MonitorWhen(name='Charlie')
self.created = self.instance.name_changed
def test_save_no_change(self):
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Jose(self):
self.instance.name = 'Jose'
self.instance.save()
self.assertTrue(self.instance.name_changed > self.created)
def test_save_changed_to_Maria(self):
self.instance.name = 'Maria'
self.instance.save()
self.assertTrue(self.instance.name_changed > self.created)
def test_save_changed_to_Pedro(self):
self.instance.name = 'Pedro'
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_double_save(self):
self.instance.name = 'Jose'
self.instance.save()
changed = self.instance.name_changed
self.instance.save()
self.assertEqual(self.instance.name_changed, changed)
class MonitorWhenEmptyFieldTests(TestCase):
"""
Monitor should never be updated id when is an empty list.
"""
def setUp(self):
self.instance = MonitorWhenEmpty(name='Charlie')
self.created = self.instance.name_changed
def test_save_no_change(self):
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Jose(self):
self.instance.name = 'Jose'
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Maria(self):
self.instance.name = 'Maria'
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
class StatusFieldTests(TestCase):
def test_status_with_default_filled(self):
@ -201,7 +269,7 @@ class ChoicesTests(TestCase):
def test_indexing(self):
self.assertEqual(self.STATUS[1], ('PUBLISHED', 'PUBLISHED'))
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
def test_iteration(self):
@ -223,10 +291,12 @@ class ChoicesTests(TestCase):
with self.assertRaises(ValueError):
Choices(('a',))
def test_contains_value(self):
self.assertTrue('PUBLISHED' in self.STATUS)
self.assertTrue('DRAFT' in self.STATUS)
def test_doesnt_contain_value(self):
self.assertFalse('UNPUBLISHED' in self.STATUS)
@ -236,6 +306,32 @@ class ChoicesTests(TestCase):
list(copy.deepcopy(self.STATUS)))
def test_equality(self):
self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED'))
def test_inequality(self):
self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED'])
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_composability(self):
self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS)
self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS)
self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS)
def test_option_groups(self):
c = Choices(('group a', ['one', 'two']), ['group b', ('three',)])
self.assertEqual(
list(c),
[
('group a', [('one', 'one'), ('two', 'two')]),
('group b', [('three', 'three')]),
],
)
class LabelChoicesTests(ChoicesTests):
def setUp(self):
self.STATUS = Choices(
@ -254,7 +350,7 @@ class LabelChoicesTests(ChoicesTests):
def test_indexing(self):
self.assertEqual(self.STATUS[1], ('PUBLISHED', 'is published'))
self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
def test_default(self):
@ -269,6 +365,23 @@ class LabelChoicesTests(ChoicesTests):
self.assertEqual(len(self.STATUS), 3)
def test_equality(self):
self.assertEqual(self.STATUS, Choices(
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
'DELETED',
))
def test_inequality(self):
self.assertNotEqual(self.STATUS, [
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
'DELETED'
])
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_repr(self):
self.assertEqual(repr(self.STATUS), "Choices" + repr((
('DRAFT', 'DRAFT', 'is draft'),
@ -276,6 +389,7 @@ class LabelChoicesTests(ChoicesTests):
('DELETED', 'DELETED', 'DELETED'),
)))
def test_contains_value(self):
self.assertTrue('PUBLISHED' in self.STATUS)
self.assertTrue('DRAFT' in self.STATUS)
@ -283,13 +397,46 @@ class LabelChoicesTests(ChoicesTests):
# and the internal representation are both DELETED.
self.assertTrue('DELETED' in self.STATUS)
def test_doesnt_contain_value(self):
self.assertFalse('UNPUBLISHED' in self.STATUS)
def test_doesnt_contain_display_value(self):
self.assertFalse('is draft' in self.STATUS)
def test_composability(self):
self.assertEqual(
Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'),
self.STATUS
)
self.assertEqual(
(('DRAFT', 'is draft',),) + Choices(('PUBLISHED', 'is published'), 'DELETED'),
self.STATUS
)
self.assertEqual(
Choices(('DRAFT', 'is draft',)) + (('PUBLISHED', 'is published'), 'DELETED'),
self.STATUS
)
def test_option_groups(self):
c = Choices(
('group a', [(1, 'one'), (2, 'two')]),
['group b', ((3, 'three'),)]
)
self.assertEqual(
list(c),
[
('group a', [(1, 'one'), (2, 'two')]),
('group b', [(3, 'three')]),
],
)
class IdentifierChoicesTests(ChoicesTests):
def setUp(self):
@ -307,7 +454,7 @@ class IdentifierChoicesTests(ChoicesTests):
def test_indexing(self):
self.assertEqual(self.STATUS[1], (1, 'is published'))
self.assertEqual(self.STATUS[1], 'is published')
def test_getattr(self):
@ -325,20 +472,88 @@ class IdentifierChoicesTests(ChoicesTests):
(2, 'DELETED', 'is deleted'),
)))
def test_contains_value(self):
self.assertTrue(0 in self.STATUS)
self.assertTrue(1 in self.STATUS)
self.assertTrue(2 in self.STATUS)
def test_doesnt_contain_value(self):
self.assertFalse(3 in self.STATUS)
def test_doesnt_contain_display_value(self):
self.assertFalse('is draft' in self.STATUS)
def test_doesnt_contain_python_attr(self):
self.assertFalse('PUBLISHED' in self.STATUS)
def test_equality(self):
self.assertEqual(self.STATUS, Choices(
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted')
))
def test_inequality(self):
self.assertNotEqual(self.STATUS, [
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted')
])
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_composability(self):
self.assertEqual(
Choices(
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published')
) + Choices(
(2, 'DELETED', 'is deleted'),
),
self.STATUS
)
self.assertEqual(
Choices(
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published')
) + (
(2, 'DELETED', 'is deleted'),
),
self.STATUS
)
self.assertEqual(
(
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published')
) + Choices(
(2, 'DELETED', 'is deleted'),
),
self.STATUS
)
def test_option_groups(self):
c = Choices(
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
['group b', ((3, 'THREE', 'three'),)]
)
self.assertEqual(
list(c),
[
('group a', [(1, 'one'), (2, 'two')]),
('group b', [(3, 'three')]),
],
)
class InheritanceManagerTests(TestCase):
def setUp(self):
self.child1 = InheritanceManagerTestChild1.objects.create()
@ -401,8 +616,26 @@ class InheritanceManagerTests(TestCase):
self.assertEqual(
set(
self.get_manager().select_subclasses(
"inheritancemanagertestchild1__"
"inheritancemanagertestgrandchild1"
"inheritancemanagertestchild1__inheritancemanagertestgrandchild1"
)
),
children,
)
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_children_and_grandchildren(self):
children = set([
self.child1,
InheritanceManagerTestParent(pk=self.child2.pk),
self.grandchild1,
InheritanceManagerTestChild1(pk=self.grandchild1_2.pk),
])
self.assertEqual(
set(
self.get_manager().select_subclasses(
"inheritancemanagertestchild1",
"inheritancemanagertestchild1__inheritancemanagertestgrandchild1"
)
),
children,
@ -415,6 +648,12 @@ class InheritanceManagerTests(TestCase):
self.child1)
def test_get_subclass_on_queryset(self):
self.assertEqual(
self.get_manager().all().get_subclass(pk=self.child1.pk),
self.child1)
def test_prior_select_related(self):
with self.assertNumQueries(1):
obj = self.get_manager().select_related(
@ -767,52 +1006,61 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertChanged(name=None, number=None)
self.instance.name = ''
self.assertChanged(name=None, number=None)
self.instance.mutable = [1,2,3]
self.assertChanged(name=None, number=None, mutable=None)
def test_pre_save_has_changed(self):
self.assertHasChanged(name=True, number=False)
self.assertHasChanged(name=True, number=False, mutable=False)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=False)
self.assertHasChanged(name=True, number=False, mutable=False)
self.instance.number = 7
self.assertHasChanged(name=True, number=True)
self.instance.mutable = [1,2,3]
self.assertHasChanged(name=True, number=True, mutable=True)
def test_first_save(self):
self.assertHasChanged(name=True, number=False)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='', number=None, id=None)
self.assertHasChanged(name=True, number=False, mutable=False)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='', number=None, id=None, mutable=None)
self.assertChanged(name=None)
self.instance.name = 'retro'
self.instance.number = 4
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertChanged(name=None, number=None)
self.instance.mutable = [1,2,3]
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertChanged(name=None, number=None, mutable=None)
# Django 1.4 doesn't have update_fields
if django.VERSION >= (1, 5, 0):
self.instance.save(update_fields=[])
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertChanged(name=None, number=None)
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertChanged(name=None, number=None, mutable=None)
with self.assertRaises(ValueError):
self.instance.save(update_fields=['number'])
def test_post_save_has_changed(self):
self.update_instance(name='retro', number=4)
self.assertHasChanged(name=False, number=False)
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.assertHasChanged(name=False, number=False, mutable=False)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=False)
self.instance.number = 8
self.assertHasChanged(name=True, number=True)
self.instance.mutable[1] = 4
self.assertHasChanged(name=True, number=True, mutable=True)
self.instance.name = 'retro'
self.assertHasChanged(name=False, number=True)
self.assertHasChanged(name=False, number=True, mutable=True)
def test_post_save_previous(self):
self.update_instance(name='retro', number=4)
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.instance.name = 'new age'
self.assertPrevious(name='retro', number=4)
self.assertPrevious(name='retro', number=4, mutable=[1,2,3])
self.instance.mutable[1] = 4
self.assertPrevious(name='retro', number=4, mutable=[1,2,3])
def test_post_save_changed(self):
self.update_instance(name='retro', number=4)
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged(name='retro')
@ -820,36 +1068,48 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertChanged(name='retro', number=4)
self.instance.name = 'retro'
self.assertChanged(number=4)
self.instance.mutable[1] = 4
self.assertChanged(number=4, mutable=[1,2,3])
self.instance.mutable = [1,2,3]
self.assertChanged(number=4)
def test_current(self):
self.assertCurrent(id=None, name='', number=None)
self.assertCurrent(id=None, name='', number=None, mutable=None)
self.instance.name = 'new age'
self.assertCurrent(id=None, name='new age', number=None)
self.assertCurrent(id=None, name='new age', number=None, mutable=None)
self.instance.number = 8
self.assertCurrent(id=None, name='new age', number=8)
self.assertCurrent(id=None, name='new age', number=8, mutable=None)
self.instance.mutable = [1,2,3]
self.assertCurrent(id=None, name='new age', number=8, mutable=[1,2,3])
self.instance.mutable[1] = 4
self.assertCurrent(id=None, name='new age', number=8, mutable=[1,4,3])
self.instance.save()
self.assertCurrent(id=self.instance.id, name='new age', number=8)
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1,4,3])
@skipUnless(
django.VERSION >= (1, 5, 0), "Django 1.4 doesn't have update_fields")
def test_update_fields(self):
self.update_instance(name='retro', number=4)
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.assertChanged()
self.instance.name = 'new age'
self.instance.number = 8
self.assertChanged(name='retro', number=4)
self.instance.mutable = [4,5,6]
self.assertChanged(name='retro', number=4, mutable=[1,2,3])
self.instance.save(update_fields=[])
self.assertChanged(name='retro', number=4)
self.assertChanged(name='retro', number=4, mutable=[1,2,3])
self.instance.save(update_fields=['name'])
in_db = self.tracked_class.objects.get(id=self.instance.id)
self.assertEqual(in_db.name, self.instance.name)
self.assertNotEqual(in_db.number, self.instance.number)
self.assertChanged(number=4)
self.assertChanged(number=4, mutable=[1,2,3])
self.instance.save(update_fields=['number'])
self.assertChanged(mutable=[1,2,3])
self.instance.save(update_fields=['mutable'])
self.assertChanged()
in_db = self.tracked_class.objects.get(id=self.instance.id)
self.assertEqual(in_db.name, self.instance.name)
self.assertEqual(in_db.number, self.instance.number)
self.assertEqual(in_db.mutable, self.instance.mutable)
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
@ -1141,24 +1401,27 @@ class ModelTrackerTests(FieldTrackerTests):
self.assertChanged()
self.instance.name = ''
self.assertChanged()
self.instance.mutable = [1,2,3]
self.assertChanged()
def test_first_save(self):
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='', number=None, id=None)
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='', number=None, id=None, mutable=None)
self.assertChanged()
self.instance.name = 'retro'
self.instance.number = 4
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.instance.mutable = [1,2,3]
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertChanged()
# Django 1.4 doesn't have update_fields
if django.VERSION >= (1, 5, 0):
self.instance.save(update_fields=[])
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='retro', number=4, id=None)
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertChanged()
with self.assertRaises(ValueError):
self.instance.save(update_fields=['number'])

View file

@ -1,4 +1,7 @@
from __future__ import unicode_literals
from copy import deepcopy
from django.db import models
from django.core.exceptions import FieldError
@ -20,8 +23,12 @@ class FieldInstanceTracker(object):
else:
self.saved_data.update(**self.current(fields=fields))
# preventing mutable fields side effects
for field, field_value in self.saved_data.items():
self.saved_data[field] = deepcopy(field_value)
def current(self, fields=None):
"""Return dict of current values for all tracked fields"""
"""Returns dict of current values for all tracked fields"""
if fields is None:
fields = self.fields
return dict((f, self.get_field_value(f)) for f in fields)
@ -34,7 +41,7 @@ class FieldInstanceTracker(object):
raise FieldError('field "%s" not tracked' % field)
def previous(self, field):
"""Return currently saved value of given field"""
"""Returns currently saved value of given field"""
return self.saved_data.get(field)
def changed(self):
@ -54,7 +61,7 @@ class FieldTracker(object):
self.fields = fields
def get_field_map(self, cls):
"""Return dict mapping fields names to model attribute names"""
"""Returns dict mapping fields names to model attribute names"""
field_map = dict((field, field) for field in self.fields)
all_fields = dict((f.name, f.attname) for f in cls._meta.local_fields)
field_map.update(**dict((k, v) for (k, v) in all_fields.items()