Do not override custom descriptors when present.

This commit adds a collection of wrapper classes for tracking fields
while still using custom descriptors that may be present. This fixes
a bug where deferring a model field with a custom descriptor meant
that the descriptor was overridden in all subsequent queries.
This commit is contained in:
Lucas Wiman 2018-04-03 15:43:29 -07:00 committed by Lucas Wiman
parent be52bc9290
commit 80b099f129
3 changed files with 95 additions and 8 deletions

View file

@ -29,12 +29,73 @@ class DescriptorMixin(object):
return self.field_name
class DescriptorWrapper(object):
def __init__(self, field_name, descriptor, tracker_attname):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname
def __get__(self, instance, owner):
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
if self.descriptor:
value = self.descriptor.__get__(instance, owner)
else:
value = instance.__dict__[self.field_name]
if was_deferred:
tracker_instance = getattr(instance, self.tracker_attname)
tracker_instance.saved_data[self.field_name] = deepcopy(value)
return value
@staticmethod
def cls_for_descriptor(descriptor):
has_set = hasattr(descriptor, '__set__')
has_del = hasattr(descriptor, '__delete__')
if has_set and has_del:
return FullDescriptorWrapper
elif has_set:
return SettableDescriptorWrapper
elif has_del:
return DeleteableDescriptorWrapper
else:
return DescriptorWrapper
class SettableDescriptorWrapper(DescriptorWrapper):
"""
Descriptor wrapper for descriptors with a __delete__ method.
This should not be used for descriptors
"""
def __set__(self, instance, value):
return self.descriptor.__set__(instance, value)
class DeleteableDescriptorWrapper(DescriptorWrapper):
"""
Descriptor wrapper for descriptors with a __delete__ method.
This should not be used for descriptors
"""
def __delete__(self, instance):
self.descriptor.__delete__(instance)
class FullDescriptorWrapper(SettableDescriptorWrapper, DeleteableDescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""
class FieldInstanceTracker(object):
def __init__(self, instance, fields, field_map):
self.instance = instance
self.fields = fields
self.field_map = field_map
self.init_deferred_fields()
if django.VERSION < (1, 10):
self.init_deferred_fields()
def get_field_value(self, field):
return getattr(self.instance, self.field_map[field])
@ -54,10 +115,11 @@ class FieldInstanceTracker(object):
def current(self, fields=None):
"""Returns dict of current values for all tracked fields"""
if fields is None:
if self.instance._deferred_fields:
deferred_fields = self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()
if deferred_fields:
fields = [
field for field in self.fields
if field not in self.instance._deferred_fields
if field not in deferred_fields
]
else:
fields = self.fields
@ -135,6 +197,15 @@ class FieldTracker(object):
if self.fields is None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
if django.VERSION >= (1, 10):
for field_name in self.fields:
if django.VERSION >= (1, 10):
descriptor = getattr(sender, field_name)
else:
descriptor = sender.__dict__.get(field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
self.field_map = self.get_field_map(sender)
models.signals.post_init.connect(self.initialize_tracker)
self.model_class = sender

View file

@ -181,13 +181,22 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.number = 1
self.instance.save()
item = list(self.tracked_class.objects.only('name').all())[0]
self.assertTrue(item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue(item.get_deferred_fields())
else:
self.assertTrue(item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), None)
self.assertTrue('number' in item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue('number' in item.get_deferred_fields())
else:
self.assertTrue('number' in item._deferred_fields)
self.assertEqual(item.number, 1)
self.assertTrue('number' not in item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue('number' not in item.get_deferred_fields())
else:
self.assertTrue('number' not in item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), 1)
self.assertFalse(item.tracker.has_changed('number'))

View file

@ -1,5 +1,6 @@
from __future__ import unicode_literals
import django
from django.test import TestCase
from tests.models import ModelWithCustomDescriptor
@ -30,9 +31,15 @@ class CustomDescriptorTests(TestCase):
def test_deferred(self):
instance = ModelWithCustomDescriptor.objects.only('id').get(
pk=self.instance.pk)
self.assertIn('custom_field', instance.get_deferred_fields())
if django.VERSION >= (1, 10):
self.assertIn('custom_field', instance.get_deferred_fields())
else:
self.assertIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.custom_field, '1')
self.assertNotIn('custom_field', instance.get_deferred_fields())
if django.VERSION >= (1, 10):
self.assertNotIn('custom_field', instance.get_deferred_fields())
else:
self.assertNotIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.regular_field, 1)
self.assertEqual(instance.tracked_custom_field, '1')
self.assertEqual(instance.tracked_regular_field, 1)