diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 38150a8..9de6892 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -1497,6 +1497,9 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): item.number = 2 self.assertTrue(item.tracker.has_changed('number')) + + def test_can_pickle_objects(self): + pickle.dumps(self.instance) class FieldTrackedModelCustomTests(FieldTrackerTestCase, diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 6cb4355..021b12f 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -2,10 +2,11 @@ from __future__ import unicode_literals from copy import deepcopy -from django.db import models from django.core.exceptions import FieldError +from django.db import models from django.db.models.query_utils import DeferredAttribute - +from django.db.models.signals import post_save +from django.dispatch import receiver class FieldInstanceTracker(object): def __init__(self, instance, fields, field_map): @@ -119,20 +120,16 @@ class FieldTracker(object): models.signals.post_init.connect(self.initialize_tracker) self.model_class = sender setattr(sender, self.name, self) - - def initialize_tracker(self, sender, instance, **kwargs): - if not isinstance(instance, self.model_class): - return # Only init instances of given model (including children) - tracker = self.tracker_class(instance, self.fields, self.field_map) - setattr(instance, self.attname, tracker) - tracker.set_saved_fields() - self.patch_save(instance) - - def patch_save(self, instance): - original_save = instance.save - def save(**kwargs): - ret = original_save(**kwargs) + + # Rather than patch the save method on the instance, + # we can observe the post_save signal on the class. + @receiver(post_save, sender=None, weak=False) + def handler(sender, instance, **kwargs): + if not isinstance(instance, self.model_class): + return + update_fields = kwargs.get('update_fields') + if not update_fields and update_fields is not None: # () or [] fields = update_fields elif update_fields is None: @@ -142,12 +139,19 @@ class FieldTracker(object): field for field in update_fields if field in self.fields ) + getattr(instance, self.attname).set_saved_fields( fields=fields ) - return ret - instance.save = save + + def initialize_tracker(self, sender, instance, **kwargs): + if not isinstance(instance, self.model_class): + return # Only init instances of given model (including children) + tracker = self.tracker_class(instance, self.fields, self.field_map) + setattr(instance, self.attname, tracker) + tracker.set_saved_fields() + def __get__(self, instance, owner): if instance is None: return self