From be52bc929009b58535de75dfdb40e6aa58bb4836 Mon Sep 17 00:00:00 2001 From: Lucas Wiman Date: Tue, 3 Apr 2018 13:18:03 -0700 Subject: [PATCH] Add failing test for deferred attributes. --- model_utils/tracker.py | 3 +- tests/models.py | 35 +++++++++++++++ tests/test_models/test_deferred_fields.py | 53 +++++++++++++++++++++++ 3 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 tests/test_models/test_deferred_fields.py diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 0fec85d..5da9dd4 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -107,8 +107,7 @@ class FieldInstanceTracker(object): field_tracker = FileDescriptorTracker(field_obj.field) setattr(self.instance.__class__, field, field_tracker) else: - field_tracker = DeferredAttributeTracker( - field_obj.field_name, None) + field_tracker = DeferredAttributeTracker(field, type(self.instance)) setattr(self.instance.__class__, field, field_tracker) diff --git a/tests/models.py b/tests/models.py index a65d499..841f612 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals, absolute_import from django.db import models +from django.db.models.query_utils import DeferredAttribute from django.db.models import Manager from django.utils.encoding import python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ @@ -331,3 +332,37 @@ class CustomSoftDelete(SoftDeletableModel): is_read = models.BooleanField(default=False) objects = CustomSoftDeleteManager() + + +class StringyDescriptor(object): + """ + Descriptor that returns a string version of the underlying integer value. + """ + def __init__(self, name): + self.name = name + + def __get__(self, obj, cls=None): + if obj is None: + return self + if self.name in obj.get_deferred_fields(): + # This queries the database, and sets the value on the instance. + DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls) + return str(obj.__dict__[self.name]) + + def __set__(self, obj, value): + obj.__dict__[self.name] = int(value) + + +class CustomDescriptorField(models.IntegerField): + def contribute_to_class(self, cls, name, **kwargs): + super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs) + setattr(cls, name, StringyDescriptor(name)) + + +class ModelWithCustomDescriptor(models.Model): + custom_field = CustomDescriptorField() + tracked_custom_field = CustomDescriptorField() + regular_field = models.IntegerField() + tracked_regular_field = models.IntegerField() + + tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field']) diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py new file mode 100644 index 0000000..6a159be --- /dev/null +++ b/tests/test_models/test_deferred_fields.py @@ -0,0 +1,53 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from tests.models import ModelWithCustomDescriptor + + +class CustomDescriptorTests(TestCase): + def setUp(self): + self.instance = ModelWithCustomDescriptor.objects.create( + custom_field='1', + tracked_custom_field='1', + regular_field=1, + tracked_regular_field=1, + ) + + def test_custom_descriptor_works(self): + instance = self.instance + self.assertEqual(instance.custom_field, '1') + self.assertEqual(instance.__dict__['custom_field'], 1) + self.assertEqual(instance.regular_field, 1) + instance.custom_field = 2 + self.assertEqual(instance.custom_field, '2') + self.assertEqual(instance.__dict__['custom_field'], 2) + instance.save() + intance = ModelWithCustomDescriptor.objects.get(pk=instance.pk) + self.assertEqual(instance.custom_field, '2') + self.assertEqual(instance.__dict__['custom_field'], 2) + + def test_deferred(self): + instance = ModelWithCustomDescriptor.objects.only('id').get( + pk=self.instance.pk) + self.assertIn('custom_field', instance.get_deferred_fields()) + self.assertEqual(instance.custom_field, '1') + self.assertNotIn('custom_field', instance.get_deferred_fields()) + self.assertEqual(instance.regular_field, 1) + self.assertEqual(instance.tracked_custom_field, '1') + self.assertEqual(instance.tracked_regular_field, 1) + + self.assertFalse(instance.tracker.has_changed('tracked_custom_field')) + self.assertFalse(instance.tracker.has_changed('tracked_regular_field')) + + instance.tracked_custom_field = 2 + instance.tracked_regular_field = 2 + self.assertTrue(instance.tracker.has_changed('tracked_custom_field')) + self.assertTrue(instance.tracker.has_changed('tracked_regular_field')) + instance.save() + + instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk) + self.assertEqual(instance.custom_field, '1') + self.assertEqual(instance.regular_field, 1) + self.assertEqual(instance.tracked_custom_field, '2') + self.assertEqual(instance.tracked_regular_field, 2)