diff --git a/AUTHORS.rst b/AUTHORS.rst index 7ceed61..8da575b 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -16,6 +16,7 @@ Jannis Leidel Javier GarcĂ­a Sogo Jeff Elmore Keryn Knight +Michael van Tellingen Mikhail Silonov Patryk Zawadzki Paul McLanahan diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index abad8e1..fac85ac 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -1432,6 +1432,24 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertEqual(in_db.number, self.instance.number) self.assertEqual(in_db.mutable, self.instance.mutable) + def test_with_deferred(self): + self.instance.name = 'new age' + self.instance.number = 1 + self.instance.save() + item = list(self.tracked_class.objects.only('name').all())[0] + self.assertTrue(item.tracker.deferred_fields) + + self.assertEqual(item.tracker.previous('number'), None) + self.assertTrue('number' in item.tracker.deferred_fields) + + self.assertEqual(item.number, 1) + self.assertTrue('number' not in item.tracker.deferred_fields) + self.assertEqual(item.tracker.previous('number'), 1) + self.assertFalse(item.tracker.has_changed('number')) + + item.number = 2 + self.assertTrue(item.tracker.has_changed('number')) + class FieldTrackedModelCustomTests(FieldTrackerTestCase, FieldTrackerCommonTests): diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 77eaaa4..6cb4355 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -4,6 +4,7 @@ from copy import deepcopy from django.db import models from django.core.exceptions import FieldError +from django.db.models.query_utils import DeferredAttribute class FieldInstanceTracker(object): @@ -11,6 +12,7 @@ class FieldInstanceTracker(object): self.instance = instance self.fields = fields self.field_map = field_map + self.init_deferred_fields() def get_field_value(self, field): return getattr(self.instance, self.field_map[field]) @@ -30,7 +32,14 @@ class FieldInstanceTracker(object): def current(self, fields=None): """Returns dict of current values for all tracked fields""" if fields is None: - fields = self.fields + if self.deferred_fields: + fields = [ + field for field in self.fields + if field not in self.deferred_fields + ] + else: + fields = self.fields + return dict((f, self.get_field_value(f)) for f in fields) def has_changed(self, field): @@ -52,6 +61,35 @@ class FieldInstanceTracker(object): if self.has_changed(field) ) + def init_deferred_fields(self): + self.deferred_fields = [] + if not self.instance._deferred: + return + + class DeferredAttributeTracker(DeferredAttribute): + def __get__(field, instance, owner): + data = instance.__dict__ + if data.get(field.field_name, field) is field: + self.deferred_fields.remove(field.field_name) + value = super(DeferredAttributeTracker, field).__get__( + instance, owner) + self.saved_data[field.field_name] = deepcopy(value) + return data[field.field_name] + + for field in self.fields: + field_obj = self.instance.__class__.__dict__.get(field) + if isinstance(field_obj, DeferredAttribute): + self.deferred_fields.append(field) + + # Django 1.4 + model = None + if hasattr(field_obj, 'model_ref'): + model = field_obj.model_ref() + + field_tracker = DeferredAttributeTracker( + field_obj.field_name, model) + setattr(self.instance.__class__, field, field_tracker) + class FieldTracker(object):