mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Do not override custom descriptors when present.
This commit adds a collection of wrapper classes for tracking fields while still using custom descriptors that may be present. This fixes a bug where deferring a model field with a custom descriptor meant that the descriptor was overridden in all subsequent queries.
This commit is contained in:
parent
be52bc9290
commit
80b099f129
3 changed files with 95 additions and 8 deletions
|
|
@ -29,12 +29,73 @@ class DescriptorMixin(object):
|
|||
return self.field_name
|
||||
|
||||
|
||||
class DescriptorWrapper(object):
|
||||
|
||||
def __init__(self, field_name, descriptor, tracker_attname):
|
||||
self.field_name = field_name
|
||||
self.descriptor = descriptor
|
||||
self.tracker_attname = tracker_attname
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||
if self.descriptor:
|
||||
value = self.descriptor.__get__(instance, owner)
|
||||
else:
|
||||
value = instance.__dict__[self.field_name]
|
||||
if was_deferred:
|
||||
tracker_instance = getattr(instance, self.tracker_attname)
|
||||
tracker_instance.saved_data[self.field_name] = deepcopy(value)
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def cls_for_descriptor(descriptor):
|
||||
has_set = hasattr(descriptor, '__set__')
|
||||
has_del = hasattr(descriptor, '__delete__')
|
||||
if has_set and has_del:
|
||||
return FullDescriptorWrapper
|
||||
elif has_set:
|
||||
return SettableDescriptorWrapper
|
||||
elif has_del:
|
||||
return DeleteableDescriptorWrapper
|
||||
else:
|
||||
return DescriptorWrapper
|
||||
|
||||
|
||||
class SettableDescriptorWrapper(DescriptorWrapper):
|
||||
"""
|
||||
Descriptor wrapper for descriptors with a __delete__ method.
|
||||
|
||||
This should not be used for descriptors
|
||||
"""
|
||||
def __set__(self, instance, value):
|
||||
return self.descriptor.__set__(instance, value)
|
||||
|
||||
|
||||
class DeleteableDescriptorWrapper(DescriptorWrapper):
|
||||
"""
|
||||
Descriptor wrapper for descriptors with a __delete__ method.
|
||||
|
||||
This should not be used for descriptors
|
||||
"""
|
||||
def __delete__(self, instance):
|
||||
self.descriptor.__delete__(instance)
|
||||
|
||||
|
||||
class FullDescriptorWrapper(SettableDescriptorWrapper, DeleteableDescriptorWrapper):
|
||||
"""
|
||||
Wrapper for descriptors with all three descriptor methods.
|
||||
"""
|
||||
|
||||
|
||||
class FieldInstanceTracker(object):
|
||||
def __init__(self, instance, fields, field_map):
|
||||
self.instance = instance
|
||||
self.fields = fields
|
||||
self.field_map = field_map
|
||||
self.init_deferred_fields()
|
||||
if django.VERSION < (1, 10):
|
||||
self.init_deferred_fields()
|
||||
|
||||
def get_field_value(self, field):
|
||||
return getattr(self.instance, self.field_map[field])
|
||||
|
|
@ -54,10 +115,11 @@ class FieldInstanceTracker(object):
|
|||
def current(self, fields=None):
|
||||
"""Returns dict of current values for all tracked fields"""
|
||||
if fields is None:
|
||||
if self.instance._deferred_fields:
|
||||
deferred_fields = self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()
|
||||
if deferred_fields:
|
||||
fields = [
|
||||
field for field in self.fields
|
||||
if field not in self.instance._deferred_fields
|
||||
if field not in deferred_fields
|
||||
]
|
||||
else:
|
||||
fields = self.fields
|
||||
|
|
@ -135,6 +197,15 @@ class FieldTracker(object):
|
|||
if self.fields is None:
|
||||
self.fields = (field.attname for field in sender._meta.fields)
|
||||
self.fields = set(self.fields)
|
||||
if django.VERSION >= (1, 10):
|
||||
for field_name in self.fields:
|
||||
if django.VERSION >= (1, 10):
|
||||
descriptor = getattr(sender, field_name)
|
||||
else:
|
||||
descriptor = sender.__dict__.get(field_name)
|
||||
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
|
||||
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
|
||||
setattr(sender, field_name, wrapped_descriptor)
|
||||
self.field_map = self.get_field_map(sender)
|
||||
models.signals.post_init.connect(self.initialize_tracker)
|
||||
self.model_class = sender
|
||||
|
|
|
|||
|
|
@ -181,13 +181,22 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.number = 1
|
||||
self.instance.save()
|
||||
item = list(self.tracked_class.objects.only('name').all())[0]
|
||||
self.assertTrue(item._deferred_fields)
|
||||
if django.VERSION >= (1, 10):
|
||||
self.assertTrue(item.get_deferred_fields())
|
||||
else:
|
||||
self.assertTrue(item._deferred_fields)
|
||||
|
||||
self.assertEqual(item.tracker.previous('number'), None)
|
||||
self.assertTrue('number' in item._deferred_fields)
|
||||
if django.VERSION >= (1, 10):
|
||||
self.assertTrue('number' in item.get_deferred_fields())
|
||||
else:
|
||||
self.assertTrue('number' in item._deferred_fields)
|
||||
|
||||
self.assertEqual(item.number, 1)
|
||||
self.assertTrue('number' not in item._deferred_fields)
|
||||
if django.VERSION >= (1, 10):
|
||||
self.assertTrue('number' not in item.get_deferred_fields())
|
||||
else:
|
||||
self.assertTrue('number' not in item._deferred_fields)
|
||||
self.assertEqual(item.tracker.previous('number'), 1)
|
||||
self.assertFalse(item.tracker.has_changed('number'))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import django
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import ModelWithCustomDescriptor
|
||||
|
|
@ -30,9 +31,15 @@ class CustomDescriptorTests(TestCase):
|
|||
def test_deferred(self):
|
||||
instance = ModelWithCustomDescriptor.objects.only('id').get(
|
||||
pk=self.instance.pk)
|
||||
self.assertIn('custom_field', instance.get_deferred_fields())
|
||||
if django.VERSION >= (1, 10):
|
||||
self.assertIn('custom_field', instance.get_deferred_fields())
|
||||
else:
|
||||
self.assertIn('custom_field', instance._deferred_fields)
|
||||
self.assertEqual(instance.custom_field, '1')
|
||||
self.assertNotIn('custom_field', instance.get_deferred_fields())
|
||||
if django.VERSION >= (1, 10):
|
||||
self.assertNotIn('custom_field', instance.get_deferred_fields())
|
||||
else:
|
||||
self.assertNotIn('custom_field', instance._deferred_fields)
|
||||
self.assertEqual(instance.regular_field, 1)
|
||||
self.assertEqual(instance.tracked_custom_field, '1')
|
||||
self.assertEqual(instance.tracked_regular_field, 1)
|
||||
|
|
|
|||
Loading…
Reference in a new issue