diff --git a/AUTHORS.rst b/AUTHORS.rst index 33224c5..058b613 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -44,3 +44,4 @@ Karl Wan Nan Wo zyegfryed Radosław Jan Ganczarek Lucas Wiman +Jack Cushman diff --git a/CHANGES.rst b/CHANGES.rst index 9bd6129..b367ada 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -11,6 +11,9 @@ master (unreleased) ModelIterable instead of BaseIterable, fixes GH-277. +- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return + correct responses for deferred fields. + 3.1.1 (2017.12.17) ------------------ diff --git a/docs/utilities.rst b/docs/utilities.rst index 44824f5..b763ba0 100644 --- a/docs/utilities.rst +++ b/docs/utilities.rst @@ -150,6 +150,10 @@ Returns the value of the given field during the last save: Returns ``None`` when the model instance isn't saved yet. +If a field is `deferred`_, calling ``previous()`` will load the previous value from the database. + +.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer + has_changed ~~~~~~~~~~~ @@ -167,6 +171,8 @@ Returns ``True`` if the given field has changed since the last save. The ``has_c The ``has_changed`` method relies on ``previous`` to determine whether a field's values has changed. +If a field is `deferred`_ and has been assigned locally, calling ``has_changed()`` +will load the previous value from the database to perform the comparison. changed ~~~~~~~ diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 0ec3044..d1101e2 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -116,12 +116,31 @@ class FieldInstanceTracker(object): def has_changed(self, field): """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: + # deferred fields haven't changed + if field in self.instance._deferred_fields and field not in self.instance.__dict__: + return False return self.previous(field) != self.get_field_value(field) else: raise FieldError('field "%s" not tracked' % field) def previous(self, field): """Returns currently saved value of given field""" + + # handle deferred fields that have not yet been loaded from the database + if self.instance.pk and field in self.instance._deferred_fields and field not in self.saved_data: + + # if the field has not been assigned locally, simply fetch and un-defer the value + if field not in self.instance.__dict__: + self.get_field_value(field) + + # if the field has been assigned locally, store the local value, fetch the database value, + # store database value to saved_data, and restore the local value + else: + current_value = self.get_field_value(field) + self.instance.refresh_from_db(fields=[field]) + self.saved_data[field] = deepcopy(self.get_field_value(field)) + setattr(self.instance, self.field_map[field], current_value) + return self.saved_data.get(field) def changed(self): diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index b389861..bc0da4c 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -180,18 +180,27 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.instance.name = 'new age' self.instance.number = 1 self.instance.save() - item = list(self.tracked_class.objects.only('name').all())[0] + item = self.tracked_class.objects.only('name').first() if django.VERSION >= (1, 10): self.assertTrue(item.get_deferred_fields()) else: self.assertTrue(item._deferred_fields) - self.assertEqual(item.tracker.previous('number'), None) - if django.VERSION >= (1, 10): - self.assertTrue('number' in item.get_deferred_fields()) - else: - self.assertTrue('number' in item._deferred_fields) + # has_changed() returns False for deferred fields, without un-deferring them. + # Use an if because ModelTracked doesn't support has_changed() in this case. + if self.tracked_class == Tracked: + self.assertEqual(item.tracker.previous('number'), None) + if django.VERSION >= (1, 10): + self.assertTrue('number' in item.get_deferred_fields()) + else: + self.assertTrue('number' in item._deferred_fields) + # previous() un-defers field and returns value + self.assertEqual(item.tracker.previous('number'), 1) + self.assertTrue('number' not in item._deferred_fields) + + # examining a deferred field un-defers it + item = self.tracked_class.objects.only('name').first() self.assertEqual(item.number, 1) if django.VERSION >= (1, 10): self.assertTrue('number' not in item.get_deferred_fields()) @@ -200,9 +209,36 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertEqual(item.tracker.previous('number'), 1) self.assertFalse(item.tracker.has_changed('number')) + # has_changed() returns correct values after deferred field is examined + self.assertFalse(item.tracker.has_changed('number')) item.number = 2 self.assertTrue(item.tracker.has_changed('number')) + # previous() returns correct value after deferred field is examined + self.assertEqual(item.tracker.previous('number'), 1) + + # assigning to a deferred field un-defers it + # Use an if because ModelTracked doesn't handle this case. + if self.tracked_class == Tracked: + + item = self.tracked_class.objects.only('name').first() + item.number = 2 + + # _deferred_fields is not updated by assignment + self.assertTrue('number' in item._deferred_fields) + + # previous() fetches correct value from database after deferred field is assigned + self.assertEqual(item.tracker.previous('number'), 1) + + # database fetch of previous() value doesn't affect current value + self.assertEqual(item.number, 2) + + # has_changed() returns correct values after deferred field is assigned + self.assertTrue(item.tracker.has_changed('number')) + item.number = 1 + self.assertFalse(item.tracker.has_changed('number')) + + class FieldTrackerMultipleInstancesTests(TestCase):