Add ModelTracker with tests

This commit is contained in:
Trey Hunner 2013-02-16 14:52:31 -08:00
parent 579abf8e66
commit c528a347e0
3 changed files with 202 additions and 4 deletions

View file

@ -3,7 +3,7 @@ from datetime import datetime
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.db.models.fields import FieldDoesNotExist
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import ImproperlyConfigured, FieldError
from model_utils.managers import QueryManager
from model_utils.fields import AutoCreatedField, AutoLastModifiedField, \
@ -97,3 +97,79 @@ def add_timeframed_query_manager(sender, **kwargs):
models.signals.class_prepared.connect(add_status_query_managers)
models.signals.class_prepared.connect(add_timeframed_query_manager)
class ModelTracker(object):
def __init__(self, fields=None):
self.fields = fields
def contribute_to_class(self, cls, name):
self.name = name
models.signals.class_prepared.connect(self.finalize, sender=cls)
def finalize(self, sender, **kwargs):
descriptor = ModelTrackerDescriptor(sender, self.name, self.fields)
setattr(sender, self.name, descriptor)
class ModelTrackerDescriptor(object):
def __init__(self, cls, name, fields):
self.attname = '_%s' % name
self.fields = fields
if self.fields is None:
self.fields = [field.attname for field in cls._meta.local_fields]
models.signals.post_init.connect(self.initialize, sender=cls)
def initialize(self, sender, instance, **kwargs):
tracker = ModelInstanceTracker(instance, self.fields)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
def patch_save(self, instance):
original_save = instance.save
def save(**kwargs):
ret = original_save()
getattr(instance, self.attname).set_saved_fields()
return ret
setattr(instance, 'save', save)
def __get__(self, instance, owner):
if instance is None:
return self
else:
return getattr(instance, self.attname)
class ModelInstanceTracker(object):
def __init__(self, instance, fields):
self.instance = instance
self.fields = fields
def set_saved_fields(self):
self.saved_data = self.current_fields()
def current_fields(self):
return (dict((f, getattr(self.instance, f)) for f in self.fields)
if self.instance.pk else {})
def has_changed(self, field):
"""Returns ``True`` if field has changed from currently saved value"""
if not self.instance.pk:
return True
elif field in self.saved_data:
return self.saved_data.get(field) != getattr(self.instance, field)
else:
raise FieldError('field "%s" not tracked' % field)
def previous(self, field):
"""Return currently saved value of given field"""
return self.saved_data.get(field)
def changed(self):
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
saved = self.saved_data.iteritems()
current = self.current_fields()
return dict((k, v) for k, v in saved if v != current[k])

View file

@ -1,7 +1,7 @@
from django.db import models
from django.utils.translation import ugettext_lazy as _
from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel
from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel, ModelTracker
from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager
from model_utils.fields import SplitField, MonitorField
from model_utils import Choices
@ -218,3 +218,17 @@ class Spot(models.Model):
owner = models.ForeignKey(Dude, related_name='spots_owned')
objects = PassThroughManager.for_queryset_class(SpotQuerySet)()
class Tracked(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
tracker = ModelTracker()
class TrackedNotDefault(models.Model):
name = models.CharField(max_length=20)
number = models.IntegerField()
name_tracker = ModelTracker(fields=['name'])

View file

@ -6,7 +6,7 @@ from datetime import datetime, timedelta
import django
from django.db import models
from django.db.models.fields import FieldDoesNotExist
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import ImproperlyConfigured, FieldError
from django.test import TestCase
from model_utils import Choices
@ -18,7 +18,8 @@ from model_utils.tests.models import (
InheritanceManagerTestParent, InheritanceManagerTestChild1,
InheritanceManagerTestChild2, TimeStamp, Post, Article, Status,
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot)
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot,
Tracked, TrackedNotDefault)
@ -624,3 +625,110 @@ class CreatePassThroughManagerTests(TestCase):
def test_related_manager_create(self):
self.dude.spots_owned.create(name='The Crib', closed=True, secure=True)
class ModelTrackerTestCase(TestCase):
def assertHasChanged(self, **kwargs):
for field, value in kwargs.iteritems():
if value is None:
self.assertRaises(FieldError, self.tracker.has_changed, field)
else:
self.assertEqual(self.tracker.has_changed(field), value)
def assertPrevious(self, **kwargs):
for field, value in kwargs.iteritems():
self.assertEqual(self.tracker.previous(field), value)
def assertChanged(self, **kwargs):
self.assertEqual(self.tracker.changed(), kwargs)
def update_instance(self, **kwargs):
for field, value in kwargs.iteritems():
setattr(self.instance, field, value)
self.instance.save()
class ModelTrackerCommonTests(object):
def test_pre_save_has_changed(self):
self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True)
def test_pre_save_changed(self):
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
self.instance.number = 8
self.assertChanged()
self.instance.name = ''
self.assertChanged()
def test_pre_save_previous(self):
self.assertPrevious(name=None, number=None)
self.instance.name = 'new age'
self.instance.number = 8
self.assertPrevious(name=None, number=None)
class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests):
def setUp(self):
self.instance = Tracked()
self.tracker = self.instance.tracker
def test_post_save_has_changed(self):
self.update_instance(name='retro', number=4)
self.assertHasChanged(name=False, number=False)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=False)
self.instance.number = 8
self.assertHasChanged(name=True, number=True)
self.instance.name = 'retro'
self.assertHasChanged(name=False, number=True)
def test_post_save_previous(self):
self.update_instance(name='retro', number=4)
self.instance.name = 'new age'
self.assertPrevious(name='retro', number=4)
def test_post_save_changed(self):
self.update_instance(name='retro', number=4)
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged(name='retro')
self.instance.number = 8
self.assertChanged(name='retro', number=4)
self.instance.name = 'retro'
self.assertChanged(number=4)
class FieldTrackedModelCustomTests(ModelTrackerTestCase,
ModelTrackerCommonTests):
def setUp(self):
self.instance = TrackedNotDefault()
self.tracker = self.instance.name_tracker
def test_post_save_has_changed(self):
self.update_instance(name='retro', number=4)
self.assertHasChanged(name=False, number=None)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=None)
self.instance.number = 8
self.assertHasChanged(name=True, number=None)
self.instance.name = 'retro'
self.assertHasChanged(name=False, number=None)
def test_post_save_previous(self):
self.update_instance(name='retro', number=4)
self.instance.name = 'new age'
self.assertPrevious(name='retro', number=None)
def test_post_save_changed(self):
self.update_instance(name='retro', number=4)
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged(name='retro')
self.instance.number = 8
self.assertChanged(name='retro')
self.instance.name = 'retro'
self.assertChanged()