diff --git a/model_utils/tracker.py b/model_utils/tracker.py index b80d3b9..e38d81f 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -9,7 +9,7 @@ class FieldInstanceTracker(object): self.fields = fields self.field_map = field_map - def get_field(self, field): + def get_field_value(self, field): return getattr(self.instance, self.field_map[field]) def set_saved_fields(self, fields=None): @@ -23,7 +23,7 @@ class FieldInstanceTracker(object): def current(self, fields=None): if fields is None: fields = self.fields - return dict((f, self.get_field(f)) for f in fields) + return dict((f, self.get_field_value(f)) for f in fields) def has_changed(self, field): """Returns ``True`` if field has changed from currently saved value""" @@ -54,11 +54,13 @@ class FieldTracker(object): def __init__(self, fields=None): self.fields = fields - def set_field_map(self, cls): - self.field_map = dict((field, field) for field in self.fields) + def get_field_map(self, cls): + """Return ``dict`` mapping fields to model instance attribute names""" + field_map = dict((field, field) for field in self.fields) all_fields = dict((f.name, f.attname) for f in cls._meta.local_fields) - self.field_map.update(**dict((k, v) for (k, v) in all_fields.items() - if k in self.field_map)) + field_map.update(**dict((k, v) for (k, v) in all_fields.items() + if k in field_map)) + return field_map def contribute_to_class(self, cls, name): self.name = name @@ -68,7 +70,7 @@ class FieldTracker(object): def finalize_class(self, sender, **kwargs): if self.fields is None: self.fields = [field.attname for field in sender._meta.local_fields] - self.set_field_map(sender) + self.field_map = self.get_field_map(sender) models.signals.post_init.connect(self.initialize_tracker, sender=sender) setattr(sender, self.name, self) @@ -101,5 +103,5 @@ class ModelInstanceTracker(FieldInstanceTracker): class ModelTracker(FieldTracker): tracker_class = ModelInstanceTracker - def set_field_map(self, cls): - self.field_map = dict((field, field) for field in self.fields) + def get_field_map(self, cls): + return dict((field, field) for field in self.fields)