Annotate test_field_tracker module

This commit is contained in:
Maarten ter Huurne 2024-03-29 17:12:32 +01:00
parent 949d110d04
commit 7d6cad0200
3 changed files with 66 additions and 35 deletions

View file

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any
from django.core.cache import cache from django.core.cache import cache
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import models from django.db import models
@ -7,7 +9,7 @@ from django.db.models.fields.files import FieldFile
from django.test import TestCase from django.test import TestCase
from model_utils import FieldTracker from model_utils import FieldTracker
from model_utils.tracker import DescriptorWrapper from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker
from tests.models import ( from tests.models import (
InheritedModelTracked, InheritedModelTracked,
InheritedTracked, InheritedTracked,
@ -26,12 +28,18 @@ from tests.models import (
TrackerTimeStamped, TrackerTimeStamped,
) )
if TYPE_CHECKING:
MixinBase = TestCase
else:
MixinBase = object
class FieldTrackerTestCase(TestCase):
tracker = None class FieldTrackerMixin(MixinBase):
def assertHasChanged(self, *, tracker=None, **kwargs): tracker: FieldInstanceTracker
instance: models.Model
def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None: if tracker is None:
tracker = self.tracker tracker = self.tracker
for field, value in kwargs.items(): for field, value in kwargs.items():
@ -41,29 +49,35 @@ class FieldTrackerTestCase(TestCase):
else: else:
self.assertEqual(tracker.has_changed(field), value) self.assertEqual(tracker.has_changed(field), value)
def assertPrevious(self, *, tracker=None, **kwargs): def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None: if tracker is None:
tracker = self.tracker tracker = self.tracker
for field, value in kwargs.items(): for field, value in kwargs.items():
self.assertEqual(tracker.previous(field), value) self.assertEqual(tracker.previous(field), value)
def assertChanged(self, *, tracker=None, **kwargs): def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None: if tracker is None:
tracker = self.tracker tracker = self.tracker
self.assertEqual(tracker.changed(), kwargs) self.assertEqual(tracker.changed(), kwargs)
def assertCurrent(self, *, tracker=None, **kwargs): def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None: if tracker is None:
tracker = self.tracker tracker = self.tracker
self.assertEqual(tracker.current(), kwargs) self.assertEqual(tracker.current(), kwargs)
def update_instance(self, **kwargs): def update_instance(self, **kwargs: Any) -> None:
for field, value in kwargs.items(): for field, value in kwargs.items():
setattr(self.instance, field, value) setattr(self.instance, field, value)
self.instance.save() self.instance.save()
class FieldTrackerCommonTests: class FieldTrackerCommonMixin(FieldTrackerMixin):
instance: (
Tracked | TrackedNotDefault | TrackedMultiple
| ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
| TrackedAbstract
)
def test_pre_save_previous(self) -> None: def test_pre_save_previous(self) -> None:
self.assertPrevious(name=None, number=None) self.assertPrevious(name=None, number=None)
@ -72,9 +86,10 @@ class FieldTrackerCommonTests:
self.assertPrevious(name=None, number=None) self.assertPrevious(name=None, number=None)
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): class FieldTrackerTests(FieldTrackerCommonMixin, TestCase):
tracked_class: type[models.Model] = Tracked tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked
instance: Tracked | ModelTracked | TrackedAbstract
def setUp(self) -> None: def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
@ -219,6 +234,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.number = 1 self.instance.number = 1
self.instance.save() self.instance.save()
item = self.tracked_class.objects.only('name').first() item = self.tracked_class.objects.only('name').first()
assert item is not None
self.assertTrue(item.get_deferred_fields()) self.assertTrue(item.get_deferred_fields())
# has_changed() returns False for deferred fields, without un-deferring them. # has_changed() returns False for deferred fields, without un-deferring them.
@ -234,6 +250,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
# examining a deferred field un-defers it # examining a deferred field un-defers it
item = self.tracked_class.objects.only('name').first() item = self.tracked_class.objects.only('name').first()
assert item is not None
self.assertEqual(item.number, 1) self.assertEqual(item.number, 1)
self.assertTrue('number' not in item.get_deferred_fields()) self.assertTrue('number' not in item.get_deferred_fields())
self.assertEqual(item.tracker.previous('number'), 1) self.assertEqual(item.tracker.previous('number'), 1)
@ -252,6 +269,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
if self.tracked_class == Tracked: if self.tracked_class == Tracked:
item = self.tracked_class.objects.only('name').first() item = self.tracked_class.objects.only('name').first()
assert item is not None
item.number = 2 item.number = 2
# previous() fetches correct value from database after deferred field is assigned # previous() fetches correct value from database after deferred field is assigned
@ -278,10 +296,10 @@ class FieldTrackerMultipleInstancesTests(TestCase):
instance.name instance.name
class FieldTrackedModelCustomTests(FieldTrackerTestCase, class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase):
FieldTrackerCommonTests):
tracked_class: type[models.Model] = TrackedNotDefault tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault
instance: TrackedNotDefault | ModelTrackedNotDefault
def setUp(self) -> None: def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
@ -358,9 +376,10 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.assertChanged() self.assertChanged()
class FieldTrackedModelAttributeTests(FieldTrackerTestCase): class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase):
tracked_class = TrackedNonFieldAttr tracked_class = TrackedNonFieldAttr
instance: TrackedNonFieldAttr
def setUp(self) -> None: def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
@ -409,10 +428,10 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.assertCurrent(rounded=8) self.assertCurrent(rounded=8)
class FieldTrackedModelMultiTests(FieldTrackerTestCase, class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase):
FieldTrackerCommonTests):
tracked_class: type[models.Model] = TrackedMultiple tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple
instance: TrackedMultiple | ModelTrackedMultiple
def setUp(self) -> None: def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
@ -501,10 +520,11 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.assertCurrent(tracker=self.trackers[1], number=8) self.assertCurrent(tracker=self.trackers[1], number=8)
class FieldTrackerForeignKeyTests(FieldTrackerTestCase): class FieldTrackerForeignKeyMixin(FieldTrackerMixin):
fk_class: type[models.Model] = Tracked fk_class: type[Tracked | ModelTracked]
tracked_class: type[models.Model] = TrackedFK tracked_class: type[TrackedFK | ModelTrackedFK]
instance: TrackedFK | ModelTrackedFK
def setUp(self) -> None: def setUp(self) -> None:
self.old_fk = self.fk_class.objects.create(number=8) self.old_fk = self.fk_class.objects.create(number=8)
@ -543,11 +563,18 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
self.assertCurrent(fk=self.instance.fk_id) self.assertCurrent(fk=self.instance.fk_id)
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase): class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
fk_class = Tracked
tracked_class = TrackedFK
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase):
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError.""" """Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
fk_class = Tracked fk_class = Tracked
tracked_class = TrackedFK tracked_class = TrackedFK
instance: TrackedFK
def setUp(self) -> None: def setUp(self) -> None:
model_tracked = self.fk_class.objects.create(name="", number=0) model_tracked = self.fk_class.objects.create(name="", number=0)
@ -566,10 +593,11 @@ class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
class FieldTrackerTimeStampedTests(FieldTrackerTestCase): class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase):
fk_class = Tracked fk_class = Tracked
tracked_class = TrackerTimeStamped tracked_class = TrackerTimeStamped
instance: TrackerTimeStamped
def setUp(self) -> None: def setUp(self) -> None:
self.instance = self.tracked_class.objects.create(name='old', number=1) self.instance = self.tracked_class.objects.create(name='old', number=1)
@ -605,9 +633,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
tracked_class = InheritedTrackedFK tracked_class = InheritedTrackedFK
class FieldTrackerFileFieldTests(FieldTrackerTestCase): class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase):
tracked_class = TrackedFileField tracked_class = TrackedFileField
instance: TrackedFileField
def setUp(self) -> None: def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
@ -629,7 +658,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
self.assertEqual(self.tracker.saved_data, {}) self.assertEqual(self.tracker.saved_data, {})
self.update_instance(some_file=self.some_file) self.update_instance(some_file=self.some_file)
field_file_copy = self.tracker.saved_data.get('some_file') field_file_copy = self.tracker.saved_data.get('some_file')
self.assertIsNotNone(field_file_copy) assert field_file_copy is not None
self.assertEqual(field_file_copy.__getstate__().get('instance'), None) self.assertEqual(field_file_copy.__getstate__().get('instance'), None)
self.assertEqual(self.instance.some_file.instance, self.instance) self.assertEqual(self.instance.some_file.instance, self.instance)
self.assertIsInstance(self.instance.some_file, FieldFile) self.assertIsInstance(self.instance.some_file, FieldFile)
@ -730,7 +759,8 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
class ModelTrackerTests(FieldTrackerTests): class ModelTrackerTests(FieldTrackerTests):
tracked_class: type[models.Model] = ModelTracked tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked
instance: ModelTracked
def test_cache_compatible(self) -> None: def test_cache_compatible(self) -> None:
cache.set('key', self.instance) cache.set('key', self.instance)
@ -846,10 +876,11 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
self.assertChanged() self.assertChanged()
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
fk_class = ModelTracked fk_class = ModelTracked
tracked_class = ModelTrackedFK tracked_class = ModelTrackedFK
instance: ModelTrackedFK
def test_custom_without_id(self) -> None: def test_custom_without_id(self) -> None:
with self.assertNumQueries(2): with self.assertNumQueries(2):
@ -886,11 +917,11 @@ class TrackerContextDecoratorTests(TestCase):
self.instance = Tracked.objects.create(number=1) self.instance = Tracked.objects.create(number=1)
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
def assertChanged(self, *fields): def assertChanged(self, *fields: str) -> None:
for f in fields: for f in fields:
self.assertTrue(self.tracker.has_changed(f)) self.assertTrue(self.tracker.has_changed(f))
def assertNotChanged(self, *fields): def assertNotChanged(self, *fields: str) -> None:
for f in fields: for f in fields:
self.assertFalse(self.tracker.has_changed(f)) self.assertFalse(self.tracker.has_changed(f))
@ -921,7 +952,7 @@ class TrackerContextDecoratorTests(TestCase):
def test_tracker_decorator(self) -> None: def test_tracker_decorator(self) -> None:
@Tracked.tracker @Tracked.tracker
def tracked_method(obj): def tracked_method(obj: Tracked) -> None:
obj.name = 'new' obj.name = 'new'
self.assertChanged('name') self.assertChanged('name')
@ -932,7 +963,7 @@ class TrackerContextDecoratorTests(TestCase):
def test_tracker_decorator_fields(self) -> None: def test_tracker_decorator_fields(self) -> None:
@Tracked.tracker(fields=['name']) @Tracked.tracker(fields=['name'])
def tracked_method(obj): def tracked_method(obj: Tracked) -> None:
obj.name = 'new' obj.name = 'new'
obj.number += 1 obj.number += 1
self.assertChanged('name', 'number') self.assertChanged('name', 'number')

View file

@ -34,7 +34,7 @@ class MonitorFieldTests(TestCase):
def test_no_monitor_arg(self) -> None: def test_no_monitor_arg(self) -> None:
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
MonitorField() MonitorField() # type: ignore[call-arg]
def test_monitor_default_is_none_when_nullable(self) -> None: def test_monitor_default_is_none_when_nullable(self) -> None:
self.assertIsNone(self.instance.name_changed_nullable) self.assertIsNone(self.instance.name_changed_nullable)

View file

@ -31,7 +31,7 @@ class UrlsaftTokenFieldTests(TestCase):
def test_factory_not_callable(self) -> None: def test_factory_not_callable(self) -> None:
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
UrlsafeTokenField(factory='INVALID') UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type]
def test_get_default(self) -> None: def test_get_default(self) -> None:
field = UrlsafeTokenField() field = UrlsafeTokenField()
@ -57,8 +57,8 @@ class UrlsaftTokenFieldTests(TestCase):
self.assertIs(field.default, NOT_PROVIDED) self.assertIs(field.default, NOT_PROVIDED)
def test_deconstruct(self) -> None: def test_deconstruct(self) -> None:
def test_factory() -> None: def test_factory(max_length: int) -> str:
pass assert False
instance = UrlsafeTokenField(factory=test_factory) instance = UrlsafeTokenField(factory=test_factory)
name, path, args, kwargs = instance.deconstruct() name, path, args, kwargs = instance.deconstruct()
new_instance = UrlsafeTokenField(*args, **kwargs) new_instance = UrlsafeTokenField(*args, **kwargs)