Merge pull request #603 from ProtixIT/type-annotations

Add type annotations
This commit is contained in:
Jelmer 2024-06-19 17:24:56 +02:00 committed by GitHub
commit 731ed804f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 1142 additions and 642 deletions

View file

@ -1,2 +1,8 @@
[run]
include = model_utils/*.py
[report]
exclude_also =
# Exclusive to mypy:
if TYPE_CHECKING:$
\.\.\.$

View file

@ -1,7 +1,39 @@
from __future__ import annotations
import copy
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
T = TypeVar("T")
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
# The type aliases defined here are evaluated when the django-stubs mypy plugin
# loads this module, so they must be able to execute under the lowest supported
# Python VM:
# - typing.List, typing.Tuple become obsolete in Pyton 3.9
# - typing.Union becomes obsolete in Pyton 3.10
from typing import List, Tuple, Union
from django_stubs_ext import StrOrPromise
# The type argument 'T' to 'Choices' is the database representation type.
_Double = Tuple[T, StrOrPromise]
_Triple = Tuple[T, str, StrOrPromise]
_Group = Tuple[StrOrPromise, Sequence["_Choice[T]"]]
_Choice = Union[_Double[T], _Triple[T], _Group[T]]
# Choices can only be given as a single string if 'T' is 'str'.
_GroupStr = Tuple[StrOrPromise, Sequence["_ChoiceStr"]]
_ChoiceStr = Union[str, _Double[str], _Triple[str], _GroupStr]
# Note that we only accept lists and tuples in groups, not arbitrary sequences.
# However, annotating it as such causes many problems.
_DoubleRead = Union[_Double[T], Tuple[StrOrPromise, Iterable["_DoubleRead[T]"]]]
_DoubleCollector = List[Union[_Double[T], Tuple[StrOrPromise, "_DoubleCollector[T]"]]]
_TripleCollector = List[Union[_Triple[T], Tuple[StrOrPromise, "_TripleCollector[T]"]]]
class Choices:
class Choices(Generic[T]):
"""
A class to encapsulate handy functionality for lists of choices
for a Django model field.
@ -41,36 +73,60 @@ class Choices:
"""
def __init__(self, *choices):
@overload
def __init__(self: Choices[str], *choices: _ChoiceStr):
...
@overload
def __init__(self, *choices: _Choice[T]):
...
def __init__(self, *choices: _ChoiceStr | _Choice[T]):
# list of choices expanded to triples - can include optgroups
self._triples = []
self._triples: _TripleCollector[T] = []
# list of choices as (db, human-readable) - can include optgroups
self._doubles = []
self._doubles: _DoubleCollector[T] = []
# dictionary mapping db representation to human-readable
self._display_map = {}
self._display_map: dict[T, StrOrPromise | list[_Triple[T]]] = {}
# dictionary mapping Python identifier to db representation
self._identifier_map = {}
self._identifier_map: dict[str, T] = {}
# set of db representations
self._db_values = set()
self._db_values: set[T] = set()
self._process(choices)
def _store(self, triple, triple_collector, double_collector):
def _store(
self,
triple: tuple[T, str, StrOrPromise],
triple_collector: _TripleCollector[T],
double_collector: _DoubleCollector[T]
) -> None:
self._identifier_map[triple[1]] = triple[0]
self._display_map[triple[0]] = triple[2]
self._db_values.add(triple[0])
triple_collector.append(triple)
double_collector.append((triple[0], triple[2]))
def _process(self, choices, triple_collector=None, double_collector=None):
def _process(
self,
choices: Iterable[_ChoiceStr | _Choice[T]],
triple_collector: _TripleCollector[T] | None = None,
double_collector: _DoubleCollector[T] | None = None
) -> None:
if triple_collector is None:
triple_collector = self._triples
if double_collector is None:
double_collector = self._doubles
store = lambda c: self._store(c, triple_collector, double_collector)
def store(c: tuple[Any, str, StrOrPromise]) -> None:
self._store(c, triple_collector, double_collector)
for choice in choices:
# The type inference is not very accurate here:
# - we lied in the type aliases, stating groups contain an arbitrary Sequence
# rather than only list or tuple
# - there is no way to express that _ChoiceStr is only used when T=str
# - mypy 1.9.0 doesn't narrow types based on the value of len()
if isinstance(choice, (list, tuple)):
if len(choice) == 3:
store(choice)
@ -79,13 +135,13 @@ class Choices:
# option group
group_name = choice[0]
subchoices = choice[1]
tc = []
tc: _TripleCollector[T] = []
triple_collector.append((group_name, tc))
dc = []
dc: _DoubleCollector[T] = []
double_collector.append((group_name, dc))
self._process(subchoices, tc, dc)
else:
store((choice[0], choice[0], choice[1]))
store((choice[0], cast(str, choice[0]), cast('StrOrPromise', choice[1])))
else:
raise ValueError(
"Choices can't take a list of length %s, only 2 or 3"
@ -94,54 +150,74 @@ class Choices:
else:
store((choice, choice, choice))
def __len__(self):
def __len__(self) -> int:
return len(self._doubles)
def __iter__(self):
def __iter__(self) -> Iterator[_DoubleRead[T]]:
return iter(self._doubles)
def __reversed__(self):
def __reversed__(self) -> Iterator[_DoubleRead[T]]:
return reversed(self._doubles)
def __getattr__(self, attname):
def __getattr__(self, attname: str) -> T:
try:
return self._identifier_map[attname]
except KeyError:
raise AttributeError(attname)
def __getitem__(self, key):
def __getitem__(self, key: T) -> StrOrPromise | Sequence[_Triple[T]]:
return self._display_map[key]
def __add__(self, other):
@overload
def __add__(self: Choices[str], other: Choices[str] | Iterable[_ChoiceStr]) -> Choices[str]:
...
@overload
def __add__(self, other: Choices[T] | Iterable[_Choice[T]]) -> Choices[T]:
...
def __add__(self, other: Choices[Any] | Iterable[_ChoiceStr | _Choice[Any]]) -> Choices[Any]:
other_args: list[Any]
if isinstance(other, self.__class__):
other = other._triples
other_args = other._triples
else:
other = list(other)
return Choices(*(self._triples + other))
other_args = list(other)
return Choices(*(self._triples + other_args))
def __radd__(self, other):
@overload
def __radd__(self: Choices[str], other: Iterable[_ChoiceStr]) -> Choices[str]:
...
@overload
def __radd__(self, other: Iterable[_Choice[T]]) -> Choices[T]:
...
def __radd__(self, other: Iterable[_ChoiceStr] | Iterable[_Choice[T]]) -> Choices[Any]:
# radd is never called for matching types, so we don't check here
other = list(other)
return Choices(*(other + self._triples))
other_args = list(other)
# The exact type of 'other' depends on our type argument 'T', which
# is expressed in the overloading, but lost within this method body.
return Choices(*(other_args + self._triples)) # type: ignore[arg-type]
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self._triples == other._triples
return False
def __repr__(self):
def __repr__(self) -> str:
return '{}({})'.format(
self.__class__.__name__,
', '.join("%s" % repr(i) for i in self._triples)
)
def __contains__(self, item):
def __contains__(self, item: T) -> bool:
return item in self._db_values
def __deepcopy__(self, memo):
return self.__class__(*copy.deepcopy(self._triples, memo))
def __deepcopy__(self, memo: dict[int, Any] | None) -> Choices[T]:
args: list[Any] = copy.deepcopy(self._triples, memo)
return self.__class__(*args)
def subset(self, *new_identifiers):
def subset(self, *new_identifiers: str) -> Choices[T]:
identifiers = set(self._identifier_map.keys())
if not identifiers.issuperset(new_identifiers):
@ -150,7 +226,8 @@ class Choices:
identifiers.symmetric_difference(new_identifiers),
)
return self.__class__(*[
args: list[Any] = [
choice for choice in self._triples
if choice[1] in new_identifiers
])
]
return self.__class__(*args)

View file

@ -1,15 +1,27 @@
from __future__ import annotations
import secrets
import uuid
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Union
from django.conf import settings
from django.core.exceptions import ValidationError
from django.db import models
from django.utils.timezone import now
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from datetime import date, datetime
DateTimeFieldBase = models.DateTimeField[Union[str, datetime, date], datetime]
else:
DateTimeFieldBase = models.DateTimeField
DEFAULT_CHOICES_NAME = 'STATUS'
class AutoCreatedField(models.DateTimeField):
class AutoCreatedField(DateTimeFieldBase):
"""
A DateTimeField that automatically populates itself at
object creation.
@ -18,7 +30,7 @@ class AutoCreatedField(models.DateTimeField):
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
kwargs.setdefault('editable', False)
kwargs.setdefault('default', now)
super().__init__(*args, **kwargs)
@ -31,13 +43,13 @@ class AutoLastModifiedField(AutoCreatedField):
By default, sets editable=False and default=datetime.now.
"""
def get_default(self):
def get_default(self) -> datetime:
"""Return the default value for this field."""
if not hasattr(self, "_default"):
self._default = self._get_default()
self._default = super().get_default()
return self._default
def pre_save(self, model_instance, add):
def pre_save(self, model_instance: models.Model, add: bool) -> datetime:
value = now()
if add:
current_value = getattr(model_instance, self.attname, self.get_default())
@ -68,13 +80,19 @@ class StatusField(models.CharField):
South can handle this field when it freezes a model.
"""
def __init__(self, *args, no_check_for_status=False, choices_name=DEFAULT_CHOICES_NAME, **kwargs):
def __init__(
self,
*args: Any,
no_check_for_status: bool = False,
choices_name: str = DEFAULT_CHOICES_NAME,
**kwargs: Any
):
kwargs.setdefault('max_length', 100)
self.check_for_status = not no_check_for_status
self.choices_name = choices_name
super().__init__(*args, **kwargs)
def prepare_class(self, sender, **kwargs):
def prepare_class(self, sender: type[models.Model], **kwargs: Any) -> None:
if not sender._meta.abstract and self.check_for_status:
assert hasattr(sender, self.choices_name), \
"To use StatusField, the model '%s' must have a %s choices class attribute." \
@ -83,7 +101,7 @@ class StatusField(models.CharField):
if not self.has_default():
self.default = tuple(getattr(sender, self.choices_name))[0][0] # set first as default
def contribute_to_class(self, cls, name, *args, **kwargs):
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
models.signals.class_prepared.connect(self.prepare_class, sender=cls)
# we don't set the real choices until class_prepared (so we can rely on
# the STATUS class attr being available), but we need to set some dummy
@ -91,13 +109,13 @@ class StatusField(models.CharField):
self.choices = [(0, 'dummy')]
super().contribute_to_class(cls, name, *args, **kwargs)
def deconstruct(self):
def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
name, path, args, kwargs = super().deconstruct()
kwargs['no_check_for_status'] = True
return name, path, args, kwargs
class MonitorField(models.DateTimeField):
class MonitorField(DateTimeFieldBase):
"""
A DateTimeField that monitors another field on the same model and
sets itself to the current date/time whenever the monitored field
@ -105,30 +123,28 @@ class MonitorField(models.DateTimeField):
"""
def __init__(self, *args, monitor, when=None, **kwargs):
def __init__(self, *args: Any, monitor: str, when: Iterable[Any] | None = None, **kwargs: Any):
default = None if kwargs.get("null") else now
kwargs.setdefault('default', default)
self.monitor = monitor
if when is not None:
when = set(when)
self.when = when
self.when = None if when is None else set(when)
super().__init__(*args, **kwargs)
def contribute_to_class(self, cls, name, *args, **kwargs):
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
self.monitor_attname = '_monitor_%s' % name
models.signals.post_init.connect(self._save_initial, sender=cls)
super().contribute_to_class(cls, name, *args, **kwargs)
def get_monitored_value(self, instance):
def get_monitored_value(self, instance: models.Model) -> Any:
return getattr(instance, self.monitor)
def _save_initial(self, sender, instance, **kwargs):
def _save_initial(self, sender: type[models.Model], instance: models.Model, **kwargs: Any) -> None:
if self.monitor in instance.get_deferred_fields():
# Fix related to issue #241 to avoid recursive error on double monitor fields
return
setattr(instance, self.monitor_attname, self.get_monitored_value(instance))
def pre_save(self, model_instance, add):
def pre_save(self, model_instance: models.Model, add: bool) -> Any:
value = now()
previous = getattr(model_instance, self.monitor_attname, None)
current = self.get_monitored_value(model_instance)
@ -138,7 +154,7 @@ class MonitorField(models.DateTimeField):
self._save_initial(model_instance.__class__, model_instance)
return super().pre_save(model_instance, add)
def deconstruct(self):
def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
name, path, args, kwargs = super().deconstruct()
kwargs['monitor'] = self.monitor
if self.when is not None:
@ -152,12 +168,12 @@ SPLIT_MARKER = getattr(settings, 'SPLIT_MARKER', '<!-- split -->')
SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2)
def _excerpt_field_name(name):
def _excerpt_field_name(name: str) -> str:
return '_%s_excerpt' % name
def get_excerpt(content):
excerpt = []
def get_excerpt(content: str) -> str:
excerpt: list[str] = []
default_excerpt = []
paras_seen = 0
for line in content.splitlines():
@ -173,7 +189,7 @@ def get_excerpt(content):
class SplitText:
def __init__(self, instance, field_name, excerpt_field_name):
def __init__(self, instance: models.Model, field_name: str, excerpt_field_name: str):
# instead of storing actual values store a reference to the instance
# along with field names, this makes assignment possible
self.instance = instance
@ -181,36 +197,36 @@ class SplitText:
self.excerpt_field_name = excerpt_field_name
@property
def content(self):
def content(self) -> str:
return self.instance.__dict__[self.field_name]
@content.setter
def content(self, val):
def content(self, val: str) -> None:
setattr(self.instance, self.field_name, val)
@property
def excerpt(self):
def excerpt(self) -> str:
return getattr(self.instance, self.excerpt_field_name)
@property
def has_more(self):
def has_more(self) -> bool:
return self.excerpt.strip() != self.content.strip()
def __str__(self):
def __str__(self) -> str:
return self.content
class SplitDescriptor:
def __init__(self, field):
def __init__(self, field: SplitField):
self.field = field
self.excerpt_field_name = _excerpt_field_name(self.field.name)
def __get__(self, instance, owner):
def __get__(self, instance: models.Model, owner: type[models.Model]) -> SplitText:
if instance is None:
raise AttributeError('Can only be accessed via an instance.')
return SplitText(instance, self.field.name, self.excerpt_field_name)
def __set__(self, obj, value):
def __set__(self, obj: models.Model, value: SplitText | str) -> None:
if isinstance(value, SplitText):
obj.__dict__[self.field.name] = value.content
setattr(obj, self.excerpt_field_name, value.excerpt)
@ -218,25 +234,32 @@ class SplitDescriptor:
obj.__dict__[self.field.name] = value
class SplitField(models.TextField):
def contribute_to_class(self, cls, name, *args, **kwargs):
if TYPE_CHECKING:
_SplitFieldBase = models.TextField[Union[SplitText, str], SplitText]
else:
_SplitFieldBase = models.TextField
class SplitField(_SplitFieldBase):
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
if not cls._meta.abstract:
excerpt_field = models.TextField(editable=False)
excerpt_field: models.TextField = models.TextField(editable=False)
cls.add_to_class(_excerpt_field_name(name), excerpt_field)
super().contribute_to_class(cls, name, *args, **kwargs)
setattr(cls, self.name, SplitDescriptor(self))
def pre_save(self, model_instance, add):
value = super().pre_save(model_instance, add)
def pre_save(self, model_instance: models.Model, add: bool) -> str:
value: SplitText = super().pre_save(model_instance, add)
excerpt = get_excerpt(value.content)
setattr(model_instance, _excerpt_field_name(self.attname), excerpt)
return value.content
def value_to_string(self, obj):
def value_to_string(self, obj: models.Model) -> str:
value = self.value_from_object(obj)
return value.content
def get_prep_value(self, value):
def get_prep_value(self, value: Any) -> str:
try:
return value.content
except AttributeError:
@ -248,7 +271,14 @@ class UUIDField(models.UUIDField):
A field for storing universally unique identifiers. Use Python UUID class.
"""
def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs):
def __init__(
self,
primary_key: bool = True,
version: int = 4,
editable: bool = False,
*args: Any,
**kwargs: Any
):
"""
Parameters
----------
@ -274,6 +304,7 @@ class UUIDField(models.UUIDField):
raise ValidationError(
'UUID version is not valid.')
default: Callable[..., uuid.UUID]
if version == 1:
default = uuid.uuid1
elif version == 3:
@ -294,7 +325,15 @@ class UrlsafeTokenField(models.CharField):
A field for storing a unique token in database.
"""
def __init__(self, editable=False, max_length=128, factory=None, **kwargs):
max_length: int
def __init__(
self,
editable: bool = False,
max_length: int = 128,
factory: Callable[[int], str] | None = None,
**kwargs: Any
):
"""
Parameters
----------
@ -319,14 +358,14 @@ class UrlsafeTokenField(models.CharField):
super().__init__(editable=editable, max_length=max_length, **kwargs)
def get_default(self):
def get_default(self) -> str:
if self._factory is not None:
return self._factory(self.max_length)
# generate a token of length x1.33 approx. trim up to max length
token = secrets.token_urlsafe(self.max_length)[:self.max_length]
return token
def deconstruct(self):
def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
name, path, args, kwargs = super().deconstruct()
kwargs['factory'] = self._factory
return name, path, args, kwargs

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload
from django.core.exceptions import ObjectDoesNotExist
from django.db import connection, models
@ -7,57 +10,84 @@ from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.sql.datastructures import Join
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
class InheritanceIterable(ModelIterable):
def __iter__(self):
queryset = self.queryset
iter = ModelIterable(queryset)
if getattr(queryset, 'subclasses', False):
extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in subclasses:
sub_obj = queryset._get_sub_obj_recurse(obj, s)
if sub_obj:
break
if not sub_obj:
sub_obj = obj
if TYPE_CHECKING:
from collections.abc import Iterator
if getattr(queryset, '_annotated', False):
for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))
from django.db.models.query import BaseIterable
for k in extras:
def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
iter: ModelIterable[ModelT] = ModelIterable(queryset)
if hasattr(queryset, 'subclasses'):
assert hasattr(queryset, '_get_sub_obj_recurse')
extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in subclasses:
sub_obj = queryset._get_sub_obj_recurse(obj, s)
if sub_obj:
break
if not sub_obj:
sub_obj = obj
if hasattr(queryset, '_annotated'):
for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))
yield sub_obj
else:
yield from iter
for k in extras:
setattr(sub_obj, k, getattr(obj, k))
yield sub_obj
else:
yield from iter
class InheritanceQuerySetMixin:
def __init__(self, *args, **kwargs):
if TYPE_CHECKING:
class InheritanceIterable(ModelIterable[ModelT]):
queryset: QuerySet[ModelT]
def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
...
def __iter__(self) -> Iterator[ModelT]:
...
else:
class InheritanceIterable(ModelIterable):
def __iter__(self):
return _iter_inheritance_queryset(self.queryset)
class InheritanceQuerySetMixin(Generic[ModelT]):
model: type[ModelT]
subclasses: Sequence[str]
def __init__(self, *args: object, **kwargs: object):
super().__init__(*args, **kwargs)
self._iterable_class = InheritanceIterable
self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable
def select_subclasses(self, *subclasses):
calculated_subclasses = self._get_subclasses_recurse(self.model)
def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]:
model: type[ModelT] = self.model
calculated_subclasses = self._get_subclasses_recurse(model)
# if none were passed in, we can just short circuit and select all
if not subclasses:
subclasses = calculated_subclasses
selected_subclasses = calculated_subclasses
else:
verified_subclasses = []
verified_subclasses: list[str] = []
for subclass in subclasses:
# special case for passing in the same model as the queryset
# is bound against. Rather than raise an error later, we know
# we can allow this through.
if subclass is self.model:
if subclass is model:
continue
if not isinstance(subclass, (str,)):
if not isinstance(subclass, str):
subclass = self._get_ancestors_path(subclass)
if subclass in calculated_subclasses:
@ -67,38 +97,39 @@ class InheritanceQuerySetMixin:
'{!r} is not in the discovered subclasses, tried: {}'.format(
subclass, ', '.join(calculated_subclasses))
)
subclasses = verified_subclasses
selected_subclasses = verified_subclasses
if subclasses:
new_qs = self.select_related(*subclasses)
else:
new_qs = self
new_qs.subclasses = subclasses
new_qs = cast('InheritanceQuerySet[ModelT]', self)
if selected_subclasses:
new_qs = new_qs.select_related(*selected_subclasses)
new_qs.subclasses = selected_subclasses
return new_qs
def _chain(self, **kwargs):
def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]:
update = {}
for name in ['subclasses', '_annotated']:
if hasattr(self, name):
update[name] = getattr(self, name)
chained = super()._chain(**kwargs)
# django-stubs doesn't include this private API.
chained = super()._chain(**kwargs) # type: ignore[misc]
chained.__dict__.update(update)
return chained
def _clone(self):
qs = super()._clone()
def _clone(self) -> InheritanceQuerySet[ModelT]:
# django-stubs doesn't include this private API.
qs = super()._clone() # type: ignore[misc]
for name in ['subclasses', '_annotated']:
if hasattr(self, name):
setattr(qs, name, getattr(self, name))
return qs
def annotate(self, *args, **kwargs):
qset = super().annotate(*args, **kwargs)
def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs)
qset._annotated = [a.default_alias for a in args] + list(kwargs.keys())
return qset
def _get_subclasses_recurse(self, model):
def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]:
"""
Given a Model class, find all related objects, exploring children
recursively, returning a `list` of strings representing the
@ -124,7 +155,7 @@ class InheritanceQuerySetMixin:
subclasses.append(rel.get_accessor_name())
return subclasses
def _get_ancestors_path(self, model):
def _get_ancestors_path(self, model: type[models.Model]) -> str:
"""
Serves as an opposite to _get_subclasses_recurse, instead walking from
the Model class up the Model's ancestry and constructing the desired
@ -134,7 +165,7 @@ class InheritanceQuerySetMixin:
raise ValueError(
f"{model!r} is not a subclass of {self.model!r}")
ancestry = []
ancestry: list[str] = []
# should be a OneToOneField or None
parent_link = model._meta.get_ancestor_link(self.model)
@ -147,7 +178,7 @@ class InheritanceQuerySetMixin:
return LOOKUP_SEP.join(ancestry)
def _get_sub_obj_recurse(self, obj, s):
def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None:
rel, _, s = s.partition(LOOKUP_SEP)
try:
@ -160,12 +191,14 @@ class InheritanceQuerySetMixin:
else:
return node
def get_subclass(self, *args, **kwargs):
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
return self.select_subclasses().get(*args, **kwargs)
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
def instance_of(self, *models):
# Defining the 'model' attribute using a generic type triggers a bug in mypy:
# https://github.com/python/mypy/issues/9031
class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
"""
Fetch only objects that are instances of the provided model(s).
"""
@ -188,88 +221,187 @@ class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
) for field in model._meta.parents.values()
]) + ')')
return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
return cast(
'InheritanceQuerySet[ModelT]',
self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
)
class InheritanceManagerMixin:
class InheritanceManagerMixin(Generic[ModelT]):
_queryset_class = InheritanceQuerySet
def get_queryset(self):
return self._queryset_class(self.model)
if TYPE_CHECKING:
from collections.abc import Sequence
def select_subclasses(self, *subclasses):
def none(self) -> InheritanceQuerySet[ModelT]:
...
def all(self) -> InheritanceQuerySet[ModelT]:
...
def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
...
def exclude(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
...
def complex_filter(self, filter_obj: Any) -> InheritanceQuerySet[ModelT]:
...
def union(self, *other_qs: Any, all: bool = ...) -> InheritanceQuerySet[ModelT]:
...
def intersection(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]:
...
def difference(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]:
...
def select_for_update(
self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ...
) -> InheritanceQuerySet[ModelT]:
...
def select_related(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
...
def prefetch_related(self, *lookups: Any) -> InheritanceQuerySet[ModelT]:
...
def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
...
def alias(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
...
def order_by(self, *field_names: Any) -> InheritanceQuerySet[ModelT]:
...
def distinct(self, *field_names: Any) -> InheritanceQuerySet[ModelT]:
...
def extra(
self,
select: dict[str, Any] | None = ...,
where: list[str] | None = ...,
params: list[Any] | None = ...,
tables: list[str] | None = ...,
order_by: Sequence[str] | None = ...,
select_params: Sequence[Any] | None = ...,
) -> InheritanceQuerySet[Any]:
...
def reverse(self) -> InheritanceQuerySet[ModelT]:
...
def defer(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
...
def only(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
...
def using(self, alias: str | None) -> InheritanceQuerySet[ModelT]:
...
def get_queryset(self) -> InheritanceQuerySet[ModelT]:
model: type[ModelT] = self.model # type: ignore[attr-defined]
return self._queryset_class(model)
def select_subclasses(
self, *subclasses: str | type[models.Model]
) -> InheritanceQuerySet[ModelT]:
return self.get_queryset().select_subclasses(*subclasses)
def get_subclass(self, *args, **kwargs):
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
return self.get_queryset().get_subclass(*args, **kwargs)
def instance_of(self, *models):
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
return self.get_queryset().instance_of(*models)
class InheritanceManager(InheritanceManagerMixin, models.Manager):
class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]):
pass
class QueryManagerMixin:
class QueryManagerMixin(Generic[ModelT]):
def __init__(self, *args, **kwargs):
@overload
def __init__(self, *args: models.Q):
...
@overload
def __init__(self, **kwargs: object):
...
def __init__(self, *args: models.Q, **kwargs: object):
if args:
self._q = args[0]
else:
self._q = models.Q(**kwargs)
self._order_by = None
self._order_by: tuple[Any, ...] | None = None
super().__init__()
def order_by(self, *args):
def order_by(self, *args: Any) -> QueryManager[ModelT]:
self._order_by = args
return self
return cast('QueryManager[ModelT]', self)
def get_queryset(self):
qs = super().get_queryset().filter(self._q)
def get_queryset(self) -> QuerySet[ModelT]:
qs = super().get_queryset() # type: ignore[misc]
qs = qs.filter(self._q)
if self._order_by is not None:
return qs.order_by(*self._order_by)
return qs
class QueryManager(QueryManagerMixin, models.Manager):
class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc]
pass
class SoftDeletableQuerySetMixin:
class SoftDeletableQuerySetMixin(Generic[ModelT]):
"""
QuerySet for SoftDeletableModel. Instead of removing instance sets
its ``is_removed`` field to True.
"""
def delete(self):
def delete(self) -> None:
"""
Soft delete objects from queryset (set their ``is_removed``
field to True)
"""
self.update(is_removed=True)
cast(QuerySet[ModelT], self).update(is_removed=True)
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet):
# Note that our delete() method does not return anything, unlike Django's.
# https://github.com/jazzband/django-model-utils/issues/541
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
pass
class SoftDeletableManagerMixin:
class SoftDeletableManagerMixin(Generic[ModelT]):
"""
Manager that limits the queryset by default to show only not removed
instances of model.
"""
_queryset_class = SoftDeletableQuerySet
def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs):
_db: str | None
def __init__(
self,
*args: object,
_emit_deprecation_warnings: bool = False,
**kwargs: object
):
self.emit_deprecation_warnings = _emit_deprecation_warnings
super().__init__(*args, **kwargs)
def get_queryset(self):
def get_queryset(self) -> SoftDeletableQuerySet[ModelT]:
"""
Return queryset limited to not removed entries.
"""
model: type[ModelT] = self.model # type: ignore[attr-defined]
if self.emit_deprecation_warnings:
warning_message = (
"{0}.objects model manager will include soft-deleted objects in an "
@ -277,23 +409,23 @@ class SoftDeletableManagerMixin:
"excluding soft-deleted objects. See "
"https://django-model-utils.readthedocs.io/en/stable/models.html"
"#softdeletablemodel for more information."
).format(self.model.__class__.__name__)
).format(model.__class__.__name__)
warnings.warn(warning_message, DeprecationWarning)
kwargs = {'model': self.model, 'using': self._db}
if hasattr(self, '_hints'):
kwargs['hints'] = self._hints
return self._queryset_class(**kwargs).filter(is_removed=False)
return self._queryset_class(
model=model,
using=self._db,
**({'hints': self._hints} if hasattr(self, '_hints') else {})
).filter(is_removed=False)
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager):
class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]):
pass
class JoinQueryset(models.QuerySet):
class JoinQueryset(models.QuerySet[Any]):
def join(self, qs=None):
def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]:
'''
Join one queryset together with another using a temporary table. If
no queryset is used, it will use the current queryset and join that
@ -308,11 +440,11 @@ class JoinQueryset(models.QuerySet):
to_field = 'id'
if qs:
fk = [
fks = [
fk for fk in qs.model._meta.fields
if getattr(fk, 'related_model', None) == self.model
]
fk = fk[0] if fk else None
fk = fks[0] if fks else None
model_set = f'{self.model.__name__.lower()}_set'
key = fk or getattr(qs.model, model_set, None)
@ -331,7 +463,7 @@ class JoinQueryset(models.QuerySet):
else:
fk_column = 'id'
qs = self.only(fk_column)
new_qs = self.model.objects.all()
new_qs = self.model._default_manager.all()
TABLE_NAME = 'temp_stuff'
query, params = qs.query.sql_with_params()
@ -369,21 +501,24 @@ class JoinQueryset(models.QuerySet):
return new_qs
class JoinManagerMixin:
"""
Manager that adds a method join. This method allows you to join two
querysets together.
"""
_queryset_class = JoinQueryset
if not TYPE_CHECKING:
# Hide deprecated API during type checking, to encourage switch to
# 'JoinQueryset.as_manager()', which is supported by the mypy plugin
# of django-stubs.
def get_queryset(self):
warnings.warn(
"JoinManager and JoinManagerMixin are deprecated. "
"Please use 'JoinQueryset.as_manager()' instead.",
DeprecationWarning
)
return self._queryset_class(model=self.model, using=self._db)
class JoinManagerMixin:
"""
Manager that adds a method join. This method allows you to join two
querysets together.
"""
def get_queryset(self):
warnings.warn(
"JoinManager and JoinManagerMixin are deprecated. "
"Please use 'JoinQueryset.as_manager()' instead.",
DeprecationWarning
)
return self._queryset_class(model=self.model, using=self._db)
class JoinManager(JoinManagerMixin, models.Manager):
pass
class JoinManager(JoinManagerMixin):
pass

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Any, Literal, TypeVar, overload
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.db.models.functions import Now
@ -12,6 +16,8 @@ from model_utils.fields import (
)
from model_utils.managers import QueryManager, SoftDeletableManager
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
now = Now()
@ -24,7 +30,7 @@ class TimeStampedModel(models.Model):
created = AutoCreatedField(_('created'))
modified = AutoLastModifiedField(_('modified'))
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
"""
Overriding the save method in order to make sure that
modified field is updated even if it is not given as
@ -65,7 +71,7 @@ class StatusModel(models.Model):
status = StatusField(_('status'))
status_changed = MonitorField(_('status changed'), monitor='status')
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
"""
Overriding the save method in order to make sure that
status_changed field is updated even if it is not given as
@ -81,7 +87,7 @@ class StatusModel(models.Model):
abstract = True
def add_status_query_managers(sender, **kwargs):
def add_status_query_managers(sender: type[models.Model], **kwargs: Any) -> None:
"""
Add a Querymanager for each status item dynamically.
@ -90,6 +96,7 @@ def add_status_query_managers(sender, **kwargs):
return
default_manager = sender._meta.default_manager
assert default_manager is not None
for value, display in getattr(sender, 'STATUS', ()):
if _field_exists(sender, value):
@ -103,7 +110,7 @@ def add_status_query_managers(sender, **kwargs):
sender._meta.default_manager_name = default_manager.name
def add_timeframed_query_manager(sender, **kwargs):
def add_timeframed_query_manager(sender: type[models.Model], **kwargs: Any) -> None:
"""
Add a QueryManager for a specific timeframe.
@ -126,7 +133,7 @@ models.signals.class_prepared.connect(add_status_query_managers)
models.signals.class_prepared.connect(add_timeframed_query_manager)
def _field_exists(model_class, field_name):
def _field_exists(model_class: type[models.Model], field_name: str) -> bool:
return field_name in [f.attname for f in model_class._meta.local_fields]
@ -142,11 +149,28 @@ class SoftDeletableModel(models.Model):
class Meta:
abstract = True
objects = SoftDeletableManager(_emit_deprecation_warnings=True)
available_objects = SoftDeletableManager()
objects: models.Manager[SoftDeletableModel] = SoftDeletableManager(_emit_deprecation_warnings=True)
available_objects: models.Manager[SoftDeletableModel] = SoftDeletableManager()
all_objects = models.Manager()
def delete(self, using=None, *args, soft=True, **kwargs):
# Note that soft delete does not return anything,
# which doesn't conform to Django's interface.
# https://github.com/jazzband/django-model-utils/issues/541
@overload # type: ignore[override]
def delete(
self, using: Any = None, *args: Any, soft: Literal[True] = True, **kwargs: Any
) -> None:
...
@overload
def delete(
self, using: Any = None, *args: Any, soft: Literal[False], **kwargs: Any
) -> tuple[int, dict[str, int]]:
...
def delete(
self, using: Any = None, *args: Any, soft: bool = True, **kwargs: Any
) -> tuple[int, dict[str, int]] | None:
"""
Soft delete object (set its ``is_removed`` field to True).
Actually delete object if setting ``soft`` to False.
@ -154,6 +178,7 @@ class SoftDeletableModel(models.Model):
if soft:
self.is_removed = True
self.save(using=using)
return None
else:
return super().delete(using, *args, **kwargs)

View file

@ -1,10 +1,45 @@
from __future__ import annotations
from copy import deepcopy
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Protocol,
TypeVar,
cast,
overload,
)
from django.core.exceptions import FieldError
from django.db import models
from django.db.models.fields.files import FieldFile
if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from types import TracebackType
class _AugmentedModel(models.Model):
_instance_initialized: bool
_deferred_fields: set[str]
T = TypeVar("T")
class Descriptor(Protocol[T]):
def __get__(self, instance: object, owner: type[object]) -> T:
...
def __set__(self, instance: object, value: T) -> None:
...
class FullDescriptor(Descriptor[T]):
def __delete__(self, instance: object) -> None:
...
class LightStateFieldFile(FieldFile):
"""
@ -16,7 +51,7 @@ class LightStateFieldFile(FieldFile):
Django 3.1+ can make the app unusable, as CPU and memory usage gets easily
multiplied by magnitudes.
"""
def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
"""
We don't need to deepcopy the instance, so nullify if provided.
"""
@ -26,27 +61,35 @@ class LightStateFieldFile(FieldFile):
return state
def lightweight_deepcopy(value):
def lightweight_deepcopy(value: T) -> T:
"""
Use our lightweight class to avoid copying the instance on a FieldFile deepcopy.
"""
if isinstance(value, FieldFile):
value = LightStateFieldFile(
value = cast(T, LightStateFieldFile(
instance=value.instance,
field=value.field,
name=value.name,
)
))
return deepcopy(value)
class DescriptorWrapper:
class DescriptorWrapper(Generic[T]):
def __init__(self, field_name, descriptor, tracker_attname):
def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname
def __get__(self, instance, owner):
@overload
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]:
...
@overload
def __get__(self, instance: models.Model, owner: type[models.Model]) -> T:
...
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T:
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
@ -56,7 +99,7 @@ class DescriptorWrapper:
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
return value
def __set__(self, instance, value):
def __set__(self, instance: models.Model, value: T) -> None:
initialized = hasattr(instance, '_instance_initialized')
was_deferred = self.field_name in instance.get_deferred_fields()
@ -79,23 +122,23 @@ class DescriptorWrapper:
else:
instance.__dict__[self.field_name] = value
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> T:
return getattr(self.descriptor, attr)
@staticmethod
def cls_for_descriptor(descriptor):
def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]:
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
return DescriptorWrapper
class FullDescriptorWrapper(DescriptorWrapper):
class FullDescriptorWrapper(DescriptorWrapper[T]):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj):
self.descriptor.__delete__(obj)
def __delete__(self, obj: models.Model) -> None:
cast(FullDescriptor[T], self.descriptor).__delete__(obj)
class FieldsContext:
@ -119,7 +162,12 @@ class FieldsContext:
"""
def __init__(self, tracker, *fields, state=None):
def __init__(
self,
tracker: FieldInstanceTracker,
*fields: str,
state: dict[str, int] | None = None
):
"""
:param tracker: FieldInstanceTracker instance to be reset after
context exit
@ -137,7 +185,7 @@ class FieldsContext:
self.fields = fields
self.state = state
def __enter__(self):
def __enter__(self) -> FieldsContext:
"""
Increments tracked fields occurrences count in shared state.
"""
@ -146,7 +194,12 @@ class FieldsContext:
self.state[f] += 1
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
"""
Decrements tracked fields occurrences count in shared state.
@ -164,29 +217,34 @@ class FieldsContext:
class FieldInstanceTracker:
def __init__(self, instance, fields, field_map):
self.instance = instance
def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]):
self.instance = cast('_AugmentedModel', instance)
self.fields = fields
self.field_map = field_map
self.context = FieldsContext(self, *self.fields)
def __enter__(self):
def __enter__(self) -> FieldsContext:
return self.context.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
return self.context.__exit__(exc_type, exc_val, exc_tb)
def __call__(self, *fields):
def __call__(self, *fields: str) -> FieldsContext:
return FieldsContext(self, *fields, state=self.context.state)
@property
def deferred_fields(self):
def deferred_fields(self) -> set[str]:
return self.instance.get_deferred_fields()
def get_field_value(self, field):
def get_field_value(self, field: str) -> Any:
return getattr(self.instance, self.field_map[field])
def set_saved_fields(self, fields=None):
def set_saved_fields(self, fields: Iterable[str] | None = None) -> None:
if not self.instance.pk:
self.saved_data = {}
elif fields is None:
@ -198,7 +256,7 @@ class FieldInstanceTracker:
for field, field_value in self.saved_data.items():
self.saved_data[field] = lightweight_deepcopy(field_value)
def current(self, fields=None):
def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]:
"""Returns dict of current values for all tracked fields"""
if fields is None:
deferred_fields = self.deferred_fields
@ -212,17 +270,19 @@ class FieldInstanceTracker:
return {f: self.get_field_value(f) for f in fields}
def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
if field in self.deferred_fields and field not in self.instance.__dict__:
return False
return self.previous(field) != self.get_field_value(field)
prev: object = self.previous(field)
curr: object = self.get_field_value(field)
return prev != curr
else:
raise FieldError('field "%s" not tracked' % field)
def previous(self, field):
def previous(self, field: str) -> Any:
"""Returns currently saved value of given field"""
# handle deferred fields that have not yet been loaded from the database
@ -242,7 +302,7 @@ class FieldInstanceTracker:
return self.saved_data.get(field)
def changed(self):
def changed(self) -> dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
return {
field: self.previous(field)
@ -255,13 +315,34 @@ class FieldTracker:
tracker_class = FieldInstanceTracker
def __init__(self, fields=None):
self.fields = fields
def __init__(self, fields: Iterable[str] | None = None):
# finalize_class() will replace None; pretend it is never None.
self.fields = cast(Iterable[str], fields)
def __call__(self, func=None, fields=None):
def decorator(f):
@overload
def __call__(
self,
func: None = None,
fields: Iterable[str] | None = None
) -> Callable[[Callable[..., T]], Callable[..., T]]:
...
@overload
def __call__(
self,
func: Callable[..., T],
fields: Iterable[str] | None = None
) -> Callable[..., T]:
...
def __call__(
self,
func: Callable[..., T] | None = None,
fields: Iterable[str] | None = None
) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]:
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@wraps(f)
def inner(obj, *args, **kwargs):
def inner(obj: models.Model, *args: object, **kwargs: object) -> T:
tracker = getattr(obj, self.attname)
field_list = tracker.fields if fields is None else fields
with tracker(*field_list):
@ -272,7 +353,7 @@ class FieldTracker:
return decorator
return decorator(func)
def get_field_map(self, cls):
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
"""Returns dict mapping fields names to model attribute names"""
field_map = {field: field for field in self.fields}
all_fields = {f.name: f.attname for f in cls._meta.fields}
@ -280,17 +361,17 @@ class FieldTracker:
if k in field_map})
return field_map
def contribute_to_class(self, cls, name):
def contribute_to_class(self, cls: type[models.Model], name: str) -> None:
self.name = name
self.attname = '_%s' % name
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
def finalize_class(self, sender, **kwargs):
if self.fields is None:
def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
if self.fields is None or TYPE_CHECKING:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
for field_name in self.fields:
descriptor = getattr(sender, field_name)
descriptor: models.Field[Any, Any] = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
@ -300,34 +381,39 @@ class FieldTracker:
setattr(sender, self.name, self)
self.patch_save(sender)
def initialize_tracker(self, sender, instance, **kwargs):
def initialize_tracker(
self,
sender: type[models.Model],
instance: models.Model,
**kwargs: object
) -> None:
if not isinstance(instance, self.model_class):
return # Only init instances of given model (including children)
tracker = self.tracker_class(instance, self.fields, self.field_map)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
instance._instance_initialized = True
cast('_AugmentedModel', instance)._instance_initialized = True
def patch_init(self, model):
def patch_init(self, model: type[models.Model]) -> None:
original = getattr(model, '__init__')
@wraps(original)
def inner(instance, *args, **kwargs):
def inner(instance: models.Model, *args: Any, **kwargs: Any) -> None:
original(instance, *args, **kwargs)
self.initialize_tracker(model, instance)
setattr(model, '__init__', inner)
def patch_save(self, model):
def patch_save(self, model: type[models.Model]) -> None:
self._patch(model, 'save_base', 'update_fields')
self._patch(model, 'refresh_from_db', 'fields')
def _patch(self, model, method, fields_kwarg):
def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None:
original = getattr(model, method)
@wraps(original)
def inner(instance, *args, **kwargs):
update_fields = kwargs.get(fields_kwarg)
def inner(instance: models.Model, *args: object, **kwargs: Any) -> object:
update_fields: Iterable[str] | None = kwargs.get(fields_kwarg)
if update_fields is None:
fields = self.fields
else:
@ -341,7 +427,15 @@ class FieldTracker:
setattr(model, method, inner)
def __get__(self, instance, owner):
@overload
def __get__(self, instance: None, owner: type[models.Model]) -> FieldTracker:
...
@overload
def __get__(self, instance: models.Model, owner: type[models.Model]) -> FieldInstanceTracker:
...
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker | FieldInstanceTracker:
if instance is None:
return self
else:
@ -350,16 +444,18 @@ class FieldTracker:
class ModelInstanceTracker(FieldInstanceTracker):
def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if not self.instance.pk:
return True
elif field in self.saved_data:
return self.previous(field) != self.get_field_value(field)
prev: object = self.previous(field)
curr: object = self.get_field_value(field)
return prev != curr
else:
raise FieldError('field "%s" not tracked' % field)
def changed(self):
def changed(self) -> dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
@ -371,5 +467,5 @@ class ModelInstanceTracker(FieldInstanceTracker):
class ModelTracker(FieldTracker):
tracker_class = ModelInstanceTracker
def get_field_map(self, cls):
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
return {field: field for field in self.fields}

View file

@ -1,4 +1,6 @@
[mypy]
disallow_incomplete_defs=True
disallow_untyped_defs=True
implicit_reexport=False
pretty=True
show_error_codes=True

View file

@ -1,2 +1,3 @@
mypy==1.9.0
django-stubs==4.2.7
mypy==1.10.0
django-stubs==5.0.2
pytest==7.4.3

View file

@ -1,7 +1,12 @@
from __future__ import annotations
from typing import Any
from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
def mutable_from_db(value):
def mutable_from_db(value: object) -> Any:
if value == '':
return None
try:
@ -12,7 +17,7 @@ def mutable_from_db(value):
return value
def mutable_to_db(value):
def mutable_to_db(value: object) -> str:
if value is None:
return ''
if isinstance(value, list):
@ -21,12 +26,12 @@ def mutable_to_db(value):
class MutableField(models.TextField):
def to_python(self, value):
def to_python(self, value: object) -> Any:
return mutable_from_db(value)
def from_db_value(self, value, expression, connection):
def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any:
return mutable_from_db(value)
def get_db_prep_save(self, value, connection):
def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str:
value = super().get_db_prep_save(value, connection)
return mutable_to_db(value)

View file

@ -1,9 +1,10 @@
from __future__ import annotations
from typing import ClassVar
from typing import Any, ClassVar, TypeVar, overload
from django.db import models
from django.db.models import Manager
from django.db.models.query import QuerySet
from django.db.models.query_utils import DeferredAttribute
from django.utils.translation import gettext_lazy as _
@ -11,7 +12,7 @@ from model_utils import Choices
from model_utils.fields import MonitorField, SplitField, StatusField, UUIDField
from model_utils.managers import (
InheritanceManager,
JoinManager,
JoinQueryset,
QueryManager,
SoftDeletableManager,
SoftDeletableQuerySet,
@ -26,6 +27,8 @@ from model_utils.models import (
from model_utils.tracker import FieldTracker, ModelTracker
from tests.fields import MutableField
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
class InheritanceManagerTestRelated(models.Model):
pass
@ -43,7 +46,7 @@ class InheritanceManagerTestParent(models.Model):
on_delete=models.CASCADE)
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()
def __str__(self):
def __str__(self) -> str:
return "{}({})".format(
self.__class__.__name__[len('InheritanceManagerTest'):],
self.pk,
@ -128,7 +131,7 @@ class DoubleMonitored(models.Model):
class Status(StatusModel):
STATUS = Choices(
STATUS: Choices[str] = Choices(
("active", _("active")),
("deleted", _("deleted")),
("on_hold", _("on hold")),
@ -184,7 +187,8 @@ class Post(models.Model):
public: ClassVar[QueryManager[Post]] = QueryManager(published=True)
public_confirmed: ClassVar[QueryManager[Post]] = QueryManager(
models.Q(published=True) & models.Q(confirmed=True))
public_reversed = QueryManager(published=True).order_by("-order")
public_reversed: ClassVar[QueryManager[Post]] = QueryManager(
published=True).order_by("-order")
class Meta:
ordering = ("order",)
@ -203,7 +207,6 @@ class SplitFieldAbstractParent(models.Model):
class AbstractTracked(models.Model):
number: models.IntegerField
class Meta:
abstract = True
@ -216,7 +219,7 @@ class Tracked(models.Model):
tracker = FieldTracker()
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
""" No-op save() to ensure that FieldTracker.patch_save() works. """
super().save(*args, **kwargs)
@ -228,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel):
tracker = FieldTracker()
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
""" Automatically add "modified" to update_fields."""
update_fields = kwargs.get('update_fields')
if update_fields is not None:
@ -263,7 +266,7 @@ class TrackedNonFieldAttr(models.Model):
number = models.FloatField()
@property
def rounded(self):
def rounded(self) -> int | None:
return round(self.number) if self.number is not None else None
tracker = FieldTracker(fields=['rounded'])
@ -352,43 +355,52 @@ class SoftDeletable(SoftDeletableModel):
all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager()
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet):
def only_read(self):
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]):
def only_read(self) -> QuerySet[ModelT]:
return self.filter(is_read=True)
class CustomSoftDelete(SoftDeletableModel):
is_read = models.BooleanField(default=False)
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc]
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)()
class StringyDescriptor:
"""
Descriptor that returns a string version of the underlying integer value.
"""
def __init__(self, name):
def __init__(self, name: str):
self.name = name
def __get__(self, obj, cls=None):
@overload
def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor:
...
@overload
def __get__(self, obj: models.Model, cls: type[models.Model]) -> str:
...
def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str:
if obj is None:
return self
if self.name in obj.get_deferred_fields():
# This queries the database, and sets the value on the instance.
assert cls is not None
fields_map = {f.name: f for f in cls._meta.fields}
field = fields_map[self.name]
DeferredAttribute(field=field).__get__(obj, cls)
return str(obj.__dict__[self.name])
def __set__(self, obj, value):
def __set__(self, obj: object, value: str) -> None:
obj.__dict__[self.name] = int(value)
def __delete__(self, obj):
def __delete__(self, obj: object) -> None:
del obj.__dict__[self.name]
class CustomDescriptorField(models.IntegerField):
def contribute_to_class(self, cls, name, *args, **kwargs):
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
super().contribute_to_class(cls, name, *args, **kwargs)
setattr(cls, name, StringyDescriptor(name))
@ -404,7 +416,7 @@ class ModelWithCustomDescriptor(models.Model):
class BoxJoinModel(models.Model):
name = models.CharField(max_length=32)
objects: ClassVar[JoinManager[BoxJoinModel]] = JoinManager()
objects = JoinQueryset.as_manager()
class JoinItemForeignKey(models.Model):
@ -414,7 +426,7 @@ class JoinItemForeignKey(models.Model):
null=True,
on_delete=models.CASCADE
)
objects: ClassVar[JoinManager[JoinItemForeignKey]] = JoinManager()
objects = JoinQueryset.as_manager()
class CustomUUIDModel(UUIDModel):

View file

@ -1,116 +1,129 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic, TypeVar
import pytest
from django.test import TestCase
from model_utils import Choices
T = TypeVar("T")
class ChoicesTests(TestCase):
def setUp(self):
class ChoicesTestsMixin(Generic[T]):
STATUS: Choices[T]
def test_getattr(self) -> None:
assert self.STATUS.DRAFT == 'DRAFT'
def test_len(self) -> None:
assert len(self.STATUS) == 2
def test_repr(self) -> None:
assert repr(self.STATUS) == "Choices" + repr((
('DRAFT', 'DRAFT', 'DRAFT'),
('PUBLISHED', 'PUBLISHED', 'PUBLISHED'),
))
def test_wrong_length_tuple(self) -> None:
with pytest.raises(ValueError):
Choices(('a',)) # type: ignore[arg-type]
def test_deepcopy(self) -> None:
import copy
assert list(self.STATUS) == list(copy.deepcopy(self.STATUS))
def test_equality(self) -> None:
assert self.STATUS == Choices('DRAFT', 'PUBLISHED')
def test_inequality(self) -> None:
assert self.STATUS != ['DRAFT', 'PUBLISHED']
assert self.STATUS != Choices('DRAFT')
def test_composability(self) -> None:
assert Choices('DRAFT') + Choices('PUBLISHED') == self.STATUS
assert Choices('DRAFT') + ('PUBLISHED',) == self.STATUS
assert ('DRAFT',) + Choices('PUBLISHED') == self.STATUS
def test_option_groups(self) -> None:
# Note: The implementation accepts any kind of sequence, but the type system can only
# track per-index types for tuples.
if TYPE_CHECKING:
c = Choices(('group a', ['one', 'two']), ('group b', ('three',)))
else:
c = Choices(('group a', ['one', 'two']), ['group b', ('three',)])
assert list(c) == [
('group a', [('one', 'one'), ('two', 'two')]),
('group b', [('three', 'three')]),
]
class ChoicesTests(TestCase, ChoicesTestsMixin[str]):
def setUp(self) -> None:
self.STATUS = Choices('DRAFT', 'PUBLISHED')
def test_getattr(self):
self.assertEqual(self.STATUS.DRAFT, 'DRAFT')
def test_indexing(self):
def test_indexing(self) -> None:
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
def test_iteration(self):
def test_iteration(self) -> None:
self.assertEqual(tuple(self.STATUS),
(('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
def test_reversed(self):
def test_reversed(self) -> None:
self.assertEqual(tuple(reversed(self.STATUS)),
(('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT')))
def test_len(self):
self.assertEqual(len(self.STATUS), 2)
def test_repr(self):
self.assertEqual(repr(self.STATUS), "Choices" + repr((
('DRAFT', 'DRAFT', 'DRAFT'),
('PUBLISHED', 'PUBLISHED', 'PUBLISHED'),
)))
def test_wrong_length_tuple(self):
with self.assertRaises(ValueError):
Choices(('a',))
def test_contains_value(self):
def test_contains_value(self) -> None:
self.assertTrue('PUBLISHED' in self.STATUS)
self.assertTrue('DRAFT' in self.STATUS)
def test_doesnt_contain_value(self):
def test_doesnt_contain_value(self) -> None:
self.assertFalse('UNPUBLISHED' in self.STATUS)
def test_deepcopy(self):
import copy
self.assertEqual(list(self.STATUS),
list(copy.deepcopy(self.STATUS)))
def test_equality(self):
self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED'))
def test_inequality(self):
self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED'])
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_composability(self):
self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS)
self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS)
self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS)
def test_option_groups(self):
c = Choices(('group a', ['one', 'two']), ['group b', ('three',)])
self.assertEqual(
list(c),
[
('group a', [('one', 'one'), ('two', 'two')]),
('group b', [('three', 'three')]),
],
)
class LabelChoicesTests(ChoicesTests):
def setUp(self):
class LabelChoicesTests(TestCase, ChoicesTestsMixin[str]):
def setUp(self) -> None:
self.STATUS = Choices(
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
'DELETED',
)
def test_iteration(self):
def test_iteration(self) -> None:
self.assertEqual(tuple(self.STATUS), (
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
('DELETED', 'DELETED'),
))
def test_reversed(self):
def test_reversed(self) -> None:
self.assertEqual(tuple(reversed(self.STATUS)), (
('DELETED', 'DELETED'),
('PUBLISHED', 'is published'),
('DRAFT', 'is draft'),
))
def test_indexing(self):
def test_indexing(self) -> None:
self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
def test_default(self):
def test_default(self) -> None:
self.assertEqual(self.STATUS.DELETED, 'DELETED')
def test_provided(self):
def test_provided(self) -> None:
self.assertEqual(self.STATUS.DRAFT, 'DRAFT')
def test_len(self):
def test_len(self) -> None:
self.assertEqual(len(self.STATUS), 3)
def test_equality(self):
def test_equality(self) -> None:
self.assertEqual(self.STATUS, Choices(
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
'DELETED',
))
def test_inequality(self):
def test_inequality(self) -> None:
self.assertNotEqual(self.STATUS, [
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
@ -118,27 +131,27 @@ class LabelChoicesTests(ChoicesTests):
])
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_repr(self):
def test_repr(self) -> None:
self.assertEqual(repr(self.STATUS), "Choices" + repr((
('DRAFT', 'DRAFT', 'is draft'),
('PUBLISHED', 'PUBLISHED', 'is published'),
('DELETED', 'DELETED', 'DELETED'),
)))
def test_contains_value(self):
def test_contains_value(self) -> None:
self.assertTrue('PUBLISHED' in self.STATUS)
self.assertTrue('DRAFT' in self.STATUS)
# This should be True, because both the display value
# and the internal representation are both DELETED.
self.assertTrue('DELETED' in self.STATUS)
def test_doesnt_contain_value(self):
def test_doesnt_contain_value(self) -> None:
self.assertFalse('UNPUBLISHED' in self.STATUS)
def test_doesnt_contain_display_value(self):
def test_doesnt_contain_display_value(self) -> None:
self.assertFalse('is draft' in self.STATUS)
def test_composability(self):
def test_composability(self) -> None:
self.assertEqual(
Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'),
self.STATUS
@ -154,11 +167,17 @@ class LabelChoicesTests(ChoicesTests):
self.STATUS
)
def test_option_groups(self):
c = Choices(
('group a', [(1, 'one'), (2, 'two')]),
['group b', ((3, 'three'),)]
)
def test_option_groups(self) -> None:
if TYPE_CHECKING:
c = Choices[int](
('group a', [(1, 'one'), (2, 'two')]),
('group b', ((3, 'three'),))
)
else:
c = Choices(
('group a', [(1, 'one'), (2, 'two')]),
['group b', ((3, 'three'),)]
)
self.assertEqual(
list(c),
[
@ -168,65 +187,65 @@ class LabelChoicesTests(ChoicesTests):
)
class IdentifierChoicesTests(ChoicesTests):
def setUp(self):
class IdentifierChoicesTests(TestCase, ChoicesTestsMixin[int]):
def setUp(self) -> None:
self.STATUS = Choices(
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted'))
def test_iteration(self):
def test_iteration(self) -> None:
self.assertEqual(tuple(self.STATUS), (
(0, 'is draft'),
(1, 'is published'),
(2, 'is deleted'),
))
def test_reversed(self):
def test_reversed(self) -> None:
self.assertEqual(tuple(reversed(self.STATUS)), (
(2, 'is deleted'),
(1, 'is published'),
(0, 'is draft'),
))
def test_indexing(self):
def test_indexing(self) -> None:
self.assertEqual(self.STATUS[1], 'is published')
def test_getattr(self):
def test_getattr(self) -> None:
self.assertEqual(self.STATUS.DRAFT, 0)
def test_len(self):
def test_len(self) -> None:
self.assertEqual(len(self.STATUS), 3)
def test_repr(self):
def test_repr(self) -> None:
self.assertEqual(repr(self.STATUS), "Choices" + repr((
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted'),
)))
def test_contains_value(self):
def test_contains_value(self) -> None:
self.assertTrue(0 in self.STATUS)
self.assertTrue(1 in self.STATUS)
self.assertTrue(2 in self.STATUS)
def test_doesnt_contain_value(self):
def test_doesnt_contain_value(self) -> None:
self.assertFalse(3 in self.STATUS)
def test_doesnt_contain_display_value(self):
self.assertFalse('is draft' in self.STATUS)
def test_doesnt_contain_display_value(self) -> None:
self.assertFalse('is draft' in self.STATUS) # type: ignore[operator]
def test_doesnt_contain_python_attr(self):
self.assertFalse('PUBLISHED' in self.STATUS)
def test_doesnt_contain_python_attr(self) -> None:
self.assertFalse('PUBLISHED' in self.STATUS) # type: ignore[operator]
def test_equality(self):
def test_equality(self) -> None:
self.assertEqual(self.STATUS, Choices(
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'),
(2, 'DELETED', 'is deleted')
))
def test_inequality(self):
def test_inequality(self) -> None:
self.assertNotEqual(self.STATUS, [
(0, 'DRAFT', 'is draft'),
(1, 'PUBLISHED', 'is published'),
@ -234,7 +253,7 @@ class IdentifierChoicesTests(ChoicesTests):
])
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
def test_composability(self):
def test_composability(self) -> None:
self.assertEqual(
Choices(
(0, 'DRAFT', 'is draft'),
@ -265,11 +284,17 @@ class IdentifierChoicesTests(ChoicesTests):
self.STATUS
)
def test_option_groups(self):
c = Choices(
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
['group b', ((3, 'THREE', 'three'),)]
)
def test_option_groups(self) -> None:
if TYPE_CHECKING:
c = Choices[int](
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
('group b', ((3, 'THREE', 'three'),))
)
else:
c = Choices(
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
['group b', ((3, 'THREE', 'three'),)]
)
self.assertEqual(
list(c),
[
@ -281,26 +306,26 @@ class IdentifierChoicesTests(ChoicesTests):
class SubsetChoicesTest(TestCase):
def setUp(self):
self.choices = Choices(
def setUp(self) -> None:
self.choices = Choices[int](
(0, 'a', 'A'),
(1, 'b', 'B'),
)
def test_nonexistent_identifiers_raise(self):
def test_nonexistent_identifiers_raise(self) -> None:
with self.assertRaises(ValueError):
self.choices.subset('a', 'c')
def test_solo_nonexistent_identifiers_raise(self):
def test_solo_nonexistent_identifiers_raise(self) -> None:
with self.assertRaises(ValueError):
self.choices.subset('c')
def test_empty_subset_passes(self):
def test_empty_subset_passes(self) -> None:
subset = self.choices.subset()
self.assertEqual(subset, Choices())
def test_subset_returns_correct_subset(self):
def test_subset_returns_correct_subset(self) -> None:
subset = self.choices.subset('a')
self.assertEqual(subset, Choices((0, 'a', 'A')))

View file

@ -1,5 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from django.core.cache import cache
from django.core.exceptions import FieldError
from django.db import models
@ -7,7 +9,7 @@ from django.db.models.fields.files import FieldFile
from django.test import TestCase
from model_utils import FieldTracker
from model_utils.tracker import DescriptorWrapper
from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker
from tests.models import (
InheritedModelTracked,
InheritedTracked,
@ -26,12 +28,18 @@ from tests.models import (
TrackerTimeStamped,
)
if TYPE_CHECKING:
MixinBase = TestCase
else:
MixinBase = object
class FieldTrackerTestCase(TestCase):
tracker = None
class FieldTrackerMixin(MixinBase):
def assertHasChanged(self, *, tracker=None, **kwargs):
tracker: FieldInstanceTracker
instance: models.Model
def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
for field, value in kwargs.items():
@ -41,49 +49,57 @@ class FieldTrackerTestCase(TestCase):
else:
self.assertEqual(tracker.has_changed(field), value)
def assertPrevious(self, *, tracker=None, **kwargs):
def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
for field, value in kwargs.items():
self.assertEqual(tracker.previous(field), value)
def assertChanged(self, *, tracker=None, **kwargs):
def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
self.assertEqual(tracker.changed(), kwargs)
def assertCurrent(self, *, tracker=None, **kwargs):
def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
if tracker is None:
tracker = self.tracker
self.assertEqual(tracker.current(), kwargs)
def update_instance(self, **kwargs):
def update_instance(self, **kwargs: Any) -> None:
for field, value in kwargs.items():
setattr(self.instance, field, value)
self.instance.save()
class FieldTrackerCommonTests:
class FieldTrackerCommonMixin(FieldTrackerMixin):
def test_pre_save_previous(self):
instance: (
Tracked | TrackedNotDefault | TrackedMultiple
| ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
| TrackedAbstract
)
def test_pre_save_previous(self) -> None:
self.assertPrevious(name=None, number=None)
self.instance.name = 'new age'
self.instance.number = 8
self.assertPrevious(name=None, number=None)
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
class FieldTrackerTests(FieldTrackerCommonMixin, TestCase):
tracked_class: type[models.Model] = Tracked
tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked
instance: Tracked | ModelTracked | TrackedAbstract
def setUp(self):
def setUp(self) -> None:
self.instance = self.tracked_class()
self.tracker = self.instance.tracker
def test_descriptor(self):
self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker))
def test_descriptor(self) -> None:
tracker = self.tracked_class.tracker
self.assertTrue(isinstance(tracker, FieldTracker))
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.assertChanged(name=None)
self.instance.name = 'new age'
self.assertChanged(name=None)
@ -94,7 +110,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.mutable = [1, 2, 3]
self.assertChanged(name=None, number=None, mutable=None)
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=False, mutable=False)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=False, mutable=False)
@ -103,12 +119,12 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.mutable = [1, 2, 3]
self.assertHasChanged(name=True, number=True, mutable=True)
def test_save_with_args(self):
def test_save_with_args(self) -> None:
self.instance.number = 1
self.instance.save(False, False, None, None)
self.assertChanged()
def test_first_save(self):
def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=False, mutable=False)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='', number=None, id=None, mutable=None)
@ -129,7 +145,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
with self.assertRaises(ValueError):
self.instance.save(update_fields=['number'])
def test_post_save_has_changed(self):
def test_post_save_has_changed(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertHasChanged(name=False, number=False, mutable=False)
self.instance.name = 'new age'
@ -141,14 +157,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.name = 'retro'
self.assertHasChanged(name=False, number=True, mutable=True)
def test_post_save_previous(self):
def test_post_save_previous(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.instance.name = 'new age'
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
self.instance.mutable[1] = 4
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
def test_post_save_changed(self):
def test_post_save_changed(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertChanged()
self.instance.name = 'new age'
@ -162,7 +178,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.mutable = [1, 2, 3]
self.assertChanged(number=4)
def test_current(self):
def test_current(self) -> None:
self.assertCurrent(id=None, name='', number=None, mutable=None)
self.instance.name = 'new age'
self.assertCurrent(id=None, name='new age', number=None, mutable=None)
@ -175,7 +191,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.save()
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3])
def test_update_fields(self):
def test_update_fields(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertChanged()
self.instance.name = 'new age'
@ -198,7 +214,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertEqual(in_db.number, self.instance.number)
self.assertEqual(in_db.mutable, self.instance.mutable)
def test_refresh_from_db(self):
def test_refresh_from_db(self) -> None:
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.tracked_class.objects.filter(pk=self.instance.pk).update(
name='new age', number=8, mutable=[3, 2, 1])
@ -214,11 +230,12 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.refresh_from_db()
self.assertChanged()
def test_with_deferred(self):
def test_with_deferred(self) -> None:
self.instance.name = 'new age'
self.instance.number = 1
self.instance.save()
item = self.tracked_class.objects.only('name').first()
assert item is not None
self.assertTrue(item.get_deferred_fields())
# has_changed() returns False for deferred fields, without un-deferring them.
@ -234,6 +251,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
# examining a deferred field un-defers it
item = self.tracked_class.objects.only('name').first()
assert item is not None
self.assertEqual(item.number, 1)
self.assertTrue('number' not in item.get_deferred_fields())
self.assertEqual(item.tracker.previous('number'), 1)
@ -252,6 +270,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
if self.tracked_class == Tracked:
item = self.tracked_class.objects.only('name').first()
assert item is not None
item.number = 2
# previous() fetches correct value from database after deferred field is assigned
@ -268,7 +287,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
class FieldTrackerMultipleInstancesTests(TestCase):
def test_with_deferred_fields_access_multiple(self):
def test_with_deferred_fields_access_multiple(self) -> None:
Tracked.objects.create(pk=1, name='foo', number=1)
Tracked.objects.create(pk=2, name='bar', number=2)
@ -278,16 +297,16 @@ class FieldTrackerMultipleInstancesTests(TestCase):
instance.name
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
FieldTrackerCommonTests):
class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase):
tracked_class: type[models.Model] = TrackedNotDefault
tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault
instance: TrackedNotDefault | ModelTrackedNotDefault
def setUp(self):
def setUp(self) -> None:
self.instance = self.tracked_class()
self.tracker = self.instance.name_tracker
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.assertChanged(name=None)
self.instance.name = 'new age'
self.assertChanged(name=None)
@ -296,7 +315,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.name = ''
self.assertChanged(name=None)
def test_first_save(self):
def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=None)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='')
@ -308,14 +327,14 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.assertCurrent(name='retro')
self.assertChanged(name=None)
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=None)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=None)
self.instance.number = 7
self.assertHasChanged(name=True, number=None)
def test_post_save_has_changed(self):
def test_post_save_has_changed(self) -> None:
self.update_instance(name='retro', number=4)
self.assertHasChanged(name=False, number=None)
self.instance.name = 'new age'
@ -325,12 +344,12 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.name = 'retro'
self.assertHasChanged(name=False, number=None)
def test_post_save_previous(self):
def test_post_save_previous(self) -> None:
self.update_instance(name='retro', number=4)
self.instance.name = 'new age'
self.assertPrevious(name='retro', number=None)
def test_post_save_changed(self):
def test_post_save_changed(self) -> None:
self.update_instance(name='retro', number=4)
self.assertChanged()
self.instance.name = 'new age'
@ -340,7 +359,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.name = 'retro'
self.assertChanged()
def test_current(self):
def test_current(self) -> None:
self.assertCurrent(name='')
self.instance.name = 'new age'
self.assertCurrent(name='new age')
@ -349,7 +368,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.instance.save()
self.assertCurrent(name='new age')
def test_update_fields(self):
def test_update_fields(self) -> None:
self.update_instance(name='retro', number=4)
self.assertChanged()
self.instance.name = 'new age'
@ -358,15 +377,16 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
self.assertChanged()
class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase):
tracked_class = TrackedNonFieldAttr
instance: TrackedNonFieldAttr
def setUp(self):
def setUp(self) -> None:
self.instance = self.tracked_class()
self.tracker = self.instance.tracker
def test_previous(self):
def test_previous(self) -> None:
self.assertPrevious(rounded=None)
self.instance.number = 7.5
self.assertPrevious(rounded=None)
@ -377,7 +397,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.instance.save()
self.assertPrevious(rounded=7)
def test_has_changed(self):
def test_has_changed(self) -> None:
self.assertHasChanged(rounded=False)
self.instance.number = 7.5
self.assertHasChanged(rounded=True)
@ -388,7 +408,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.instance.number = 7.8
self.assertHasChanged(rounded=False)
def test_changed(self):
def test_changed(self) -> None:
self.assertChanged()
self.instance.number = 7.5
self.assertPrevious(rounded=None)
@ -401,7 +421,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.instance.save()
self.assertPrevious()
def test_current(self):
def test_current(self) -> None:
self.assertCurrent(rounded=None)
self.instance.number = 7.5
self.assertCurrent(rounded=8)
@ -409,17 +429,17 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
self.assertCurrent(rounded=8)
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
FieldTrackerCommonTests):
class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase):
tracked_class: type[models.Model] = TrackedMultiple
tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple
instance: TrackedMultiple | ModelTrackedMultiple
def setUp(self):
def setUp(self) -> None:
self.instance = self.tracked_class()
self.trackers = [self.instance.name_tracker,
self.instance.number_tracker]
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.tracker = self.instance.name_tracker
self.assertChanged(name=None)
self.instance.name = 'new age'
@ -435,7 +455,7 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.instance.number = 8
self.assertChanged(number=None)
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.tracker = self.instance.name_tracker
self.assertHasChanged(name=True, number=None)
self.instance.name = 'new age'
@ -445,12 +465,12 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.instance.name = 'new age'
self.assertHasChanged(name=None, number=False)
def test_pre_save_previous(self):
def test_pre_save_previous(self) -> None:
for tracker in self.trackers:
self.tracker = tracker
super().test_pre_save_previous()
def test_post_save_has_changed(self):
def test_post_save_has_changed(self) -> None:
self.update_instance(name='retro', number=4)
self.assertHasChanged(tracker=self.trackers[0], name=False, number=None)
self.assertHasChanged(tracker=self.trackers[1], name=None, number=False)
@ -465,14 +485,14 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.assertHasChanged(tracker=self.trackers[0], name=False, number=None)
self.assertHasChanged(tracker=self.trackers[1], name=None, number=False)
def test_post_save_previous(self):
def test_post_save_previous(self) -> None:
self.update_instance(name='retro', number=4)
self.instance.name = 'new age'
self.instance.number = 8
self.assertPrevious(tracker=self.trackers[0], name='retro', number=None)
self.assertPrevious(tracker=self.trackers[1], name=None, number=4)
def test_post_save_changed(self):
def test_post_save_changed(self) -> None:
self.update_instance(name='retro', number=4)
self.assertChanged(tracker=self.trackers[0])
self.assertChanged(tracker=self.trackers[1])
@ -487,7 +507,7 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.assertChanged(tracker=self.trackers[0])
self.assertChanged(tracker=self.trackers[1])
def test_current(self):
def test_current(self) -> None:
self.assertCurrent(tracker=self.trackers[0], name='')
self.assertCurrent(tracker=self.trackers[1], number=None)
self.instance.name = 'new age'
@ -501,88 +521,97 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
self.assertCurrent(tracker=self.trackers[1], number=8)
class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
class FieldTrackerForeignKeyMixin(FieldTrackerMixin):
fk_class: type[models.Model] = Tracked
tracked_class: type[models.Model] = TrackedFK
fk_class: type[Tracked | ModelTracked]
tracked_class: type[TrackedFK | ModelTrackedFK]
instance: TrackedFK | ModelTrackedFK
def setUp(self):
def setUp(self) -> None:
self.old_fk = self.fk_class.objects.create(number=8)
self.instance = self.tracked_class.objects.create(fk=self.old_fk)
self.instance = self.tracked_class.objects.create(fk=self.old_fk) # type: ignore[misc]
def test_default(self):
def test_default(self) -> None:
self.tracker = self.instance.tracker
self.assertChanged()
self.assertPrevious()
self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id)
self.instance.fk = self.fk_class.objects.create(number=8)
self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment]
self.assertChanged(fk_id=self.old_fk.id)
self.assertPrevious(fk_id=self.old_fk.id)
self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id)
def test_custom(self):
def test_custom(self) -> None:
self.tracker = self.instance.custom_tracker
self.assertChanged()
self.assertPrevious()
self.assertCurrent(fk_id=self.old_fk.id)
self.instance.fk = self.fk_class.objects.create(number=8)
self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment]
self.assertChanged(fk_id=self.old_fk.id)
self.assertPrevious(fk_id=self.old_fk.id)
self.assertCurrent(fk_id=self.instance.fk_id)
def test_custom_without_id(self):
def test_custom_without_id(self) -> None:
with self.assertNumQueries(1):
self.tracked_class.objects.get()
self.tracker = self.instance.custom_tracker_without_id
self.assertChanged()
self.assertPrevious()
self.assertCurrent(fk=self.old_fk.id)
self.instance.fk = self.fk_class.objects.create(number=8)
self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment]
self.assertChanged(fk=self.old_fk.id)
self.assertPrevious(fk=self.old_fk.id)
self.assertCurrent(fk=self.instance.fk_id)
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
fk_class = Tracked
tracked_class = TrackedFK
def setUp(self):
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase):
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
fk_class = Tracked
tracked_class = TrackedFK
instance: TrackedFK
def setUp(self) -> None:
model_tracked = self.fk_class.objects.create(name="", number=0)
self.instance = self.tracked_class.objects.create(fk=model_tracked)
def test_default(self):
def test_default(self) -> None:
self.tracker = self.instance.tracker
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
def test_custom(self):
def test_custom(self) -> None:
self.tracker = self.instance.custom_tracker
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
def test_custom_without_id(self):
def test_custom_without_id(self) -> None:
self.tracker = self.instance.custom_tracker_without_id
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
class FieldTrackerTimeStampedTests(FieldTrackerTestCase):
class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase):
fk_class = Tracked
tracked_class = TrackerTimeStamped
instance: TrackerTimeStamped
def setUp(self):
def setUp(self) -> None:
self.instance = self.tracked_class.objects.create(name='old', number=1)
self.tracker = self.instance.tracker
def test_set_modified_on_save(self):
def test_set_modified_on_save(self) -> None:
old_modified = self.instance.modified
self.instance.name = 'new'
self.instance.save()
self.assertGreater(self.instance.modified, old_modified)
self.assertChanged()
def test_set_modified_on_save_update_fields(self):
def test_set_modified_on_save_update_fields(self) -> None:
old_modified = self.instance.modified
self.instance.name = 'new'
self.instance.save(update_fields=('name',))
@ -594,7 +623,7 @@ class InheritedFieldTrackerTests(FieldTrackerTests):
tracked_class = InheritedTracked
def test_child_fields_not_tracked(self):
def test_child_fields_not_tracked(self) -> None:
self.name2 = 'test'
self.assertEqual(self.tracker.previous('name2'), None)
self.assertRaises(FieldError, self.tracker.has_changed, 'name2')
@ -605,17 +634,18 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
tracked_class = InheritedTrackedFK
class FieldTrackerFileFieldTests(FieldTrackerTestCase):
class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase):
tracked_class = TrackedFileField
instance: TrackedFileField
def setUp(self):
def setUp(self) -> None:
self.instance = self.tracked_class()
self.tracker = self.instance.tracker
self.some_file = 'something.txt'
self.another_file = 'another.txt'
def test_saved_data_without_instance(self):
def test_saved_data_without_instance(self) -> None:
"""
Tests that instance won't get copied by the Field Tracker.
@ -629,27 +659,27 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
self.assertEqual(self.tracker.saved_data, {})
self.update_instance(some_file=self.some_file)
field_file_copy = self.tracker.saved_data.get('some_file')
self.assertIsNotNone(field_file_copy)
assert field_file_copy is not None
self.assertEqual(field_file_copy.__getstate__().get('instance'), None)
self.assertEqual(self.instance.some_file.instance, self.instance)
self.assertIsInstance(self.instance.some_file, FieldFile)
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.assertChanged(some_file=None)
self.instance.some_file = self.some_file
self.assertChanged(some_file=None)
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(some_file=True)
self.instance.some_file = self.some_file
self.assertHasChanged(some_file=True)
def test_pre_save_previous(self):
def test_pre_save_previous(self) -> None:
self.assertPrevious(some_file=None)
self.instance.some_file = self.some_file
self.assertPrevious(some_file=None)
def test_post_save_changed(self):
def test_post_save_changed(self) -> None:
self.update_instance(some_file=self.some_file)
self.assertChanged()
previous_file = self.instance.some_file
@ -667,7 +697,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
some_file=previous_file,
)
def test_post_save_has_changed(self):
def test_post_save_has_changed(self) -> None:
self.update_instance(some_file=self.some_file)
self.assertHasChanged(some_file=False)
self.instance.some_file = self.another_file
@ -687,7 +717,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
some_file=True,
)
def test_post_save_previous(self):
def test_post_save_previous(self) -> None:
self.update_instance(some_file=self.some_file)
previous_file = self.instance.some_file
self.instance.some_file = self.another_file
@ -707,7 +737,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
some_file=previous_file,
)
def test_current(self):
def test_current(self) -> None:
self.assertCurrent(some_file=self.instance.some_file, id=None)
self.instance.some_file = self.some_file
self.assertCurrent(some_file=self.instance.some_file, id=None)
@ -730,9 +760,10 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
class ModelTrackerTests(FieldTrackerTests):
tracked_class: type[models.Model] = ModelTracked
tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked
instance: ModelTracked
def test_cache_compatible(self):
def test_cache_compatible(self) -> None:
cache.set('key', self.instance)
instance = cache.get('key')
instance.number = 1
@ -742,7 +773,7 @@ class ModelTrackerTests(FieldTrackerTests):
instance.number = 2
self.assertHasChanged(number=True)
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
@ -753,7 +784,7 @@ class ModelTrackerTests(FieldTrackerTests):
self.instance.mutable = [1, 2, 3]
self.assertChanged()
def test_first_save(self):
def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='', number=None, id=None, mutable=None)
@ -774,7 +805,7 @@ class ModelTrackerTests(FieldTrackerTests):
with self.assertRaises(ValueError):
self.instance.save(update_fields=['number'])
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True)
@ -786,7 +817,7 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
tracked_class = ModelTrackedNotDefault
def test_first_save(self):
def test_first_save(self) -> None:
self.assertHasChanged(name=True, number=True)
self.assertPrevious(name=None, number=None)
self.assertCurrent(name='')
@ -798,14 +829,14 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
self.assertCurrent(name='retro')
self.assertChanged()
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True)
self.instance.number = 7
self.assertHasChanged(name=True, number=True)
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged()
@ -819,7 +850,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
tracked_class = ModelTrackedMultiple
def test_pre_save_has_changed(self):
def test_pre_save_has_changed(self) -> None:
self.tracker = self.instance.name_tracker
self.assertHasChanged(name=True, number=True)
self.instance.name = 'new age'
@ -829,7 +860,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=True)
def test_pre_save_changed(self):
def test_pre_save_changed(self) -> None:
self.tracker = self.instance.name_tracker
self.assertChanged()
self.instance.name = 'new age'
@ -846,12 +877,13 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
self.assertChanged()
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
fk_class = ModelTracked
tracked_class = ModelTrackedFK
instance: ModelTrackedFK
def test_custom_without_id(self):
def test_custom_without_id(self) -> None:
with self.assertNumQueries(2):
self.tracked_class.objects.get()
self.tracker = self.instance.custom_tracker_without_id
@ -869,7 +901,7 @@ class InheritedModelTrackerTests(ModelTrackerTests):
tracked_class = InheritedModelTracked
def test_child_fields_not_tracked(self):
def test_child_fields_not_tracked(self) -> None:
self.name2 = 'test'
self.assertEqual(self.tracker.previous('name2'), None)
self.assertTrue(self.tracker.has_changed('name2'))
@ -882,19 +914,19 @@ class AbstractModelTrackerTests(ModelTrackerTests):
class TrackerContextDecoratorTests(TestCase):
def setUp(self):
def setUp(self) -> None:
self.instance = Tracked.objects.create(number=1)
self.tracker = self.instance.tracker
def assertChanged(self, *fields):
def assertChanged(self, *fields: str) -> None:
for f in fields:
self.assertTrue(self.tracker.has_changed(f))
def assertNotChanged(self, *fields):
def assertNotChanged(self, *fields: str) -> None:
for f in fields:
self.assertFalse(self.tracker.has_changed(f))
def test_context_manager(self):
def test_context_manager(self) -> None:
with self.tracker:
with self.tracker:
self.instance.name = 'new'
@ -905,7 +937,7 @@ class TrackerContextDecoratorTests(TestCase):
self.assertNotChanged('name')
def test_context_manager_fields(self):
def test_context_manager_fields(self) -> None:
with self.tracker('number'):
with self.tracker('number', 'name'):
self.instance.name = 'new'
@ -918,10 +950,10 @@ class TrackerContextDecoratorTests(TestCase):
self.assertNotChanged('number', 'name')
def test_tracker_decorator(self):
def test_tracker_decorator(self) -> None:
@Tracked.tracker
def tracked_method(obj):
def tracked_method(obj: Tracked) -> None:
obj.name = 'new'
self.assertChanged('name')
@ -929,10 +961,10 @@ class TrackerContextDecoratorTests(TestCase):
self.assertNotChanged('name')
def test_tracker_decorator_fields(self):
def test_tracker_decorator_fields(self) -> None:
@Tracked.tracker(fields=['name'])
def tracked_method(obj):
def tracked_method(obj: Tracked) -> None:
obj.name = 'new'
obj.number += 1
self.assertChanged('name', 'number')
@ -942,7 +974,7 @@ class TrackerContextDecoratorTests(TestCase):
self.assertChanged('number')
self.assertNotChanged('name')
def test_tracker_context_with_save(self):
def test_tracker_context_with_save(self) -> None:
with self.tracker:
self.instance.name = 'new'

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime, timezone
import time_machine
@ -8,33 +10,33 @@ from tests.models import DoubleMonitored, Monitored, MonitorWhen, MonitorWhenEmp
class MonitorFieldTests(TestCase):
def setUp(self):
def setUp(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)):
self.instance = Monitored(name='Charlie')
self.created = self.instance.name_changed
def test_save_no_change(self):
def test_save_no_change(self) -> None:
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed(self):
def test_save_changed(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
self.instance.name = 'Maria'
self.instance.save()
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
def test_double_save(self):
def test_double_save(self) -> None:
self.instance.name = 'Jose'
self.instance.save()
changed = self.instance.name_changed
self.instance.save()
self.assertEqual(self.instance.name_changed, changed)
def test_no_monitor_arg(self):
def test_no_monitor_arg(self) -> None:
with self.assertRaises(TypeError):
MonitorField()
MonitorField() # type: ignore[call-arg]
def test_monitor_default_is_none_when_nullable(self):
def test_monitor_default_is_none_when_nullable(self) -> None:
self.assertIsNone(self.instance.name_changed_nullable)
expected_datetime = datetime(2022, 1, 18, 12, 0, 0, tzinfo=timezone.utc)
@ -49,33 +51,33 @@ class MonitorWhenFieldTests(TestCase):
"""
Will record changes only when name is 'Jose' or 'Maria'
"""
def setUp(self):
def setUp(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)):
self.instance = MonitorWhen(name='Charlie')
self.created = self.instance.name_changed
def test_save_no_change(self):
def test_save_no_change(self) -> None:
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Jose(self):
def test_save_changed_to_Jose(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
self.instance.name = 'Jose'
self.instance.save()
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
def test_save_changed_to_Maria(self):
def test_save_changed_to_Maria(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
self.instance.name = 'Maria'
self.instance.save()
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
def test_save_changed_to_Pedro(self):
def test_save_changed_to_Pedro(self) -> None:
self.instance.name = 'Pedro'
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_double_save(self):
def test_double_save(self) -> None:
self.instance.name = 'Jose'
self.instance.save()
changed = self.instance.name_changed
@ -87,20 +89,20 @@ class MonitorWhenEmptyFieldTests(TestCase):
"""
Monitor should never be updated id when is an empty list.
"""
def setUp(self):
def setUp(self) -> None:
self.instance = MonitorWhenEmpty(name='Charlie')
self.created = self.instance.name_changed
def test_save_no_change(self):
def test_save_no_change(self) -> None:
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Jose(self):
def test_save_changed_to_Jose(self) -> None:
self.instance.name = 'Jose'
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
def test_save_changed_to_Maria(self):
def test_save_changed_to_Maria(self) -> None:
self.instance.name = 'Maria'
self.instance.save()
self.assertEqual(self.instance.name_changed, self.created)
@ -108,18 +110,18 @@ class MonitorWhenEmptyFieldTests(TestCase):
class MonitorDoubleFieldTests(TestCase):
def setUp(self):
def setUp(self) -> None:
DoubleMonitored.objects.create(name='Charlie', name2='Charlie2')
def test_recursion_error_with_only(self):
def test_recursion_error_with_only(self) -> None:
# Any field passed to only() is generating a recursion error
list(DoubleMonitored.objects.only('id'))
def test_recursion_error_with_defer(self):
def test_recursion_error_with_defer(self) -> None:
# Only monitored fields passed to defer() are failing
list(DoubleMonitored.objects.defer('name'))
def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self):
def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self) -> None:
obj = DoubleMonitored.objects.defer('name').get(name='Charlie')
with time_machine.travel(datetime(2016, 12, 1, tzinfo=timezone.utc)):
obj.name = 'Charlie2'

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.test import TestCase
from tests.models import Article, SplitFieldAbstractParent
@ -7,62 +9,62 @@ class SplitFieldTests(TestCase):
full_text = 'summary\n\n<!-- split -->\n\nmore'
excerpt = 'summary\n'
def setUp(self):
def setUp(self) -> None:
self.post = Article.objects.create(
title='example post', body=self.full_text)
def test_unicode_content(self):
def test_unicode_content(self) -> None:
self.assertEqual(str(self.post.body), self.full_text)
def test_excerpt(self):
def test_excerpt(self) -> None:
self.assertEqual(self.post.body.excerpt, self.excerpt)
def test_content(self):
def test_content(self) -> None:
self.assertEqual(self.post.body.content, self.full_text)
def test_has_more(self):
def test_has_more(self) -> None:
self.assertTrue(self.post.body.has_more)
def test_not_has_more(self):
def test_not_has_more(self) -> None:
post = Article.objects.create(title='example 2',
body='some text\n\nsome more\n')
self.assertFalse(post.body.has_more)
def test_load_back(self):
def test_load_back(self) -> None:
post = Article.objects.get(pk=self.post.pk)
self.assertEqual(post.body.content, self.post.body.content)
self.assertEqual(post.body.excerpt, self.post.body.excerpt)
def test_assign_to_body(self):
def test_assign_to_body(self) -> None:
new_text = 'different\n\n<!-- split -->\n\nother'
self.post.body = new_text
self.post.save()
self.assertEqual(str(self.post.body), new_text)
def test_assign_to_content(self):
def test_assign_to_content(self) -> None:
new_text = 'different\n\n<!-- split -->\n\nother'
self.post.body.content = new_text
self.post.save()
self.assertEqual(str(self.post.body), new_text)
def test_assign_to_excerpt(self):
def test_assign_to_excerpt(self) -> None:
with self.assertRaises(AttributeError):
self.post.body.excerpt = 'this should fail'
self.post.body.excerpt = 'this should fail' # type: ignore[misc]
def test_access_via_class(self):
def test_access_via_class(self) -> None:
with self.assertRaises(AttributeError):
Article.body
def test_assign_splittext(self):
def test_assign_splittext(self) -> None:
a = Article(title='Some Title')
a.body = self.post.body
self.assertEqual(a.body.excerpt, 'summary\n')
def test_value_to_string(self):
def test_value_to_string(self) -> None:
f = self.post._meta.get_field('body')
self.assertEqual(f.value_to_string(self.post), self.full_text)
def test_abstract_inheritance(self):
def test_abstract_inheritance(self) -> None:
class Child(SplitFieldAbstractParent):
pass

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.test import TestCase
from model_utils.fields import StatusField
@ -11,22 +13,22 @@ from tests.models import (
class StatusFieldTests(TestCase):
def test_status_with_default_filled(self):
def test_status_with_default_filled(self) -> None:
instance = StatusFieldDefaultFilled()
self.assertEqual(instance.status, instance.STATUS.yes)
def test_status_with_default_not_filled(self):
def test_status_with_default_not_filled(self) -> None:
instance = StatusFieldDefaultNotFilled()
self.assertEqual(instance.status, instance.STATUS.no)
def test_no_check_for_status(self):
def test_no_check_for_status(self) -> None:
field = StatusField(no_check_for_status=True)
# this model has no STATUS attribute, so checking for it would error
field.prepare_class(Article)
def test_get_status_display(self):
def test_get_status_display(self) -> None:
instance = StatusFieldDefaultFilled()
self.assertEqual(instance.get_status_display(), "Yes")
def test_choices_name(self):
def test_choices_name(self) -> None:
StatusFieldChoicesName()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from unittest.mock import Mock
from django.db.models import NOT_PROVIDED
@ -7,41 +9,41 @@ from model_utils.fields import UrlsafeTokenField
class UrlsaftTokenFieldTests(TestCase):
def test_editable_default(self):
def test_editable_default(self) -> None:
field = UrlsafeTokenField()
self.assertFalse(field.editable)
def test_editable(self):
def test_editable(self) -> None:
field = UrlsafeTokenField(editable=True)
self.assertTrue(field.editable)
def test_max_length_default(self):
def test_max_length_default(self) -> None:
field = UrlsafeTokenField()
self.assertEqual(field.max_length, 128)
def test_max_length(self):
def test_max_length(self) -> None:
field = UrlsafeTokenField(max_length=256)
self.assertEqual(field.max_length, 256)
def test_factory_default(self):
def test_factory_default(self) -> None:
field = UrlsafeTokenField()
self.assertIsNone(field._factory)
def test_factory_not_callable(self):
def test_factory_not_callable(self) -> None:
with self.assertRaises(TypeError):
UrlsafeTokenField(factory='INVALID')
UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type]
def test_get_default(self):
def test_get_default(self) -> None:
field = UrlsafeTokenField()
value = field.get_default()
self.assertEqual(len(value), field.max_length)
def test_get_default_with_non_default_max_length(self):
def test_get_default_with_non_default_max_length(self) -> None:
field = UrlsafeTokenField(max_length=64)
value = field.get_default()
self.assertEqual(len(value), 64)
def test_get_default_with_factory(self):
def test_get_default_with_factory(self) -> None:
token = 'SAMPLE_TOKEN'
factory = Mock(return_value=token)
field = UrlsafeTokenField(factory=factory)
@ -50,13 +52,13 @@ class UrlsaftTokenFieldTests(TestCase):
self.assertEqual(value, token)
factory.assert_called_once_with(field.max_length)
def test_no_default_param(self):
def test_no_default_param(self) -> None:
field = UrlsafeTokenField(default='DEFAULT')
self.assertIs(field.default, NOT_PROVIDED)
def test_deconstruct(self):
def test_factory():
pass
def test_deconstruct(self) -> None:
def test_factory(max_length: int) -> str:
assert False
instance = UrlsafeTokenField(factory=test_factory)
name, path, args, kwargs = instance.deconstruct()
new_instance = UrlsafeTokenField(*args, **kwargs)

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import uuid
from django.core.exceptions import ValidationError
@ -8,31 +10,31 @@ from model_utils.fields import UUIDField
class UUIDFieldTests(TestCase):
def test_uuid_version_default(self):
def test_uuid_version_default(self) -> None:
instance = UUIDField()
self.assertEqual(instance.default, uuid.uuid4)
def test_uuid_version_1(self):
def test_uuid_version_1(self) -> None:
instance = UUIDField(version=1)
self.assertEqual(instance.default, uuid.uuid1)
def test_uuid_version_2_error(self):
def test_uuid_version_2_error(self) -> None:
self.assertRaises(ValidationError, UUIDField, 'version', 2)
def test_uuid_version_3(self):
def test_uuid_version_3(self) -> None:
instance = UUIDField(version=3)
self.assertEqual(instance.default, uuid.uuid3)
def test_uuid_version_4(self):
def test_uuid_version_4(self) -> None:
instance = UUIDField(version=4)
self.assertEqual(instance.default, uuid.uuid4)
def test_uuid_version_5(self):
def test_uuid_version_5(self) -> None:
instance = UUIDField(version=5)
self.assertEqual(instance.default, uuid.uuid5)
def test_uuid_version_bellow_min(self):
def test_uuid_version_bellow_min(self) -> None:
self.assertRaises(ValidationError, UUIDField, 'version', 0)
def test_uuid_version_above_max(self):
def test_uuid_version_above_max(self) -> None:
self.assertRaises(ValidationError, UUIDField, 'version', 6)

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.db.models import Prefetch
from django.test import TestCase
@ -5,7 +7,7 @@ from tests.models import InheritanceManagerTestChild1, InheritanceManagerTestPar
class InheritanceIterableTest(TestCase):
def test_prefetch(self):
def test_prefetch(self) -> None:
qs = InheritanceManagerTestChild1.objects.all().prefetch_related(
Prefetch(
'normal_field',

View file

@ -1,6 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from django.db import models
from django.test import TestCase
from model_utils.managers import InheritanceManager
from tests.models import (
InheritanceManagerTestChild1,
InheritanceManagerTestChild2,
@ -14,19 +19,22 @@ from tests.models import (
TimeFrame,
)
if TYPE_CHECKING:
from django.db.models.fields.related_descriptors import RelatedManager
class InheritanceManagerTests(TestCase):
def setUp(self):
def setUp(self) -> None:
self.child1 = InheritanceManagerTestChild1.objects.create()
self.child2 = InheritanceManagerTestChild2.objects.create()
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
self.grandchild1_2 = \
InheritanceManagerTestGrandChild1_2.objects.create()
def get_manager(self):
def get_manager(self) -> InheritanceManager[InheritanceManagerTestParent]:
return InheritanceManagerTestParent.objects
def test_normal(self):
def test_normal(self) -> None:
children = {
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk),
@ -35,14 +43,14 @@ class InheritanceManagerTests(TestCase):
}
self.assertEqual(set(self.get_manager().all()), children)
def test_select_all_subclasses(self):
def test_select_all_subclasses(self) -> None:
children = {self.child1, self.child2}
children.add(self.grandchild1)
children.add(self.grandchild1_2)
self.assertEqual(
set(self.get_manager().select_subclasses()), children)
def test_select_subclasses_invalid_relation(self):
def test_select_subclasses_invalid_relation(self) -> None:
"""
If an invalid relation string is provided, we can provide the user
with a list which is valid, rather than just have the select_related()
@ -52,7 +60,7 @@ class InheritanceManagerTests(TestCase):
with self.assertRaisesRegex(ValueError, regex):
self.get_manager().select_subclasses('user')
def test_select_specific_subclasses(self):
def test_select_specific_subclasses(self) -> None:
children = {
self.child1,
InheritanceManagerTestParent(pk=self.child2.pk),
@ -67,7 +75,7 @@ class InheritanceManagerTests(TestCase):
children,
)
def test_select_specific_grandchildren(self):
def test_select_specific_grandchildren(self) -> None:
children = {
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk),
@ -83,7 +91,7 @@ class InheritanceManagerTests(TestCase):
children,
)
def test_children_and_grandchildren(self):
def test_children_and_grandchildren(self) -> None:
children = {
self.child1,
InheritanceManagerTestParent(pk=self.child2.pk),
@ -100,24 +108,24 @@ class InheritanceManagerTests(TestCase):
children,
)
def test_get_subclass(self):
def test_get_subclass(self) -> None:
self.assertEqual(
self.get_manager().get_subclass(pk=self.child1.pk),
self.child1)
def test_get_subclass_on_queryset(self):
def test_get_subclass_on_queryset(self) -> None:
self.assertEqual(
self.get_manager().all().get_subclass(pk=self.child1.pk),
self.child1)
def test_prior_select_related(self):
def test_prior_select_related(self) -> None:
with self.assertNumQueries(1):
obj = self.get_manager().select_related(
"inheritancemanagertestchild1").select_subclasses(
"inheritancemanagertestchild2").get(pk=self.child1.pk)
obj.inheritancemanagertestchild1
def test_manually_specifying_parent_fk_including_grandchildren(self):
def test_manually_specifying_parent_fk_including_grandchildren(self) -> None:
"""
given a Model which inherits from another Model, but also declares
the OneToOne link manually using `related_name` and `parent_link`,
@ -148,7 +156,7 @@ class InheritanceManagerTests(TestCase):
self.assertEqual(set(results.subclasses),
set(expected_related_names))
def test_manually_specifying_parent_fk_single_subclass(self):
def test_manually_specifying_parent_fk_single_subclass(self) -> None:
"""
Using a string related_name when the relation is manually defined
instead of implicit should still work in the same way.
@ -168,11 +176,11 @@ class InheritanceManagerTests(TestCase):
self.assertEqual(set(results.subclasses),
set(expected_related_names))
def test_filter_on_values_queryset(self):
def test_filter_on_values_queryset(self) -> None:
queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk)
self.assertEqual(list(queryset), [{'id': self.child1.pk}])
def test_values_list_on_select_subclasses(self):
def test_values_list_on_select_subclasses(self) -> None:
"""
Using `select_subclasses` in conjunction with `values_list()` raised an
exception in `_get_sub_obj_recurse()` because the result of `values_list()`
@ -217,14 +225,14 @@ class InheritanceManagerTests(TestCase):
class InheritanceManagerUsingModelsTests(TestCase):
def setUp(self):
def setUp(self) -> None:
self.parent1 = InheritanceManagerTestParent.objects.create()
self.child1 = InheritanceManagerTestChild1.objects.create()
self.child2 = InheritanceManagerTestChild2.objects.create()
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create()
def test_select_subclass_by_child_model(self):
def test_select_subclass_by_child_model(self) -> None:
"""
Confirm that passing a child model works the same as passing the
select_related manually
@ -236,7 +244,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(objs.subclasses, objsmodels.subclasses)
self.assertEqual(list(objs), list(objsmodels))
def test_select_subclass_by_grandchild_model(self):
def test_select_subclass_by_grandchild_model(self) -> None:
"""
Confirm that passing a grandchild model works the same as passing the
select_related manually
@ -249,7 +257,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(objs.subclasses, objsmodels.subclasses)
self.assertEqual(list(objs), list(objsmodels))
def test_selecting_all_subclasses_specifically_grandchildren(self):
def test_selecting_all_subclasses_specifically_grandchildren(self) -> None:
"""
A bare select_subclasses() should achieve the same results as doing
select_subclasses and specifying all possible subclasses.
@ -266,7 +274,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
self.assertEqual(list(objs), list(objsmodels))
def test_selecting_all_subclasses_specifically_children(self):
def test_selecting_all_subclasses_specifically_children(self) -> None:
"""
A bare select_subclasses() should achieve the same results as doing
select_subclasses and specifying all possible subclasses.
@ -294,7 +302,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
self.assertEqual(list(objs), list(objsmodels))
def test_select_subclass_just_self(self):
def test_select_subclass_just_self(self) -> None:
"""
Passing in the same model as the manager/queryset is bound against
(ie: the root parent) should have no effect on the result set.
@ -310,7 +318,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
])
def test_select_subclass_invalid_related_model(self):
def test_select_subclass_invalid_related_model(self) -> None:
"""
Confirming that giving a stupid model doesn't work.
"""
@ -319,7 +327,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestParent.objects.select_subclasses(
TimeFrame).order_by('pk')
def test_mixing_strings_and_classes_with_grandchildren(self):
def test_mixing_strings_and_classes_with_grandchildren(self) -> None:
"""
Given arguments consisting of both strings and model classes,
ensure the right resolutions take place, accounting for the extra
@ -340,7 +348,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
]
self.assertEqual(list(objs), expecting2)
def test_mixing_strings_and_classes_with_children(self):
def test_mixing_strings_and_classes_with_children(self) -> None:
"""
Given arguments consisting of both strings and model classes,
ensure the right resolutions take place, walking down as far as
@ -362,7 +370,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
]
self.assertEqual(list(objs), expecting2)
def test_duplications(self):
def test_duplications(self) -> None:
"""
Check that even if the same thing is provided as a string and a model
that the right results are retrieved.
@ -379,7 +387,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
])
def test_child_doesnt_accidentally_get_parent(self):
def test_child_doesnt_accidentally_get_parent(self) -> None:
"""
Given a Child model which also has an InheritanceManager,
none of the returned objects should be Parent objects.
@ -392,7 +400,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
InheritanceManagerTestChild1(pk=self.grandchild1_2.pk),
], list(objs))
def test_manually_specifying_parent_fk_only_specific_child(self):
def test_manually_specifying_parent_fk_only_specific_child(self) -> None:
"""
given a Model which inherits from another Model, but also declares
the OneToOne link manually using `related_name` and `parent_link`,
@ -416,7 +424,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(results.subclasses),
set(expected_related_names))
def test_extras_descend(self):
def test_extras_descend(self) -> None:
"""
Ensure that extra(select=) values are copied onto sub-classes.
"""
@ -425,25 +433,25 @@ class InheritanceManagerUsingModelsTests(TestCase):
)
self.assertTrue(all(result.foo == (result.id + 1) for result in results))
def test_limit_to_specific_subclass(self):
def test_limit_to_specific_subclass(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3)
self.assertEqual([child3], list(results))
def test_limit_to_specific_subclass_with_custom_db_column(self):
def test_limit_to_specific_subclass_with_custom_db_column(self) -> None:
item = InheritanceManagerTestChild3_1.objects.create()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3_1)
self.assertEqual([item], list(results))
def test_limit_to_specific_grandchild_class(self):
def test_limit_to_specific_grandchild_class(self) -> None:
grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1)
self.assertEqual([grandchild1], list(results))
def test_limit_to_child_fetches_grandchildren_as_child_class(self):
def test_limit_to_child_fetches_grandchildren_as_child_class(self) -> None:
# Not sure if this is the desired behaviour...?
children = InheritanceManagerTestChild1.objects.all()
@ -451,7 +459,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(children), set(results))
def test_can_fetch_limited_class_grandchildren(self):
def test_can_fetch_limited_class_grandchildren(self) -> None:
# Not sure if this is the desired behaviour...?
children = InheritanceManagerTestChild1.objects.select_subclasses()
@ -459,7 +467,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set(children), set(results))
def test_selecting_multiple_instance_classes(self):
def test_selecting_multiple_instance_classes(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create()
children1 = InheritanceManagerTestChild1.objects.all()
@ -467,7 +475,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual(set([child3] + list(children1)), set(results))
def test_selecting_multiple_instance_classes_including_grandchildren(self):
def test_selecting_multiple_instance_classes_including_grandchildren(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create()
grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
@ -475,7 +483,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
self.assertEqual({child3, grandchild1}, set(results))
def test_select_subclasses_interaction_with_instance_of(self):
def test_select_subclasses_interaction_with_instance_of(self) -> None:
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3)
@ -484,7 +492,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
class InheritanceManagerRelatedTests(InheritanceManagerTests):
def setUp(self):
def setUp(self) -> None:
self.related = InheritanceManagerTestRelated.objects.create()
self.child1 = InheritanceManagerTestChild1.objects.create(
related=self.related)
@ -493,16 +501,16 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related)
def get_manager(self):
def get_manager(self) -> RelatedManager[InheritanceManagerTestParent]: # type: ignore[override]
return self.related.imtests
def test_get_method_with_select_subclasses(self):
def test_get_method_with_select_subclasses(self) -> None:
self.assertEqual(
InheritanceManagerTestParent.objects.select_subclasses().get(
id=self.child1.id),
self.child1)
def test_get_method_with_select_subclasses_check_for_useless_join(self):
def test_get_method_with_select_subclasses_check_for_useless_join(self) -> None:
child4 = InheritanceManagerTestChild4.objects.create(related=self.related, other_onetoone=self.child1)
self.assertEqual(
str(InheritanceManagerTestChild4.objects.select_subclasses().filter(
@ -510,26 +518,26 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
str(InheritanceManagerTestChild4.objects.select_subclasses().select_related(None).filter(
id=child4.id).query))
def test_annotate_with_select_subclasses(self):
def test_annotate_with_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
models.Count('id'))
self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
def test_annotate_with_named_arguments_with_select_subclasses(self):
def test_annotate_with_named_arguments_with_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
test_count=models.Count('id'))
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
def test_annotate_before_select_subclasses(self):
def test_annotate_before_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.annotate(
models.Count('id')).select_subclasses()
self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
def test_annotate_with_named_arguments_before_select_subclasses(self):
def test_annotate_with_named_arguments_before_select_subclasses(self) -> None:
qs = InheritanceManagerTestParent.objects.annotate(
test_count=models.Count('id')).select_subclasses()
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self):
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None:
qs = InheritanceManagerTestParent.objects.select_subclasses()
self.assertEqual(qs.subclasses, qs._clone().subclasses)

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from django.test import TestCase
from tests.models import BoxJoinModel, JoinItemForeignKey
class JoinManagerTest(TestCase):
def setUp(self):
def setUp(self) -> None:
for i in range(20):
BoxJoinModel.objects.create(name=f'name_{i}')
@ -13,24 +15,24 @@ class JoinManagerTest(TestCase):
)
JoinItemForeignKey.objects.create(weight=20)
def test_self_join(self):
def test_self_join(self) -> None:
a_slice = BoxJoinModel.objects.all()[0:10]
with self.assertNumQueries(1):
result = a_slice.join()
self.assertEqual(result.count(), 10)
def test_self_join_with_where_statement(self):
def test_self_join_with_where_statement(self) -> None:
qs = BoxJoinModel.objects.filter(name='name_1')
result = qs.join()
self.assertEqual(result.count(), 1)
def test_join_with_other_qs(self):
def test_join_with_other_qs(self) -> None:
item_qs = JoinItemForeignKey.objects.filter(weight=10)
boxes = BoxJoinModel.objects.all().join(qs=item_qs)
self.assertEqual(boxes.count(), 1)
self.assertEqual(boxes[0].name, 'name_1')
def test_reverse_join(self):
def test_reverse_join(self) -> None:
box_qs = BoxJoinModel.objects.filter(name='name_1')
items = JoinItemForeignKey.objects.all().join(box_qs)
self.assertEqual(items.count(), 1)

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from django.test import TestCase
from tests.models import Post
class QueryManagerTests(TestCase):
def setUp(self):
def setUp(self) -> None:
data = ((True, True, 0),
(True, False, 4),
(False, False, 2),
@ -14,14 +16,14 @@ class QueryManagerTests(TestCase):
for p, c, o in data:
Post.objects.create(published=p, confirmed=c, order=o)
def test_passing_kwargs(self):
def test_passing_kwargs(self) -> None:
qs = Post.public.all()
self.assertEqual([p.order for p in qs], [0, 1, 4, 5])
def test_passing_Q(self):
def test_passing_Q(self) -> None:
qs = Post.public_confirmed.all()
self.assertEqual([p.order for p in qs], [0, 1])
def test_ordering(self):
def test_ordering(self) -> None:
qs = Post.public_reversed.all()
self.assertEqual([p.order for p in qs], [5, 4, 1, 0])

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.test import TestCase
from tests.models import CustomSoftDelete
@ -5,21 +7,21 @@ from tests.models import CustomSoftDelete
class CustomSoftDeleteManagerTests(TestCase):
def test_custom_manager_empty(self):
def test_custom_manager_empty(self) -> None:
qs = CustomSoftDelete.available_objects.only_read()
self.assertEqual(qs.count(), 0)
def test_custom_qs_empty(self):
def test_custom_qs_empty(self) -> None:
qs = CustomSoftDelete.available_objects.all().only_read()
self.assertEqual(qs.count(), 0)
def test_is_read(self):
def test_is_read(self) -> None:
for is_read in [True, False, True, False]:
CustomSoftDelete.available_objects.create(is_read=is_read)
qs = CustomSoftDelete.available_objects.only_read()
self.assertEqual(qs.count(), 2)
def test_is_read_removed(self):
def test_is_read_removed(self) -> None:
for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]:
CustomSoftDelete.available_objects.create(is_read=is_read, is_removed=is_removed)
qs = CustomSoftDelete.available_objects.only_read()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.core.exceptions import ImproperlyConfigured
from django.db import models
from django.test import TestCase
@ -8,10 +10,10 @@ from tests.models import StatusManagerAdded
class StatusManagerAddedTests(TestCase):
def test_manager_available(self):
def test_manager_available(self) -> None:
self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager))
def test_conflict_error(self):
def test_conflict_error(self) -> None:
with self.assertRaises(ImproperlyConfigured):
class ErrorModel(StatusModel):
STATUS = (

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.core.management import call_command
from django.test import TestCase
@ -5,23 +7,23 @@ from model_utils.fields import get_excerpt
class MigrationsTests(TestCase):
def test_makemigrations(self):
def test_makemigrations(self) -> None:
call_command('makemigrations', dry_run=True)
class GetExcerptTests(TestCase):
def test_split(self):
def test_split(self) -> None:
e = get_excerpt("some content\n\n<!-- split -->\n\nsome more")
self.assertEqual(e, 'some content\n')
def test_auto_split(self):
def test_auto_split(self) -> None:
e = get_excerpt("para one\n\npara two\n\npara three")
self.assertEqual(e, 'para one\n\npara two')
def test_middle_of_para(self):
def test_middle_of_para(self) -> None:
e = get_excerpt("some text\n<!-- split -->\nmore text")
self.assertEqual(e, 'some text')
def test_middle_of_line(self):
def test_middle_of_line(self) -> None:
e = get_excerpt("some text <!-- split --> more text")
self.assertEqual(e, "some text <!-- split --> more text")

View file

@ -1,10 +1,12 @@
from __future__ import annotations
from django.test import TestCase
from tests.models import ModelWithCustomDescriptor
class CustomDescriptorTests(TestCase):
def setUp(self):
def setUp(self) -> None:
self.instance = ModelWithCustomDescriptor.objects.create(
custom_field='1',
tracked_custom_field='1',
@ -12,7 +14,7 @@ class CustomDescriptorTests(TestCase):
tracked_regular_field=1,
)
def test_custom_descriptor_works(self):
def test_custom_descriptor_works(self) -> None:
instance = self.instance
self.assertEqual(instance.custom_field, '1')
self.assertEqual(instance.__dict__['custom_field'], 1)
@ -25,7 +27,7 @@ class CustomDescriptorTests(TestCase):
self.assertEqual(instance.custom_field, '2')
self.assertEqual(instance.__dict__['custom_field'], 2)
def test_deferred(self):
def test_deferred(self) -> None:
instance = ModelWithCustomDescriptor.objects.only('id').get(
pk=self.instance.pk)
self.assertIn('custom_field', instance.get_deferred_fields())

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.test import TestCase
from django.utils.connection import ConnectionDoesNotExist
@ -5,7 +7,7 @@ from tests.models import SoftDeletable
class SoftDeletableModelTests(TestCase):
def test_can_only_see_not_removed_entries(self):
def test_can_only_see_not_removed_entries(self) -> None:
SoftDeletable.available_objects.create(name='a', is_removed=True)
SoftDeletable.available_objects.create(name='b', is_removed=False)
@ -14,7 +16,7 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(queryset.count(), 1)
self.assertEqual(queryset[0].name, 'b')
def test_instance_cannot_be_fully_deleted(self):
def test_instance_cannot_be_fully_deleted(self) -> None:
instance = SoftDeletable.available_objects.create(name='a')
instance.delete()
@ -22,7 +24,7 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(SoftDeletable.available_objects.count(), 0)
self.assertEqual(SoftDeletable.all_objects.count(), 1)
def test_instance_cannot_be_fully_deleted_via_queryset(self):
def test_instance_cannot_be_fully_deleted_via_queryset(self) -> None:
SoftDeletable.available_objects.create(name='a')
SoftDeletable.available_objects.all().delete()
@ -30,12 +32,12 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(SoftDeletable.available_objects.count(), 0)
self.assertEqual(SoftDeletable.all_objects.count(), 1)
def test_delete_instance_no_connection(self):
def test_delete_instance_no_connection(self) -> None:
obj = SoftDeletable.available_objects.create(name='a')
self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other')
def test_instance_purge(self):
def test_instance_purge(self) -> None:
instance = SoftDeletable.available_objects.create(name='a')
instance.delete(soft=False)
@ -43,11 +45,11 @@ class SoftDeletableModelTests(TestCase):
self.assertEqual(SoftDeletable.available_objects.count(), 0)
self.assertEqual(SoftDeletable.all_objects.count(), 0)
def test_instance_purge_no_connection(self):
def test_instance_purge_no_connection(self) -> None:
instance = SoftDeletable.available_objects.create(name='a')
self.assertRaises(ConnectionDoesNotExist, instance.delete,
using='other', soft=False)
def test_deprecation_warning(self):
def test_deprecation_warning(self) -> None:
self.assertWarns(DeprecationWarning, SoftDeletable.objects.all)

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime, timezone
import time_machine
@ -7,12 +9,14 @@ from tests.models import CustomManagerStatusModel, Status, StatusPlainTuple
class StatusModelTests(TestCase):
def setUp(self):
model: type[Status] | type[StatusPlainTuple]
def setUp(self) -> None:
self.model = Status
self.on_hold = Status.STATUS.on_hold
self.active = Status.STATUS.active
def test_created(self):
def test_created(self) -> None:
with time_machine.travel(datetime(2016, 1, 1)):
c1 = self.model.objects.create()
self.assertTrue(c1.status_changed, datetime(2016, 1, 1))
@ -21,7 +25,7 @@ class StatusModelTests(TestCase):
self.assertEqual(self.model.active.count(), 2)
self.assertEqual(self.model.deleted.count(), 0)
def test_modification(self):
def test_modification(self) -> None:
t1 = self.model.objects.create()
date_created = t1.status_changed
t1.status = self.on_hold
@ -37,7 +41,7 @@ class StatusModelTests(TestCase):
t1.save()
self.assertTrue(t1.status_changed > date_active_again)
def test_save_with_update_fields_overrides_status_changed_provided(self):
def test_save_with_update_fields_overrides_status_changed_provided(self) -> None:
'''
Tests if the save method updated status_changed field
accordingly when update_fields is used as an argument
@ -52,7 +56,7 @@ class StatusModelTests(TestCase):
self.assertEqual(t1.status_changed, datetime(2020, 1, 2, tzinfo=timezone.utc))
def test_save_with_update_fields_overrides_status_changed_not_provided(self):
def test_save_with_update_fields_overrides_status_changed_not_provided(self) -> None:
'''
Tests if the save method updated status_changed field
accordingly when update_fields is used as an argument
@ -69,7 +73,7 @@ class StatusModelTests(TestCase):
class StatusModelPlainTupleTests(StatusModelTests):
def setUp(self):
def setUp(self) -> None:
self.model = StatusPlainTuple
self.on_hold = StatusPlainTuple.STATUS[2][0]
self.active = StatusPlainTuple.STATUS[0][0]
@ -77,7 +81,7 @@ class StatusModelPlainTupleTests(StatusModelTests):
class StatusModelDefaultManagerTests(TestCase):
def test_default_manager_is_not_status_model_generated_ones(self):
def test_default_manager_is_not_status_model_generated_ones(self) -> None:
# Regression test for GH-251
# The logic behind order for managers seems to have changed in Django 1.10
# and affects default manager.

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime, timedelta
from django.core.exceptions import ImproperlyConfigured
@ -10,36 +12,36 @@ from tests.models import TimeFrame, TimeFrameManagerAdded
class TimeFramedModelTests(TestCase):
def setUp(self):
def setUp(self) -> None:
self.now = datetime.now()
def test_not_yet_begun(self):
def test_not_yet_begun(self) -> None:
TimeFrame.objects.create(start=self.now + timedelta(days=2))
self.assertEqual(TimeFrame.timeframed.count(), 0)
def test_finished(self):
def test_finished(self) -> None:
TimeFrame.objects.create(end=self.now - timedelta(days=1))
self.assertEqual(TimeFrame.timeframed.count(), 0)
def test_no_end(self):
def test_no_end(self) -> None:
TimeFrame.objects.create(start=self.now - timedelta(days=10))
self.assertEqual(TimeFrame.timeframed.count(), 1)
def test_no_start(self):
def test_no_start(self) -> None:
TimeFrame.objects.create(end=self.now + timedelta(days=2))
self.assertEqual(TimeFrame.timeframed.count(), 1)
def test_within_range(self):
def test_within_range(self) -> None:
TimeFrame.objects.create(start=self.now - timedelta(days=1),
end=self.now + timedelta(days=1))
self.assertEqual(TimeFrame.timeframed.count(), 1)
class TimeFrameManagerAddedTests(TestCase):
def test_manager_available(self):
def test_manager_available(self) -> None:
self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager))
def test_conflict_error(self):
def test_conflict_error(self) -> None:
with self.assertRaises(ImproperlyConfigured):
class ErrorModel(TimeFramedModel):
timeframed = models.BooleanField()

View file

@ -1,3 +1,6 @@
from __future__ import annotations
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
import time_machine
@ -7,19 +10,19 @@ from tests.models import TimeStamp, TimeStampWithStatusModel
class TimeStampedModelTests(TestCase):
def test_created(self):
def test_created(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStamp.objects.create()
self.assertEqual(t1.created, datetime(2016, 1, 1, tzinfo=timezone.utc))
def test_created_sets_modified(self):
def test_created_sets_modified(self) -> None:
'''
Ensure that on creation that modified is set exactly equal to created.
'''
t1 = TimeStamp.objects.create()
self.assertEqual(t1.created, t1.modified)
def test_modified(self):
def test_modified(self) -> None:
with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStamp.objects.create()
@ -28,7 +31,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.modified, datetime(2016, 1, 2, tzinfo=timezone.utc))
def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self):
def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self) -> None:
"""
Setting the created date when first creating an object
should be permissible.
@ -38,7 +41,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.created, different_date)
self.assertEqual(t1.modified, different_date)
def test_overriding_modified_via_object_creation(self):
def test_overriding_modified_via_object_creation(self) -> None:
"""
Setting the modified date explicitly should be possible when
first creating an object, but not thereafter.
@ -48,7 +51,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.modified, different_date)
self.assertNotEqual(t1.created, different_date)
def test_overriding_created_after_object_created(self):
def test_overriding_created_after_object_created(self) -> None:
"""
The created date may be changed post-create
"""
@ -58,7 +61,7 @@ class TimeStampedModelTests(TestCase):
t1.save()
self.assertEqual(t1.created, different_date)
def test_overriding_modified_after_object_created(self):
def test_overriding_modified_after_object_created(self) -> None:
"""
The modified date should always be updated when the object
is saved, regardless of attempts to change it.
@ -69,7 +72,7 @@ class TimeStampedModelTests(TestCase):
t1.save()
self.assertNotEqual(t1.modified, different_date)
def test_overrides_using_save(self):
def test_overrides_using_save(self) -> None:
"""
The first time an object is saved, allow modification of both
created and modified fields.
@ -90,7 +93,7 @@ class TimeStampedModelTests(TestCase):
self.assertNotEqual(t1.modified, different_date2)
self.assertNotEqual(t1.modified, different_date)
def test_save_with_update_fields_overrides_modified_provided_within_a(self):
def test_save_with_update_fields_overrides_modified_provided_within_a(self) -> None:
"""
Tests if the save method updated modified field
accordingly when update_fields is used as an argument
@ -111,8 +114,8 @@ class TimeStampedModelTests(TestCase):
t1.save(update_fields=update_fields)
self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc))
def test_save_is_skipped_for_empty_update_fields_iterable(self):
tests = (
def test_save_is_skipped_for_empty_update_fields_iterable(self) -> None:
tests: Iterable[Iterable[str]] = (
[], # list
(), # tuple
set(), # set
@ -131,7 +134,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.test_field, 0)
self.assertEqual(t1.modified, datetime(2020, 1, 1, tzinfo=timezone.utc))
def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self):
def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self) -> None:
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStamp.objects.create()
@ -140,7 +143,7 @@ class TimeStampedModelTests(TestCase):
self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc))
def test_model_inherit_timestampmodel_and_statusmodel(self):
def test_model_inherit_timestampmodel_and_statusmodel(self) -> None:
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)):
t1 = TimeStampWithStatusModel.objects.create()

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from django.test import TestCase
from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel
@ -5,13 +7,13 @@ from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel
class UUIDFieldTests(TestCase):
def test_uuid_model_with_uuid_field_as_primary_key(self):
def test_uuid_model_with_uuid_field_as_primary_key(self) -> None:
instance = CustomUUIDModel()
instance.save()
self.assertEqual(instance.id.__class__.__name__, 'UUID')
self.assertEqual(instance.id, instance.pk)
def test_uuid_model_with_uuid_field_as_not_primary_key(self):
def test_uuid_model_with_uuid_field_as_not_primary_key(self) -> None:
instance = CustomNotPrimaryUUIDModel()
instance.save()
self.assertEqual(instance.uuid.__class__.__name__, 'UUID')