diff --git a/model_utils/__init__.py b/model_utils/__init__.py index 0281dbb..731e9b7 100644 --- a/model_utils/__init__.py +++ b/model_utils/__init__.py @@ -1,2 +1,2 @@ from .choices import Choices -from .tracker import ModelTracker +from .tracker import FieldTracker, ModelTracker diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index ce2109d..5032d82 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -2,7 +2,7 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel -from model_utils.tracker import ModelTracker +from model_utils.tracker import FieldTracker, ModelTracker from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager from model_utils.fields import SplitField, MonitorField, StatusField from model_utils import Choices @@ -225,28 +225,58 @@ class Tracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() - tracker = ModelTracker() + tracker = FieldTracker() class TrackedFK(models.Model): fk = models.ForeignKey('Tracked') - tracker = ModelTracker() - custom_tracker = ModelTracker(fields=['fk_id']) - custom_tracker_without_id = ModelTracker(fields=['fk']) + tracker = FieldTracker() + custom_tracker = FieldTracker(fields=['fk_id']) + custom_tracker_without_id = FieldTracker(fields=['fk']) class TrackedNotDefault(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() - name_tracker = ModelTracker(fields=['name']) + name_tracker = FieldTracker(fields=['name']) class TrackedMultiple(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() + name_tracker = FieldTracker(fields=['name']) + number_tracker = FieldTracker(fields=['number']) + + +class ModelTracked(models.Model): + name = models.CharField(max_length=20) + number = models.IntegerField() + + tracker = ModelTracker() + + +class ModelTrackedFK(models.Model): + fk = models.ForeignKey('ModelTracked') + + tracker = ModelTracker() + custom_tracker = ModelTracker(fields=['fk_id']) + custom_tracker_without_id = ModelTracker(fields=['fk']) + + +class ModelTrackedNotDefault(models.Model): + name = models.CharField(max_length=20) + number = models.IntegerField() + + name_tracker = ModelTracker(fields=['name']) + + +class ModelTrackedMultiple(models.Model): + name = models.CharField(max_length=20) + number = models.IntegerField() + name_tracker = ModelTracker(fields=['name']) number_tracker = ModelTracker(fields=['number']) diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 5be08c0..c841590 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -10,7 +10,7 @@ from django.utils.six import text_type from django.core.exceptions import ImproperlyConfigured, FieldError from django.test import TestCase -from model_utils import Choices, ModelTracker +from model_utils import Choices, FieldTracker, ModelTracker from model_utils.fields import get_excerpt, MonitorField, StatusField from model_utils.managers import QueryManager from model_utils.models import StatusModel, TimeFramedModel @@ -20,6 +20,7 @@ from model_utils.tests.models import ( InheritanceManagerTestChild2, TimeStamp, Post, Article, Status, StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded, TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot, + ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple, Tracked, TrackedFK, TrackedNotDefault, TrackedMultiple, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled) @@ -652,7 +653,7 @@ class CreatePassThroughManagerTests(TestCase): self.dude.spots_owned.create(name='The Crib', closed=True, secure=True) -class ModelTrackerTestCase(TestCase): +class FieldTrackerTestCase(TestCase): tracker = None @@ -683,7 +684,7 @@ class ModelTrackerTestCase(TestCase): self.instance.save() -class ModelTrackerCommonTests(object): +class FieldTrackerCommonTests(object): def test_pre_save_has_changed(self): self.assertHasChanged(name=True, number=True) @@ -706,13 +707,16 @@ class ModelTrackerCommonTests(object): self.assertPrevious(name=None, number=None) -class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests): +class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): + + tracked_class = Tracked + def setUp(self): - self.instance = Tracked() + self.instance = self.tracked_class() self.tracker = self.instance.tracker def test_descriptor(self): - self.assertTrue(isinstance(Tracked.tracker, ModelTracker)) + self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) def test_first_save(self): self.assertHasChanged(name=True, number=True) @@ -780,21 +784,24 @@ class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests): self.instance.save(update_fields=[]) self.assertChanged(name='retro', number=4) self.instance.save(update_fields=['name']) - in_db = Tracked.objects.get(id=self.instance.id) + in_db = self.tracked_class.objects.get(id=self.instance.id) self.assertEqual(in_db.name, self.instance.name) self.assertNotEqual(in_db.number, self.instance.number) self.assertChanged(number=4) self.instance.save(update_fields=['number']) self.assertChanged() - in_db = Tracked.objects.get(id=self.instance.id) + in_db = self.tracked_class.objects.get(id=self.instance.id) self.assertEqual(in_db.name, self.instance.name) self.assertEqual(in_db.number, self.instance.number) -class FieldTrackedModelCustomTests(ModelTrackerTestCase, - ModelTrackerCommonTests): +class FieldTrackedModelCustomTests(FieldTrackerTestCase, + FieldTrackerCommonTests): + + tracked_class = TrackedNotDefault + def setUp(self): - self.instance = TrackedNotDefault() + self.instance = self.tracked_class() self.tracker = self.instance.name_tracker def test_post_save_has_changed(self): @@ -832,10 +839,13 @@ class FieldTrackedModelCustomTests(ModelTrackerTestCase, self.assertCurrent(name='new age') -class FieldTrackedModelMultiTests(ModelTrackerTestCase, - ModelTrackerCommonTests): +class FieldTrackedModelMultiTests(FieldTrackerTestCase, + FieldTrackerCommonTests): + + tracked_class = TrackedMultiple + def setUp(self): - self.instance = TrackedMultiple() + self.instance = self.tracked_class() self.trackers = [self.instance.name_tracker, self.instance.number_tracker] @@ -905,17 +915,21 @@ class FieldTrackedModelMultiTests(ModelTrackerTestCase, self.assertCurrent(tracker=self.trackers[1], number=8) -class ModelTrackerForeignKeyTests(ModelTrackerTestCase): +class FieldTrackerForeignKeyTests(FieldTrackerTestCase): + + fk_class = Tracked + tracked_class = TrackedFK + def setUp(self): - self.old_fk = Tracked.objects.create(number=8) - self.instance = TrackedFK.objects.create(fk=self.old_fk) + self.old_fk = self.fk_class.objects.create(number=8) + self.instance = self.tracked_class.objects.create(fk=self.old_fk) def test_default(self): self.tracker = self.instance.tracker self.assertChanged() self.assertPrevious() self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id) - self.instance.fk = Tracked.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) self.assertChanged(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) @@ -925,20 +939,41 @@ class ModelTrackerForeignKeyTests(ModelTrackerTestCase): self.assertChanged() self.assertPrevious() self.assertCurrent(fk_id=self.old_fk.id) - self.instance.fk = Tracked.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) self.assertChanged(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id) self.assertCurrent(fk_id=self.instance.fk_id) def test_custom_without_id(self): with self.assertNumQueries(2): - TrackedFK.objects.get() + self.tracked_class.objects.get() self.tracker = self.instance.custom_tracker_without_id self.assertChanged() self.assertPrevious() self.assertCurrent(fk=self.old_fk) - self.instance.fk = Tracked.objects.create(number=8) + self.instance.fk = self.fk_class.objects.create(number=8) self.assertNotEqual(self.instance.fk, self.old_fk) self.assertChanged(fk=self.old_fk) self.assertPrevious(fk=self.old_fk) self.assertCurrent(fk=self.instance.fk) + + +class ModelTrackerTests(FieldTrackerTests): + + tracked_class = ModelTracked + + +class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests): + + tracked_class = ModelTrackedNotDefault + + +class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests): + + tracked_class = ModelTrackedMultiple + + +class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): + + fk_class = ModelTracked + tracked_class = ModelTrackedFK diff --git a/model_utils/tracker.py b/model_utils/tracker.py index a366490..cbd0fcd 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -3,44 +3,7 @@ from django.db import models from django.core.exceptions import FieldError -class ModelTracker(object): - def __init__(self, fields=None): - self.fields = fields - - def contribute_to_class(self, cls, name): - self.name = name - self.attname = '_%s' % name - models.signals.class_prepared.connect(self.finalize_class, sender=cls) - - def finalize_class(self, sender, **kwargs): - if self.fields is None: - self.fields = [field.attname for field in sender._meta.local_fields] - models.signals.post_init.connect(self.initialize_tracker, sender=sender) - setattr(sender, self.name, self) - - def initialize_tracker(self, sender, instance, **kwargs): - tracker = ModelInstanceTracker(instance, self.fields) - setattr(instance, self.attname, tracker) - tracker.set_saved_fields() - self.patch_save(instance) - - def patch_save(self, instance): - original_save = instance.save - def save(**kwargs): - ret = original_save(**kwargs) - getattr(instance, self.attname).set_saved_fields( - fields=kwargs.get('update_fields')) - return ret - instance.save = save - - def __get__(self, instance, owner): - if instance is None: - return self - else: - return getattr(instance, self.attname) - - -class ModelInstanceTracker(object): +class FieldInstanceTracker(object): def __init__(self, instance, fields): self.instance = instance self.fields = fields @@ -78,3 +41,51 @@ class ModelInstanceTracker(object): saved = self.saved_data.items() current = self.current() return dict((k, v) for k, v in saved if v != current[k]) + + +class FieldTracker(object): + + tracker_class = FieldInstanceTracker + + def __init__(self, fields=None): + self.fields = fields + + def contribute_to_class(self, cls, name): + self.name = name + self.attname = '_%s' % name + models.signals.class_prepared.connect(self.finalize_class, sender=cls) + + def finalize_class(self, sender, **kwargs): + if self.fields is None: + self.fields = [field.attname for field in sender._meta.local_fields] + models.signals.post_init.connect(self.initialize_tracker, sender=sender) + setattr(sender, self.name, self) + + def initialize_tracker(self, sender, instance, **kwargs): + tracker = self.tracker_class(instance, self.fields) + setattr(instance, self.attname, tracker) + tracker.set_saved_fields() + self.patch_save(instance) + + def patch_save(self, instance): + original_save = instance.save + def save(**kwargs): + ret = original_save(**kwargs) + getattr(instance, self.attname).set_saved_fields( + fields=kwargs.get('update_fields')) + return ret + instance.save = save + + def __get__(self, instance, owner): + if instance is None: + return self + else: + return getattr(instance, self.attname) + + +class ModelInstanceTracker(FieldInstanceTracker): + pass + + +class ModelTracker(FieldTracker): + tracker_class = ModelInstanceTracker