From b9f954074c4988b4f6ffc02991a92e9286c469c6 Mon Sep 17 00:00:00 2001 From: Mikhail Silonov Date: Thu, 8 Aug 2013 13:18:33 +0400 Subject: [PATCH] Fixed a bug causing `KeyError` when saving with the parameter `update_fields` in which there are untracked fields. --- AUTHORS.rst | 1 + CHANGES.rst | 3 +++ model_utils/tests/tests.py | 12 +++++++++++- model_utils/tracker.py | 16 ++++++++++++++-- 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index 3a33840..5ea4eba 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -19,3 +19,4 @@ Simon Meers sayane Trey Hunner zyegfryed +Mikhail Silonov diff --git a/CHANGES.rst b/CHANGES.rst index 9b3bb75..44d3959 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,9 @@ master (unreleased) * `Choices` now `__contains__` its Python identifier values. Thanks Keryn Knight. (Merge of GH-69). +* Fixed a bug causing ``KeyError`` when saving with the parameter + ``update_fields`` in which there are untracked fields. + 1.4.0 (2013.06.03) ------------------ diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index ab60a62..b85b08d 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -923,6 +923,16 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase, self.instance.save() self.assertCurrent(name='new age') + def test_update_fields(self): + # Django 1.4 doesn't have update_fields + if django.VERSION >= (1, 5, 0): + self.update_instance(name='retro', number=4) + self.assertChanged() + self.instance.name = 'new age' + self.instance.number = 8 + self.instance.save(update_fields=['name', 'number']) + self.assertChanged() + class FieldTrackedModelAttributeTests(FieldTrackerTestCase): @@ -976,7 +986,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase): class FieldTrackedModelMultiTests(FieldTrackerTestCase, - FieldTrackerCommonTests): + FieldTrackerCommonTests): tracked_class = TrackedMultiple diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 31fd0f2..aecf27c 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -68,7 +68,8 @@ class FieldTracker(object): def finalize_class(self, sender, **kwargs): if self.fields is None: - self.fields = [field.attname for field in sender._meta.local_fields] + self.fields = (field.attname for field in sender._meta.local_fields) + self.fields = set(self.fields) self.field_map = self.get_field_map(sender) models.signals.post_init.connect(self.initialize_tracker, sender=sender) setattr(sender, self.name, self) @@ -83,8 +84,19 @@ class FieldTracker(object): original_save = instance.save def save(**kwargs): ret = original_save(**kwargs) + update_fields = kwargs.get('update_fields') + if not update_fields and update_fields is not None: # () or [] + fields = update_fields + elif update_fields is None: + fields = None + else: + fields = ( + field for field in update_fields if + field in self.fields + ) getattr(instance, self.attname).set_saved_fields( - fields=kwargs.get('update_fields')) + fields=fields + ) return ret instance.save = save