mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Create FieldTracker that mirrors ModelTracker
This commit is contained in:
parent
97e0f5edbf
commit
6532784acd
4 changed files with 142 additions and 66 deletions
|
|
@ -1,2 +1,2 @@
|
|||
from .choices import Choices
|
||||
from .tracker import ModelTracker
|
||||
from .tracker import FieldTracker, ModelTracker
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from django.db import models
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel
|
||||
from model_utils.tracker import ModelTracker
|
||||
from model_utils.tracker import FieldTracker, ModelTracker
|
||||
from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager
|
||||
from model_utils.fields import SplitField, MonitorField, StatusField
|
||||
from model_utils import Choices
|
||||
|
|
@ -225,28 +225,58 @@ class Tracked(models.Model):
|
|||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField()
|
||||
|
||||
tracker = ModelTracker()
|
||||
tracker = FieldTracker()
|
||||
|
||||
|
||||
class TrackedFK(models.Model):
|
||||
fk = models.ForeignKey('Tracked')
|
||||
|
||||
tracker = ModelTracker()
|
||||
custom_tracker = ModelTracker(fields=['fk_id'])
|
||||
custom_tracker_without_id = ModelTracker(fields=['fk'])
|
||||
tracker = FieldTracker()
|
||||
custom_tracker = FieldTracker(fields=['fk_id'])
|
||||
custom_tracker_without_id = FieldTracker(fields=['fk'])
|
||||
|
||||
|
||||
class TrackedNotDefault(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField()
|
||||
|
||||
name_tracker = ModelTracker(fields=['name'])
|
||||
name_tracker = FieldTracker(fields=['name'])
|
||||
|
||||
|
||||
class TrackedMultiple(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField()
|
||||
|
||||
name_tracker = FieldTracker(fields=['name'])
|
||||
number_tracker = FieldTracker(fields=['number'])
|
||||
|
||||
|
||||
class ModelTracked(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField()
|
||||
|
||||
tracker = ModelTracker()
|
||||
|
||||
|
||||
class ModelTrackedFK(models.Model):
|
||||
fk = models.ForeignKey('ModelTracked')
|
||||
|
||||
tracker = ModelTracker()
|
||||
custom_tracker = ModelTracker(fields=['fk_id'])
|
||||
custom_tracker_without_id = ModelTracker(fields=['fk'])
|
||||
|
||||
|
||||
class ModelTrackedNotDefault(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField()
|
||||
|
||||
name_tracker = ModelTracker(fields=['name'])
|
||||
|
||||
|
||||
class ModelTrackedMultiple(models.Model):
|
||||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField()
|
||||
|
||||
name_tracker = ModelTracker(fields=['name'])
|
||||
number_tracker = ModelTracker(fields=['number'])
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from django.utils.six import text_type
|
|||
from django.core.exceptions import ImproperlyConfigured, FieldError
|
||||
from django.test import TestCase
|
||||
|
||||
from model_utils import Choices, ModelTracker
|
||||
from model_utils import Choices, FieldTracker, ModelTracker
|
||||
from model_utils.fields import get_excerpt, MonitorField, StatusField
|
||||
from model_utils.managers import QueryManager
|
||||
from model_utils.models import StatusModel, TimeFramedModel
|
||||
|
|
@ -20,6 +20,7 @@ from model_utils.tests.models import (
|
|||
InheritanceManagerTestChild2, TimeStamp, Post, Article, Status,
|
||||
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
|
||||
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot,
|
||||
ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple,
|
||||
Tracked, TrackedFK, TrackedNotDefault, TrackedMultiple,
|
||||
StatusFieldDefaultFilled, StatusFieldDefaultNotFilled)
|
||||
|
||||
|
|
@ -652,7 +653,7 @@ class CreatePassThroughManagerTests(TestCase):
|
|||
self.dude.spots_owned.create(name='The Crib', closed=True, secure=True)
|
||||
|
||||
|
||||
class ModelTrackerTestCase(TestCase):
|
||||
class FieldTrackerTestCase(TestCase):
|
||||
|
||||
tracker = None
|
||||
|
||||
|
|
@ -683,7 +684,7 @@ class ModelTrackerTestCase(TestCase):
|
|||
self.instance.save()
|
||||
|
||||
|
||||
class ModelTrackerCommonTests(object):
|
||||
class FieldTrackerCommonTests(object):
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
|
|
@ -706,13 +707,16 @@ class ModelTrackerCommonTests(object):
|
|||
self.assertPrevious(name=None, number=None)
|
||||
|
||||
|
||||
class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests):
|
||||
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||
|
||||
tracked_class = Tracked
|
||||
|
||||
def setUp(self):
|
||||
self.instance = Tracked()
|
||||
self.instance = self.tracked_class()
|
||||
self.tracker = self.instance.tracker
|
||||
|
||||
def test_descriptor(self):
|
||||
self.assertTrue(isinstance(Tracked.tracker, ModelTracker))
|
||||
self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker))
|
||||
|
||||
def test_first_save(self):
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
|
|
@ -780,21 +784,24 @@ class ModelTrackerTests(ModelTrackerTestCase, ModelTrackerCommonTests):
|
|||
self.instance.save(update_fields=[])
|
||||
self.assertChanged(name='retro', number=4)
|
||||
self.instance.save(update_fields=['name'])
|
||||
in_db = Tracked.objects.get(id=self.instance.id)
|
||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||
self.assertEqual(in_db.name, self.instance.name)
|
||||
self.assertNotEqual(in_db.number, self.instance.number)
|
||||
self.assertChanged(number=4)
|
||||
self.instance.save(update_fields=['number'])
|
||||
self.assertChanged()
|
||||
in_db = Tracked.objects.get(id=self.instance.id)
|
||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||
self.assertEqual(in_db.name, self.instance.name)
|
||||
self.assertEqual(in_db.number, self.instance.number)
|
||||
|
||||
|
||||
class FieldTrackedModelCustomTests(ModelTrackerTestCase,
|
||||
ModelTrackerCommonTests):
|
||||
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
|
||||
tracked_class = TrackedNotDefault
|
||||
|
||||
def setUp(self):
|
||||
self.instance = TrackedNotDefault()
|
||||
self.instance = self.tracked_class()
|
||||
self.tracker = self.instance.name_tracker
|
||||
|
||||
def test_post_save_has_changed(self):
|
||||
|
|
@ -832,10 +839,13 @@ class FieldTrackedModelCustomTests(ModelTrackerTestCase,
|
|||
self.assertCurrent(name='new age')
|
||||
|
||||
|
||||
class FieldTrackedModelMultiTests(ModelTrackerTestCase,
|
||||
ModelTrackerCommonTests):
|
||||
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
|
||||
tracked_class = TrackedMultiple
|
||||
|
||||
def setUp(self):
|
||||
self.instance = TrackedMultiple()
|
||||
self.instance = self.tracked_class()
|
||||
self.trackers = [self.instance.name_tracker,
|
||||
self.instance.number_tracker]
|
||||
|
||||
|
|
@ -905,17 +915,21 @@ class FieldTrackedModelMultiTests(ModelTrackerTestCase,
|
|||
self.assertCurrent(tracker=self.trackers[1], number=8)
|
||||
|
||||
|
||||
class ModelTrackerForeignKeyTests(ModelTrackerTestCase):
|
||||
class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackedFK
|
||||
|
||||
def setUp(self):
|
||||
self.old_fk = Tracked.objects.create(number=8)
|
||||
self.instance = TrackedFK.objects.create(fk=self.old_fk)
|
||||
self.old_fk = self.fk_class.objects.create(number=8)
|
||||
self.instance = self.tracked_class.objects.create(fk=self.old_fk)
|
||||
|
||||
def test_default(self):
|
||||
self.tracker = self.instance.tracker
|
||||
self.assertChanged()
|
||||
self.assertPrevious()
|
||||
self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id)
|
||||
self.instance.fk = Tracked.objects.create(number=8)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8)
|
||||
self.assertChanged(fk_id=self.old_fk.id)
|
||||
self.assertPrevious(fk_id=self.old_fk.id)
|
||||
self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id)
|
||||
|
|
@ -925,20 +939,41 @@ class ModelTrackerForeignKeyTests(ModelTrackerTestCase):
|
|||
self.assertChanged()
|
||||
self.assertPrevious()
|
||||
self.assertCurrent(fk_id=self.old_fk.id)
|
||||
self.instance.fk = Tracked.objects.create(number=8)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8)
|
||||
self.assertChanged(fk_id=self.old_fk.id)
|
||||
self.assertPrevious(fk_id=self.old_fk.id)
|
||||
self.assertCurrent(fk_id=self.instance.fk_id)
|
||||
|
||||
def test_custom_without_id(self):
|
||||
with self.assertNumQueries(2):
|
||||
TrackedFK.objects.get()
|
||||
self.tracked_class.objects.get()
|
||||
self.tracker = self.instance.custom_tracker_without_id
|
||||
self.assertChanged()
|
||||
self.assertPrevious()
|
||||
self.assertCurrent(fk=self.old_fk)
|
||||
self.instance.fk = Tracked.objects.create(number=8)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8)
|
||||
self.assertNotEqual(self.instance.fk, self.old_fk)
|
||||
self.assertChanged(fk=self.old_fk)
|
||||
self.assertPrevious(fk=self.old_fk)
|
||||
self.assertCurrent(fk=self.instance.fk)
|
||||
|
||||
|
||||
class ModelTrackerTests(FieldTrackerTests):
|
||||
|
||||
tracked_class = ModelTracked
|
||||
|
||||
|
||||
class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
|
||||
|
||||
tracked_class = ModelTrackedNotDefault
|
||||
|
||||
|
||||
class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
|
||||
|
||||
tracked_class = ModelTrackedMultiple
|
||||
|
||||
|
||||
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
|
||||
|
||||
fk_class = ModelTracked
|
||||
tracked_class = ModelTrackedFK
|
||||
|
|
|
|||
|
|
@ -3,44 +3,7 @@ from django.db import models
|
|||
from django.core.exceptions import FieldError
|
||||
|
||||
|
||||
class ModelTracker(object):
|
||||
def __init__(self, fields=None):
|
||||
self.fields = fields
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
self.name = name
|
||||
self.attname = '_%s' % name
|
||||
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
|
||||
|
||||
def finalize_class(self, sender, **kwargs):
|
||||
if self.fields is None:
|
||||
self.fields = [field.attname for field in sender._meta.local_fields]
|
||||
models.signals.post_init.connect(self.initialize_tracker, sender=sender)
|
||||
setattr(sender, self.name, self)
|
||||
|
||||
def initialize_tracker(self, sender, instance, **kwargs):
|
||||
tracker = ModelInstanceTracker(instance, self.fields)
|
||||
setattr(instance, self.attname, tracker)
|
||||
tracker.set_saved_fields()
|
||||
self.patch_save(instance)
|
||||
|
||||
def patch_save(self, instance):
|
||||
original_save = instance.save
|
||||
def save(**kwargs):
|
||||
ret = original_save(**kwargs)
|
||||
getattr(instance, self.attname).set_saved_fields(
|
||||
fields=kwargs.get('update_fields'))
|
||||
return ret
|
||||
instance.save = save
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return getattr(instance, self.attname)
|
||||
|
||||
|
||||
class ModelInstanceTracker(object):
|
||||
class FieldInstanceTracker(object):
|
||||
def __init__(self, instance, fields):
|
||||
self.instance = instance
|
||||
self.fields = fields
|
||||
|
|
@ -78,3 +41,51 @@ class ModelInstanceTracker(object):
|
|||
saved = self.saved_data.items()
|
||||
current = self.current()
|
||||
return dict((k, v) for k, v in saved if v != current[k])
|
||||
|
||||
|
||||
class FieldTracker(object):
|
||||
|
||||
tracker_class = FieldInstanceTracker
|
||||
|
||||
def __init__(self, fields=None):
|
||||
self.fields = fields
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
self.name = name
|
||||
self.attname = '_%s' % name
|
||||
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
|
||||
|
||||
def finalize_class(self, sender, **kwargs):
|
||||
if self.fields is None:
|
||||
self.fields = [field.attname for field in sender._meta.local_fields]
|
||||
models.signals.post_init.connect(self.initialize_tracker, sender=sender)
|
||||
setattr(sender, self.name, self)
|
||||
|
||||
def initialize_tracker(self, sender, instance, **kwargs):
|
||||
tracker = self.tracker_class(instance, self.fields)
|
||||
setattr(instance, self.attname, tracker)
|
||||
tracker.set_saved_fields()
|
||||
self.patch_save(instance)
|
||||
|
||||
def patch_save(self, instance):
|
||||
original_save = instance.save
|
||||
def save(**kwargs):
|
||||
ret = original_save(**kwargs)
|
||||
getattr(instance, self.attname).set_saved_fields(
|
||||
fields=kwargs.get('update_fields'))
|
||||
return ret
|
||||
instance.save = save
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
return getattr(instance, self.attname)
|
||||
|
||||
|
||||
class ModelInstanceTracker(FieldInstanceTracker):
|
||||
pass
|
||||
|
||||
|
||||
class ModelTracker(FieldTracker):
|
||||
tracker_class = ModelInstanceTracker
|
||||
|
|
|
|||
Loading…
Reference in a new issue