From 80b099f12918d28c297005fbbb15ed1ebaf10a12 Mon Sep 17 00:00:00 2001 From: Lucas Wiman Date: Tue, 3 Apr 2018 15:43:29 -0700 Subject: [PATCH] Do not override custom descriptors when present. This commit adds a collection of wrapper classes for tracking fields while still using custom descriptors that may be present. This fixes a bug where deferring a model field with a custom descriptor meant that the descriptor was overridden in all subsequent queries. --- model_utils/tracker.py | 77 ++++++++++++++++++++++- tests/test_fields/test_field_tracker.py | 15 ++++- tests/test_models/test_deferred_fields.py | 11 +++- 3 files changed, 95 insertions(+), 8 deletions(-) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 5da9dd4..095d87e 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -29,12 +29,73 @@ class DescriptorMixin(object): return self.field_name +class DescriptorWrapper(object): + + def __init__(self, field_name, descriptor, tracker_attname): + self.field_name = field_name + self.descriptor = descriptor + self.tracker_attname = tracker_attname + + def __get__(self, instance, owner): + if instance is None: + return self + was_deferred = self.field_name in instance.get_deferred_fields() + if self.descriptor: + value = self.descriptor.__get__(instance, owner) + else: + value = instance.__dict__[self.field_name] + if was_deferred: + tracker_instance = getattr(instance, self.tracker_attname) + tracker_instance.saved_data[self.field_name] = deepcopy(value) + return value + + @staticmethod + def cls_for_descriptor(descriptor): + has_set = hasattr(descriptor, '__set__') + has_del = hasattr(descriptor, '__delete__') + if has_set and has_del: + return FullDescriptorWrapper + elif has_set: + return SettableDescriptorWrapper + elif has_del: + return DeleteableDescriptorWrapper + else: + return DescriptorWrapper + + +class SettableDescriptorWrapper(DescriptorWrapper): + """ + Descriptor wrapper for descriptors with a __delete__ method. + + This should not be used for descriptors + """ + def __set__(self, instance, value): + return self.descriptor.__set__(instance, value) + + +class DeleteableDescriptorWrapper(DescriptorWrapper): + """ + Descriptor wrapper for descriptors with a __delete__ method. + + This should not be used for descriptors + """ + def __delete__(self, instance): + self.descriptor.__delete__(instance) + + +class FullDescriptorWrapper(SettableDescriptorWrapper, DeleteableDescriptorWrapper): + """ + Wrapper for descriptors with all three descriptor methods. + """ + + class FieldInstanceTracker(object): def __init__(self, instance, fields, field_map): self.instance = instance self.fields = fields self.field_map = field_map - self.init_deferred_fields() + if django.VERSION < (1, 10): + self.init_deferred_fields() def get_field_value(self, field): return getattr(self.instance, self.field_map[field]) @@ -54,10 +115,11 @@ class FieldInstanceTracker(object): def current(self, fields=None): """Returns dict of current values for all tracked fields""" if fields is None: - if self.instance._deferred_fields: + deferred_fields = self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields() + if deferred_fields: fields = [ field for field in self.fields - if field not in self.instance._deferred_fields + if field not in deferred_fields ] else: fields = self.fields @@ -135,6 +197,15 @@ class FieldTracker(object): if self.fields is None: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) + if django.VERSION >= (1, 10): + for field_name in self.fields: + if django.VERSION >= (1, 10): + descriptor = getattr(sender, field_name) + else: + descriptor = sender.__dict__.get(field_name) + wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) + wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) + setattr(sender, field_name, wrapped_descriptor) self.field_map = self.get_field_map(sender) models.signals.post_init.connect(self.initialize_tracker) self.model_class = sender diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 00c28b3..b389861 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -181,13 +181,22 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.instance.number = 1 self.instance.save() item = list(self.tracked_class.objects.only('name').all())[0] - self.assertTrue(item._deferred_fields) + if django.VERSION >= (1, 10): + self.assertTrue(item.get_deferred_fields()) + else: + self.assertTrue(item._deferred_fields) self.assertEqual(item.tracker.previous('number'), None) - self.assertTrue('number' in item._deferred_fields) + if django.VERSION >= (1, 10): + self.assertTrue('number' in item.get_deferred_fields()) + else: + self.assertTrue('number' in item._deferred_fields) self.assertEqual(item.number, 1) - self.assertTrue('number' not in item._deferred_fields) + if django.VERSION >= (1, 10): + self.assertTrue('number' not in item.get_deferred_fields()) + else: + self.assertTrue('number' not in item._deferred_fields) self.assertEqual(item.tracker.previous('number'), 1) self.assertFalse(item.tracker.has_changed('number')) diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py index 6a159be..b235843 100644 --- a/tests/test_models/test_deferred_fields.py +++ b/tests/test_models/test_deferred_fields.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import django from django.test import TestCase from tests.models import ModelWithCustomDescriptor @@ -30,9 +31,15 @@ class CustomDescriptorTests(TestCase): def test_deferred(self): instance = ModelWithCustomDescriptor.objects.only('id').get( pk=self.instance.pk) - self.assertIn('custom_field', instance.get_deferred_fields()) + if django.VERSION >= (1, 10): + self.assertIn('custom_field', instance.get_deferred_fields()) + else: + self.assertIn('custom_field', instance._deferred_fields) self.assertEqual(instance.custom_field, '1') - self.assertNotIn('custom_field', instance.get_deferred_fields()) + if django.VERSION >= (1, 10): + self.assertNotIn('custom_field', instance.get_deferred_fields()) + else: + self.assertNotIn('custom_field', instance._deferred_fields) self.assertEqual(instance.regular_field, 1) self.assertEqual(instance.tracked_custom_field, '1') self.assertEqual(instance.tracked_regular_field, 1)