mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-05-27 11:03:59 +00:00
Use a signal handler instead of patching save.
References #83. Instead of patching the save method of a tracked model class, we can use a signal handler on post_save, which means we can still pickle our model class. Note we can't just listen for the signal from the class we have, but instead listen for all post_save signals. This means we actually install a new signal handler for each tracked model class, which fires on all model save occurrences (and returns immediately if this handler doesn't care). We probably could improve this to have a registry of tracked models, or something, that allows us to just install one signal handler, and filter according to membership.
This commit is contained in:
parent
fdf20e9d13
commit
3496fe4291
2 changed files with 24 additions and 17 deletions
|
|
@ -1497,6 +1497,9 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
|
|
||||||
item.number = 2
|
item.number = 2
|
||||||
self.assertTrue(item.tracker.has_changed('number'))
|
self.assertTrue(item.tracker.has_changed('number'))
|
||||||
|
|
||||||
|
def test_can_pickle_objects(self):
|
||||||
|
pickle.dumps(self.instance)
|
||||||
|
|
||||||
|
|
||||||
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,11 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from django.db import models
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
|
from django.db import models
|
||||||
from django.db.models.query_utils import DeferredAttribute
|
from django.db.models.query_utils import DeferredAttribute
|
||||||
|
from django.db.models.signals import post_save
|
||||||
|
from django.dispatch import receiver
|
||||||
|
|
||||||
class FieldInstanceTracker(object):
|
class FieldInstanceTracker(object):
|
||||||
def __init__(self, instance, fields, field_map):
|
def __init__(self, instance, fields, field_map):
|
||||||
|
|
@ -119,20 +120,16 @@ class FieldTracker(object):
|
||||||
models.signals.post_init.connect(self.initialize_tracker)
|
models.signals.post_init.connect(self.initialize_tracker)
|
||||||
self.model_class = sender
|
self.model_class = sender
|
||||||
setattr(sender, self.name, self)
|
setattr(sender, self.name, self)
|
||||||
|
|
||||||
def initialize_tracker(self, sender, instance, **kwargs):
|
# Rather than patch the save method on the instance,
|
||||||
if not isinstance(instance, self.model_class):
|
# we can observe the post_save signal on the class.
|
||||||
return # Only init instances of given model (including children)
|
@receiver(post_save, sender=None, weak=False)
|
||||||
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
def handler(sender, instance, **kwargs):
|
||||||
setattr(instance, self.attname, tracker)
|
if not isinstance(instance, self.model_class):
|
||||||
tracker.set_saved_fields()
|
return
|
||||||
self.patch_save(instance)
|
|
||||||
|
|
||||||
def patch_save(self, instance):
|
|
||||||
original_save = instance.save
|
|
||||||
def save(**kwargs):
|
|
||||||
ret = original_save(**kwargs)
|
|
||||||
update_fields = kwargs.get('update_fields')
|
update_fields = kwargs.get('update_fields')
|
||||||
|
|
||||||
if not update_fields and update_fields is not None: # () or []
|
if not update_fields and update_fields is not None: # () or []
|
||||||
fields = update_fields
|
fields = update_fields
|
||||||
elif update_fields is None:
|
elif update_fields is None:
|
||||||
|
|
@ -142,12 +139,19 @@ class FieldTracker(object):
|
||||||
field for field in update_fields if
|
field for field in update_fields if
|
||||||
field in self.fields
|
field in self.fields
|
||||||
)
|
)
|
||||||
|
|
||||||
getattr(instance, self.attname).set_saved_fields(
|
getattr(instance, self.attname).set_saved_fields(
|
||||||
fields=fields
|
fields=fields
|
||||||
)
|
)
|
||||||
return ret
|
|
||||||
instance.save = save
|
|
||||||
|
|
||||||
|
def initialize_tracker(self, sender, instance, **kwargs):
|
||||||
|
if not isinstance(instance, self.model_class):
|
||||||
|
return # Only init instances of given model (including children)
|
||||||
|
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
||||||
|
setattr(instance, self.attname, tracker)
|
||||||
|
tracker.set_saved_fields()
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return self
|
return self
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue