From 9ee6065f811cc27a65e2dfcd5f5b6c9857ba00fb Mon Sep 17 00:00:00 2001 From: Alexey Evseev Date: Mon, 5 Sep 2016 17:51:48 +0300 Subject: [PATCH] Support Django 1.10 deferred FileField with FieldTracker --- AUTHORS.rst | 1 + CHANGES.rst | 2 + model_utils/tests/models.py | 7 +++ model_utils/tests/tests.py | 106 +++++++++++++++++++++++++++++++++++- model_utils/tracker.py | 50 ++++++++++++----- 5 files changed, 151 insertions(+), 15 deletions(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index 668453e..1ba4b8f 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -1,6 +1,7 @@ ad-m Alejandro Varas Alex Orange +Alexey Evseev Andy Freeland Artis Avotins Bram Boogaard diff --git a/CHANGES.rst b/CHANGES.rst index f9b0a37..131a7fb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,8 @@ CHANGES master (unreleased) ------------------- +* Fix issue with field tracker and deferred FileField for Django 1.10. + 2.5.2 (2016.08.09) ------------------ diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index c77c034..c086838 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -242,6 +242,12 @@ class TrackedMultiple(models.Model): number_tracker = FieldTracker(fields=['number']) +class TrackedFileField(models.Model): + some_file = models.FileField(upload_to='test_location') + + tracker = FieldTracker() + + class InheritedTracked(Tracked): name2 = models.CharField(max_length=20) @@ -281,6 +287,7 @@ class ModelTrackedMultiple(models.Model): name_tracker = ModelTracker(fields=['name']) number_tracker = ModelTracker(fields=['number']) + class InheritedModelTracked(ModelTracked): name2 = models.CharField(max_length=20) diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 14b6329..719d185 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -27,7 +27,7 @@ from model_utils.tests.models import ( TimeFrameManagerAdded, SplitFieldAbstractParent, ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple, InheritedModelTracked, Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple, - InheritedTracked, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled, + InheritedTracked, TrackedFileField, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled, InheritanceManagerTestChild3, StatusFieldChoicesName) @@ -1735,6 +1735,110 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests): tracked_class = InheritedTrackedFK +class FieldTrackerFileFieldTests(FieldTrackerTestCase): + + tracked_class = TrackedFileField + + def setUp(self): + self.instance = self.tracked_class() + self.tracker = self.instance.tracker + self.some_file = 'something.txt' + self.another_file = 'another.txt' + + def test_pre_save_changed(self): + self.assertChanged(some_file=None) + self.instance.some_file = self.some_file + self.assertChanged(some_file=None) + + def test_pre_save_has_changed(self): + self.assertHasChanged(some_file=True) + self.instance.some_file = self.some_file + self.assertHasChanged(some_file=True) + + def test_pre_save_previous(self): + self.assertPrevious(some_file=None) + self.instance.some_file = self.some_file + self.assertPrevious(some_file=None) + + def test_post_save_changed(self): + self.update_instance(some_file=self.some_file) + self.assertChanged() + previous_file = self.instance.some_file + self.instance.some_file = self.another_file + self.assertChanged(some_file=previous_file) + # test deferred file field + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertChanged(tracker=deferred_instance.tracker) + + previous_file = deferred_instance.some_file + deferred_instance.some_file = self.another_file + self.assertChanged( + tracker=deferred_instance.tracker, + some_file=previous_file, + ) + + def test_post_save_has_changed(self): + self.update_instance(some_file=self.some_file) + self.assertHasChanged(some_file=False) + self.instance.some_file = self.another_file + self.assertHasChanged(some_file=True) + + # test deferred file field + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertHasChanged( + tracker=deferred_instance.tracker, + some_file=False, + ) + + deferred_instance.some_file = self.another_file + self.assertHasChanged( + tracker=deferred_instance.tracker, + some_file=True, + ) + + def test_post_save_previous(self): + self.update_instance(some_file=self.some_file) + previous_file = self.instance.some_file + self.instance.some_file = self.another_file + self.assertPrevious(some_file=previous_file) + + # test deferred file field + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertPrevious( + tracker=deferred_instance.tracker, + some_file=previous_file, + ) + + deferred_instance.some_file = self.another_file + self.assertPrevious( + tracker=deferred_instance.tracker, + some_file=previous_file, + ) + + def test_current(self): + self.assertCurrent(some_file=self.instance.some_file, id=None) + self.instance.some_file = self.some_file + self.assertCurrent(some_file=self.instance.some_file, id=None) + + # test deferred file field + self.instance.save() + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertCurrent( + some_file=self.instance.some_file, + id=self.instance.id, + ) + + self.instance.some_file = self.another_file + self.assertCurrent( + some_file=self.instance.some_file, + id=self.instance.id, + ) + + class ModelTrackerTests(FieldTrackerTests): tracked_class = ModelTracked diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 6aa7d8a..93e4a5d 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -5,9 +5,30 @@ from copy import deepcopy import django from django.core.exceptions import FieldError from django.db import models +from django.db.models.fields.files import FileDescriptor from django.db.models.query_utils import DeferredAttribute +class DescriptorMixin(object): + tracker_instance = None + + def __get__(self, instance, owner): + if instance is None: + return self + was_deferred = False + field_name = self._get_field_name() + if field_name in instance._deferred_fields: + instance._deferred_fields.remove(field_name) + was_deferred = True + value = super(DescriptorMixin, self).__get__(instance, owner) + if was_deferred: + self.tracker_instance.saved_data[field_name] = deepcopy(value) + return value + + def _get_field_name(self): + return self.field_name + + class FieldInstanceTracker(object): def __init__(self, instance, fields, field_map): self.instance = instance @@ -67,17 +88,14 @@ class FieldInstanceTracker(object): if hasattr(self.instance, '_deferred') and not self.instance._deferred: return - class DeferredAttributeTracker(DeferredAttribute): - def __get__(field, instance, owner): - if instance is None: - return field - data = instance.__dict__ - if data.get(field.field_name, field) is field: - instance._deferred_fields.remove(field.field_name) - value = super(DeferredAttributeTracker, field).__get__( - instance, owner) - self.saved_data[field.field_name] = deepcopy(value) - return data[field.field_name] + class DeferredAttributeTracker(DescriptorMixin, DeferredAttribute): + tracker_instance = self + + class FileDescriptorTracker(DescriptorMixin, FileDescriptor): + tracker_instance = self + + def _get_field_name(self): + return self.field.name if django.VERSION >= (1, 8): self.instance._deferred_fields = self.instance.get_deferred_fields() @@ -86,9 +104,13 @@ class FieldInstanceTracker(object): field_obj = getattr(self.instance.__class__, field) else: field_obj = self.instance.__class__.__dict__.get(field) - field_tracker = DeferredAttributeTracker( - field_obj.field_name, None) - setattr(self.instance.__class__, field, field_tracker) + if isinstance(field_obj, FileDescriptor): + field_tracker = FileDescriptorTracker(field_obj.field) + setattr(self.instance.__class__, field, field_tracker) + else: + field_tracker = DeferredAttributeTracker( + field_obj.field_name, None) + setattr(self.instance.__class__, field, field_tracker) else: for field in self.fields: field_obj = self.instance.__class__.__dict__.get(field)