mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Annotate test_field_tracker module
This commit is contained in:
parent
949d110d04
commit
7d6cad0200
3 changed files with 66 additions and 35 deletions
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import models
|
||||
|
|
@ -7,7 +9,7 @@ from django.db.models.fields.files import FieldFile
|
|||
from django.test import TestCase
|
||||
|
||||
from model_utils import FieldTracker
|
||||
from model_utils.tracker import DescriptorWrapper
|
||||
from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker
|
||||
from tests.models import (
|
||||
InheritedModelTracked,
|
||||
InheritedTracked,
|
||||
|
|
@ -26,12 +28,18 @@ from tests.models import (
|
|||
TrackerTimeStamped,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
MixinBase = TestCase
|
||||
else:
|
||||
MixinBase = object
|
||||
|
||||
class FieldTrackerTestCase(TestCase):
|
||||
|
||||
tracker = None
|
||||
class FieldTrackerMixin(MixinBase):
|
||||
|
||||
def assertHasChanged(self, *, tracker=None, **kwargs):
|
||||
tracker: FieldInstanceTracker
|
||||
instance: models.Model
|
||||
|
||||
def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
for field, value in kwargs.items():
|
||||
|
|
@ -41,29 +49,35 @@ class FieldTrackerTestCase(TestCase):
|
|||
else:
|
||||
self.assertEqual(tracker.has_changed(field), value)
|
||||
|
||||
def assertPrevious(self, *, tracker=None, **kwargs):
|
||||
def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
for field, value in kwargs.items():
|
||||
self.assertEqual(tracker.previous(field), value)
|
||||
|
||||
def assertChanged(self, *, tracker=None, **kwargs):
|
||||
def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
self.assertEqual(tracker.changed(), kwargs)
|
||||
|
||||
def assertCurrent(self, *, tracker=None, **kwargs):
|
||||
def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
self.assertEqual(tracker.current(), kwargs)
|
||||
|
||||
def update_instance(self, **kwargs):
|
||||
def update_instance(self, **kwargs: Any) -> None:
|
||||
for field, value in kwargs.items():
|
||||
setattr(self.instance, field, value)
|
||||
self.instance.save()
|
||||
|
||||
|
||||
class FieldTrackerCommonTests:
|
||||
class FieldTrackerCommonMixin(FieldTrackerMixin):
|
||||
|
||||
instance: (
|
||||
Tracked | TrackedNotDefault | TrackedMultiple
|
||||
| ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
|
||||
| TrackedAbstract
|
||||
)
|
||||
|
||||
def test_pre_save_previous(self) -> None:
|
||||
self.assertPrevious(name=None, number=None)
|
||||
|
|
@ -72,9 +86,10 @@ class FieldTrackerCommonTests:
|
|||
self.assertPrevious(name=None, number=None)
|
||||
|
||||
|
||||
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||
class FieldTrackerTests(FieldTrackerCommonMixin, TestCase):
|
||||
|
||||
tracked_class: type[models.Model] = Tracked
|
||||
tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked
|
||||
instance: Tracked | ModelTracked | TrackedAbstract
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
|
|
@ -219,6 +234,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.number = 1
|
||||
self.instance.save()
|
||||
item = self.tracked_class.objects.only('name').first()
|
||||
assert item is not None
|
||||
self.assertTrue(item.get_deferred_fields())
|
||||
|
||||
# has_changed() returns False for deferred fields, without un-deferring them.
|
||||
|
|
@ -234,6 +250,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
|
||||
# examining a deferred field un-defers it
|
||||
item = self.tracked_class.objects.only('name').first()
|
||||
assert item is not None
|
||||
self.assertEqual(item.number, 1)
|
||||
self.assertTrue('number' not in item.get_deferred_fields())
|
||||
self.assertEqual(item.tracker.previous('number'), 1)
|
||||
|
|
@ -252,6 +269,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
if self.tracked_class == Tracked:
|
||||
|
||||
item = self.tracked_class.objects.only('name').first()
|
||||
assert item is not None
|
||||
item.number = 2
|
||||
|
||||
# previous() fetches correct value from database after deferred field is assigned
|
||||
|
|
@ -278,10 +296,10 @@ class FieldTrackerMultipleInstancesTests(TestCase):
|
|||
instance.name
|
||||
|
||||
|
||||
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase):
|
||||
|
||||
tracked_class: type[models.Model] = TrackedNotDefault
|
||||
tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault
|
||||
instance: TrackedNotDefault | ModelTrackedNotDefault
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
|
|
@ -358,9 +376,10 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.assertChanged()
|
||||
|
||||
|
||||
class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
||||
class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase):
|
||||
|
||||
tracked_class = TrackedNonFieldAttr
|
||||
instance: TrackedNonFieldAttr
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
|
|
@ -409,10 +428,10 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
|||
self.assertCurrent(rounded=8)
|
||||
|
||||
|
||||
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase):
|
||||
|
||||
tracked_class: type[models.Model] = TrackedMultiple
|
||||
tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple
|
||||
instance: TrackedMultiple | ModelTrackedMultiple
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
|
|
@ -501,10 +520,11 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
|||
self.assertCurrent(tracker=self.trackers[1], number=8)
|
||||
|
||||
|
||||
class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
|
||||
class FieldTrackerForeignKeyMixin(FieldTrackerMixin):
|
||||
|
||||
fk_class: type[models.Model] = Tracked
|
||||
tracked_class: type[models.Model] = TrackedFK
|
||||
fk_class: type[Tracked | ModelTracked]
|
||||
tracked_class: type[TrackedFK | ModelTrackedFK]
|
||||
instance: TrackedFK | ModelTrackedFK
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.old_fk = self.fk_class.objects.create(number=8)
|
||||
|
|
@ -543,11 +563,18 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
|
|||
self.assertCurrent(fk=self.instance.fk_id)
|
||||
|
||||
|
||||
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
|
||||
class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackedFK
|
||||
|
||||
|
||||
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase):
|
||||
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackedFK
|
||||
instance: TrackedFK
|
||||
|
||||
def setUp(self) -> None:
|
||||
model_tracked = self.fk_class.objects.create(name="", number=0)
|
||||
|
|
@ -566,10 +593,11 @@ class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
|
|||
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
|
||||
|
||||
|
||||
class FieldTrackerTimeStampedTests(FieldTrackerTestCase):
|
||||
class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase):
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackerTimeStamped
|
||||
instance: TrackerTimeStamped
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class.objects.create(name='old', number=1)
|
||||
|
|
@ -605,9 +633,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
|
|||
tracked_class = InheritedTrackedFK
|
||||
|
||||
|
||||
class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
||||
class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase):
|
||||
|
||||
tracked_class = TrackedFileField
|
||||
instance: TrackedFileField
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
|
|
@ -629,7 +658,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
self.assertEqual(self.tracker.saved_data, {})
|
||||
self.update_instance(some_file=self.some_file)
|
||||
field_file_copy = self.tracker.saved_data.get('some_file')
|
||||
self.assertIsNotNone(field_file_copy)
|
||||
assert field_file_copy is not None
|
||||
self.assertEqual(field_file_copy.__getstate__().get('instance'), None)
|
||||
self.assertEqual(self.instance.some_file.instance, self.instance)
|
||||
self.assertIsInstance(self.instance.some_file, FieldFile)
|
||||
|
|
@ -730,7 +759,8 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
|
||||
class ModelTrackerTests(FieldTrackerTests):
|
||||
|
||||
tracked_class: type[models.Model] = ModelTracked
|
||||
tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked
|
||||
instance: ModelTracked
|
||||
|
||||
def test_cache_compatible(self) -> None:
|
||||
cache.set('key', self.instance)
|
||||
|
|
@ -846,10 +876,11 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
|
|||
self.assertChanged()
|
||||
|
||||
|
||||
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
|
||||
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
|
||||
|
||||
fk_class = ModelTracked
|
||||
tracked_class = ModelTrackedFK
|
||||
instance: ModelTrackedFK
|
||||
|
||||
def test_custom_without_id(self) -> None:
|
||||
with self.assertNumQueries(2):
|
||||
|
|
@ -886,11 +917,11 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
self.instance = Tracked.objects.create(number=1)
|
||||
self.tracker = self.instance.tracker
|
||||
|
||||
def assertChanged(self, *fields):
|
||||
def assertChanged(self, *fields: str) -> None:
|
||||
for f in fields:
|
||||
self.assertTrue(self.tracker.has_changed(f))
|
||||
|
||||
def assertNotChanged(self, *fields):
|
||||
def assertNotChanged(self, *fields: str) -> None:
|
||||
for f in fields:
|
||||
self.assertFalse(self.tracker.has_changed(f))
|
||||
|
||||
|
|
@ -921,7 +952,7 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
def test_tracker_decorator(self) -> None:
|
||||
|
||||
@Tracked.tracker
|
||||
def tracked_method(obj):
|
||||
def tracked_method(obj: Tracked) -> None:
|
||||
obj.name = 'new'
|
||||
self.assertChanged('name')
|
||||
|
||||
|
|
@ -932,7 +963,7 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
def test_tracker_decorator_fields(self) -> None:
|
||||
|
||||
@Tracked.tracker(fields=['name'])
|
||||
def tracked_method(obj):
|
||||
def tracked_method(obj: Tracked) -> None:
|
||||
obj.name = 'new'
|
||||
obj.number += 1
|
||||
self.assertChanged('name', 'number')
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class MonitorFieldTests(TestCase):
|
|||
|
||||
def test_no_monitor_arg(self) -> None:
|
||||
with self.assertRaises(TypeError):
|
||||
MonitorField()
|
||||
MonitorField() # type: ignore[call-arg]
|
||||
|
||||
def test_monitor_default_is_none_when_nullable(self) -> None:
|
||||
self.assertIsNone(self.instance.name_changed_nullable)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class UrlsaftTokenFieldTests(TestCase):
|
|||
|
||||
def test_factory_not_callable(self) -> None:
|
||||
with self.assertRaises(TypeError):
|
||||
UrlsafeTokenField(factory='INVALID')
|
||||
UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type]
|
||||
|
||||
def test_get_default(self) -> None:
|
||||
field = UrlsafeTokenField()
|
||||
|
|
@ -57,8 +57,8 @@ class UrlsaftTokenFieldTests(TestCase):
|
|||
self.assertIs(field.default, NOT_PROVIDED)
|
||||
|
||||
def test_deconstruct(self) -> None:
|
||||
def test_factory() -> None:
|
||||
pass
|
||||
def test_factory(max_length: int) -> str:
|
||||
assert False
|
||||
instance = UrlsafeTokenField(factory=test_factory)
|
||||
name, path, args, kwargs = instance.deconstruct()
|
||||
new_instance = UrlsafeTokenField(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue