Create FieldTracker that mirrors ModelTracker

This commit is contained in:
Trey Hunner 2013-05-23 12:50:34 -07:00
parent 97e0f5edbf
commit 6532784acd
4 changed files with 142 additions and 66 deletions

View file

@ -1,2 +1,2 @@
from .choices import Choices
from .tracker import ModelTracker
from .tracker import FieldTracker, ModelTracker

View file

@ -2,7 +2,7 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _
from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel
from model_utils.tracker import ModelTracker
from model_utils.tracker import FieldTracker, ModelTracker
from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager
from model_utils.fields import SplitField, MonitorField, StatusField
from model_utils import Choices
@ -225,28 +225,58 @@ class Tracked(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
tracker = ModelTracker()
tracker = FieldTracker()
class TrackedFK(models.Model):
fk = models.ForeignKey('Tracked')
tracker = ModelTracker()
custom_tracker = ModelTracker(fields=['fk_id'])
custom_tracker_without_id = ModelTracker(fields=['fk'])
tracker = FieldTracker()
custom_tracker = FieldTracker(fields=['fk_id'])
custom_tracker_without_id = FieldTracker(fields=['fk'])
class TrackedNotDefault(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
name_tracker = ModelTracker(fields=['name'])
name_tracker = FieldTracker(fields=['name'])
class TrackedMultiple(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
name_tracker = FieldTracker(fields=['name'])
number_tracker = FieldTracker(fields=['number'])
class ModelTracked(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
tracker = ModelTracker()
class ModelTrackedFK(models.Model):
fk = models.ForeignKey('ModelTracked')
tracker = ModelTracker()
custom_tracker = ModelTracker(fields=['fk_id'])
custom_tracker_without_id = ModelTracker(fields=['fk'])
class ModelTrackedNotDefault(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
name_tracker = ModelTracker(fields=['name'])
class ModelTrackedMultiple(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
name_tracker = ModelTracker(fields=['name'])
number_tracker = ModelTracker(fields=['number'])

View file

@ -10,7 +10,7 @@ from django.utils.six import text_type
from django.core.exceptions import ImproperlyConfigured, FieldError
from django.test import TestCase
from model_utils import Choices, ModelTracker
from model_utils import Choices, FieldTracker, ModelTracker
from model_utils.fields import get_excerpt, MonitorField, StatusField
from model_utils.managers import QueryManager
from model_utils.models import StatusModel, TimeFramedModel
@ -20,6 +20,7 @@ from model_utils.tests.models import (
InheritanceManagerTestChild2, TimeStamp, Post, Article, Status,
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot,
ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple,
Tracked, TrackedFK, TrackedNotDefault, TrackedMultiple,
StatusFieldDefaultFilled, StatusFieldDefaultNotFilled)
@ -652,7 +653,7 @@ class CreatePassThroughManagerTests(TestCase):
self.dude.spots_owned.create(name='The Crib', closed=True, secure=True)
class ModelTrackerTestCase(TestCase):
class FieldTrackerTestCase(TestCase):
tracker = None
@ -683,7 +684,7 @@ class ModelTrackerTestCase(TestCase):
self.instance.save()
class ModelTrackerCommonTests(object):
class FieldTrackerCommonTests(object):
def test_pre_save_has_changed(self):
self.assertHasChanged(name=True, number=True)
@ -706,13 +707,16 @@ class ModelTrackerCommonTests(object):
self.assertPrevious(name=None, number=None)
class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests):
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
tracked_class = Tracked
def setUp(self):
self.instance = Tracked()
self.instance = self.tracked_class()
self.tracker = self.instance.tracker
def test_descriptor(self):
self.assertTrue(isinstance(Tracked.tracker, ModelTracker))
self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker))
def test_first_save(self):
self.assertHasChanged(name=True, number=True)
@ -780,21 +784,24 @@ class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests):
self.instance.save(update_fields=[])
self.assertChanged(name='retro', number=4)
self.instance.save(update_fields=['name'])
in_db = Tracked.objects.get(id=self.instance.id)
in_db = self.tracked_class.objects.get(id=self.instance.id)
self.assertEqual(in_db.name, self.instance.name)
self.assertNotEqual(in_db.number, self.instance.number)
self.assertChanged(number=4)
self.instance.save(update_fields=['number'])
self.assertChanged()
in_db = Tracked.objects.get(id=self.instance.id)
in_db = self.tracked_class.objects.get(id=self.instance.id)
self.assertEqual(in_db.name, self.instance.name)
self.assertEqual(in_db.number, self.instance.number)
class FieldTrackedModelCustomTests(ModelTrackerTestCase,
ModelTrackerCommonTests):
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
FieldTrackerCommonTests):
tracked_class = TrackedNotDefault
def setUp(self):
self.instance = TrackedNotDefault()
self.instance = self.tracked_class()
self.tracker = self.instance.name_tracker
def test_post_save_has_changed(self):
@ -832,10 +839,13 @@ class FieldTrackedModelCustomTests(ModelTrackerTestCase,
self.assertCurrent(name='new age')
class FieldTrackedModelMultiTests(ModelTrackerTestCase,
ModelTrackerCommonTests):
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
FieldTrackerCommonTests):
tracked_class = TrackedMultiple
def setUp(self):
self.instance = TrackedMultiple()
self.instance = self.tracked_class()
self.trackers = [self.instance.name_tracker,
self.instance.number_tracker]
@ -905,17 +915,21 @@ class FieldTrackedModelMultiTests(ModelTrackerTestCase,
self.assertCurrent(tracker=self.trackers[1], number=8)
class ModelTrackerForeignKeyTests(ModelTrackerTestCase):
class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
fk_class = Tracked
tracked_class = TrackedFK
def setUp(self):
self.old_fk = Tracked.objects.create(number=8)
self.instance = TrackedFK.objects.create(fk=self.old_fk)
self.old_fk = self.fk_class.objects.create(number=8)
self.instance = self.tracked_class.objects.create(fk=self.old_fk)
def test_default(self):
self.tracker = self.instance.tracker
self.assertChanged()
self.assertPrevious()
self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id)
self.instance.fk = Tracked.objects.create(number=8)
self.instance.fk = self.fk_class.objects.create(number=8)
self.assertChanged(fk_id=self.old_fk.id)
self.assertPrevious(fk_id=self.old_fk.id)
self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id)
@ -925,20 +939,41 @@ class ModelTrackerForeignKeyTests(ModelTrackerTestCase):
self.assertChanged()
self.assertPrevious()
self.assertCurrent(fk_id=self.old_fk.id)
self.instance.fk = Tracked.objects.create(number=8)
self.instance.fk = self.fk_class.objects.create(number=8)
self.assertChanged(fk_id=self.old_fk.id)
self.assertPrevious(fk_id=self.old_fk.id)
self.assertCurrent(fk_id=self.instance.fk_id)
def test_custom_without_id(self):
with self.assertNumQueries(2):
TrackedFK.objects.get()
self.tracked_class.objects.get()
self.tracker = self.instance.custom_tracker_without_id
self.assertChanged()
self.assertPrevious()
self.assertCurrent(fk=self.old_fk)
self.instance.fk = Tracked.objects.create(number=8)
self.instance.fk = self.fk_class.objects.create(number=8)
self.assertNotEqual(self.instance.fk, self.old_fk)
self.assertChanged(fk=self.old_fk)
self.assertPrevious(fk=self.old_fk)
self.assertCurrent(fk=self.instance.fk)
class ModelTrackerTests(FieldTrackerTests):
tracked_class = ModelTracked
class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
tracked_class = ModelTrackedNotDefault
class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
tracked_class = ModelTrackedMultiple
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
fk_class = ModelTracked
tracked_class = ModelTrackedFK

View file

@ -3,44 +3,7 @@ from django.db import models
from django.core.exceptions import FieldError
class ModelTracker(object):
def __init__(self, fields=None):
self.fields = fields
def contribute_to_class(self, cls, name):
self.name = name
self.attname = '_%s' % name
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
def finalize_class(self, sender, **kwargs):
if self.fields is None:
self.fields = [field.attname for field in sender._meta.local_fields]
models.signals.post_init.connect(self.initialize_tracker, sender=sender)
setattr(sender, self.name, self)
def initialize_tracker(self, sender, instance, **kwargs):
tracker = ModelInstanceTracker(instance, self.fields)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
def patch_save(self, instance):
original_save = instance.save
def save(**kwargs):
ret = original_save(**kwargs)
getattr(instance, self.attname).set_saved_fields(
fields=kwargs.get('update_fields'))
return ret
instance.save = save
def __get__(self, instance, owner):
if instance is None:
return self
else:
return getattr(instance, self.attname)
class ModelInstanceTracker(object):
class FieldInstanceTracker(object):
def __init__(self, instance, fields):
self.instance = instance
self.fields = fields
@ -78,3 +41,51 @@ class ModelInstanceTracker(object):
saved = self.saved_data.items()
current = self.current()
return dict((k, v) for k, v in saved if v != current[k])
class FieldTracker(object):
tracker_class = FieldInstanceTracker
def __init__(self, fields=None):
self.fields = fields
def contribute_to_class(self, cls, name):
self.name = name
self.attname = '_%s' % name
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
def finalize_class(self, sender, **kwargs):
if self.fields is None:
self.fields = [field.attname for field in sender._meta.local_fields]
models.signals.post_init.connect(self.initialize_tracker, sender=sender)
setattr(sender, self.name, self)
def initialize_tracker(self, sender, instance, **kwargs):
tracker = self.tracker_class(instance, self.fields)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
def patch_save(self, instance):
original_save = instance.save
def save(**kwargs):
ret = original_save(**kwargs)
getattr(instance, self.attname).set_saved_fields(
fields=kwargs.get('update_fields'))
return ret
instance.save = save
def __get__(self, instance, owner):
if instance is None:
return self
else:
return getattr(instance, self.attname)
class ModelInstanceTracker(FieldInstanceTracker):
pass
class ModelTracker(FieldTracker):
tracker_class = ModelInstanceTracker