Don't compare FK instances in FieldTracker

This commit is contained in:
Trey Hunner 2013-05-23 13:18:21 -07:00
parent 4bae3e999b
commit 54c996f17f

View file

@ -4,9 +4,13 @@ from django.core.exceptions import FieldError
class FieldInstanceTracker(object):
def __init__(self, instance, fields):
def __init__(self, instance, fields, field_map):
self.instance = instance
self.fields = fields
self.field_map = field_map
def get_field(self, field):
return getattr(self.instance, self.field_map[field])
def set_saved_fields(self, fields=None):
if not self.instance.pk:
@ -19,14 +23,14 @@ class FieldInstanceTracker(object):
def current(self, fields=None):
if fields is None:
fields = self.fields
return dict((f, getattr(self.instance, f)) for f in fields)
return dict((f, self.get_field(f)) for f in fields)
def has_changed(self, field):
"""Returns ``True`` if field has changed from currently saved value"""
if not self.instance.pk:
return True
elif field in self.saved_data:
return self.saved_data.get(field) != getattr(self.instance, field)
return self.saved_data.get(field) != self.get_field(field)
else:
raise FieldError('field "%s" not tracked' % field)
@ -50,6 +54,12 @@ class FieldTracker(object):
def __init__(self, fields=None):
self.fields = fields
def set_field_map(self, cls):
self.field_map = dict((field, field) for field in self.fields)
all_fields = dict((f.name, f.attname) for f in cls._meta.local_fields)
self.field_map.update(**dict((k, v) for (k, v) in all_fields.items()
if k in self.field_map))
def contribute_to_class(self, cls, name):
self.name = name
self.attname = '_%s' % name
@ -58,11 +68,12 @@ class FieldTracker(object):
def finalize_class(self, sender, **kwargs):
if self.fields is None:
self.fields = [field.attname for field in sender._meta.local_fields]
self.set_field_map(sender)
models.signals.post_init.connect(self.initialize_tracker, sender=sender)
setattr(sender, self.name, self)
def initialize_tracker(self, sender, instance, **kwargs):
tracker = self.tracker_class(instance, self.fields)
tracker = self.tracker_class(instance, self.fields, self.field_map)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
@ -89,3 +100,6 @@ class ModelInstanceTracker(FieldInstanceTracker):
class ModelTracker(FieldTracker):
tracker_class = ModelInstanceTracker
def set_field_map(self, cls):
self.field_map = dict((field, field) for field in self.fields)