Annotate return type of test methods

This commit is contained in:
Maarten ter Huurne 2023-03-22 18:50:18 +01:00
parent 218843d754
commit 23f1811b9d
20 changed files with 295 additions and 289 deletions

View file

@ -6,61 +6,61 @@ from model_utils import Choices
class ChoicesTests(TestCase): class ChoicesTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.STATUS = Choices('DRAFT', 'PUBLISHED') self.STATUS = Choices('DRAFT', 'PUBLISHED')
def test_getattr(self): def test_getattr(self) -> None:
self.assertEqual(self.STATUS.DRAFT, 'DRAFT') self.assertEqual(self.STATUS.DRAFT, 'DRAFT')
def test_indexing(self): def test_indexing(self) -> None:
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
def test_iteration(self): def test_iteration(self) -> None:
self.assertEqual(tuple(self.STATUS), self.assertEqual(tuple(self.STATUS),
(('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
def test_reversed(self): def test_reversed(self) -> None:
self.assertEqual(tuple(reversed(self.STATUS)), self.assertEqual(tuple(reversed(self.STATUS)),
(('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT')))
def test_len(self): def test_len(self) -> None:
self.assertEqual(len(self.STATUS), 2) self.assertEqual(len(self.STATUS), 2)
def test_repr(self): def test_repr(self) -> None:
self.assertEqual(repr(self.STATUS), "Choices" + repr(( self.assertEqual(repr(self.STATUS), "Choices" + repr((
('DRAFT', 'DRAFT', 'DRAFT'), ('DRAFT', 'DRAFT', 'DRAFT'),
('PUBLISHED', 'PUBLISHED', 'PUBLISHED'), ('PUBLISHED', 'PUBLISHED', 'PUBLISHED'),
))) )))
def test_wrong_length_tuple(self): def test_wrong_length_tuple(self) -> None:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
Choices(('a',)) Choices(('a',))
def test_contains_value(self): def test_contains_value(self) -> None:
self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('PUBLISHED' in self.STATUS)
self.assertTrue('DRAFT' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS)
def test_doesnt_contain_value(self): def test_doesnt_contain_value(self) -> None:
self.assertFalse('UNPUBLISHED' in self.STATUS) self.assertFalse('UNPUBLISHED' in self.STATUS)
def test_deepcopy(self): def test_deepcopy(self) -> None:
import copy import copy
self.assertEqual(list(self.STATUS), self.assertEqual(list(self.STATUS),
list(copy.deepcopy(self.STATUS))) list(copy.deepcopy(self.STATUS)))
def test_equality(self): def test_equality(self) -> None:
self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED'))
def test_inequality(self): def test_inequality(self) -> None:
self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED'])
self.assertNotEqual(self.STATUS, Choices('DRAFT')) self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_composability(self): def test_composability(self) -> None:
self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS)
self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS)
self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS)
def test_option_groups(self): def test_option_groups(self) -> None:
c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) c = Choices(('group a', ['one', 'two']), ['group b', ('three',)])
self.assertEqual( self.assertEqual(
list(c), list(c),
@ -72,47 +72,47 @@ class ChoicesTests(TestCase):
class LabelChoicesTests(ChoicesTests): class LabelChoicesTests(ChoicesTests):
def setUp(self): def setUp(self) -> None:
self.STATUS = Choices( self.STATUS = Choices(
('DRAFT', 'is draft'), ('DRAFT', 'is draft'),
('PUBLISHED', 'is published'), ('PUBLISHED', 'is published'),
'DELETED', 'DELETED',
) )
def test_iteration(self): def test_iteration(self) -> None:
self.assertEqual(tuple(self.STATUS), ( self.assertEqual(tuple(self.STATUS), (
('DRAFT', 'is draft'), ('DRAFT', 'is draft'),
('PUBLISHED', 'is published'), ('PUBLISHED', 'is published'),
('DELETED', 'DELETED'), ('DELETED', 'DELETED'),
)) ))
def test_reversed(self): def test_reversed(self) -> None:
self.assertEqual(tuple(reversed(self.STATUS)), ( self.assertEqual(tuple(reversed(self.STATUS)), (
('DELETED', 'DELETED'), ('DELETED', 'DELETED'),
('PUBLISHED', 'is published'), ('PUBLISHED', 'is published'),
('DRAFT', 'is draft'), ('DRAFT', 'is draft'),
)) ))
def test_indexing(self): def test_indexing(self) -> None:
self.assertEqual(self.STATUS['PUBLISHED'], 'is published') self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
def test_default(self): def test_default(self) -> None:
self.assertEqual(self.STATUS.DELETED, 'DELETED') self.assertEqual(self.STATUS.DELETED, 'DELETED')
def test_provided(self): def test_provided(self) -> None:
self.assertEqual(self.STATUS.DRAFT, 'DRAFT') self.assertEqual(self.STATUS.DRAFT, 'DRAFT')
def test_len(self): def test_len(self) -> None:
self.assertEqual(len(self.STATUS), 3) self.assertEqual(len(self.STATUS), 3)
def test_equality(self): def test_equality(self) -> None:
self.assertEqual(self.STATUS, Choices( self.assertEqual(self.STATUS, Choices(
('DRAFT', 'is draft'), ('DRAFT', 'is draft'),
('PUBLISHED', 'is published'), ('PUBLISHED', 'is published'),
'DELETED', 'DELETED',
)) ))
def test_inequality(self): def test_inequality(self) -> None:
self.assertNotEqual(self.STATUS, [ self.assertNotEqual(self.STATUS, [
('DRAFT', 'is draft'), ('DRAFT', 'is draft'),
('PUBLISHED', 'is published'), ('PUBLISHED', 'is published'),
@ -120,27 +120,27 @@ class LabelChoicesTests(ChoicesTests):
]) ])
self.assertNotEqual(self.STATUS, Choices('DRAFT')) self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_repr(self): def test_repr(self) -> None:
self.assertEqual(repr(self.STATUS), "Choices" + repr(( self.assertEqual(repr(self.STATUS), "Choices" + repr((
('DRAFT', 'DRAFT', 'is draft'), ('DRAFT', 'DRAFT', 'is draft'),
('PUBLISHED', 'PUBLISHED', 'is published'), ('PUBLISHED', 'PUBLISHED', 'is published'),
('DELETED', 'DELETED', 'DELETED'), ('DELETED', 'DELETED', 'DELETED'),
))) )))
def test_contains_value(self): def test_contains_value(self) -> None:
self.assertTrue('PUBLISHED' in self.STATUS) self.assertTrue('PUBLISHED' in self.STATUS)
self.assertTrue('DRAFT' in self.STATUS) self.assertTrue('DRAFT' in self.STATUS)
# This should be True, because both the display value # This should be True, because both the display value
# and the internal representation are both DELETED. # and the internal representation are both DELETED.
self.assertTrue('DELETED' in self.STATUS) self.assertTrue('DELETED' in self.STATUS)
def test_doesnt_contain_value(self): def test_doesnt_contain_value(self) -> None:
self.assertFalse('UNPUBLISHED' in self.STATUS) self.assertFalse('UNPUBLISHED' in self.STATUS)
def test_doesnt_contain_display_value(self): def test_doesnt_contain_display_value(self) -> None:
self.assertFalse('is draft' in self.STATUS) self.assertFalse('is draft' in self.STATUS)
def test_composability(self): def test_composability(self) -> None:
self.assertEqual( self.assertEqual(
Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'), Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'),
self.STATUS self.STATUS
@ -156,7 +156,7 @@ class LabelChoicesTests(ChoicesTests):
self.STATUS self.STATUS
) )
def test_option_groups(self): def test_option_groups(self) -> None:
c = Choices( c = Choices(
('group a', [(1, 'one'), (2, 'two')]), ('group a', [(1, 'one'), (2, 'two')]),
['group b', ((3, 'three'),)] ['group b', ((3, 'three'),)]
@ -171,64 +171,64 @@ class LabelChoicesTests(ChoicesTests):
class IdentifierChoicesTests(ChoicesTests): class IdentifierChoicesTests(ChoicesTests):
def setUp(self): def setUp(self) -> None:
self.STATUS = Choices( self.STATUS = Choices(
(0, 'DRAFT', 'is draft'), (0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'), (1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted')) (2, 'DELETED', 'is deleted'))
def test_iteration(self): def test_iteration(self) -> None:
self.assertEqual(tuple(self.STATUS), ( self.assertEqual(tuple(self.STATUS), (
(0, 'is draft'), (0, 'is draft'),
(1, 'is published'), (1, 'is published'),
(2, 'is deleted'), (2, 'is deleted'),
)) ))
def test_reversed(self): def test_reversed(self) -> None:
self.assertEqual(tuple(reversed(self.STATUS)), ( self.assertEqual(tuple(reversed(self.STATUS)), (
(2, 'is deleted'), (2, 'is deleted'),
(1, 'is published'), (1, 'is published'),
(0, 'is draft'), (0, 'is draft'),
)) ))
def test_indexing(self): def test_indexing(self) -> None:
self.assertEqual(self.STATUS[1], 'is published') self.assertEqual(self.STATUS[1], 'is published')
def test_getattr(self): def test_getattr(self) -> None:
self.assertEqual(self.STATUS.DRAFT, 0) self.assertEqual(self.STATUS.DRAFT, 0)
def test_len(self): def test_len(self) -> None:
self.assertEqual(len(self.STATUS), 3) self.assertEqual(len(self.STATUS), 3)
def test_repr(self): def test_repr(self) -> None:
self.assertEqual(repr(self.STATUS), "Choices" + repr(( self.assertEqual(repr(self.STATUS), "Choices" + repr((
(0, 'DRAFT', 'is draft'), (0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'), (1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted'), (2, 'DELETED', 'is deleted'),
))) )))
def test_contains_value(self): def test_contains_value(self) -> None:
self.assertTrue(0 in self.STATUS) self.assertTrue(0 in self.STATUS)
self.assertTrue(1 in self.STATUS) self.assertTrue(1 in self.STATUS)
self.assertTrue(2 in self.STATUS) self.assertTrue(2 in self.STATUS)
def test_doesnt_contain_value(self): def test_doesnt_contain_value(self) -> None:
self.assertFalse(3 in self.STATUS) self.assertFalse(3 in self.STATUS)
def test_doesnt_contain_display_value(self): def test_doesnt_contain_display_value(self) -> None:
self.assertFalse('is draft' in self.STATUS) self.assertFalse('is draft' in self.STATUS)
def test_doesnt_contain_python_attr(self): def test_doesnt_contain_python_attr(self) -> None:
self.assertFalse('PUBLISHED' in self.STATUS) self.assertFalse('PUBLISHED' in self.STATUS)
def test_equality(self): def test_equality(self) -> None:
self.assertEqual(self.STATUS, Choices( self.assertEqual(self.STATUS, Choices(
(0, 'DRAFT', 'is draft'), (0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'), (1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted') (2, 'DELETED', 'is deleted')
)) ))
def test_inequality(self): def test_inequality(self) -> None:
self.assertNotEqual(self.STATUS, [ self.assertNotEqual(self.STATUS, [
(0, 'DRAFT', 'is draft'), (0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'), (1, 'PUBLISHED', 'is published'),
@ -236,7 +236,7 @@ class IdentifierChoicesTests(ChoicesTests):
]) ])
self.assertNotEqual(self.STATUS, Choices('DRAFT')) self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_composability(self): def test_composability(self) -> None:
self.assertEqual( self.assertEqual(
Choices( Choices(
(0, 'DRAFT', 'is draft'), (0, 'DRAFT', 'is draft'),
@ -267,7 +267,7 @@ class IdentifierChoicesTests(ChoicesTests):
self.STATUS self.STATUS
) )
def test_option_groups(self): def test_option_groups(self) -> None:
c = Choices( c = Choices(
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
['group b', ((3, 'THREE', 'three'),)] ['group b', ((3, 'THREE', 'three'),)]
@ -283,26 +283,26 @@ class IdentifierChoicesTests(ChoicesTests):
class SubsetChoicesTest(TestCase): class SubsetChoicesTest(TestCase):
def setUp(self): def setUp(self) -> None:
self.choices = Choices( self.choices = Choices(
(0, 'a', 'A'), (0, 'a', 'A'),
(1, 'b', 'B'), (1, 'b', 'B'),
) )
def test_nonexistent_identifiers_raise(self): def test_nonexistent_identifiers_raise(self) -> None:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.choices.subset('a', 'c') self.choices.subset('a', 'c')
def test_solo_nonexistent_identifiers_raise(self): def test_solo_nonexistent_identifiers_raise(self) -> None:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.choices.subset('c') self.choices.subset('c')
def test_empty_subset_passes(self): def test_empty_subset_passes(self) -> None:
subset = self.choices.subset() subset = self.choices.subset()
self.assertEqual(subset, Choices()) self.assertEqual(subset, Choices())
def test_subset_returns_correct_subset(self): def test_subset_returns_correct_subset(self) -> None:
subset = self.choices.subset('a') subset = self.choices.subset('a')
self.assertEqual(subset, Choices((0, 'a', 'A'))) self.assertEqual(subset, Choices((0, 'a', 'A')))

View file

@ -65,7 +65,7 @@ class FieldTrackerTestCase(TestCase):
class FieldTrackerCommonTests: class FieldTrackerCommonTests:
def test_pre_save_previous(self): def test_pre_save_previous(self) -> None:
self.assertPrevious(name=None, number=None) self.assertPrevious(name=None, number=None)
self.instance.name = 'new age' self.instance.name = 'new age'
self.instance.number = 8 self.instance.number = 8
@ -76,14 +76,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
tracked_class: type[models.Model] = Tracked tracked_class: type[models.Model] = Tracked
def setUp(self): def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
def test_descriptor(self): def test_descriptor(self) -> None:
self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker))
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.assertChanged(name=None) self.assertChanged(name=None)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertChanged(name=None) self.assertChanged(name=None)
@ -94,7 +94,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.mutable = [1, 2, 3] self.instance.mutable = [1, 2, 3]
self.assertChanged(name=None, number=None, mutable=None) self.assertChanged(name=None, number=None, mutable=None)
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=False, mutable=False) self.assertHasChanged(name=True, number=False, mutable=False)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertHasChanged(name=True, number=False, mutable=False) self.assertHasChanged(name=True, number=False, mutable=False)
@ -103,12 +103,12 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.mutable = [1, 2, 3] self.instance.mutable = [1, 2, 3]
self.assertHasChanged(name=True, number=True, mutable=True) self.assertHasChanged(name=True, number=True, mutable=True)
def test_save_with_args(self): def test_save_with_args(self) -> None:
self.instance.number = 1 self.instance.number = 1
self.instance.save(False, False, None, None) self.instance.save(False, False, None, None)
self.assertChanged() self.assertChanged()
def test_first_save(self): def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=False, mutable=False) self.assertHasChanged(name=True, number=False, mutable=False)
self.assertPrevious(name=None, number=None, mutable=None) self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='', number=None, id=None, mutable=None) self.assertCurrent(name='', number=None, id=None, mutable=None)
@ -129,7 +129,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.instance.save(update_fields=['number']) self.instance.save(update_fields=['number'])
def test_post_save_has_changed(self): def test_post_save_has_changed(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertHasChanged(name=False, number=False, mutable=False) self.assertHasChanged(name=False, number=False, mutable=False)
self.instance.name = 'new age' self.instance.name = 'new age'
@ -141,14 +141,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.name = 'retro' self.instance.name = 'retro'
self.assertHasChanged(name=False, number=True, mutable=True) self.assertHasChanged(name=False, number=True, mutable=True)
def test_post_save_previous(self): def test_post_save_previous(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
self.instance.mutable[1] = 4 self.instance.mutable[1] = 4
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
def test_post_save_changed(self): def test_post_save_changed(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
@ -162,7 +162,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.mutable = [1, 2, 3] self.instance.mutable = [1, 2, 3]
self.assertChanged(number=4) self.assertChanged(number=4)
def test_current(self): def test_current(self) -> None:
self.assertCurrent(id=None, name='', number=None, mutable=None) self.assertCurrent(id=None, name='', number=None, mutable=None)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertCurrent(id=None, name='new age', number=None, mutable=None) self.assertCurrent(id=None, name='new age', number=None, mutable=None)
@ -175,7 +175,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.save() self.instance.save()
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3]) self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3])
def test_update_fields(self): def test_update_fields(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
@ -198,7 +198,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertEqual(in_db.number, self.instance.number) self.assertEqual(in_db.number, self.instance.number)
self.assertEqual(in_db.mutable, self.instance.mutable) self.assertEqual(in_db.mutable, self.instance.mutable)
def test_refresh_from_db(self): def test_refresh_from_db(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.tracked_class.objects.filter(pk=self.instance.pk).update( self.tracked_class.objects.filter(pk=self.instance.pk).update(
name='new age', number=8, mutable=[3, 2, 1]) name='new age', number=8, mutable=[3, 2, 1])
@ -214,7 +214,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.refresh_from_db() self.instance.refresh_from_db()
self.assertChanged() self.assertChanged()
def test_with_deferred(self): def test_with_deferred(self) -> None:
self.instance.name = 'new age' self.instance.name = 'new age'
self.instance.number = 1 self.instance.number = 1
self.instance.save() self.instance.save()
@ -268,7 +268,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
class FieldTrackerMultipleInstancesTests(TestCase): class FieldTrackerMultipleInstancesTests(TestCase):
def test_with_deferred_fields_access_multiple(self): def test_with_deferred_fields_access_multiple(self) -> None:
Tracked.objects.create(pk=1, name='foo', number=1) Tracked.objects.create(pk=1, name='foo', number=1)
Tracked.objects.create(pk=2, name='bar', number=2) Tracked.objects.create(pk=2, name='bar', number=2)
@ -283,11 +283,11 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
tracked_class: type[models.Model] = TrackedNotDefault tracked_class: type[models.Model] = TrackedNotDefault
def setUp(self): def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
self.tracker = self.instance.name_tracker self.tracker = self.instance.name_tracker
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.assertChanged(name=None) self.assertChanged(name=None)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertChanged(name=None) self.assertChanged(name=None)
@ -296,7 +296,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.name = '' self.instance.name = ''
self.assertChanged(name=None) self.assertChanged(name=None)
def test_first_save(self): def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=None) self.assertHasChanged(name=True, number=None)
self.assertPrevious(name=None, number=None) self.assertPrevious(name=None, number=None)
self.assertCurrent(name='') self.assertCurrent(name='')
@ -308,14 +308,14 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.assertCurrent(name='retro') self.assertCurrent(name='retro')
self.assertChanged(name=None) self.assertChanged(name=None)
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=None) self.assertHasChanged(name=True, number=None)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertHasChanged(name=True, number=None) self.assertHasChanged(name=True, number=None)
self.instance.number = 7 self.instance.number = 7
self.assertHasChanged(name=True, number=None) self.assertHasChanged(name=True, number=None)
def test_post_save_has_changed(self): def test_post_save_has_changed(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.assertHasChanged(name=False, number=None) self.assertHasChanged(name=False, number=None)
self.instance.name = 'new age' self.instance.name = 'new age'
@ -325,12 +325,12 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.name = 'retro' self.instance.name = 'retro'
self.assertHasChanged(name=False, number=None) self.assertHasChanged(name=False, number=None)
def test_post_save_previous(self): def test_post_save_previous(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertPrevious(name='retro', number=None) self.assertPrevious(name='retro', number=None)
def test_post_save_changed(self): def test_post_save_changed(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
@ -340,7 +340,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.name = 'retro' self.instance.name = 'retro'
self.assertChanged() self.assertChanged()
def test_current(self): def test_current(self) -> None:
self.assertCurrent(name='') self.assertCurrent(name='')
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertCurrent(name='new age') self.assertCurrent(name='new age')
@ -349,7 +349,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.save() self.instance.save()
self.assertCurrent(name='new age') self.assertCurrent(name='new age')
def test_update_fields(self): def test_update_fields(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
@ -362,11 +362,11 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
tracked_class = TrackedNonFieldAttr tracked_class = TrackedNonFieldAttr
def setUp(self): def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
def test_previous(self): def test_previous(self) -> None:
self.assertPrevious(rounded=None) self.assertPrevious(rounded=None)
self.instance.number = 7.5 self.instance.number = 7.5
self.assertPrevious(rounded=None) self.assertPrevious(rounded=None)
@ -377,7 +377,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.instance.save() self.instance.save()
self.assertPrevious(rounded=7) self.assertPrevious(rounded=7)
def test_has_changed(self): def test_has_changed(self) -> None:
self.assertHasChanged(rounded=False) self.assertHasChanged(rounded=False)
self.instance.number = 7.5 self.instance.number = 7.5
self.assertHasChanged(rounded=True) self.assertHasChanged(rounded=True)
@ -388,7 +388,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.instance.number = 7.8 self.instance.number = 7.8
self.assertHasChanged(rounded=False) self.assertHasChanged(rounded=False)
def test_changed(self): def test_changed(self) -> None:
self.assertChanged() self.assertChanged()
self.instance.number = 7.5 self.instance.number = 7.5
self.assertPrevious(rounded=None) self.assertPrevious(rounded=None)
@ -401,7 +401,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.instance.save() self.instance.save()
self.assertPrevious() self.assertPrevious()
def test_current(self): def test_current(self) -> None:
self.assertCurrent(rounded=None) self.assertCurrent(rounded=None)
self.instance.number = 7.5 self.instance.number = 7.5
self.assertCurrent(rounded=8) self.assertCurrent(rounded=8)
@ -414,12 +414,12 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
tracked_class: type[models.Model] = TrackedMultiple tracked_class: type[models.Model] = TrackedMultiple
def setUp(self): def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
self.trackers = [self.instance.name_tracker, self.trackers = [self.instance.name_tracker,
self.instance.number_tracker] self.instance.number_tracker]
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.tracker = self.instance.name_tracker self.tracker = self.instance.name_tracker
self.assertChanged(name=None) self.assertChanged(name=None)
self.instance.name = 'new age' self.instance.name = 'new age'
@ -435,7 +435,7 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.instance.number = 8 self.instance.number = 8
self.assertChanged(number=None) self.assertChanged(number=None)
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.tracker = self.instance.name_tracker self.tracker = self.instance.name_tracker
self.assertHasChanged(name=True, number=None) self.assertHasChanged(name=True, number=None)
self.instance.name = 'new age' self.instance.name = 'new age'
@ -445,12 +445,12 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertHasChanged(name=None, number=False) self.assertHasChanged(name=None, number=False)
def test_pre_save_previous(self): def test_pre_save_previous(self) -> None:
for tracker in self.trackers: for tracker in self.trackers:
self.tracker = tracker self.tracker = tracker
super().test_pre_save_previous() super().test_pre_save_previous()
def test_post_save_has_changed(self): def test_post_save_has_changed(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) self.assertHasChanged(tracker=self.trackers[0], name=False, number=None)
self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) self.assertHasChanged(tracker=self.trackers[1], name=None, number=False)
@ -465,14 +465,14 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) self.assertHasChanged(tracker=self.trackers[0], name=False, number=None)
self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) self.assertHasChanged(tracker=self.trackers[1], name=None, number=False)
def test_post_save_previous(self): def test_post_save_previous(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.instance.name = 'new age' self.instance.name = 'new age'
self.instance.number = 8 self.instance.number = 8
self.assertPrevious(tracker=self.trackers[0], name='retro', number=None) self.assertPrevious(tracker=self.trackers[0], name='retro', number=None)
self.assertPrevious(tracker=self.trackers[1], name=None, number=4) self.assertPrevious(tracker=self.trackers[1], name=None, number=4)
def test_post_save_changed(self): def test_post_save_changed(self) -> None:
self.update_instance(name='retro', number=4) self.update_instance(name='retro', number=4)
self.assertChanged(tracker=self.trackers[0]) self.assertChanged(tracker=self.trackers[0])
self.assertChanged(tracker=self.trackers[1]) self.assertChanged(tracker=self.trackers[1])
@ -487,7 +487,7 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.assertChanged(tracker=self.trackers[0]) self.assertChanged(tracker=self.trackers[0])
self.assertChanged(tracker=self.trackers[1]) self.assertChanged(tracker=self.trackers[1])
def test_current(self): def test_current(self) -> None:
self.assertCurrent(tracker=self.trackers[0], name='') self.assertCurrent(tracker=self.trackers[0], name='')
self.assertCurrent(tracker=self.trackers[1], number=None) self.assertCurrent(tracker=self.trackers[1], number=None)
self.instance.name = 'new age' self.instance.name = 'new age'
@ -506,11 +506,11 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
fk_class: type[models.Model] = Tracked fk_class: type[models.Model] = Tracked
tracked_class: type[models.Model] = TrackedFK tracked_class: type[models.Model] = TrackedFK
def setUp(self): def setUp(self) -> None:
self.old_fk = self.fk_class.objects.create(number=8) self.old_fk = self.fk_class.objects.create(number=8)
self.instance = self.tracked_class.objects.create(fk=self.old_fk) self.instance = self.tracked_class.objects.create(fk=self.old_fk)
def test_default(self): def test_default(self) -> None:
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
self.assertChanged() self.assertChanged()
self.assertPrevious() self.assertPrevious()
@ -520,7 +520,7 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
self.assertPrevious(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id)
self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id)
def test_custom(self): def test_custom(self) -> None:
self.tracker = self.instance.custom_tracker self.tracker = self.instance.custom_tracker
self.assertChanged() self.assertChanged()
self.assertPrevious() self.assertPrevious()
@ -530,7 +530,7 @@ class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
self.assertPrevious(fk_id=self.old_fk.id) self.assertPrevious(fk_id=self.old_fk.id)
self.assertCurrent(fk_id=self.instance.fk_id) self.assertCurrent(fk_id=self.instance.fk_id)
def test_custom_without_id(self): def test_custom_without_id(self) -> None:
with self.assertNumQueries(1): with self.assertNumQueries(1):
self.tracked_class.objects.get() self.tracked_class.objects.get()
self.tracker = self.instance.custom_tracker_without_id self.tracker = self.instance.custom_tracker_without_id
@ -549,19 +549,19 @@ class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
fk_class = Tracked fk_class = Tracked
tracked_class = TrackedFK tracked_class = TrackedFK
def setUp(self): def setUp(self) -> None:
model_tracked = self.fk_class.objects.create(name="", number=0) model_tracked = self.fk_class.objects.create(name="", number=0)
self.instance = self.tracked_class.objects.create(fk=model_tracked) self.instance = self.tracked_class.objects.create(fk=model_tracked)
def test_default(self): def test_default(self) -> None:
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
def test_custom(self): def test_custom(self) -> None:
self.tracker = self.instance.custom_tracker self.tracker = self.instance.custom_tracker
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
def test_custom_without_id(self): def test_custom_without_id(self) -> None:
self.tracker = self.instance.custom_tracker_without_id self.tracker = self.instance.custom_tracker_without_id
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk"))) self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
@ -571,18 +571,18 @@ class FieldTrackerTimeStampedTests(FieldTrackerTestCase):
fk_class = Tracked fk_class = Tracked
tracked_class = TrackerTimeStamped tracked_class = TrackerTimeStamped
def setUp(self): def setUp(self) -> None:
self.instance = self.tracked_class.objects.create(name='old', number=1) self.instance = self.tracked_class.objects.create(name='old', number=1)
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
def test_set_modified_on_save(self): def test_set_modified_on_save(self) -> None:
old_modified = self.instance.modified old_modified = self.instance.modified
self.instance.name = 'new' self.instance.name = 'new'
self.instance.save() self.instance.save()
self.assertGreater(self.instance.modified, old_modified) self.assertGreater(self.instance.modified, old_modified)
self.assertChanged() self.assertChanged()
def test_set_modified_on_save_update_fields(self): def test_set_modified_on_save_update_fields(self) -> None:
old_modified = self.instance.modified old_modified = self.instance.modified
self.instance.name = 'new' self.instance.name = 'new'
self.instance.save(update_fields=('name',)) self.instance.save(update_fields=('name',))
@ -594,7 +594,7 @@ class InheritedFieldTrackerTests(FieldTrackerTests):
tracked_class = InheritedTracked tracked_class = InheritedTracked
def test_child_fields_not_tracked(self): def test_child_fields_not_tracked(self) -> None:
self.name2 = 'test' self.name2 = 'test'
self.assertEqual(self.tracker.previous('name2'), None) self.assertEqual(self.tracker.previous('name2'), None)
self.assertRaises(FieldError, self.tracker.has_changed, 'name2') self.assertRaises(FieldError, self.tracker.has_changed, 'name2')
@ -609,13 +609,13 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
tracked_class = TrackedFileField tracked_class = TrackedFileField
def setUp(self): def setUp(self) -> None:
self.instance = self.tracked_class() self.instance = self.tracked_class()
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
self.some_file = 'something.txt' self.some_file = 'something.txt'
self.another_file = 'another.txt' self.another_file = 'another.txt'
def test_saved_data_without_instance(self): def test_saved_data_without_instance(self) -> None:
""" """
Tests that instance won't get copied by the Field Tracker. Tests that instance won't get copied by the Field Tracker.
@ -634,22 +634,22 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
self.assertEqual(self.instance.some_file.instance, self.instance) self.assertEqual(self.instance.some_file.instance, self.instance)
self.assertIsInstance(self.instance.some_file, FieldFile) self.assertIsInstance(self.instance.some_file, FieldFile)
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.assertChanged(some_file=None) self.assertChanged(some_file=None)
self.instance.some_file = self.some_file self.instance.some_file = self.some_file
self.assertChanged(some_file=None) self.assertChanged(some_file=None)
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(some_file=True) self.assertHasChanged(some_file=True)
self.instance.some_file = self.some_file self.instance.some_file = self.some_file
self.assertHasChanged(some_file=True) self.assertHasChanged(some_file=True)
def test_pre_save_previous(self): def test_pre_save_previous(self) -> None:
self.assertPrevious(some_file=None) self.assertPrevious(some_file=None)
self.instance.some_file = self.some_file self.instance.some_file = self.some_file
self.assertPrevious(some_file=None) self.assertPrevious(some_file=None)
def test_post_save_changed(self): def test_post_save_changed(self) -> None:
self.update_instance(some_file=self.some_file) self.update_instance(some_file=self.some_file)
self.assertChanged() self.assertChanged()
previous_file = self.instance.some_file previous_file = self.instance.some_file
@ -667,7 +667,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
some_file=previous_file, some_file=previous_file,
) )
def test_post_save_has_changed(self): def test_post_save_has_changed(self) -> None:
self.update_instance(some_file=self.some_file) self.update_instance(some_file=self.some_file)
self.assertHasChanged(some_file=False) self.assertHasChanged(some_file=False)
self.instance.some_file = self.another_file self.instance.some_file = self.another_file
@ -687,7 +687,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
some_file=True, some_file=True,
) )
def test_post_save_previous(self): def test_post_save_previous(self) -> None:
self.update_instance(some_file=self.some_file) self.update_instance(some_file=self.some_file)
previous_file = self.instance.some_file previous_file = self.instance.some_file
self.instance.some_file = self.another_file self.instance.some_file = self.another_file
@ -707,7 +707,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
some_file=previous_file, some_file=previous_file,
) )
def test_current(self): def test_current(self) -> None:
self.assertCurrent(some_file=self.instance.some_file, id=None) self.assertCurrent(some_file=self.instance.some_file, id=None)
self.instance.some_file = self.some_file self.instance.some_file = self.some_file
self.assertCurrent(some_file=self.instance.some_file, id=None) self.assertCurrent(some_file=self.instance.some_file, id=None)
@ -732,7 +732,7 @@ class ModelTrackerTests(FieldTrackerTests):
tracked_class: type[models.Model] = ModelTracked tracked_class: type[models.Model] = ModelTracked
def test_cache_compatible(self): def test_cache_compatible(self) -> None:
cache.set('key', self.instance) cache.set('key', self.instance)
instance = cache.get('key') instance = cache.get('key')
instance.number = 1 instance.number = 1
@ -742,7 +742,7 @@ class ModelTrackerTests(FieldTrackerTests):
instance.number = 2 instance.number = 2
self.assertHasChanged(number=True) self.assertHasChanged(number=True)
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertChanged() self.assertChanged()
@ -753,7 +753,7 @@ class ModelTrackerTests(FieldTrackerTests):
self.instance.mutable = [1, 2, 3] self.instance.mutable = [1, 2, 3]
self.assertChanged() self.assertChanged()
def test_first_save(self): def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=True, mutable=True) self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None) self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='', number=None, id=None, mutable=None) self.assertCurrent(name='', number=None, id=None, mutable=None)
@ -774,7 +774,7 @@ class ModelTrackerTests(FieldTrackerTests):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.instance.save(update_fields=['number']) self.instance.save(update_fields=['number'])
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
@ -786,7 +786,7 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
tracked_class = ModelTrackedNotDefault tracked_class = ModelTrackedNotDefault
def test_first_save(self): def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None) self.assertPrevious(name=None, number=None)
self.assertCurrent(name='') self.assertCurrent(name='')
@ -798,14 +798,14 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
self.assertCurrent(name='retro') self.assertCurrent(name='retro')
self.assertChanged() self.assertChanged()
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
self.instance.number = 7 self.instance.number = 7
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertChanged() self.assertChanged()
@ -819,7 +819,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
tracked_class = ModelTrackedMultiple tracked_class = ModelTrackedMultiple
def test_pre_save_has_changed(self): def test_pre_save_has_changed(self) -> None:
self.tracker = self.instance.name_tracker self.tracker = self.instance.name_tracker
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age' self.instance.name = 'new age'
@ -829,7 +829,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
self.instance.name = 'new age' self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True) self.assertHasChanged(name=True, number=True)
def test_pre_save_changed(self): def test_pre_save_changed(self) -> None:
self.tracker = self.instance.name_tracker self.tracker = self.instance.name_tracker
self.assertChanged() self.assertChanged()
self.instance.name = 'new age' self.instance.name = 'new age'
@ -851,7 +851,7 @@ class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
fk_class = ModelTracked fk_class = ModelTracked
tracked_class = ModelTrackedFK tracked_class = ModelTrackedFK
def test_custom_without_id(self): def test_custom_without_id(self) -> None:
with self.assertNumQueries(2): with self.assertNumQueries(2):
self.tracked_class.objects.get() self.tracked_class.objects.get()
self.tracker = self.instance.custom_tracker_without_id self.tracker = self.instance.custom_tracker_without_id
@ -869,7 +869,7 @@ class InheritedModelTrackerTests(ModelTrackerTests):
tracked_class = InheritedModelTracked tracked_class = InheritedModelTracked
def test_child_fields_not_tracked(self): def test_child_fields_not_tracked(self) -> None:
self.name2 = 'test' self.name2 = 'test'
self.assertEqual(self.tracker.previous('name2'), None) self.assertEqual(self.tracker.previous('name2'), None)
self.assertTrue(self.tracker.has_changed('name2')) self.assertTrue(self.tracker.has_changed('name2'))
@ -882,7 +882,7 @@ class AbstractModelTrackerTests(ModelTrackerTests):
class TrackerContextDecoratorTests(TestCase): class TrackerContextDecoratorTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.instance = Tracked.objects.create(number=1) self.instance = Tracked.objects.create(number=1)
self.tracker = self.instance.tracker self.tracker = self.instance.tracker
@ -894,7 +894,7 @@ class TrackerContextDecoratorTests(TestCase):
for f in fields: for f in fields:
self.assertFalse(self.tracker.has_changed(f)) self.assertFalse(self.tracker.has_changed(f))
def test_context_manager(self): def test_context_manager(self) -> None:
with self.tracker: with self.tracker:
with self.tracker: with self.tracker:
self.instance.name = 'new' self.instance.name = 'new'
@ -905,7 +905,7 @@ class TrackerContextDecoratorTests(TestCase):
self.assertNotChanged('name') self.assertNotChanged('name')
def test_context_manager_fields(self): def test_context_manager_fields(self) -> None:
with self.tracker('number'): with self.tracker('number'):
with self.tracker('number', 'name'): with self.tracker('number', 'name'):
self.instance.name = 'new' self.instance.name = 'new'
@ -918,7 +918,7 @@ class TrackerContextDecoratorTests(TestCase):
self.assertNotChanged('number', 'name') self.assertNotChanged('number', 'name')
def test_tracker_decorator(self): def test_tracker_decorator(self) -> None:
@Tracked.tracker @Tracked.tracker
def tracked_method(obj): def tracked_method(obj):
@ -929,7 +929,7 @@ class TrackerContextDecoratorTests(TestCase):
self.assertNotChanged('name') self.assertNotChanged('name')
def test_tracker_decorator_fields(self): def test_tracker_decorator_fields(self) -> None:
@Tracked.tracker(fields=['name']) @Tracked.tracker(fields=['name'])
def tracked_method(obj): def tracked_method(obj):
@ -942,7 +942,7 @@ class TrackerContextDecoratorTests(TestCase):
self.assertChanged('number') self.assertChanged('number')
self.assertNotChanged('name') self.assertNotChanged('name')
def test_tracker_context_with_save(self): def test_tracker_context_with_save(self) -> None:
with self.tracker: with self.tracker:
self.instance.name = 'new' self.instance.name = 'new'

View file

@ -10,33 +10,33 @@ from tests.models import DoubleMonitored, Monitored, MonitorWhen, MonitorWhenEmp
class MonitorFieldTests(TestCase): class MonitorFieldTests(TestCase):
def setUp(self): def setUp(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)):
self.instance = Monitored(name='Charlie') self.instance = Monitored(name='Charlie')
self.created = self.instance.name_changed self.created = self.instance.name_changed
def test_save_no_change(self): def test_save_no_change(self) -> None:
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, self.created) self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed(self): def test_save_changed(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
self.instance.name = 'Maria' self.instance.name = 'Maria'
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
def test_double_save(self): def test_double_save(self) -> None:
self.instance.name = 'Jose' self.instance.name = 'Jose'
self.instance.save() self.instance.save()
changed = self.instance.name_changed changed = self.instance.name_changed
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, changed) self.assertEqual(self.instance.name_changed, changed)
def test_no_monitor_arg(self): def test_no_monitor_arg(self) -> None:
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
MonitorField() MonitorField()
def test_monitor_default_is_none_when_nullable(self): def test_monitor_default_is_none_when_nullable(self) -> None:
self.assertIsNone(self.instance.name_changed_nullable) self.assertIsNone(self.instance.name_changed_nullable)
expected_datetime = datetime(2022, 1, 18, 12, 0, 0, tzinfo=timezone.utc) expected_datetime = datetime(2022, 1, 18, 12, 0, 0, tzinfo=timezone.utc)
@ -51,33 +51,33 @@ class MonitorWhenFieldTests(TestCase):
""" """
Will record changes only when name is 'Jose' or 'Maria' Will record changes only when name is 'Jose' or 'Maria'
""" """
def setUp(self): def setUp(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)):
self.instance = MonitorWhen(name='Charlie') self.instance = MonitorWhen(name='Charlie')
self.created = self.instance.name_changed self.created = self.instance.name_changed
def test_save_no_change(self): def test_save_no_change(self) -> None:
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, self.created) self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Jose(self): def test_save_changed_to_Jose(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
self.instance.name = 'Jose' self.instance.name = 'Jose'
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
def test_save_changed_to_Maria(self): def test_save_changed_to_Maria(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
self.instance.name = 'Maria' self.instance.name = 'Maria'
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)) self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
def test_save_changed_to_Pedro(self): def test_save_changed_to_Pedro(self) -> None:
self.instance.name = 'Pedro' self.instance.name = 'Pedro'
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, self.created) self.assertEqual(self.instance.name_changed, self.created)
def test_double_save(self): def test_double_save(self) -> None:
self.instance.name = 'Jose' self.instance.name = 'Jose'
self.instance.save() self.instance.save()
changed = self.instance.name_changed changed = self.instance.name_changed
@ -89,20 +89,20 @@ class MonitorWhenEmptyFieldTests(TestCase):
""" """
Monitor should never be updated id when is an empty list. Monitor should never be updated id when is an empty list.
""" """
def setUp(self): def setUp(self) -> None:
self.instance = MonitorWhenEmpty(name='Charlie') self.instance = MonitorWhenEmpty(name='Charlie')
self.created = self.instance.name_changed self.created = self.instance.name_changed
def test_save_no_change(self): def test_save_no_change(self) -> None:
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, self.created) self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Jose(self): def test_save_changed_to_Jose(self) -> None:
self.instance.name = 'Jose' self.instance.name = 'Jose'
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, self.created) self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Maria(self): def test_save_changed_to_Maria(self) -> None:
self.instance.name = 'Maria' self.instance.name = 'Maria'
self.instance.save() self.instance.save()
self.assertEqual(self.instance.name_changed, self.created) self.assertEqual(self.instance.name_changed, self.created)
@ -110,18 +110,18 @@ class MonitorWhenEmptyFieldTests(TestCase):
class MonitorDoubleFieldTests(TestCase): class MonitorDoubleFieldTests(TestCase):
def setUp(self): def setUp(self) -> None:
DoubleMonitored.objects.create(name='Charlie', name2='Charlie2') DoubleMonitored.objects.create(name='Charlie', name2='Charlie2')
def test_recursion_error_with_only(self): def test_recursion_error_with_only(self) -> None:
# Any field passed to only() is generating a recursion error # Any field passed to only() is generating a recursion error
list(DoubleMonitored.objects.only('id')) list(DoubleMonitored.objects.only('id'))
def test_recursion_error_with_defer(self): def test_recursion_error_with_defer(self) -> None:
# Only monitored fields passed to defer() are failing # Only monitored fields passed to defer() are failing
list(DoubleMonitored.objects.defer('name')) list(DoubleMonitored.objects.defer('name'))
def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self): def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self) -> None:
obj = DoubleMonitored.objects.defer('name').get(name='Charlie') obj = DoubleMonitored.objects.defer('name').get(name='Charlie')
with time_machine.travel(datetime(2016, 12, 1, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 12, 1, tzinfo=timezone.utc)):
obj.name = 'Charlie2' obj.name = 'Charlie2'

View file

@ -9,62 +9,62 @@ class SplitFieldTests(TestCase):
full_text = 'summary\n\n<!-- split -->\n\nmore' full_text = 'summary\n\n<!-- split -->\n\nmore'
excerpt = 'summary\n' excerpt = 'summary\n'
def setUp(self): def setUp(self) -> None:
self.post = Article.objects.create( self.post = Article.objects.create(
title='example post', body=self.full_text) title='example post', body=self.full_text)
def test_unicode_content(self): def test_unicode_content(self) -> None:
self.assertEqual(str(self.post.body), self.full_text) self.assertEqual(str(self.post.body), self.full_text)
def test_excerpt(self): def test_excerpt(self) -> None:
self.assertEqual(self.post.body.excerpt, self.excerpt) self.assertEqual(self.post.body.excerpt, self.excerpt)
def test_content(self): def test_content(self) -> None:
self.assertEqual(self.post.body.content, self.full_text) self.assertEqual(self.post.body.content, self.full_text)
def test_has_more(self): def test_has_more(self) -> None:
self.assertTrue(self.post.body.has_more) self.assertTrue(self.post.body.has_more)
def test_not_has_more(self): def test_not_has_more(self) -> None:
post = Article.objects.create(title='example 2', post = Article.objects.create(title='example 2',
body='some text\n\nsome more\n') body='some text\n\nsome more\n')
self.assertFalse(post.body.has_more) self.assertFalse(post.body.has_more)
def test_load_back(self): def test_load_back(self) -> None:
post = Article.objects.get(pk=self.post.pk) post = Article.objects.get(pk=self.post.pk)
self.assertEqual(post.body.content, self.post.body.content) self.assertEqual(post.body.content, self.post.body.content)
self.assertEqual(post.body.excerpt, self.post.body.excerpt) self.assertEqual(post.body.excerpt, self.post.body.excerpt)
def test_assign_to_body(self): def test_assign_to_body(self) -> None:
new_text = 'different\n\n<!-- split -->\n\nother' new_text = 'different\n\n<!-- split -->\n\nother'
self.post.body = new_text self.post.body = new_text
self.post.save() self.post.save()
self.assertEqual(str(self.post.body), new_text) self.assertEqual(str(self.post.body), new_text)
def test_assign_to_content(self): def test_assign_to_content(self) -> None:
new_text = 'different\n\n<!-- split -->\n\nother' new_text = 'different\n\n<!-- split -->\n\nother'
self.post.body.content = new_text self.post.body.content = new_text
self.post.save() self.post.save()
self.assertEqual(str(self.post.body), new_text) self.assertEqual(str(self.post.body), new_text)
def test_assign_to_excerpt(self): def test_assign_to_excerpt(self) -> None:
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
self.post.body.excerpt = 'this should fail' self.post.body.excerpt = 'this should fail'
def test_access_via_class(self): def test_access_via_class(self) -> None:
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
Article.body Article.body
def test_assign_splittext(self): def test_assign_splittext(self) -> None:
a = Article(title='Some Title') a = Article(title='Some Title')
a.body = self.post.body a.body = self.post.body
self.assertEqual(a.body.excerpt, 'summary\n') self.assertEqual(a.body.excerpt, 'summary\n')
def test_value_to_string(self): def test_value_to_string(self) -> None:
f = self.post._meta.get_field('body') f = self.post._meta.get_field('body')
self.assertEqual(f.value_to_string(self.post), self.full_text) self.assertEqual(f.value_to_string(self.post), self.full_text)
def test_abstract_inheritance(self): def test_abstract_inheritance(self) -> None:
class Child(SplitFieldAbstractParent): class Child(SplitFieldAbstractParent):
pass pass

View file

@ -13,22 +13,22 @@ from tests.models import (
class StatusFieldTests(TestCase): class StatusFieldTests(TestCase):
def test_status_with_default_filled(self): def test_status_with_default_filled(self) -> None:
instance = StatusFieldDefaultFilled() instance = StatusFieldDefaultFilled()
self.assertEqual(instance.status, instance.STATUS.yes) self.assertEqual(instance.status, instance.STATUS.yes)
def test_status_with_default_not_filled(self): def test_status_with_default_not_filled(self) -> None:
instance = StatusFieldDefaultNotFilled() instance = StatusFieldDefaultNotFilled()
self.assertEqual(instance.status, instance.STATUS.no) self.assertEqual(instance.status, instance.STATUS.no)
def test_no_check_for_status(self): def test_no_check_for_status(self) -> None:
field = StatusField(no_check_for_status=True) field = StatusField(no_check_for_status=True)
# this model has no STATUS attribute, so checking for it would error # this model has no STATUS attribute, so checking for it would error
field.prepare_class(Article) field.prepare_class(Article)
def test_get_status_display(self): def test_get_status_display(self) -> None:
instance = StatusFieldDefaultFilled() instance = StatusFieldDefaultFilled()
self.assertEqual(instance.get_status_display(), "Yes") self.assertEqual(instance.get_status_display(), "Yes")
def test_choices_name(self): def test_choices_name(self) -> None:
StatusFieldChoicesName() StatusFieldChoicesName()

View file

@ -9,41 +9,41 @@ from model_utils.fields import UrlsafeTokenField
class UrlsaftTokenFieldTests(TestCase): class UrlsaftTokenFieldTests(TestCase):
def test_editable_default(self): def test_editable_default(self) -> None:
field = UrlsafeTokenField() field = UrlsafeTokenField()
self.assertFalse(field.editable) self.assertFalse(field.editable)
def test_editable(self): def test_editable(self) -> None:
field = UrlsafeTokenField(editable=True) field = UrlsafeTokenField(editable=True)
self.assertTrue(field.editable) self.assertTrue(field.editable)
def test_max_length_default(self): def test_max_length_default(self) -> None:
field = UrlsafeTokenField() field = UrlsafeTokenField()
self.assertEqual(field.max_length, 128) self.assertEqual(field.max_length, 128)
def test_max_length(self): def test_max_length(self) -> None:
field = UrlsafeTokenField(max_length=256) field = UrlsafeTokenField(max_length=256)
self.assertEqual(field.max_length, 256) self.assertEqual(field.max_length, 256)
def test_factory_default(self): def test_factory_default(self) -> None:
field = UrlsafeTokenField() field = UrlsafeTokenField()
self.assertIsNone(field._factory) self.assertIsNone(field._factory)
def test_factory_not_callable(self): def test_factory_not_callable(self) -> None:
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
UrlsafeTokenField(factory='INVALID') UrlsafeTokenField(factory='INVALID')
def test_get_default(self): def test_get_default(self) -> None:
field = UrlsafeTokenField() field = UrlsafeTokenField()
value = field.get_default() value = field.get_default()
self.assertEqual(len(value), field.max_length) self.assertEqual(len(value), field.max_length)
def test_get_default_with_non_default_max_length(self): def test_get_default_with_non_default_max_length(self) -> None:
field = UrlsafeTokenField(max_length=64) field = UrlsafeTokenField(max_length=64)
value = field.get_default() value = field.get_default()
self.assertEqual(len(value), 64) self.assertEqual(len(value), 64)
def test_get_default_with_factory(self): def test_get_default_with_factory(self) -> None:
token = 'SAMPLE_TOKEN' token = 'SAMPLE_TOKEN'
factory = Mock(return_value=token) factory = Mock(return_value=token)
field = UrlsafeTokenField(factory=factory) field = UrlsafeTokenField(factory=factory)
@ -52,12 +52,12 @@ class UrlsaftTokenFieldTests(TestCase):
self.assertEqual(value, token) self.assertEqual(value, token)
factory.assert_called_once_with(field.max_length) factory.assert_called_once_with(field.max_length)
def test_no_default_param(self): def test_no_default_param(self) -> None:
field = UrlsafeTokenField(default='DEFAULT') field = UrlsafeTokenField(default='DEFAULT')
self.assertIs(field.default, NOT_PROVIDED) self.assertIs(field.default, NOT_PROVIDED)
def test_deconstruct(self): def test_deconstruct(self) -> None:
def test_factory(): def test_factory() -> None:
pass pass
instance = UrlsafeTokenField(factory=test_factory) instance = UrlsafeTokenField(factory=test_factory)
name, path, args, kwargs = instance.deconstruct() name, path, args, kwargs = instance.deconstruct()

View file

@ -10,31 +10,31 @@ from model_utils.fields import UUIDField
class UUIDFieldTests(TestCase): class UUIDFieldTests(TestCase):
def test_uuid_version_default(self): def test_uuid_version_default(self) -> None:
instance = UUIDField() instance = UUIDField()
self.assertEqual(instance.default, uuid.uuid4) self.assertEqual(instance.default, uuid.uuid4)
def test_uuid_version_1(self): def test_uuid_version_1(self) -> None:
instance = UUIDField(version=1) instance = UUIDField(version=1)
self.assertEqual(instance.default, uuid.uuid1) self.assertEqual(instance.default, uuid.uuid1)
def test_uuid_version_2_error(self): def test_uuid_version_2_error(self) -> None:
self.assertRaises(ValidationError, UUIDField, 'version', 2) self.assertRaises(ValidationError, UUIDField, 'version', 2)
def test_uuid_version_3(self): def test_uuid_version_3(self) -> None:
instance = UUIDField(version=3) instance = UUIDField(version=3)
self.assertEqual(instance.default, uuid.uuid3) self.assertEqual(instance.default, uuid.uuid3)
def test_uuid_version_4(self): def test_uuid_version_4(self) -> None:
instance = UUIDField(version=4) instance = UUIDField(version=4)
self.assertEqual(instance.default, uuid.uuid4) self.assertEqual(instance.default, uuid.uuid4)
def test_uuid_version_5(self): def test_uuid_version_5(self) -> None:
instance = UUIDField(version=5) instance = UUIDField(version=5)
self.assertEqual(instance.default, uuid.uuid5) self.assertEqual(instance.default, uuid.uuid5)
def test_uuid_version_bellow_min(self): def test_uuid_version_bellow_min(self) -> None:
self.assertRaises(ValidationError, UUIDField, 'version', 0) self.assertRaises(ValidationError, UUIDField, 'version', 0)
def test_uuid_version_above_max(self): def test_uuid_version_above_max(self) -> None:
self.assertRaises(ValidationError, UUIDField, 'version', 6) self.assertRaises(ValidationError, UUIDField, 'version', 6)

View file

@ -7,7 +7,7 @@ from tests.models import InheritanceManagerTestChild1, InheritanceManagerTestPar
class InheritanceIterableTest(TestCase): class InheritanceIterableTest(TestCase):
def test_prefetch(self): def test_prefetch(self) -> None:
qs = InheritanceManagerTestChild1.objects.all().prefetch_related( qs = InheritanceManagerTestChild1.objects.all().prefetch_related(
Prefetch( Prefetch(
'normal_field', 'normal_field',

View file

@ -1,8 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from model_utils.managers import InheritanceManager
from tests.models import ( from tests.models import (
InheritanceManagerTestChild1, InheritanceManagerTestChild1,
InheritanceManagerTestChild2, InheritanceManagerTestChild2,
@ -16,19 +19,22 @@ from tests.models import (
TimeFrame, TimeFrame,
) )
if TYPE_CHECKING:
from django.db.models.fields.related_descriptors import RelatedManager
class InheritanceManagerTests(TestCase): class InheritanceManagerTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.child1 = InheritanceManagerTestChild1.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create()
self.child2 = InheritanceManagerTestChild2.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create()
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
self.grandchild1_2 = \ self.grandchild1_2 = \
InheritanceManagerTestGrandChild1_2.objects.create() InheritanceManagerTestGrandChild1_2.objects.create()
def get_manager(self): def get_manager(self) -> InheritanceManager[InheritanceManagerTestParent]:
return InheritanceManagerTestParent.objects return InheritanceManagerTestParent.objects
def test_normal(self): def test_normal(self) -> None:
children = { children = {
InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk), InheritanceManagerTestParent(pk=self.child2.pk),
@ -37,14 +43,14 @@ class InheritanceManagerTests(TestCase):
} }
self.assertEqual(set(self.get_manager().all()), children) self.assertEqual(set(self.get_manager().all()), children)
def test_select_all_subclasses(self): def test_select_all_subclasses(self) -> None:
children = {self.child1, self.child2} children = {self.child1, self.child2}
children.add(self.grandchild1) children.add(self.grandchild1)
children.add(self.grandchild1_2) children.add(self.grandchild1_2)
self.assertEqual( self.assertEqual(
set(self.get_manager().select_subclasses()), children) set(self.get_manager().select_subclasses()), children)
def test_select_subclasses_invalid_relation(self): def test_select_subclasses_invalid_relation(self) -> None:
""" """
If an invalid relation string is provided, we can provide the user If an invalid relation string is provided, we can provide the user
with a list which is valid, rather than just have the select_related() with a list which is valid, rather than just have the select_related()
@ -54,7 +60,7 @@ class InheritanceManagerTests(TestCase):
with self.assertRaisesRegex(ValueError, regex): with self.assertRaisesRegex(ValueError, regex):
self.get_manager().select_subclasses('user') self.get_manager().select_subclasses('user')
def test_select_specific_subclasses(self): def test_select_specific_subclasses(self) -> None:
children = { children = {
self.child1, self.child1,
InheritanceManagerTestParent(pk=self.child2.pk), InheritanceManagerTestParent(pk=self.child2.pk),
@ -69,7 +75,7 @@ class InheritanceManagerTests(TestCase):
children, children,
) )
def test_select_specific_grandchildren(self): def test_select_specific_grandchildren(self) -> None:
children = { children = {
InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk), InheritanceManagerTestParent(pk=self.child2.pk),
@ -85,7 +91,7 @@ class InheritanceManagerTests(TestCase):
children, children,
) )
def test_children_and_grandchildren(self): def test_children_and_grandchildren(self) -> None:
children = { children = {
self.child1, self.child1,
InheritanceManagerTestParent(pk=self.child2.pk), InheritanceManagerTestParent(pk=self.child2.pk),
@ -102,24 +108,24 @@ class InheritanceManagerTests(TestCase):
children, children,
) )
def test_get_subclass(self): def test_get_subclass(self) -> None:
self.assertEqual( self.assertEqual(
self.get_manager().get_subclass(pk=self.child1.pk), self.get_manager().get_subclass(pk=self.child1.pk),
self.child1) self.child1)
def test_get_subclass_on_queryset(self): def test_get_subclass_on_queryset(self) -> None:
self.assertEqual( self.assertEqual(
self.get_manager().all().get_subclass(pk=self.child1.pk), self.get_manager().all().get_subclass(pk=self.child1.pk),
self.child1) self.child1)
def test_prior_select_related(self): def test_prior_select_related(self) -> None:
with self.assertNumQueries(1): with self.assertNumQueries(1):
obj = self.get_manager().select_related( obj = self.get_manager().select_related(
"inheritancemanagertestchild1").select_subclasses( "inheritancemanagertestchild1").select_subclasses(
"inheritancemanagertestchild2").get(pk=self.child1.pk) "inheritancemanagertestchild2").get(pk=self.child1.pk)
obj.inheritancemanagertestchild1 obj.inheritancemanagertestchild1
def test_manually_specifying_parent_fk_including_grandchildren(self): def test_manually_specifying_parent_fk_including_grandchildren(self) -> None:
""" """
given a Model which inherits from another Model, but also declares given a Model which inherits from another Model, but also declares
the OneToOne link manually using `related_name` and `parent_link`, the OneToOne link manually using `related_name` and `parent_link`,
@ -150,7 +156,7 @@ class InheritanceManagerTests(TestCase):
self.assertEqual(set(results.subclasses), self.assertEqual(set(results.subclasses),
set(expected_related_names)) set(expected_related_names))
def test_manually_specifying_parent_fk_single_subclass(self): def test_manually_specifying_parent_fk_single_subclass(self) -> None:
""" """
Using a string related_name when the relation is manually defined Using a string related_name when the relation is manually defined
instead of implicit should still work in the same way. instead of implicit should still work in the same way.
@ -170,11 +176,11 @@ class InheritanceManagerTests(TestCase):
self.assertEqual(set(results.subclasses), self.assertEqual(set(results.subclasses),
set(expected_related_names)) set(expected_related_names))
def test_filter_on_values_queryset(self): def test_filter_on_values_queryset(self) -> None:
queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk) queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk)
self.assertEqual(list(queryset), [{'id': self.child1.pk}]) self.assertEqual(list(queryset), [{'id': self.child1.pk}])
def test_values_list_on_select_subclasses(self): def test_values_list_on_select_subclasses(self) -> None:
""" """
Using `select_subclasses` in conjunction with `values_list()` raised an Using `select_subclasses` in conjunction with `values_list()` raised an
exception in `_get_sub_obj_recurse()` because the result of `values_list()` exception in `_get_sub_obj_recurse()` because the result of `values_list()`
@ -219,14 +225,14 @@ class InheritanceManagerTests(TestCase):
class InheritanceManagerUsingModelsTests(TestCase): class InheritanceManagerUsingModelsTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.parent1 = InheritanceManagerTestParent.objects.create() self.parent1 = InheritanceManagerTestParent.objects.create()
self.child1 = InheritanceManagerTestChild1.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create()
self.child2 = InheritanceManagerTestChild2.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create()
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create() self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create()
def test_select_subclass_by_child_model(self): def test_select_subclass_by_child_model(self) -> None:
""" """
Confirm that passing a child model works the same as passing the Confirm that passing a child model works the same as passing the
select_related manually select_related manually
@ -238,7 +244,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(objs.subclasses, objsmodels.subclasses) self.assertEqual(objs.subclasses, objsmodels.subclasses)
self.assertEqual(list(objs), list(objsmodels)) self.assertEqual(list(objs), list(objsmodels))
def test_select_subclass_by_grandchild_model(self): def test_select_subclass_by_grandchild_model(self) -> None:
""" """
Confirm that passing a grandchild model works the same as passing the Confirm that passing a grandchild model works the same as passing the
select_related manually select_related manually
@ -251,7 +257,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(objs.subclasses, objsmodels.subclasses) self.assertEqual(objs.subclasses, objsmodels.subclasses)
self.assertEqual(list(objs), list(objsmodels)) self.assertEqual(list(objs), list(objsmodels))
def test_selecting_all_subclasses_specifically_grandchildren(self): def test_selecting_all_subclasses_specifically_grandchildren(self) -> None:
""" """
A bare select_subclasses() should achieve the same results as doing A bare select_subclasses() should achieve the same results as doing
select_subclasses and specifying all possible subclasses. select_subclasses and specifying all possible subclasses.
@ -268,7 +274,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
self.assertEqual(list(objs), list(objsmodels)) self.assertEqual(list(objs), list(objsmodels))
def test_selecting_all_subclasses_specifically_children(self): def test_selecting_all_subclasses_specifically_children(self) -> None:
""" """
A bare select_subclasses() should achieve the same results as doing A bare select_subclasses() should achieve the same results as doing
select_subclasses and specifying all possible subclasses. select_subclasses and specifying all possible subclasses.
@ -296,7 +302,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
self.assertEqual(list(objs), list(objsmodels)) self.assertEqual(list(objs), list(objsmodels))
def test_select_subclass_just_self(self): def test_select_subclass_just_self(self) -> None:
""" """
Passing in the same model as the manager/queryset is bound against Passing in the same model as the manager/queryset is bound against
(ie: the root parent) should have no effect on the result set. (ie: the root parent) should have no effect on the result set.
@ -312,7 +318,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestParent(pk=self.grandchild1_2.pk), InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
]) ])
def test_select_subclass_invalid_related_model(self): def test_select_subclass_invalid_related_model(self) -> None:
""" """
Confirming that giving a stupid model doesn't work. Confirming that giving a stupid model doesn't work.
""" """
@ -321,7 +327,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestParent.objects.select_subclasses( InheritanceManagerTestParent.objects.select_subclasses(
TimeFrame).order_by('pk') TimeFrame).order_by('pk')
def test_mixing_strings_and_classes_with_grandchildren(self): def test_mixing_strings_and_classes_with_grandchildren(self) -> None:
""" """
Given arguments consisting of both strings and model classes, Given arguments consisting of both strings and model classes,
ensure the right resolutions take place, accounting for the extra ensure the right resolutions take place, accounting for the extra
@ -342,7 +348,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
] ]
self.assertEqual(list(objs), expecting2) self.assertEqual(list(objs), expecting2)
def test_mixing_strings_and_classes_with_children(self): def test_mixing_strings_and_classes_with_children(self) -> None:
""" """
Given arguments consisting of both strings and model classes, Given arguments consisting of both strings and model classes,
ensure the right resolutions take place, walking down as far as ensure the right resolutions take place, walking down as far as
@ -364,7 +370,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
] ]
self.assertEqual(list(objs), expecting2) self.assertEqual(list(objs), expecting2)
def test_duplications(self): def test_duplications(self) -> None:
""" """
Check that even if the same thing is provided as a string and a model Check that even if the same thing is provided as a string and a model
that the right results are retrieved. that the right results are retrieved.
@ -381,7 +387,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestParent(pk=self.grandchild1_2.pk), InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
]) ])
def test_child_doesnt_accidentally_get_parent(self): def test_child_doesnt_accidentally_get_parent(self) -> None:
""" """
Given a Child model which also has an InheritanceManager, Given a Child model which also has an InheritanceManager,
none of the returned objects should be Parent objects. none of the returned objects should be Parent objects.
@ -394,7 +400,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), InheritanceManagerTestChild1(pk=self.grandchild1_2.pk),
], list(objs)) ], list(objs))
def test_manually_specifying_parent_fk_only_specific_child(self): def test_manually_specifying_parent_fk_only_specific_child(self) -> None:
""" """
given a Model which inherits from another Model, but also declares given a Model which inherits from another Model, but also declares
the OneToOne link manually using `related_name` and `parent_link`, the OneToOne link manually using `related_name` and `parent_link`,
@ -418,7 +424,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(results.subclasses), self.assertEqual(set(results.subclasses),
set(expected_related_names)) set(expected_related_names))
def test_extras_descend(self): def test_extras_descend(self) -> None:
""" """
Ensure that extra(select=) values are copied onto sub-classes. Ensure that extra(select=) values are copied onto sub-classes.
""" """
@ -427,25 +433,25 @@ class InheritanceManagerUsingModelsTests(TestCase):
) )
self.assertTrue(all(result.foo == (result.id + 1) for result in results)) self.assertTrue(all(result.foo == (result.id + 1) for result in results))
def test_limit_to_specific_subclass(self): def test_limit_to_specific_subclass(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create() child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3) results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3)
self.assertEqual([child3], list(results)) self.assertEqual([child3], list(results))
def test_limit_to_specific_subclass_with_custom_db_column(self): def test_limit_to_specific_subclass_with_custom_db_column(self) -> None:
item = InheritanceManagerTestChild3_1.objects.create() item = InheritanceManagerTestChild3_1.objects.create()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3_1) results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3_1)
self.assertEqual([item], list(results)) self.assertEqual([item], list(results))
def test_limit_to_specific_grandchild_class(self): def test_limit_to_specific_grandchild_class(self) -> None:
grandchild1 = InheritanceManagerTestGrandChild1.objects.get() grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1) results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1)
self.assertEqual([grandchild1], list(results)) self.assertEqual([grandchild1], list(results))
def test_limit_to_child_fetches_grandchildren_as_child_class(self): def test_limit_to_child_fetches_grandchildren_as_child_class(self) -> None:
# Not sure if this is the desired behaviour...? # Not sure if this is the desired behaviour...?
children = InheritanceManagerTestChild1.objects.all() children = InheritanceManagerTestChild1.objects.all()
@ -453,7 +459,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(children), set(results)) self.assertEqual(set(children), set(results))
def test_can_fetch_limited_class_grandchildren(self): def test_can_fetch_limited_class_grandchildren(self) -> None:
# Not sure if this is the desired behaviour...? # Not sure if this is the desired behaviour...?
children = InheritanceManagerTestChild1.objects.select_subclasses() children = InheritanceManagerTestChild1.objects.select_subclasses()
@ -461,7 +467,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(children), set(results)) self.assertEqual(set(children), set(results))
def test_selecting_multiple_instance_classes(self): def test_selecting_multiple_instance_classes(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create() child3 = InheritanceManagerTestChild3.objects.create()
children1 = InheritanceManagerTestChild1.objects.all() children1 = InheritanceManagerTestChild1.objects.all()
@ -469,7 +475,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set([child3] + list(children1)), set(results)) self.assertEqual(set([child3] + list(children1)), set(results))
def test_selecting_multiple_instance_classes_including_grandchildren(self): def test_selecting_multiple_instance_classes_including_grandchildren(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create() child3 = InheritanceManagerTestChild3.objects.create()
grandchild1 = InheritanceManagerTestGrandChild1.objects.get() grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
@ -477,7 +483,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual({child3, grandchild1}, set(results)) self.assertEqual({child3, grandchild1}, set(results))
def test_select_subclasses_interaction_with_instance_of(self): def test_select_subclasses_interaction_with_instance_of(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create() child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3) results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3)
@ -486,7 +492,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
class InheritanceManagerRelatedTests(InheritanceManagerTests): class InheritanceManagerRelatedTests(InheritanceManagerTests):
def setUp(self): def setUp(self) -> None:
self.related = InheritanceManagerTestRelated.objects.create() self.related = InheritanceManagerTestRelated.objects.create()
self.child1 = InheritanceManagerTestChild1.objects.create( self.child1 = InheritanceManagerTestChild1.objects.create(
related=self.related) related=self.related)
@ -495,16 +501,16 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related) self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related) self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related)
def get_manager(self): def get_manager(self) -> RelatedManager[InheritanceManagerTestParent]: # type: ignore[override]
return self.related.imtests return self.related.imtests
def test_get_method_with_select_subclasses(self): def test_get_method_with_select_subclasses(self) -> None:
self.assertEqual( self.assertEqual(
InheritanceManagerTestParent.objects.select_subclasses().get( InheritanceManagerTestParent.objects.select_subclasses().get(
id=self.child1.id), id=self.child1.id),
self.child1) self.child1)
def test_get_method_with_select_subclasses_check_for_useless_join(self): def test_get_method_with_select_subclasses_check_for_useless_join(self) -> None:
child4 = InheritanceManagerTestChild4.objects.create(related=self.related, other_onetoone=self.child1) child4 = InheritanceManagerTestChild4.objects.create(related=self.related, other_onetoone=self.child1)
self.assertEqual( self.assertEqual(
str(InheritanceManagerTestChild4.objects.select_subclasses().filter( str(InheritanceManagerTestChild4.objects.select_subclasses().filter(
@ -512,26 +518,26 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
str(InheritanceManagerTestChild4.objects.select_subclasses().select_related(None).filter( str(InheritanceManagerTestChild4.objects.select_subclasses().select_related(None).filter(
id=child4.id).query)) id=child4.id).query))
def test_annotate_with_select_subclasses(self): def test_annotate_with_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
models.Count('id')) models.Count('id'))
self.assertEqual(qs.get(id=self.child1.id).id__count, 1) self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
def test_annotate_with_named_arguments_with_select_subclasses(self): def test_annotate_with_named_arguments_with_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
test_count=models.Count('id')) test_count=models.Count('id'))
self.assertEqual(qs.get(id=self.child1.id).test_count, 1) self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
def test_annotate_before_select_subclasses(self): def test_annotate_before_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.annotate( qs = InheritanceManagerTestParent.objects.annotate(
models.Count('id')).select_subclasses() models.Count('id')).select_subclasses()
self.assertEqual(qs.get(id=self.child1.id).id__count, 1) self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
def test_annotate_with_named_arguments_before_select_subclasses(self): def test_annotate_with_named_arguments_before_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.annotate( qs = InheritanceManagerTestParent.objects.annotate(
test_count=models.Count('id')).select_subclasses() test_count=models.Count('id')).select_subclasses()
self.assertEqual(qs.get(id=self.child1.id).test_count, 1) self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self): def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses() qs = InheritanceManagerTestParent.objects.select_subclasses()
self.assertEqual(qs.subclasses, qs._clone().subclasses) self.assertEqual(qs.subclasses, qs._clone().subclasses)

View file

@ -6,7 +6,7 @@ from tests.models import BoxJoinModel, JoinItemForeignKey
class JoinManagerTest(TestCase): class JoinManagerTest(TestCase):
def setUp(self): def setUp(self) -> None:
for i in range(20): for i in range(20):
BoxJoinModel.objects.create(name=f'name_{i}') BoxJoinModel.objects.create(name=f'name_{i}')
@ -15,24 +15,24 @@ class JoinManagerTest(TestCase):
) )
JoinItemForeignKey.objects.create(weight=20) JoinItemForeignKey.objects.create(weight=20)
def test_self_join(self): def test_self_join(self) -> None:
a_slice = BoxJoinModel.objects.all()[0:10] a_slice = BoxJoinModel.objects.all()[0:10]
with self.assertNumQueries(1): with self.assertNumQueries(1):
result = a_slice.join() result = a_slice.join()
self.assertEqual(result.count(), 10) self.assertEqual(result.count(), 10)
def test_self_join_with_where_statement(self): def test_self_join_with_where_statement(self) -> None:
qs = BoxJoinModel.objects.filter(name='name_1') qs = BoxJoinModel.objects.filter(name='name_1')
result = qs.join() result = qs.join()
self.assertEqual(result.count(), 1) self.assertEqual(result.count(), 1)
def test_join_with_other_qs(self): def test_join_with_other_qs(self) -> None:
item_qs = JoinItemForeignKey.objects.filter(weight=10) item_qs = JoinItemForeignKey.objects.filter(weight=10)
boxes = BoxJoinModel.objects.all().join(qs=item_qs) boxes = BoxJoinModel.objects.all().join(qs=item_qs)
self.assertEqual(boxes.count(), 1) self.assertEqual(boxes.count(), 1)
self.assertEqual(boxes[0].name, 'name_1') self.assertEqual(boxes[0].name, 'name_1')
def test_reverse_join(self): def test_reverse_join(self) -> None:
box_qs = BoxJoinModel.objects.filter(name='name_1') box_qs = BoxJoinModel.objects.filter(name='name_1')
items = JoinItemForeignKey.objects.all().join(box_qs) items = JoinItemForeignKey.objects.all().join(box_qs)
self.assertEqual(items.count(), 1) self.assertEqual(items.count(), 1)

View file

@ -6,7 +6,7 @@ from tests.models import Post
class QueryManagerTests(TestCase): class QueryManagerTests(TestCase):
def setUp(self): def setUp(self) -> None:
data = ((True, True, 0), data = ((True, True, 0),
(True, False, 4), (True, False, 4),
(False, False, 2), (False, False, 2),
@ -16,14 +16,14 @@ class QueryManagerTests(TestCase):
for p, c, o in data: for p, c, o in data:
Post.objects.create(published=p, confirmed=c, order=o) Post.objects.create(published=p, confirmed=c, order=o)
def test_passing_kwargs(self): def test_passing_kwargs(self) -> None:
qs = Post.public.all() qs = Post.public.all()
self.assertEqual([p.order for p in qs], [0, 1, 4, 5]) self.assertEqual([p.order for p in qs], [0, 1, 4, 5])
def test_passing_Q(self): def test_passing_Q(self) -> None:
qs = Post.public_confirmed.all() qs = Post.public_confirmed.all()
self.assertEqual([p.order for p in qs], [0, 1]) self.assertEqual([p.order for p in qs], [0, 1])
def test_ordering(self): def test_ordering(self) -> None:
qs = Post.public_reversed.all() qs = Post.public_reversed.all()
self.assertEqual([p.order for p in qs], [5, 4, 1, 0]) self.assertEqual([p.order for p in qs], [5, 4, 1, 0])

View file

@ -7,21 +7,21 @@ from tests.models import CustomSoftDelete
class CustomSoftDeleteManagerTests(TestCase): class CustomSoftDeleteManagerTests(TestCase):
def test_custom_manager_empty(self): def test_custom_manager_empty(self) -> None:
qs = CustomSoftDelete.available_objects.only_read() qs = CustomSoftDelete.available_objects.only_read()
self.assertEqual(qs.count(), 0) self.assertEqual(qs.count(), 0)
def test_custom_qs_empty(self): def test_custom_qs_empty(self) -> None:
qs = CustomSoftDelete.available_objects.all().only_read() qs = CustomSoftDelete.available_objects.all().only_read()
self.assertEqual(qs.count(), 0) self.assertEqual(qs.count(), 0)
def test_is_read(self): def test_is_read(self) -> None:
for is_read in [True, False, True, False]: for is_read in [True, False, True, False]:
CustomSoftDelete.available_objects.create(is_read=is_read) CustomSoftDelete.available_objects.create(is_read=is_read)
qs = CustomSoftDelete.available_objects.only_read() qs = CustomSoftDelete.available_objects.only_read()
self.assertEqual(qs.count(), 2) self.assertEqual(qs.count(), 2)
def test_is_read_removed(self): def test_is_read_removed(self) -> None:
for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]: for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]:
CustomSoftDelete.available_objects.create(is_read=is_read, is_removed=is_removed) CustomSoftDelete.available_objects.create(is_read=is_read, is_removed=is_removed)
qs = CustomSoftDelete.available_objects.only_read() qs = CustomSoftDelete.available_objects.only_read()

View file

@ -10,10 +10,10 @@ from tests.models import StatusManagerAdded
class StatusManagerAddedTests(TestCase): class StatusManagerAddedTests(TestCase):
def test_manager_available(self): def test_manager_available(self) -> None:
self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager)) self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager))
def test_conflict_error(self): def test_conflict_error(self) -> None:
with self.assertRaises(ImproperlyConfigured): with self.assertRaises(ImproperlyConfigured):
class ErrorModel(StatusModel): class ErrorModel(StatusModel):
STATUS = ( STATUS = (

View file

@ -7,23 +7,23 @@ from model_utils.fields import get_excerpt
class MigrationsTests(TestCase): class MigrationsTests(TestCase):
def test_makemigrations(self): def test_makemigrations(self) -> None:
call_command('makemigrations', dry_run=True) call_command('makemigrations', dry_run=True)
class GetExcerptTests(TestCase): class GetExcerptTests(TestCase):
def test_split(self): def test_split(self) -> None:
e = get_excerpt("some content\n\n<!-- split -->\n\nsome more") e = get_excerpt("some content\n\n<!-- split -->\n\nsome more")
self.assertEqual(e, 'some content\n') self.assertEqual(e, 'some content\n')
def test_auto_split(self): def test_auto_split(self) -> None:
e = get_excerpt("para one\n\npara two\n\npara three") e = get_excerpt("para one\n\npara two\n\npara three")
self.assertEqual(e, 'para one\n\npara two') self.assertEqual(e, 'para one\n\npara two')
def test_middle_of_para(self): def test_middle_of_para(self) -> None:
e = get_excerpt("some text\n<!-- split -->\nmore text") e = get_excerpt("some text\n<!-- split -->\nmore text")
self.assertEqual(e, 'some text') self.assertEqual(e, 'some text')
def test_middle_of_line(self): def test_middle_of_line(self) -> None:
e = get_excerpt("some text <!-- split --> more text") e = get_excerpt("some text <!-- split --> more text")
self.assertEqual(e, "some text <!-- split --> more text") self.assertEqual(e, "some text <!-- split --> more text")

View file

@ -6,7 +6,7 @@ from tests.models import ModelWithCustomDescriptor
class CustomDescriptorTests(TestCase): class CustomDescriptorTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.instance = ModelWithCustomDescriptor.objects.create( self.instance = ModelWithCustomDescriptor.objects.create(
custom_field='1', custom_field='1',
tracked_custom_field='1', tracked_custom_field='1',
@ -14,7 +14,7 @@ class CustomDescriptorTests(TestCase):
tracked_regular_field=1, tracked_regular_field=1,
) )
def test_custom_descriptor_works(self): def test_custom_descriptor_works(self) -> None:
instance = self.instance instance = self.instance
self.assertEqual(instance.custom_field, '1') self.assertEqual(instance.custom_field, '1')
self.assertEqual(instance.__dict__['custom_field'], 1) self.assertEqual(instance.__dict__['custom_field'], 1)
@ -27,7 +27,7 @@ class CustomDescriptorTests(TestCase):
self.assertEqual(instance.custom_field, '2') self.assertEqual(instance.custom_field, '2')
self.assertEqual(instance.__dict__['custom_field'], 2) self.assertEqual(instance.__dict__['custom_field'], 2)
def test_deferred(self): def test_deferred(self) -> None:
instance = ModelWithCustomDescriptor.objects.only('id').get( instance = ModelWithCustomDescriptor.objects.only('id').get(
pk=self.instance.pk) pk=self.instance.pk)
self.assertIn('custom_field', instance.get_deferred_fields()) self.assertIn('custom_field', instance.get_deferred_fields())

View file

@ -7,7 +7,7 @@ from tests.models import SoftDeletable
class SoftDeletableModelTests(TestCase): class SoftDeletableModelTests(TestCase):
def test_can_only_see_not_removed_entries(self): def test_can_only_see_not_removed_entries(self) -> None:
SoftDeletable.available_objects.create(name='a', is_removed=True) SoftDeletable.available_objects.create(name='a', is_removed=True)
SoftDeletable.available_objects.create(name='b', is_removed=False) SoftDeletable.available_objects.create(name='b', is_removed=False)
@ -16,7 +16,7 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(queryset.count(), 1) self.assertEqual(queryset.count(), 1)
self.assertEqual(queryset[0].name, 'b') self.assertEqual(queryset[0].name, 'b')
def test_instance_cannot_be_fully_deleted(self): def test_instance_cannot_be_fully_deleted(self) -> None:
instance = SoftDeletable.available_objects.create(name='a') instance = SoftDeletable.available_objects.create(name='a')
instance.delete() instance.delete()
@ -24,7 +24,7 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.available_objects.count(), 0)
self.assertEqual(SoftDeletable.all_objects.count(), 1) self.assertEqual(SoftDeletable.all_objects.count(), 1)
def test_instance_cannot_be_fully_deleted_via_queryset(self): def test_instance_cannot_be_fully_deleted_via_queryset(self) -> None:
SoftDeletable.available_objects.create(name='a') SoftDeletable.available_objects.create(name='a')
SoftDeletable.available_objects.all().delete() SoftDeletable.available_objects.all().delete()
@ -32,12 +32,12 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.available_objects.count(), 0)
self.assertEqual(SoftDeletable.all_objects.count(), 1) self.assertEqual(SoftDeletable.all_objects.count(), 1)
def test_delete_instance_no_connection(self): def test_delete_instance_no_connection(self) -> None:
obj = SoftDeletable.available_objects.create(name='a') obj = SoftDeletable.available_objects.create(name='a')
self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other') self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other')
def test_instance_purge(self): def test_instance_purge(self) -> None:
instance = SoftDeletable.available_objects.create(name='a') instance = SoftDeletable.available_objects.create(name='a')
instance.delete(soft=False) instance.delete(soft=False)
@ -45,11 +45,11 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(SoftDeletable.available_objects.count(), 0) self.assertEqual(SoftDeletable.available_objects.count(), 0)
self.assertEqual(SoftDeletable.all_objects.count(), 0) self.assertEqual(SoftDeletable.all_objects.count(), 0)
def test_instance_purge_no_connection(self): def test_instance_purge_no_connection(self) -> None:
instance = SoftDeletable.available_objects.create(name='a') instance = SoftDeletable.available_objects.create(name='a')
self.assertRaises(ConnectionDoesNotExist, instance.delete, self.assertRaises(ConnectionDoesNotExist, instance.delete,
using='other', soft=False) using='other', soft=False)
def test_deprecation_warning(self): def test_deprecation_warning(self) -> None:
self.assertWarns(DeprecationWarning, SoftDeletable.objects.all) self.assertWarns(DeprecationWarning, SoftDeletable.objects.all)

View file

@ -9,12 +9,12 @@ from tests.models import CustomManagerStatusModel, Status, StatusPlainTuple
class StatusModelTests(TestCase): class StatusModelTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.model = Status self.model = Status
self.on_hold = Status.STATUS.on_hold self.on_hold = Status.STATUS.on_hold
self.active = Status.STATUS.active self.active = Status.STATUS.active
def test_created(self): def test_created(self) -> None:
with time_machine.travel(datetime(2016, 1, 1)): with time_machine.travel(datetime(2016, 1, 1)):
c1 = self.model.objects.create() c1 = self.model.objects.create()
self.assertTrue(c1.status_changed, datetime(2016, 1, 1)) self.assertTrue(c1.status_changed, datetime(2016, 1, 1))
@ -23,7 +23,7 @@ class StatusModelTests(TestCase):
self.assertEqual(self.model.active.count(), 2) self.assertEqual(self.model.active.count(), 2)
self.assertEqual(self.model.deleted.count(), 0) self.assertEqual(self.model.deleted.count(), 0)
def test_modification(self): def test_modification(self) -> None:
t1 = self.model.objects.create() t1 = self.model.objects.create()
date_created = t1.status_changed date_created = t1.status_changed
t1.status = self.on_hold t1.status = self.on_hold
@ -39,7 +39,7 @@ class StatusModelTests(TestCase):
t1.save() t1.save()
self.assertTrue(t1.status_changed > date_active_again) self.assertTrue(t1.status_changed > date_active_again)
def test_save_with_update_fields_overrides_status_changed_provided(self): def test_save_with_update_fields_overrides_status_changed_provided(self) -> None:
''' '''
Tests if the save method updated status_changed field Tests if the save method updated status_changed field
accordingly when update_fields is used as an argument accordingly when update_fields is used as an argument
@ -54,7 +54,7 @@ class StatusModelTests(TestCase):
self.assertEqual(t1.status_changed, datetime(2020, 1, 2, tzinfo=timezone.utc)) self.assertEqual(t1.status_changed, datetime(2020, 1, 2, tzinfo=timezone.utc))
def test_save_with_update_fields_overrides_status_changed_not_provided(self): def test_save_with_update_fields_overrides_status_changed_not_provided(self) -> None:
''' '''
Tests if the save method updated status_changed field Tests if the save method updated status_changed field
accordingly when update_fields is used as an argument accordingly when update_fields is used as an argument
@ -71,7 +71,7 @@ class StatusModelTests(TestCase):
class StatusModelPlainTupleTests(StatusModelTests): class StatusModelPlainTupleTests(StatusModelTests):
def setUp(self): def setUp(self) -> None:
self.model = StatusPlainTuple self.model = StatusPlainTuple
self.on_hold = StatusPlainTuple.STATUS[2][0] self.on_hold = StatusPlainTuple.STATUS[2][0]
self.active = StatusPlainTuple.STATUS[0][0] self.active = StatusPlainTuple.STATUS[0][0]
@ -79,7 +79,7 @@ class StatusModelPlainTupleTests(StatusModelTests):
class StatusModelDefaultManagerTests(TestCase): class StatusModelDefaultManagerTests(TestCase):
def test_default_manager_is_not_status_model_generated_ones(self): def test_default_manager_is_not_status_model_generated_ones(self) -> None:
# Regression test for GH-251 # Regression test for GH-251
# The logic behind order for managers seems to have changed in Django 1.10 # The logic behind order for managers seems to have changed in Django 1.10
# and affects default manager. # and affects default manager.

View file

@ -12,36 +12,36 @@ from tests.models import TimeFrame, TimeFrameManagerAdded
class TimeFramedModelTests(TestCase): class TimeFramedModelTests(TestCase):
def setUp(self): def setUp(self) -> None:
self.now = datetime.now() self.now = datetime.now()
def test_not_yet_begun(self): def test_not_yet_begun(self) -> None:
TimeFrame.objects.create(start=self.now + timedelta(days=2)) TimeFrame.objects.create(start=self.now + timedelta(days=2))
self.assertEqual(TimeFrame.timeframed.count(), 0) self.assertEqual(TimeFrame.timeframed.count(), 0)
def test_finished(self): def test_finished(self) -> None:
TimeFrame.objects.create(end=self.now - timedelta(days=1)) TimeFrame.objects.create(end=self.now - timedelta(days=1))
self.assertEqual(TimeFrame.timeframed.count(), 0) self.assertEqual(TimeFrame.timeframed.count(), 0)
def test_no_end(self): def test_no_end(self) -> None:
TimeFrame.objects.create(start=self.now - timedelta(days=10)) TimeFrame.objects.create(start=self.now - timedelta(days=10))
self.assertEqual(TimeFrame.timeframed.count(), 1) self.assertEqual(TimeFrame.timeframed.count(), 1)
def test_no_start(self): def test_no_start(self) -> None:
TimeFrame.objects.create(end=self.now + timedelta(days=2)) TimeFrame.objects.create(end=self.now + timedelta(days=2))
self.assertEqual(TimeFrame.timeframed.count(), 1) self.assertEqual(TimeFrame.timeframed.count(), 1)
def test_within_range(self): def test_within_range(self) -> None:
TimeFrame.objects.create(start=self.now - timedelta(days=1), TimeFrame.objects.create(start=self.now - timedelta(days=1),
end=self.now + timedelta(days=1)) end=self.now + timedelta(days=1))
self.assertEqual(TimeFrame.timeframed.count(), 1) self.assertEqual(TimeFrame.timeframed.count(), 1)
class TimeFrameManagerAddedTests(TestCase): class TimeFrameManagerAddedTests(TestCase):
def test_manager_available(self): def test_manager_available(self) -> None:
self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager)) self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager))
def test_conflict_error(self): def test_conflict_error(self) -> None:
with self.assertRaises(ImproperlyConfigured): with self.assertRaises(ImproperlyConfigured):
class ErrorModel(TimeFramedModel): class ErrorModel(TimeFramedModel):
timeframed = models.BooleanField() timeframed = models.BooleanField()

View file

@ -9,19 +9,19 @@ from tests.models import TimeStamp, TimeStampWithStatusModel
class TimeStampedModelTests(TestCase): class TimeStampedModelTests(TestCase):
def test_created(self): def test_created(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStamp.objects.create() t1 = TimeStamp.objects.create()
self.assertEqual(t1.created, datetime(2016, 1, 1, tzinfo=timezone.utc)) self.assertEqual(t1.created, datetime(2016, 1, 1, tzinfo=timezone.utc))
def test_created_sets_modified(self): def test_created_sets_modified(self) -> None:
''' '''
Ensure that on creation that modified is set exactly equal to created. Ensure that on creation that modified is set exactly equal to created.
''' '''
t1 = TimeStamp.objects.create() t1 = TimeStamp.objects.create()
self.assertEqual(t1.created, t1.modified) self.assertEqual(t1.created, t1.modified)
def test_modified(self): def test_modified(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)): with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStamp.objects.create() t1 = TimeStamp.objects.create()
@ -30,7 +30,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.modified, datetime(2016, 1, 2, tzinfo=timezone.utc)) self.assertEqual(t1.modified, datetime(2016, 1, 2, tzinfo=timezone.utc))
def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self): def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self) -> None:
""" """
Setting the created date when first creating an object Setting the created date when first creating an object
should be permissible. should be permissible.
@ -40,7 +40,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.created, different_date) self.assertEqual(t1.created, different_date)
self.assertEqual(t1.modified, different_date) self.assertEqual(t1.modified, different_date)
def test_overriding_modified_via_object_creation(self): def test_overriding_modified_via_object_creation(self) -> None:
""" """
Setting the modified date explicitly should be possible when Setting the modified date explicitly should be possible when
first creating an object, but not thereafter. first creating an object, but not thereafter.
@ -50,7 +50,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.modified, different_date) self.assertEqual(t1.modified, different_date)
self.assertNotEqual(t1.created, different_date) self.assertNotEqual(t1.created, different_date)
def test_overriding_created_after_object_created(self): def test_overriding_created_after_object_created(self) -> None:
""" """
The created date may be changed post-create The created date may be changed post-create
""" """
@ -60,7 +60,7 @@ class TimeStampedModelTests(TestCase):
t1.save() t1.save()
self.assertEqual(t1.created, different_date) self.assertEqual(t1.created, different_date)
def test_overriding_modified_after_object_created(self): def test_overriding_modified_after_object_created(self) -> None:
""" """
The modified date should always be updated when the object The modified date should always be updated when the object
is saved, regardless of attempts to change it. is saved, regardless of attempts to change it.
@ -71,7 +71,7 @@ class TimeStampedModelTests(TestCase):
t1.save() t1.save()
self.assertNotEqual(t1.modified, different_date) self.assertNotEqual(t1.modified, different_date)
def test_overrides_using_save(self): def test_overrides_using_save(self) -> None:
""" """
The first time an object is saved, allow modification of both The first time an object is saved, allow modification of both
created and modified fields. created and modified fields.
@ -92,7 +92,7 @@ class TimeStampedModelTests(TestCase):
self.assertNotEqual(t1.modified, different_date2) self.assertNotEqual(t1.modified, different_date2)
self.assertNotEqual(t1.modified, different_date) self.assertNotEqual(t1.modified, different_date)
def test_save_with_update_fields_overrides_modified_provided_within_a(self): def test_save_with_update_fields_overrides_modified_provided_within_a(self) -> None:
""" """
Tests if the save method updated modified field Tests if the save method updated modified field
accordingly when update_fields is used as an argument accordingly when update_fields is used as an argument
@ -113,7 +113,7 @@ class TimeStampedModelTests(TestCase):
t1.save(update_fields=update_fields) t1.save(update_fields=update_fields)
self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc))
def test_save_is_skipped_for_empty_update_fields_iterable(self): def test_save_is_skipped_for_empty_update_fields_iterable(self) -> None:
tests = ( tests = (
[], # list [], # list
(), # tuple (), # tuple
@ -133,7 +133,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.test_field, 0) self.assertEqual(t1.test_field, 0)
self.assertEqual(t1.modified, datetime(2020, 1, 1, tzinfo=timezone.utc)) self.assertEqual(t1.modified, datetime(2020, 1, 1, tzinfo=timezone.utc))
def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self): def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self) -> None:
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)): with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStamp.objects.create() t1 = TimeStamp.objects.create()
@ -142,7 +142,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc)) self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc))
def test_model_inherit_timestampmodel_and_statusmodel(self): def test_model_inherit_timestampmodel_and_statusmodel(self) -> None:
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)): with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStampWithStatusModel.objects.create() t1 = TimeStampWithStatusModel.objects.create()

View file

@ -7,13 +7,13 @@ from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel
class UUIDFieldTests(TestCase): class UUIDFieldTests(TestCase):
def test_uuid_model_with_uuid_field_as_primary_key(self): def test_uuid_model_with_uuid_field_as_primary_key(self) -> None:
instance = CustomUUIDModel() instance = CustomUUIDModel()
instance.save() instance.save()
self.assertEqual(instance.id.__class__.__name__, 'UUID') self.assertEqual(instance.id.__class__.__name__, 'UUID')
self.assertEqual(instance.id, instance.pk) self.assertEqual(instance.id, instance.pk)
def test_uuid_model_with_uuid_field_as_not_primary_key(self): def test_uuid_model_with_uuid_field_as_not_primary_key(self) -> None:
instance = CustomNotPrimaryUUIDModel() instance = CustomNotPrimaryUUIDModel()
instance.save() instance.save()
self.assertEqual(instance.uuid.__class__.__name__, 'UUID') self.assertEqual(instance.uuid.__class__.__name__, 'UUID')