Merge pull request #317 from lucaswiman/django-1.11-compatibility

Fix handling of deferred fields on django 1.10+
This commit is contained in:
Lucas Wiman 2018-07-02 11:30:09 -07:00 committed by GitHub
commit 16dec4d12d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 287 additions and 66 deletions

View file

@ -8,7 +8,7 @@ python:
- 3.6
install: pip install tox-travis codecov
# positional args ({posargs}) to pass into tox.ini
script: tox -- --cov
script: tox -- --cov --cov-append
after_success: codecov
deploy:
provider: pypi

View file

@ -44,3 +44,5 @@
| Karl Wan Nan Wo <karl.wnw@gmail.com>
| zyegfryed
| Radosław Jan Ganczarek <radoslaw@ganczarek.in>
| Lucas Wiman <lucas.wiman@gmail.com>
| Jack Cushman <jcushman@law.harvard.edu>

View file

@ -3,13 +3,15 @@ CHANGES
master (unreleased)
-------------------
- Fix handling of deferred attributes on Django 1.10+, fixes GH-278
- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return
correct responses for deferred fields.
3.1.2 (2018.05.09)
------------------
* Update InheritanceIterable to inherit from
ModelIterable instead of BaseIterable, fixes GH-277.
3.1.1 (2017.12.17)
------------------

View file

@ -150,6 +150,10 @@ Returns the value of the given field during the last save:
Returns ``None`` when the model instance isn't saved yet.
If a field is `deferred`_, calling ``previous()`` will load the previous value from the database.
.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer
has_changed
~~~~~~~~~~~
@ -167,6 +171,8 @@ Returns ``True`` if the given field has changed since the last save. The ``has_c
The ``has_changed`` method relies on ``previous`` to determine whether a
field's values has changed.
If a field is `deferred`_ and has been assigned locally, calling ``has_changed()``
will load the previous value from the database to perform the comparison.
changed
~~~~~~~

View file

@ -48,11 +48,10 @@ class InheritanceIterable(ModelIterable):
class InheritanceQuerySetMixin(object):
def __init__(self, *args, **kwargs):
super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs)
if django.VERSION > (1, 8):
self._iterable_class = InheritanceIterable
self._iterable_class = InheritanceIterable
def select_subclasses(self, *subclasses):
levels = self._get_maximum_depth()
levels = None
calculated_subclasses = self._get_subclasses_recurse(
self.model, levels=levels)
# if none were passed in, we can just short circuit and select all
@ -151,12 +150,9 @@ class InheritanceQuerySetMixin(object):
recursively, returning a `list` of strings representing the
relations for select_related
"""
if django.VERSION < (1, 8):
related_objects = model._meta.get_all_related_objects()
else:
related_objects = [
f for f in model._meta.get_fields()
if isinstance(f, OneToOneRel)]
related_objects = [
f for f in model._meta.get_fields()
if isinstance(f, OneToOneRel)]
rels = [
rel for rel in related_objects
@ -199,10 +195,7 @@ class InheritanceQuerySetMixin(object):
related = parent_link.remote_field
ancestry.insert(0, related.get_accessor_name())
if levels or levels is None:
if django.VERSION < (1, 8):
parent_model = related.parent_model
else:
parent_model = related.model
parent_model = related.model
parent_link = parent_model._meta.get_ancestor_link(
self.model)
else:
@ -230,17 +223,6 @@ class InheritanceQuerySetMixin(object):
def get_subclass(self, *args, **kwargs):
return self.select_subclasses().get(*args, **kwargs)
def _get_maximum_depth(self):
"""
Under Django versions < 1.6, to avoid triggering
https://code.djangoproject.com/ticket/16572 we can only look
as far as children.
"""
levels = None
if django.VERSION < (1, 6, 0):
levels = 1
return levels
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
pass

View file

@ -29,12 +29,73 @@ class DescriptorMixin(object):
return self.field_name
class DescriptorWrapper(object):
def __init__(self, field_name, descriptor, tracker_attname):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname
def __get__(self, instance, owner):
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
value = self.descriptor.__get__(instance, owner)
if was_deferred:
tracker_instance = getattr(instance, self.tracker_attname)
tracker_instance.saved_data[self.field_name] = deepcopy(value)
return value
def __set__(self, instance, value):
initialized = hasattr(instance, '_instance_intialized')
was_deferred = self.field_name in instance.get_deferred_fields()
# Sentinel attribute to detect whether we are already trying to
# set the attribute higher up the stack. This prevents infinite
# recursion when retrieving deferred values from the database.
recursion_sentinel_attname = '_setting_' + self.field_name
already_setting = hasattr(instance, recursion_sentinel_attname)
if initialized and was_deferred and not already_setting:
setattr(instance, recursion_sentinel_attname, True)
try:
# Retrieve the value to set the saved_data value.
# This will undefer the field
getattr(instance, self.field_name)
finally:
instance.__dict__.pop(recursion_sentinel_attname, None)
if hasattr(self.descriptor, '__set__'):
self.descriptor.__set__(instance, value)
else:
instance.__dict__[self.field_name] = value
@staticmethod
def cls_for_descriptor(descriptor):
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
return DescriptorWrapper
class FullDescriptorWrapper(DescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj):
self.descriptor.__delete__(obj)
class FieldInstanceTracker(object):
def __init__(self, instance, fields, field_map):
self.instance = instance
self.fields = fields
self.field_map = field_map
self.init_deferred_fields()
if django.VERSION < (1, 10):
self.init_deferred_fields()
@property
def deferred_fields(self):
return self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()
def get_field_value(self, field):
return getattr(self.instance, self.field_map[field])
@ -54,10 +115,11 @@ class FieldInstanceTracker(object):
def current(self, fields=None):
"""Returns dict of current values for all tracked fields"""
if fields is None:
if self.instance._deferred_fields:
deferred_fields = self.deferred_fields
if deferred_fields:
fields = [
field for field in self.fields
if field not in self.instance._deferred_fields
if field not in deferred_fields
]
else:
fields = self.fields
@ -67,12 +129,31 @@ class FieldInstanceTracker(object):
def has_changed(self, field):
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
if field in self.deferred_fields and field not in self.instance.__dict__:
return False
return self.previous(field) != self.get_field_value(field)
else:
raise FieldError('field "%s" not tracked' % field)
def previous(self, field):
"""Returns currently saved value of given field"""
# handle deferred fields that have not yet been loaded from the database
if self.instance.pk and field in self.deferred_fields and field not in self.saved_data:
# if the field has not been assigned locally, simply fetch and un-defer the value
if field not in self.instance.__dict__:
self.get_field_value(field)
# if the field has been assigned locally, store the local value, fetch the database value,
# store database value to saved_data, and restore the local value
else:
current_value = self.get_field_value(field)
self.instance.refresh_from_db(fields=[field])
self.saved_data[field] = deepcopy(self.get_field_value(field))
setattr(self.instance, self.field_map[field], current_value)
return self.saved_data.get(field)
def changed(self):
@ -97,35 +178,15 @@ class FieldInstanceTracker(object):
def _get_field_name(self):
return self.field.name
if django.VERSION >= (1, 8):
self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
if django.VERSION >= (1, 10):
field_obj = getattr(self.instance.__class__, field)
else:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
else:
field_tracker = DeferredAttributeTracker(
field_obj.field_name, None)
setattr(self.instance.__class__, field, field_tracker)
else:
for field in self.fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, DeferredAttribute):
self.instance._deferred_fields.add(field)
# Django 1.4
if django.VERSION >= (1, 5):
model = None
else:
model = field_obj.model_ref()
field_tracker = DeferredAttributeTracker(
field_obj.field_name, model)
setattr(self.instance.__class__, field, field_tracker)
self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
else:
field_tracker = DeferredAttributeTracker(field, type(self.instance))
setattr(self.instance.__class__, field, field_tracker)
class FieldTracker(object):
@ -152,6 +213,12 @@ class FieldTracker(object):
if self.fields is None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
if django.VERSION >= (1, 10):
for field_name in self.fields:
descriptor = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
self.field_map = self.get_field_map(sender)
models.signals.post_init.connect(self.initialize_tracker)
self.model_class = sender
@ -164,6 +231,7 @@ class FieldTracker(object):
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
instance._instance_intialized = True
def patch_save(self, instance):
original_save = instance.save

View file

@ -1,6 +1,8 @@
from __future__ import unicode_literals, absolute_import
import django
from django.db import models
from django.db.models.query_utils import DeferredAttribute
from django.db.models import Manager
from django.utils.encoding import python_2_unicode_compatible
from django.utils.translation import ugettext_lazy as _
@ -331,3 +333,43 @@ class CustomSoftDelete(SoftDeletableModel):
is_read = models.BooleanField(default=False)
objects = CustomSoftDeleteManager()
class StringyDescriptor(object):
"""
Descriptor that returns a string version of the underlying integer value.
"""
def __init__(self, name):
self.name = name
def __get__(self, obj, cls=None):
if obj is None:
return self
if self.name in obj.get_deferred_fields():
# This queries the database, and sets the value on the instance.
if django.VERSION < (2, 1):
DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls)
else:
DeferredAttribute(field_name=self.name).__get__(obj, cls)
return str(obj.__dict__[self.name])
def __set__(self, obj, value):
obj.__dict__[self.name] = int(value)
def __delete__(self, obj):
del obj.__dict__[self.name]
class CustomDescriptorField(models.IntegerField):
def contribute_to_class(self, cls, name, **kwargs):
super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs)
setattr(cls, name, StringyDescriptor(name))
class ModelWithCustomDescriptor(models.Model):
custom_field = CustomDescriptorField()
tracked_custom_field = CustomDescriptorField()
regular_field = models.IntegerField()
tracked_regular_field = models.IntegerField()
tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field'])

View file

@ -7,6 +7,7 @@ from django.core.exceptions import FieldError
from django.test import TestCase
from model_utils import FieldTracker
from model_utils.tracker import DescriptorWrapper
from tests.models import (
Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple,
InheritedTracked, TrackedFileField,
@ -180,20 +181,66 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.name = 'new age'
self.instance.number = 1
self.instance.save()
item = list(self.tracked_class.objects.only('name').all())[0]
self.assertTrue(item._deferred_fields)
item = self.tracked_class.objects.only('name').first()
if django.VERSION >= (1, 10):
self.assertTrue(item.get_deferred_fields())
else:
self.assertTrue(item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), None)
self.assertTrue('number' in item._deferred_fields)
# has_changed() returns False for deferred fields, without un-deferring them.
# Use an if because ModelTracked doesn't support has_changed() in this case.
if self.tracked_class == Tracked:
self.assertFalse(item.tracker.has_changed('number'))
if django.VERSION >= (1, 10):
self.assertIsInstance(item.__class__.number, DescriptorWrapper)
self.assertTrue('number' in item.get_deferred_fields())
else:
self.assertTrue('number' in item._deferred_fields)
# previous() un-defers field and returns value
self.assertEqual(item.tracker.previous('number'), 1)
if django.VERSION >= (1, 10):
self.assertNotIn('number', item.get_deferred_fields())
else:
self.assertNotIn('number', item._deferred_fields)
# examining a deferred field un-defers it
item = self.tracked_class.objects.only('name').first()
self.assertEqual(item.number, 1)
self.assertTrue('number' not in item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue('number' not in item.get_deferred_fields())
else:
self.assertTrue('number' not in item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), 1)
self.assertFalse(item.tracker.has_changed('number'))
# has_changed() returns correct values after deferred field is examined
self.assertFalse(item.tracker.has_changed('number'))
item.number = 2
self.assertTrue(item.tracker.has_changed('number'))
# previous() returns correct value after deferred field is examined
self.assertEqual(item.tracker.previous('number'), 1)
# assigning to a deferred field un-defers it
# Use an if because ModelTracked doesn't handle this case.
if self.tracked_class == Tracked:
item = self.tracked_class.objects.only('name').first()
item.number = 2
# previous() fetches correct value from database after deferred field is assigned
self.assertEqual(item.tracker.previous('number'), 1)
# database fetch of previous() value doesn't affect current value
self.assertEqual(item.number, 2)
# has_changed() returns correct values after deferred field is assigned
self.assertTrue(item.tracker.has_changed('number'))
item.number = 1
self.assertFalse(item.tracker.has_changed('number'))
class FieldTrackerMultipleInstancesTests(TestCase):

View file

@ -115,9 +115,6 @@ class InheritanceManagerTests(TestCase):
"inheritancemanagertestchild2").get(pk=self.child1.pk)
obj.inheritancemanagertestchild1
def test_version_determining_any_depth(self):
self.assertIsNone(self.get_manager().all()._get_maximum_depth())
def test_manually_specifying_parent_fk_including_grandchildren(self):
"""
given a Model which inherits from another Model, but also declares

View file

@ -0,0 +1,71 @@
from __future__ import unicode_literals
import django
from django.test import TestCase
from tests.models import ModelWithCustomDescriptor
class CustomDescriptorTests(TestCase):
def setUp(self):
self.instance = ModelWithCustomDescriptor.objects.create(
custom_field='1',
tracked_custom_field='1',
regular_field=1,
tracked_regular_field=1,
)
def test_custom_descriptor_works(self):
instance = self.instance
self.assertEqual(instance.custom_field, '1')
self.assertEqual(instance.__dict__['custom_field'], 1)
self.assertEqual(instance.regular_field, 1)
instance.custom_field = 2
self.assertEqual(instance.custom_field, '2')
self.assertEqual(instance.__dict__['custom_field'], 2)
instance.save()
intance = ModelWithCustomDescriptor.objects.get(pk=instance.pk)
self.assertEqual(instance.custom_field, '2')
self.assertEqual(instance.__dict__['custom_field'], 2)
def test_deferred(self):
instance = ModelWithCustomDescriptor.objects.only('id').get(
pk=self.instance.pk)
if django.VERSION >= (1, 10):
self.assertIn('custom_field', instance.get_deferred_fields())
else:
self.assertIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.custom_field, '1')
if django.VERSION >= (1, 10):
self.assertNotIn('custom_field', instance.get_deferred_fields())
else:
self.assertNotIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.regular_field, 1)
self.assertEqual(instance.tracked_custom_field, '1')
self.assertEqual(instance.tracked_regular_field, 1)
self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))
self.assertFalse(instance.tracker.has_changed('tracked_regular_field'))
instance.tracked_custom_field = 2
instance.tracked_regular_field = 2
self.assertTrue(instance.tracker.has_changed('tracked_custom_field'))
self.assertTrue(instance.tracker.has_changed('tracked_regular_field'))
instance.save()
instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk)
self.assertEqual(instance.custom_field, '1')
self.assertEqual(instance.regular_field, 1)
self.assertEqual(instance.tracked_custom_field, '2')
self.assertEqual(instance.tracked_regular_field, 2)
instance = ModelWithCustomDescriptor.objects.only('id').get(pk=instance.pk)
if django.VERSION >= (1, 10):
# This fails on 1.8 and 1.9, which is a bug in the deferred field
# implementation on those versions.
instance.tracked_custom_field = 3
self.assertEqual(instance.tracked_custom_field, '3')
self.assertTrue(instance.tracker.has_changed('tracked_custom_field'))
del instance.tracked_custom_field
self.assertEqual(instance.tracked_custom_field, '2')
self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))

View file

@ -18,6 +18,10 @@ deps =
pytest-cov
ignore_outcome =
djangotrunk: True
passenv =
CI
TRAVIS
TRAVIS_*
commands =
pip install -e .