From 7d6cad0200dba7787707a1396c36fc0fc0036914 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 29 Mar 2024 17:12:32 +0100 Subject: [PATCH] Annotate `test_field_tracker` module --- tests/test_fields/test_field_tracker.py | 93 ++++++++++++------- tests/test_fields/test_monitor_field.py | 2 +- tests/test_fields/test_urlsafe_token_field.py | 6 +- 3 files changed, 66 insertions(+), 35 deletions(-) diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 91a84c9..4dc2dc3 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + from django.core.cache import cache from django.core.exceptions import FieldError from django.db import models @@ -7,7 +9,7 @@ from django.db.models.fields.files import FieldFile from django.test import TestCase from model_utils import FieldTracker -from model_utils.tracker import DescriptorWrapper +from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker from tests.models import ( InheritedModelTracked, InheritedTracked, @@ -26,12 +28,18 @@ from tests.models import ( TrackerTimeStamped, ) +if TYPE_CHECKING: + MixinBase = TestCase +else: + MixinBase = object -class FieldTrackerTestCase(TestCase): - tracker = None +class FieldTrackerMixin(MixinBase): - def assertHasChanged(self, *, tracker=None, **kwargs): + tracker: FieldInstanceTracker + instance: models.Model + + def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker for field, value in kwargs.items(): @@ -41,29 +49,35 @@ class FieldTrackerTestCase(TestCase): else: self.assertEqual(tracker.has_changed(field), value) - def assertPrevious(self, *, tracker=None, **kwargs): + def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker for field, value in kwargs.items(): self.assertEqual(tracker.previous(field), value) - def assertChanged(self, *, tracker=None, **kwargs): + def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker self.assertEqual(tracker.changed(), kwargs) - def assertCurrent(self, *, tracker=None, **kwargs): + def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None: if tracker is None: tracker = self.tracker self.assertEqual(tracker.current(), kwargs) - def update_instance(self, **kwargs): + def update_instance(self, **kwargs: Any) -> None: for field, value in kwargs.items(): setattr(self.instance, field, value) self.instance.save() -class FieldTrackerCommonTests: +class FieldTrackerCommonMixin(FieldTrackerMixin): + + instance: ( + Tracked | TrackedNotDefault | TrackedMultiple + | ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple + | TrackedAbstract + ) def test_pre_save_previous(self) -> None: self.assertPrevious(name=None, number=None) @@ -72,9 +86,10 @@ class FieldTrackerCommonTests: self.assertPrevious(name=None, number=None) -class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): +class FieldTrackerTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = Tracked + tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked + instance: Tracked | ModelTracked | TrackedAbstract def setUp(self) -> None: self.instance = self.tracked_class() @@ -219,6 +234,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.instance.number = 1 self.instance.save() item = self.tracked_class.objects.only('name').first() + assert item is not None self.assertTrue(item.get_deferred_fields()) # has_changed() returns False for deferred fields, without un-deferring them. @@ -234,6 +250,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): # examining a deferred field un-defers it item = self.tracked_class.objects.only('name').first() + assert item is not None self.assertEqual(item.number, 1) self.assertTrue('number' not in item.get_deferred_fields()) self.assertEqual(item.tracker.previous('number'), 1) @@ -252,6 +269,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): if self.tracked_class == Tracked: item = self.tracked_class.objects.only('name').first() + assert item is not None item.number = 2 # previous() fetches correct value from database after deferred field is assigned @@ -278,10 +296,10 @@ class FieldTrackerMultipleInstancesTests(TestCase): instance.name -class FieldTrackedModelCustomTests(FieldTrackerTestCase, - FieldTrackerCommonTests): +class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = TrackedNotDefault + tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault + instance: TrackedNotDefault | ModelTrackedNotDefault def setUp(self) -> None: self.instance = self.tracked_class() @@ -358,9 +376,10 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase, self.assertChanged() -class FieldTrackedModelAttributeTests(FieldTrackerTestCase): +class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase): tracked_class = TrackedNonFieldAttr + instance: TrackedNonFieldAttr def setUp(self) -> None: self.instance = self.tracked_class() @@ -409,10 +428,10 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase): self.assertCurrent(rounded=8) -class FieldTrackedModelMultiTests(FieldTrackerTestCase, - FieldTrackerCommonTests): +class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase): - tracked_class: type[models.Model] = TrackedMultiple + tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple + instance: TrackedMultiple | ModelTrackedMultiple def setUp(self) -> None: self.instance = self.tracked_class() @@ -501,10 +520,11 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase, self.assertCurrent(tracker=self.trackers[1], number=8) -class FieldTrackerForeignKeyTests(FieldTrackerTestCase): +class FieldTrackerForeignKeyMixin(FieldTrackerMixin): - fk_class: type[models.Model] = Tracked - tracked_class: type[models.Model] = TrackedFK + fk_class: type[Tracked | ModelTracked] + tracked_class: type[TrackedFK | ModelTrackedFK] + instance: TrackedFK | ModelTrackedFK def setUp(self) -> None: self.old_fk = self.fk_class.objects.create(number=8) @@ -543,11 +563,18 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase): self.assertCurrent(fk=self.instance.fk_id) -class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase): +class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): + + fk_class = Tracked + tracked_class = TrackedFK + + +class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase): """Test that using `prefetch_related` on a tracked field does not raise a ValueError.""" fk_class = Tracked tracked_class = TrackedFK + instance: TrackedFK def setUp(self) -> None: model_tracked = self.fk_class.objects.create(name="", number=0) @@ -566,10 +593,11 @@ class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase): self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) -class FieldTrackerTimeStampedTests(FieldTrackerTestCase): +class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase): fk_class = Tracked tracked_class = TrackerTimeStamped + instance: TrackerTimeStamped def setUp(self) -> None: self.instance = self.tracked_class.objects.create(name='old', number=1) @@ -605,9 +633,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests): tracked_class = InheritedTrackedFK -class FieldTrackerFileFieldTests(FieldTrackerTestCase): +class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase): tracked_class = TrackedFileField + instance: TrackedFileField def setUp(self) -> None: self.instance = self.tracked_class() @@ -629,7 +658,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase): self.assertEqual(self.tracker.saved_data, {}) self.update_instance(some_file=self.some_file) field_file_copy = self.tracker.saved_data.get('some_file') - self.assertIsNotNone(field_file_copy) + assert field_file_copy is not None self.assertEqual(field_file_copy.__getstate__().get('instance'), None) self.assertEqual(self.instance.some_file.instance, self.instance) self.assertIsInstance(self.instance.some_file, FieldFile) @@ -730,7 +759,8 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase): class ModelTrackerTests(FieldTrackerTests): - tracked_class: type[models.Model] = ModelTracked + tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked + instance: ModelTracked def test_cache_compatible(self) -> None: cache.set('key', self.instance) @@ -846,10 +876,11 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests): self.assertChanged() -class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): +class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase): fk_class = ModelTracked tracked_class = ModelTrackedFK + instance: ModelTrackedFK def test_custom_without_id(self) -> None: with self.assertNumQueries(2): @@ -886,11 +917,11 @@ class TrackerContextDecoratorTests(TestCase): self.instance = Tracked.objects.create(number=1) self.tracker = self.instance.tracker - def assertChanged(self, *fields): + def assertChanged(self, *fields: str) -> None: for f in fields: self.assertTrue(self.tracker.has_changed(f)) - def assertNotChanged(self, *fields): + def assertNotChanged(self, *fields: str) -> None: for f in fields: self.assertFalse(self.tracker.has_changed(f)) @@ -921,7 +952,7 @@ class TrackerContextDecoratorTests(TestCase): def test_tracker_decorator(self) -> None: @Tracked.tracker - def tracked_method(obj): + def tracked_method(obj: Tracked) -> None: obj.name = 'new' self.assertChanged('name') @@ -932,7 +963,7 @@ class TrackerContextDecoratorTests(TestCase): def test_tracker_decorator_fields(self) -> None: @Tracked.tracker(fields=['name']) - def tracked_method(obj): + def tracked_method(obj: Tracked) -> None: obj.name = 'new' obj.number += 1 self.assertChanged('name', 'number') diff --git a/tests/test_fields/test_monitor_field.py b/tests/test_fields/test_monitor_field.py index 9c9ba84..19ed902 100644 --- a/tests/test_fields/test_monitor_field.py +++ b/tests/test_fields/test_monitor_field.py @@ -34,7 +34,7 @@ class MonitorFieldTests(TestCase): def test_no_monitor_arg(self) -> None: with self.assertRaises(TypeError): - MonitorField() + MonitorField() # type: ignore[call-arg] def test_monitor_default_is_none_when_nullable(self) -> None: self.assertIsNone(self.instance.name_changed_nullable) diff --git a/tests/test_fields/test_urlsafe_token_field.py b/tests/test_fields/test_urlsafe_token_field.py index 6146fe6..72bbcda 100644 --- a/tests/test_fields/test_urlsafe_token_field.py +++ b/tests/test_fields/test_urlsafe_token_field.py @@ -31,7 +31,7 @@ class UrlsaftTokenFieldTests(TestCase): def test_factory_not_callable(self) -> None: with self.assertRaises(TypeError): - UrlsafeTokenField(factory='INVALID') + UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type] def test_get_default(self) -> None: field = UrlsafeTokenField() @@ -57,8 +57,8 @@ class UrlsaftTokenFieldTests(TestCase): self.assertIs(field.default, NOT_PROVIDED) def test_deconstruct(self) -> None: - def test_factory() -> None: - pass + def test_factory(max_length: int) -> str: + assert False instance = UrlsafeTokenField(factory=test_factory) name, path, args, kwargs = instance.deconstruct() new_instance = UrlsafeTokenField(*args, **kwargs)