From 90ed7fc9052f46ea1007af26ff9953f3f1e198b7 Mon Sep 17 00:00:00 2001 From: Lucas Wiman Date: Wed, 4 Apr 2018 10:02:46 -0700 Subject: [PATCH] Improve coverage. --- model_utils/tracker.py | 33 +++++------------------ tests/models.py | 3 +++ tests/test_models/test_deferred_fields.py | 11 ++++++++ 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 095d87e..0ec3044 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -40,10 +40,7 @@ class DescriptorWrapper(object): if instance is None: return self was_deferred = self.field_name in instance.get_deferred_fields() - if self.descriptor: - value = self.descriptor.__get__(instance, owner) - else: - value = instance.__dict__[self.field_name] + value = self.descriptor.__get__(instance, owner) if was_deferred: tracker_instance = getattr(instance, self.tracker_attname) tracker_instance.saved_data[self.field_name] = deepcopy(value) @@ -53,12 +50,10 @@ class DescriptorWrapper(object): def cls_for_descriptor(descriptor): has_set = hasattr(descriptor, '__set__') has_del = hasattr(descriptor, '__delete__') - if has_set and has_del: + if has_del: return FullDescriptorWrapper elif has_set: return SettableDescriptorWrapper - elif has_del: - return DeleteableDescriptorWrapper else: return DescriptorWrapper @@ -73,20 +68,12 @@ class SettableDescriptorWrapper(DescriptorWrapper): return self.descriptor.__set__(instance, value) -class DeleteableDescriptorWrapper(DescriptorWrapper): - """ - Descriptor wrapper for descriptors with a __delete__ method. - - This should not be used for descriptors - """ - def __delete__(self, instance): - self.descriptor.__delete__(instance) - - -class FullDescriptorWrapper(SettableDescriptorWrapper, DeleteableDescriptorWrapper): +class FullDescriptorWrapper(SettableDescriptorWrapper): """ Wrapper for descriptors with all three descriptor methods. """ + def __delete__(self, obj): + self.descriptor.__delete__(obj) class FieldInstanceTracker(object): @@ -161,10 +148,7 @@ class FieldInstanceTracker(object): self.instance._deferred_fields = self.instance.get_deferred_fields() for field in self.instance._deferred_fields: - if django.VERSION >= (1, 10): - field_obj = getattr(self.instance.__class__, field) - else: - field_obj = self.instance.__class__.__dict__.get(field) + field_obj = self.instance.__class__.__dict__.get(field) if isinstance(field_obj, FileDescriptor): field_tracker = FileDescriptorTracker(field_obj.field) setattr(self.instance.__class__, field, field_tracker) @@ -199,10 +183,7 @@ class FieldTracker(object): self.fields = set(self.fields) if django.VERSION >= (1, 10): for field_name in self.fields: - if django.VERSION >= (1, 10): - descriptor = getattr(sender, field_name) - else: - descriptor = sender.__dict__.get(field_name) + descriptor = getattr(sender, field_name) wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) setattr(sender, field_name, wrapped_descriptor) diff --git a/tests/models.py b/tests/models.py index 841f612..91df0b6 100644 --- a/tests/models.py +++ b/tests/models.py @@ -352,6 +352,9 @@ class StringyDescriptor(object): def __set__(self, obj, value): obj.__dict__[self.name] = int(value) + def __delete__(self, obj): + del obj.__dict__[self.name] + class CustomDescriptorField(models.IntegerField): def contribute_to_class(self, cls, name, **kwargs): diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py index b235843..05ba336 100644 --- a/tests/test_models/test_deferred_fields.py +++ b/tests/test_models/test_deferred_fields.py @@ -58,3 +58,14 @@ class CustomDescriptorTests(TestCase): self.assertEqual(instance.regular_field, 1) self.assertEqual(instance.tracked_custom_field, '2') self.assertEqual(instance.tracked_regular_field, 2) + + instance = ModelWithCustomDescriptor.objects.only('id').get(pk=instance.pk) + if django.VERSION >= (1, 10): + # This fails on 1.8 and 1.9, which is a bug in the deferred field + # implementation on those versions. + instance.tracked_custom_field = 3 + self.assertEqual(instance.tracked_custom_field, '3') + self.assertTrue(instance.tracker.has_changed('tracked_custom_field')) + del instance.tracked_custom_field + self.assertEqual(instance.tracked_custom_field, '2') + self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))