mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Merge pull request #603 from ProtixIT/type-annotations
Add type annotations
This commit is contained in:
commit
731ed804f3
30 changed files with 1142 additions and 642 deletions
|
|
@ -1,2 +1,8 @@
|
|||
[run]
|
||||
include = model_utils/*.py
|
||||
|
||||
[report]
|
||||
exclude_also =
|
||||
# Exclusive to mypy:
|
||||
if TYPE_CHECKING:$
|
||||
\.\.\.$
|
||||
|
|
|
|||
|
|
@ -1,7 +1,39 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
|
||||
# The type aliases defined here are evaluated when the django-stubs mypy plugin
|
||||
# loads this module, so they must be able to execute under the lowest supported
|
||||
# Python VM:
|
||||
# - typing.List, typing.Tuple become obsolete in Pyton 3.9
|
||||
# - typing.Union becomes obsolete in Pyton 3.10
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from django_stubs_ext import StrOrPromise
|
||||
|
||||
# The type argument 'T' to 'Choices' is the database representation type.
|
||||
_Double = Tuple[T, StrOrPromise]
|
||||
_Triple = Tuple[T, str, StrOrPromise]
|
||||
_Group = Tuple[StrOrPromise, Sequence["_Choice[T]"]]
|
||||
_Choice = Union[_Double[T], _Triple[T], _Group[T]]
|
||||
# Choices can only be given as a single string if 'T' is 'str'.
|
||||
_GroupStr = Tuple[StrOrPromise, Sequence["_ChoiceStr"]]
|
||||
_ChoiceStr = Union[str, _Double[str], _Triple[str], _GroupStr]
|
||||
# Note that we only accept lists and tuples in groups, not arbitrary sequences.
|
||||
# However, annotating it as such causes many problems.
|
||||
|
||||
_DoubleRead = Union[_Double[T], Tuple[StrOrPromise, Iterable["_DoubleRead[T]"]]]
|
||||
_DoubleCollector = List[Union[_Double[T], Tuple[StrOrPromise, "_DoubleCollector[T]"]]]
|
||||
_TripleCollector = List[Union[_Triple[T], Tuple[StrOrPromise, "_TripleCollector[T]"]]]
|
||||
|
||||
|
||||
class Choices:
|
||||
class Choices(Generic[T]):
|
||||
"""
|
||||
A class to encapsulate handy functionality for lists of choices
|
||||
for a Django model field.
|
||||
|
|
@ -41,36 +73,60 @@ class Choices:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, *choices):
|
||||
@overload
|
||||
def __init__(self: Choices[str], *choices: _ChoiceStr):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, *choices: _Choice[T]):
|
||||
...
|
||||
|
||||
def __init__(self, *choices: _ChoiceStr | _Choice[T]):
|
||||
# list of choices expanded to triples - can include optgroups
|
||||
self._triples = []
|
||||
self._triples: _TripleCollector[T] = []
|
||||
# list of choices as (db, human-readable) - can include optgroups
|
||||
self._doubles = []
|
||||
self._doubles: _DoubleCollector[T] = []
|
||||
# dictionary mapping db representation to human-readable
|
||||
self._display_map = {}
|
||||
self._display_map: dict[T, StrOrPromise | list[_Triple[T]]] = {}
|
||||
# dictionary mapping Python identifier to db representation
|
||||
self._identifier_map = {}
|
||||
self._identifier_map: dict[str, T] = {}
|
||||
# set of db representations
|
||||
self._db_values = set()
|
||||
self._db_values: set[T] = set()
|
||||
|
||||
self._process(choices)
|
||||
|
||||
def _store(self, triple, triple_collector, double_collector):
|
||||
def _store(
|
||||
self,
|
||||
triple: tuple[T, str, StrOrPromise],
|
||||
triple_collector: _TripleCollector[T],
|
||||
double_collector: _DoubleCollector[T]
|
||||
) -> None:
|
||||
self._identifier_map[triple[1]] = triple[0]
|
||||
self._display_map[triple[0]] = triple[2]
|
||||
self._db_values.add(triple[0])
|
||||
triple_collector.append(triple)
|
||||
double_collector.append((triple[0], triple[2]))
|
||||
|
||||
def _process(self, choices, triple_collector=None, double_collector=None):
|
||||
def _process(
|
||||
self,
|
||||
choices: Iterable[_ChoiceStr | _Choice[T]],
|
||||
triple_collector: _TripleCollector[T] | None = None,
|
||||
double_collector: _DoubleCollector[T] | None = None
|
||||
) -> None:
|
||||
if triple_collector is None:
|
||||
triple_collector = self._triples
|
||||
if double_collector is None:
|
||||
double_collector = self._doubles
|
||||
|
||||
store = lambda c: self._store(c, triple_collector, double_collector)
|
||||
def store(c: tuple[Any, str, StrOrPromise]) -> None:
|
||||
self._store(c, triple_collector, double_collector)
|
||||
|
||||
for choice in choices:
|
||||
# The type inference is not very accurate here:
|
||||
# - we lied in the type aliases, stating groups contain an arbitrary Sequence
|
||||
# rather than only list or tuple
|
||||
# - there is no way to express that _ChoiceStr is only used when T=str
|
||||
# - mypy 1.9.0 doesn't narrow types based on the value of len()
|
||||
if isinstance(choice, (list, tuple)):
|
||||
if len(choice) == 3:
|
||||
store(choice)
|
||||
|
|
@ -79,13 +135,13 @@ class Choices:
|
|||
# option group
|
||||
group_name = choice[0]
|
||||
subchoices = choice[1]
|
||||
tc = []
|
||||
tc: _TripleCollector[T] = []
|
||||
triple_collector.append((group_name, tc))
|
||||
dc = []
|
||||
dc: _DoubleCollector[T] = []
|
||||
double_collector.append((group_name, dc))
|
||||
self._process(subchoices, tc, dc)
|
||||
else:
|
||||
store((choice[0], choice[0], choice[1]))
|
||||
store((choice[0], cast(str, choice[0]), cast('StrOrPromise', choice[1])))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Choices can't take a list of length %s, only 2 or 3"
|
||||
|
|
@ -94,54 +150,74 @@ class Choices:
|
|||
else:
|
||||
store((choice, choice, choice))
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return len(self._doubles)
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[_DoubleRead[T]]:
|
||||
return iter(self._doubles)
|
||||
|
||||
def __reversed__(self):
|
||||
def __reversed__(self) -> Iterator[_DoubleRead[T]]:
|
||||
return reversed(self._doubles)
|
||||
|
||||
def __getattr__(self, attname):
|
||||
def __getattr__(self, attname: str) -> T:
|
||||
try:
|
||||
return self._identifier_map[attname]
|
||||
except KeyError:
|
||||
raise AttributeError(attname)
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: T) -> StrOrPromise | Sequence[_Triple[T]]:
|
||||
return self._display_map[key]
|
||||
|
||||
def __add__(self, other):
|
||||
@overload
|
||||
def __add__(self: Choices[str], other: Choices[str] | Iterable[_ChoiceStr]) -> Choices[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __add__(self, other: Choices[T] | Iterable[_Choice[T]]) -> Choices[T]:
|
||||
...
|
||||
|
||||
def __add__(self, other: Choices[Any] | Iterable[_ChoiceStr | _Choice[Any]]) -> Choices[Any]:
|
||||
other_args: list[Any]
|
||||
if isinstance(other, self.__class__):
|
||||
other = other._triples
|
||||
other_args = other._triples
|
||||
else:
|
||||
other = list(other)
|
||||
return Choices(*(self._triples + other))
|
||||
other_args = list(other)
|
||||
return Choices(*(self._triples + other_args))
|
||||
|
||||
def __radd__(self, other):
|
||||
@overload
|
||||
def __radd__(self: Choices[str], other: Iterable[_ChoiceStr]) -> Choices[str]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __radd__(self, other: Iterable[_Choice[T]]) -> Choices[T]:
|
||||
...
|
||||
|
||||
def __radd__(self, other: Iterable[_ChoiceStr] | Iterable[_Choice[T]]) -> Choices[Any]:
|
||||
# radd is never called for matching types, so we don't check here
|
||||
other = list(other)
|
||||
return Choices(*(other + self._triples))
|
||||
other_args = list(other)
|
||||
# The exact type of 'other' depends on our type argument 'T', which
|
||||
# is expressed in the overloading, but lost within this method body.
|
||||
return Choices(*(other_args + self._triples)) # type: ignore[arg-type]
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if isinstance(other, self.__class__):
|
||||
return self._triples == other._triples
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return '{}({})'.format(
|
||||
self.__class__.__name__,
|
||||
', '.join("%s" % repr(i) for i in self._triples)
|
||||
)
|
||||
|
||||
def __contains__(self, item):
|
||||
def __contains__(self, item: T) -> bool:
|
||||
return item in self._db_values
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
return self.__class__(*copy.deepcopy(self._triples, memo))
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None) -> Choices[T]:
|
||||
args: list[Any] = copy.deepcopy(self._triples, memo)
|
||||
return self.__class__(*args)
|
||||
|
||||
def subset(self, *new_identifiers):
|
||||
def subset(self, *new_identifiers: str) -> Choices[T]:
|
||||
identifiers = set(self._identifier_map.keys())
|
||||
|
||||
if not identifiers.issuperset(new_identifiers):
|
||||
|
|
@ -150,7 +226,8 @@ class Choices:
|
|||
identifiers.symmetric_difference(new_identifiers),
|
||||
)
|
||||
|
||||
return self.__class__(*[
|
||||
args: list[Any] = [
|
||||
choice for choice in self._triples
|
||||
if choice[1] in new_identifiers
|
||||
])
|
||||
]
|
||||
return self.__class__(*args)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.utils.timezone import now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Iterable
|
||||
from datetime import date, datetime
|
||||
|
||||
DateTimeFieldBase = models.DateTimeField[Union[str, datetime, date], datetime]
|
||||
else:
|
||||
DateTimeFieldBase = models.DateTimeField
|
||||
|
||||
DEFAULT_CHOICES_NAME = 'STATUS'
|
||||
|
||||
|
||||
class AutoCreatedField(models.DateTimeField):
|
||||
class AutoCreatedField(DateTimeFieldBase):
|
||||
"""
|
||||
A DateTimeField that automatically populates itself at
|
||||
object creation.
|
||||
|
|
@ -18,7 +30,7 @@ class AutoCreatedField(models.DateTimeField):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
kwargs.setdefault('editable', False)
|
||||
kwargs.setdefault('default', now)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -31,13 +43,13 @@ class AutoLastModifiedField(AutoCreatedField):
|
|||
By default, sets editable=False and default=datetime.now.
|
||||
|
||||
"""
|
||||
def get_default(self):
|
||||
def get_default(self) -> datetime:
|
||||
"""Return the default value for this field."""
|
||||
if not hasattr(self, "_default"):
|
||||
self._default = self._get_default()
|
||||
self._default = super().get_default()
|
||||
return self._default
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
def pre_save(self, model_instance: models.Model, add: bool) -> datetime:
|
||||
value = now()
|
||||
if add:
|
||||
current_value = getattr(model_instance, self.attname, self.get_default())
|
||||
|
|
@ -68,13 +80,19 @@ class StatusField(models.CharField):
|
|||
South can handle this field when it freezes a model.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, no_check_for_status=False, choices_name=DEFAULT_CHOICES_NAME, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
no_check_for_status: bool = False,
|
||||
choices_name: str = DEFAULT_CHOICES_NAME,
|
||||
**kwargs: Any
|
||||
):
|
||||
kwargs.setdefault('max_length', 100)
|
||||
self.check_for_status = not no_check_for_status
|
||||
self.choices_name = choices_name
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def prepare_class(self, sender, **kwargs):
|
||||
def prepare_class(self, sender: type[models.Model], **kwargs: Any) -> None:
|
||||
if not sender._meta.abstract and self.check_for_status:
|
||||
assert hasattr(sender, self.choices_name), \
|
||||
"To use StatusField, the model '%s' must have a %s choices class attribute." \
|
||||
|
|
@ -83,7 +101,7 @@ class StatusField(models.CharField):
|
|||
if not self.has_default():
|
||||
self.default = tuple(getattr(sender, self.choices_name))[0][0] # set first as default
|
||||
|
||||
def contribute_to_class(self, cls, name, *args, **kwargs):
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
|
||||
models.signals.class_prepared.connect(self.prepare_class, sender=cls)
|
||||
# we don't set the real choices until class_prepared (so we can rely on
|
||||
# the STATUS class attr being available), but we need to set some dummy
|
||||
|
|
@ -91,13 +109,13 @@ class StatusField(models.CharField):
|
|||
self.choices = [(0, 'dummy')]
|
||||
super().contribute_to_class(cls, name, *args, **kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
kwargs['no_check_for_status'] = True
|
||||
return name, path, args, kwargs
|
||||
|
||||
|
||||
class MonitorField(models.DateTimeField):
|
||||
class MonitorField(DateTimeFieldBase):
|
||||
"""
|
||||
A DateTimeField that monitors another field on the same model and
|
||||
sets itself to the current date/time whenever the monitored field
|
||||
|
|
@ -105,30 +123,28 @@ class MonitorField(models.DateTimeField):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, *args, monitor, when=None, **kwargs):
|
||||
def __init__(self, *args: Any, monitor: str, when: Iterable[Any] | None = None, **kwargs: Any):
|
||||
default = None if kwargs.get("null") else now
|
||||
kwargs.setdefault('default', default)
|
||||
self.monitor = monitor
|
||||
if when is not None:
|
||||
when = set(when)
|
||||
self.when = when
|
||||
self.when = None if when is None else set(when)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def contribute_to_class(self, cls, name, *args, **kwargs):
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
|
||||
self.monitor_attname = '_monitor_%s' % name
|
||||
models.signals.post_init.connect(self._save_initial, sender=cls)
|
||||
super().contribute_to_class(cls, name, *args, **kwargs)
|
||||
|
||||
def get_monitored_value(self, instance):
|
||||
def get_monitored_value(self, instance: models.Model) -> Any:
|
||||
return getattr(instance, self.monitor)
|
||||
|
||||
def _save_initial(self, sender, instance, **kwargs):
|
||||
def _save_initial(self, sender: type[models.Model], instance: models.Model, **kwargs: Any) -> None:
|
||||
if self.monitor in instance.get_deferred_fields():
|
||||
# Fix related to issue #241 to avoid recursive error on double monitor fields
|
||||
return
|
||||
setattr(instance, self.monitor_attname, self.get_monitored_value(instance))
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
def pre_save(self, model_instance: models.Model, add: bool) -> Any:
|
||||
value = now()
|
||||
previous = getattr(model_instance, self.monitor_attname, None)
|
||||
current = self.get_monitored_value(model_instance)
|
||||
|
|
@ -138,7 +154,7 @@ class MonitorField(models.DateTimeField):
|
|||
self._save_initial(model_instance.__class__, model_instance)
|
||||
return super().pre_save(model_instance, add)
|
||||
|
||||
def deconstruct(self):
|
||||
def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
kwargs['monitor'] = self.monitor
|
||||
if self.when is not None:
|
||||
|
|
@ -152,12 +168,12 @@ SPLIT_MARKER = getattr(settings, 'SPLIT_MARKER', '<!-- split -->')
|
|||
SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2)
|
||||
|
||||
|
||||
def _excerpt_field_name(name):
|
||||
def _excerpt_field_name(name: str) -> str:
|
||||
return '_%s_excerpt' % name
|
||||
|
||||
|
||||
def get_excerpt(content):
|
||||
excerpt = []
|
||||
def get_excerpt(content: str) -> str:
|
||||
excerpt: list[str] = []
|
||||
default_excerpt = []
|
||||
paras_seen = 0
|
||||
for line in content.splitlines():
|
||||
|
|
@ -173,7 +189,7 @@ def get_excerpt(content):
|
|||
|
||||
|
||||
class SplitText:
|
||||
def __init__(self, instance, field_name, excerpt_field_name):
|
||||
def __init__(self, instance: models.Model, field_name: str, excerpt_field_name: str):
|
||||
# instead of storing actual values store a reference to the instance
|
||||
# along with field names, this makes assignment possible
|
||||
self.instance = instance
|
||||
|
|
@ -181,36 +197,36 @@ class SplitText:
|
|||
self.excerpt_field_name = excerpt_field_name
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
def content(self) -> str:
|
||||
return self.instance.__dict__[self.field_name]
|
||||
|
||||
@content.setter
|
||||
def content(self, val):
|
||||
def content(self, val: str) -> None:
|
||||
setattr(self.instance, self.field_name, val)
|
||||
|
||||
@property
|
||||
def excerpt(self):
|
||||
def excerpt(self) -> str:
|
||||
return getattr(self.instance, self.excerpt_field_name)
|
||||
|
||||
@property
|
||||
def has_more(self):
|
||||
def has_more(self) -> bool:
|
||||
return self.excerpt.strip() != self.content.strip()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return self.content
|
||||
|
||||
|
||||
class SplitDescriptor:
|
||||
def __init__(self, field):
|
||||
def __init__(self, field: SplitField):
|
||||
self.field = field
|
||||
self.excerpt_field_name = _excerpt_field_name(self.field.name)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
def __get__(self, instance: models.Model, owner: type[models.Model]) -> SplitText:
|
||||
if instance is None:
|
||||
raise AttributeError('Can only be accessed via an instance.')
|
||||
return SplitText(instance, self.field.name, self.excerpt_field_name)
|
||||
|
||||
def __set__(self, obj, value):
|
||||
def __set__(self, obj: models.Model, value: SplitText | str) -> None:
|
||||
if isinstance(value, SplitText):
|
||||
obj.__dict__[self.field.name] = value.content
|
||||
setattr(obj, self.excerpt_field_name, value.excerpt)
|
||||
|
|
@ -218,25 +234,32 @@ class SplitDescriptor:
|
|||
obj.__dict__[self.field.name] = value
|
||||
|
||||
|
||||
class SplitField(models.TextField):
|
||||
def contribute_to_class(self, cls, name, *args, **kwargs):
|
||||
if TYPE_CHECKING:
|
||||
_SplitFieldBase = models.TextField[Union[SplitText, str], SplitText]
|
||||
else:
|
||||
_SplitFieldBase = models.TextField
|
||||
|
||||
|
||||
class SplitField(_SplitFieldBase):
|
||||
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
|
||||
if not cls._meta.abstract:
|
||||
excerpt_field = models.TextField(editable=False)
|
||||
excerpt_field: models.TextField = models.TextField(editable=False)
|
||||
cls.add_to_class(_excerpt_field_name(name), excerpt_field)
|
||||
super().contribute_to_class(cls, name, *args, **kwargs)
|
||||
setattr(cls, self.name, SplitDescriptor(self))
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
value = super().pre_save(model_instance, add)
|
||||
def pre_save(self, model_instance: models.Model, add: bool) -> str:
|
||||
value: SplitText = super().pre_save(model_instance, add)
|
||||
excerpt = get_excerpt(value.content)
|
||||
setattr(model_instance, _excerpt_field_name(self.attname), excerpt)
|
||||
return value.content
|
||||
|
||||
def value_to_string(self, obj):
|
||||
def value_to_string(self, obj: models.Model) -> str:
|
||||
value = self.value_from_object(obj)
|
||||
return value.content
|
||||
|
||||
def get_prep_value(self, value):
|
||||
def get_prep_value(self, value: Any) -> str:
|
||||
try:
|
||||
return value.content
|
||||
except AttributeError:
|
||||
|
|
@ -248,7 +271,14 @@ class UUIDField(models.UUIDField):
|
|||
A field for storing universally unique identifiers. Use Python UUID class.
|
||||
"""
|
||||
|
||||
def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
primary_key: bool = True,
|
||||
version: int = 4,
|
||||
editable: bool = False,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -274,6 +304,7 @@ class UUIDField(models.UUIDField):
|
|||
raise ValidationError(
|
||||
'UUID version is not valid.')
|
||||
|
||||
default: Callable[..., uuid.UUID]
|
||||
if version == 1:
|
||||
default = uuid.uuid1
|
||||
elif version == 3:
|
||||
|
|
@ -294,7 +325,15 @@ class UrlsafeTokenField(models.CharField):
|
|||
A field for storing a unique token in database.
|
||||
"""
|
||||
|
||||
def __init__(self, editable=False, max_length=128, factory=None, **kwargs):
|
||||
max_length: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
editable: bool = False,
|
||||
max_length: int = 128,
|
||||
factory: Callable[[int], str] | None = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -319,14 +358,14 @@ class UrlsafeTokenField(models.CharField):
|
|||
|
||||
super().__init__(editable=editable, max_length=max_length, **kwargs)
|
||||
|
||||
def get_default(self):
|
||||
def get_default(self) -> str:
|
||||
if self._factory is not None:
|
||||
return self._factory(self.max_length)
|
||||
# generate a token of length x1.33 approx. trim up to max length
|
||||
token = secrets.token_urlsafe(self.max_length)[:self.max_length]
|
||||
return token
|
||||
|
||||
def deconstruct(self):
|
||||
def deconstruct(self) -> tuple[str, str, Sequence[Any], dict[str, Any]]:
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
kwargs['factory'] = self._factory
|
||||
return name, path, args, kwargs
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import connection, models
|
||||
|
|
@ -7,57 +10,84 @@ from django.db.models.fields.related import OneToOneField, OneToOneRel
|
|||
from django.db.models.query import ModelIterable, QuerySet
|
||||
from django.db.models.sql.datastructures import Join
|
||||
|
||||
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
|
||||
|
||||
class InheritanceIterable(ModelIterable):
|
||||
def __iter__(self):
|
||||
queryset = self.queryset
|
||||
iter = ModelIterable(queryset)
|
||||
if getattr(queryset, 'subclasses', False):
|
||||
extras = tuple(queryset.query.extra.keys())
|
||||
# sort the subclass names longest first,
|
||||
# so with 'a' and 'a__b' it goes as deep as possible
|
||||
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
|
||||
for obj in iter:
|
||||
sub_obj = None
|
||||
for s in subclasses:
|
||||
sub_obj = queryset._get_sub_obj_recurse(obj, s)
|
||||
if sub_obj:
|
||||
break
|
||||
if not sub_obj:
|
||||
sub_obj = obj
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
if getattr(queryset, '_annotated', False):
|
||||
for k in queryset._annotated:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
from django.db.models.query import BaseIterable
|
||||
|
||||
for k in extras:
|
||||
|
||||
def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
|
||||
iter: ModelIterable[ModelT] = ModelIterable(queryset)
|
||||
if hasattr(queryset, 'subclasses'):
|
||||
assert hasattr(queryset, '_get_sub_obj_recurse')
|
||||
extras = tuple(queryset.query.extra.keys())
|
||||
# sort the subclass names longest first,
|
||||
# so with 'a' and 'a__b' it goes as deep as possible
|
||||
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
|
||||
for obj in iter:
|
||||
sub_obj = None
|
||||
for s in subclasses:
|
||||
sub_obj = queryset._get_sub_obj_recurse(obj, s)
|
||||
if sub_obj:
|
||||
break
|
||||
if not sub_obj:
|
||||
sub_obj = obj
|
||||
|
||||
if hasattr(queryset, '_annotated'):
|
||||
for k in queryset._annotated:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
|
||||
yield sub_obj
|
||||
else:
|
||||
yield from iter
|
||||
for k in extras:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
|
||||
yield sub_obj
|
||||
else:
|
||||
yield from iter
|
||||
|
||||
|
||||
class InheritanceQuerySetMixin:
|
||||
def __init__(self, *args, **kwargs):
|
||||
if TYPE_CHECKING:
|
||||
class InheritanceIterable(ModelIterable[ModelT]):
|
||||
queryset: QuerySet[ModelT]
|
||||
|
||||
def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
|
||||
...
|
||||
|
||||
def __iter__(self) -> Iterator[ModelT]:
|
||||
...
|
||||
|
||||
else:
|
||||
class InheritanceIterable(ModelIterable):
|
||||
def __iter__(self):
|
||||
return _iter_inheritance_queryset(self.queryset)
|
||||
|
||||
|
||||
class InheritanceQuerySetMixin(Generic[ModelT]):
|
||||
|
||||
model: type[ModelT]
|
||||
subclasses: Sequence[str]
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._iterable_class = InheritanceIterable
|
||||
self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable
|
||||
|
||||
def select_subclasses(self, *subclasses):
|
||||
calculated_subclasses = self._get_subclasses_recurse(self.model)
|
||||
def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]:
|
||||
model: type[ModelT] = self.model
|
||||
calculated_subclasses = self._get_subclasses_recurse(model)
|
||||
# if none were passed in, we can just short circuit and select all
|
||||
if not subclasses:
|
||||
subclasses = calculated_subclasses
|
||||
selected_subclasses = calculated_subclasses
|
||||
else:
|
||||
verified_subclasses = []
|
||||
verified_subclasses: list[str] = []
|
||||
for subclass in subclasses:
|
||||
# special case for passing in the same model as the queryset
|
||||
# is bound against. Rather than raise an error later, we know
|
||||
# we can allow this through.
|
||||
if subclass is self.model:
|
||||
if subclass is model:
|
||||
continue
|
||||
|
||||
if not isinstance(subclass, (str,)):
|
||||
if not isinstance(subclass, str):
|
||||
subclass = self._get_ancestors_path(subclass)
|
||||
|
||||
if subclass in calculated_subclasses:
|
||||
|
|
@ -67,38 +97,39 @@ class InheritanceQuerySetMixin:
|
|||
'{!r} is not in the discovered subclasses, tried: {}'.format(
|
||||
subclass, ', '.join(calculated_subclasses))
|
||||
)
|
||||
subclasses = verified_subclasses
|
||||
selected_subclasses = verified_subclasses
|
||||
|
||||
if subclasses:
|
||||
new_qs = self.select_related(*subclasses)
|
||||
else:
|
||||
new_qs = self
|
||||
new_qs.subclasses = subclasses
|
||||
new_qs = cast('InheritanceQuerySet[ModelT]', self)
|
||||
if selected_subclasses:
|
||||
new_qs = new_qs.select_related(*selected_subclasses)
|
||||
new_qs.subclasses = selected_subclasses
|
||||
return new_qs
|
||||
|
||||
def _chain(self, **kwargs):
|
||||
def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]:
|
||||
update = {}
|
||||
for name in ['subclasses', '_annotated']:
|
||||
if hasattr(self, name):
|
||||
update[name] = getattr(self, name)
|
||||
|
||||
chained = super()._chain(**kwargs)
|
||||
# django-stubs doesn't include this private API.
|
||||
chained = super()._chain(**kwargs) # type: ignore[misc]
|
||||
chained.__dict__.update(update)
|
||||
return chained
|
||||
|
||||
def _clone(self):
|
||||
qs = super()._clone()
|
||||
def _clone(self) -> InheritanceQuerySet[ModelT]:
|
||||
# django-stubs doesn't include this private API.
|
||||
qs = super()._clone() # type: ignore[misc]
|
||||
for name in ['subclasses', '_annotated']:
|
||||
if hasattr(self, name):
|
||||
setattr(qs, name, getattr(self, name))
|
||||
return qs
|
||||
|
||||
def annotate(self, *args, **kwargs):
|
||||
qset = super().annotate(*args, **kwargs)
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs)
|
||||
qset._annotated = [a.default_alias for a in args] + list(kwargs.keys())
|
||||
return qset
|
||||
|
||||
def _get_subclasses_recurse(self, model):
|
||||
def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]:
|
||||
"""
|
||||
Given a Model class, find all related objects, exploring children
|
||||
recursively, returning a `list` of strings representing the
|
||||
|
|
@ -124,7 +155,7 @@ class InheritanceQuerySetMixin:
|
|||
subclasses.append(rel.get_accessor_name())
|
||||
return subclasses
|
||||
|
||||
def _get_ancestors_path(self, model):
|
||||
def _get_ancestors_path(self, model: type[models.Model]) -> str:
|
||||
"""
|
||||
Serves as an opposite to _get_subclasses_recurse, instead walking from
|
||||
the Model class up the Model's ancestry and constructing the desired
|
||||
|
|
@ -134,7 +165,7 @@ class InheritanceQuerySetMixin:
|
|||
raise ValueError(
|
||||
f"{model!r} is not a subclass of {self.model!r}")
|
||||
|
||||
ancestry = []
|
||||
ancestry: list[str] = []
|
||||
# should be a OneToOneField or None
|
||||
parent_link = model._meta.get_ancestor_link(self.model)
|
||||
|
||||
|
|
@ -147,7 +178,7 @@ class InheritanceQuerySetMixin:
|
|||
|
||||
return LOOKUP_SEP.join(ancestry)
|
||||
|
||||
def _get_sub_obj_recurse(self, obj, s):
|
||||
def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None:
|
||||
rel, _, s = s.partition(LOOKUP_SEP)
|
||||
|
||||
try:
|
||||
|
|
@ -160,12 +191,14 @@ class InheritanceQuerySetMixin:
|
|||
else:
|
||||
return node
|
||||
|
||||
def get_subclass(self, *args, **kwargs):
|
||||
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
|
||||
return self.select_subclasses().get(*args, **kwargs)
|
||||
|
||||
|
||||
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
|
||||
def instance_of(self, *models):
|
||||
# Defining the 'model' attribute using a generic type triggers a bug in mypy:
|
||||
# https://github.com/python/mypy/issues/9031
|
||||
class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
|
||||
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
|
||||
"""
|
||||
Fetch only objects that are instances of the provided model(s).
|
||||
"""
|
||||
|
|
@ -188,88 +221,187 @@ class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
|
|||
) for field in model._meta.parents.values()
|
||||
]) + ')')
|
||||
|
||||
return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
|
||||
return cast(
|
||||
'InheritanceQuerySet[ModelT]',
|
||||
self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
|
||||
)
|
||||
|
||||
|
||||
class InheritanceManagerMixin:
|
||||
class InheritanceManagerMixin(Generic[ModelT]):
|
||||
_queryset_class = InheritanceQuerySet
|
||||
|
||||
def get_queryset(self):
|
||||
return self._queryset_class(self.model)
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
def select_subclasses(self, *subclasses):
|
||||
def none(self) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def all(self) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def exclude(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def complex_filter(self, filter_obj: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def union(self, *other_qs: Any, all: bool = ...) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def intersection(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def difference(self, *other_qs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def select_for_update(
|
||||
self, nowait: bool = ..., skip_locked: bool = ..., of: Sequence[str] = ..., no_key: bool = ...
|
||||
) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def select_related(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def prefetch_related(self, *lookups: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def alias(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def order_by(self, *field_names: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def distinct(self, *field_names: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def extra(
|
||||
self,
|
||||
select: dict[str, Any] | None = ...,
|
||||
where: list[str] | None = ...,
|
||||
params: list[Any] | None = ...,
|
||||
tables: list[str] | None = ...,
|
||||
order_by: Sequence[str] | None = ...,
|
||||
select_params: Sequence[Any] | None = ...,
|
||||
) -> InheritanceQuerySet[Any]:
|
||||
...
|
||||
|
||||
def reverse(self) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def defer(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def only(self, *fields: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def using(self, alias: str | None) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def get_queryset(self) -> InheritanceQuerySet[ModelT]:
|
||||
model: type[ModelT] = self.model # type: ignore[attr-defined]
|
||||
return self._queryset_class(model)
|
||||
|
||||
def select_subclasses(
|
||||
self, *subclasses: str | type[models.Model]
|
||||
) -> InheritanceQuerySet[ModelT]:
|
||||
return self.get_queryset().select_subclasses(*subclasses)
|
||||
|
||||
def get_subclass(self, *args, **kwargs):
|
||||
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
|
||||
return self.get_queryset().get_subclass(*args, **kwargs)
|
||||
|
||||
def instance_of(self, *models):
|
||||
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
|
||||
return self.get_queryset().instance_of(*models)
|
||||
|
||||
|
||||
class InheritanceManager(InheritanceManagerMixin, models.Manager):
|
||||
class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]):
|
||||
pass
|
||||
|
||||
|
||||
class QueryManagerMixin:
|
||||
class QueryManagerMixin(Generic[ModelT]):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@overload
|
||||
def __init__(self, *args: models.Q):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, **kwargs: object):
|
||||
...
|
||||
|
||||
def __init__(self, *args: models.Q, **kwargs: object):
|
||||
if args:
|
||||
self._q = args[0]
|
||||
else:
|
||||
self._q = models.Q(**kwargs)
|
||||
self._order_by = None
|
||||
self._order_by: tuple[Any, ...] | None = None
|
||||
super().__init__()
|
||||
|
||||
def order_by(self, *args):
|
||||
def order_by(self, *args: Any) -> QueryManager[ModelT]:
|
||||
self._order_by = args
|
||||
return self
|
||||
return cast('QueryManager[ModelT]', self)
|
||||
|
||||
def get_queryset(self):
|
||||
qs = super().get_queryset().filter(self._q)
|
||||
def get_queryset(self) -> QuerySet[ModelT]:
|
||||
qs = super().get_queryset() # type: ignore[misc]
|
||||
qs = qs.filter(self._q)
|
||||
if self._order_by is not None:
|
||||
return qs.order_by(*self._order_by)
|
||||
return qs
|
||||
|
||||
|
||||
class QueryManager(QueryManagerMixin, models.Manager):
|
||||
class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
|
||||
class SoftDeletableQuerySetMixin:
|
||||
class SoftDeletableQuerySetMixin(Generic[ModelT]):
|
||||
"""
|
||||
QuerySet for SoftDeletableModel. Instead of removing instance sets
|
||||
its ``is_removed`` field to True.
|
||||
"""
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Soft delete objects from queryset (set their ``is_removed``
|
||||
field to True)
|
||||
"""
|
||||
self.update(is_removed=True)
|
||||
cast(QuerySet[ModelT], self).update(is_removed=True)
|
||||
|
||||
|
||||
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet):
|
||||
# Note that our delete() method does not return anything, unlike Django's.
|
||||
# https://github.com/jazzband/django-model-utils/issues/541
|
||||
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
|
||||
class SoftDeletableManagerMixin:
|
||||
class SoftDeletableManagerMixin(Generic[ModelT]):
|
||||
"""
|
||||
Manager that limits the queryset by default to show only not removed
|
||||
instances of model.
|
||||
"""
|
||||
_queryset_class = SoftDeletableQuerySet
|
||||
|
||||
def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs):
|
||||
_db: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: object,
|
||||
_emit_deprecation_warnings: bool = False,
|
||||
**kwargs: object
|
||||
):
|
||||
self.emit_deprecation_warnings = _emit_deprecation_warnings
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
def get_queryset(self) -> SoftDeletableQuerySet[ModelT]:
|
||||
"""
|
||||
Return queryset limited to not removed entries.
|
||||
"""
|
||||
|
||||
model: type[ModelT] = self.model # type: ignore[attr-defined]
|
||||
|
||||
if self.emit_deprecation_warnings:
|
||||
warning_message = (
|
||||
"{0}.objects model manager will include soft-deleted objects in an "
|
||||
|
|
@ -277,23 +409,23 @@ class SoftDeletableManagerMixin:
|
|||
"excluding soft-deleted objects. See "
|
||||
"https://django-model-utils.readthedocs.io/en/stable/models.html"
|
||||
"#softdeletablemodel for more information."
|
||||
).format(self.model.__class__.__name__)
|
||||
).format(model.__class__.__name__)
|
||||
warnings.warn(warning_message, DeprecationWarning)
|
||||
|
||||
kwargs = {'model': self.model, 'using': self._db}
|
||||
if hasattr(self, '_hints'):
|
||||
kwargs['hints'] = self._hints
|
||||
|
||||
return self._queryset_class(**kwargs).filter(is_removed=False)
|
||||
return self._queryset_class(
|
||||
model=model,
|
||||
using=self._db,
|
||||
**({'hints': self._hints} if hasattr(self, '_hints') else {})
|
||||
).filter(is_removed=False)
|
||||
|
||||
|
||||
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager):
|
||||
class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]):
|
||||
pass
|
||||
|
||||
|
||||
class JoinQueryset(models.QuerySet):
|
||||
class JoinQueryset(models.QuerySet[Any]):
|
||||
|
||||
def join(self, qs=None):
|
||||
def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]:
|
||||
'''
|
||||
Join one queryset together with another using a temporary table. If
|
||||
no queryset is used, it will use the current queryset and join that
|
||||
|
|
@ -308,11 +440,11 @@ class JoinQueryset(models.QuerySet):
|
|||
to_field = 'id'
|
||||
|
||||
if qs:
|
||||
fk = [
|
||||
fks = [
|
||||
fk for fk in qs.model._meta.fields
|
||||
if getattr(fk, 'related_model', None) == self.model
|
||||
]
|
||||
fk = fk[0] if fk else None
|
||||
fk = fks[0] if fks else None
|
||||
model_set = f'{self.model.__name__.lower()}_set'
|
||||
key = fk or getattr(qs.model, model_set, None)
|
||||
|
||||
|
|
@ -331,7 +463,7 @@ class JoinQueryset(models.QuerySet):
|
|||
else:
|
||||
fk_column = 'id'
|
||||
qs = self.only(fk_column)
|
||||
new_qs = self.model.objects.all()
|
||||
new_qs = self.model._default_manager.all()
|
||||
|
||||
TABLE_NAME = 'temp_stuff'
|
||||
query, params = qs.query.sql_with_params()
|
||||
|
|
@ -369,21 +501,24 @@ class JoinQueryset(models.QuerySet):
|
|||
return new_qs
|
||||
|
||||
|
||||
class JoinManagerMixin:
|
||||
"""
|
||||
Manager that adds a method join. This method allows you to join two
|
||||
querysets together.
|
||||
"""
|
||||
_queryset_class = JoinQueryset
|
||||
if not TYPE_CHECKING:
|
||||
# Hide deprecated API during type checking, to encourage switch to
|
||||
# 'JoinQueryset.as_manager()', which is supported by the mypy plugin
|
||||
# of django-stubs.
|
||||
|
||||
def get_queryset(self):
|
||||
warnings.warn(
|
||||
"JoinManager and JoinManagerMixin are deprecated. "
|
||||
"Please use 'JoinQueryset.as_manager()' instead.",
|
||||
DeprecationWarning
|
||||
)
|
||||
return self._queryset_class(model=self.model, using=self._db)
|
||||
class JoinManagerMixin:
|
||||
"""
|
||||
Manager that adds a method join. This method allows you to join two
|
||||
querysets together.
|
||||
"""
|
||||
|
||||
def get_queryset(self):
|
||||
warnings.warn(
|
||||
"JoinManager and JoinManagerMixin are deprecated. "
|
||||
"Please use 'JoinQueryset.as_manager()' instead.",
|
||||
DeprecationWarning
|
||||
)
|
||||
return self._queryset_class(model=self.model, using=self._db)
|
||||
|
||||
class JoinManager(JoinManagerMixin, models.Manager):
|
||||
pass
|
||||
class JoinManager(JoinManagerMixin):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal, TypeVar, overload
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import models
|
||||
from django.db.models.functions import Now
|
||||
|
|
@ -12,6 +16,8 @@ from model_utils.fields import (
|
|||
)
|
||||
from model_utils.managers import QueryManager, SoftDeletableManager
|
||||
|
||||
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
|
||||
|
||||
now = Now()
|
||||
|
||||
|
||||
|
|
@ -24,7 +30,7 @@ class TimeStampedModel(models.Model):
|
|||
created = AutoCreatedField(_('created'))
|
||||
modified = AutoLastModifiedField(_('modified'))
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Overriding the save method in order to make sure that
|
||||
modified field is updated even if it is not given as
|
||||
|
|
@ -65,7 +71,7 @@ class StatusModel(models.Model):
|
|||
status = StatusField(_('status'))
|
||||
status_changed = MonitorField(_('status changed'), monitor='status')
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Overriding the save method in order to make sure that
|
||||
status_changed field is updated even if it is not given as
|
||||
|
|
@ -81,7 +87,7 @@ class StatusModel(models.Model):
|
|||
abstract = True
|
||||
|
||||
|
||||
def add_status_query_managers(sender, **kwargs):
|
||||
def add_status_query_managers(sender: type[models.Model], **kwargs: Any) -> None:
|
||||
"""
|
||||
Add a Querymanager for each status item dynamically.
|
||||
|
||||
|
|
@ -90,6 +96,7 @@ def add_status_query_managers(sender, **kwargs):
|
|||
return
|
||||
|
||||
default_manager = sender._meta.default_manager
|
||||
assert default_manager is not None
|
||||
|
||||
for value, display in getattr(sender, 'STATUS', ()):
|
||||
if _field_exists(sender, value):
|
||||
|
|
@ -103,7 +110,7 @@ def add_status_query_managers(sender, **kwargs):
|
|||
sender._meta.default_manager_name = default_manager.name
|
||||
|
||||
|
||||
def add_timeframed_query_manager(sender, **kwargs):
|
||||
def add_timeframed_query_manager(sender: type[models.Model], **kwargs: Any) -> None:
|
||||
"""
|
||||
Add a QueryManager for a specific timeframe.
|
||||
|
||||
|
|
@ -126,7 +133,7 @@ models.signals.class_prepared.connect(add_status_query_managers)
|
|||
models.signals.class_prepared.connect(add_timeframed_query_manager)
|
||||
|
||||
|
||||
def _field_exists(model_class, field_name):
|
||||
def _field_exists(model_class: type[models.Model], field_name: str) -> bool:
|
||||
return field_name in [f.attname for f in model_class._meta.local_fields]
|
||||
|
||||
|
||||
|
|
@ -142,11 +149,28 @@ class SoftDeletableModel(models.Model):
|
|||
class Meta:
|
||||
abstract = True
|
||||
|
||||
objects = SoftDeletableManager(_emit_deprecation_warnings=True)
|
||||
available_objects = SoftDeletableManager()
|
||||
objects: models.Manager[SoftDeletableModel] = SoftDeletableManager(_emit_deprecation_warnings=True)
|
||||
available_objects: models.Manager[SoftDeletableModel] = SoftDeletableManager()
|
||||
all_objects = models.Manager()
|
||||
|
||||
def delete(self, using=None, *args, soft=True, **kwargs):
|
||||
# Note that soft delete does not return anything,
|
||||
# which doesn't conform to Django's interface.
|
||||
# https://github.com/jazzband/django-model-utils/issues/541
|
||||
@overload # type: ignore[override]
|
||||
def delete(
|
||||
self, using: Any = None, *args: Any, soft: Literal[True] = True, **kwargs: Any
|
||||
) -> None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def delete(
|
||||
self, using: Any = None, *args: Any, soft: Literal[False], **kwargs: Any
|
||||
) -> tuple[int, dict[str, int]]:
|
||||
...
|
||||
|
||||
def delete(
|
||||
self, using: Any = None, *args: Any, soft: bool = True, **kwargs: Any
|
||||
) -> tuple[int, dict[str, int]] | None:
|
||||
"""
|
||||
Soft delete object (set its ``is_removed`` field to True).
|
||||
Actually delete object if setting ``soft`` to False.
|
||||
|
|
@ -154,6 +178,7 @@ class SoftDeletableModel(models.Model):
|
|||
if soft:
|
||||
self.is_removed = True
|
||||
self.save(using=using)
|
||||
return None
|
||||
else:
|
||||
return super().delete(using, *args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,45 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Iterable,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import models
|
||||
from django.db.models.fields.files import FieldFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Mapping
|
||||
from types import TracebackType
|
||||
|
||||
class _AugmentedModel(models.Model):
|
||||
_instance_initialized: bool
|
||||
_deferred_fields: set[str]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Descriptor(Protocol[T]):
|
||||
def __get__(self, instance: object, owner: type[object]) -> T:
|
||||
...
|
||||
|
||||
def __set__(self, instance: object, value: T) -> None:
|
||||
...
|
||||
|
||||
|
||||
class FullDescriptor(Descriptor[T]):
|
||||
def __delete__(self, instance: object) -> None:
|
||||
...
|
||||
|
||||
|
||||
class LightStateFieldFile(FieldFile):
|
||||
"""
|
||||
|
|
@ -16,7 +51,7 @@ class LightStateFieldFile(FieldFile):
|
|||
Django 3.1+ can make the app unusable, as CPU and memory usage gets easily
|
||||
multiplied by magnitudes.
|
||||
"""
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
"""
|
||||
We don't need to deepcopy the instance, so nullify if provided.
|
||||
"""
|
||||
|
|
@ -26,27 +61,35 @@ class LightStateFieldFile(FieldFile):
|
|||
return state
|
||||
|
||||
|
||||
def lightweight_deepcopy(value):
|
||||
def lightweight_deepcopy(value: T) -> T:
|
||||
"""
|
||||
Use our lightweight class to avoid copying the instance on a FieldFile deepcopy.
|
||||
"""
|
||||
if isinstance(value, FieldFile):
|
||||
value = LightStateFieldFile(
|
||||
value = cast(T, LightStateFieldFile(
|
||||
instance=value.instance,
|
||||
field=value.field,
|
||||
name=value.name,
|
||||
)
|
||||
))
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
class DescriptorWrapper:
|
||||
class DescriptorWrapper(Generic[T]):
|
||||
|
||||
def __init__(self, field_name, descriptor, tracker_attname):
|
||||
def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str):
|
||||
self.field_name = field_name
|
||||
self.descriptor = descriptor
|
||||
self.tracker_attname = tracker_attname
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: models.Model, owner: type[models.Model]) -> T:
|
||||
...
|
||||
|
||||
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T:
|
||||
if instance is None:
|
||||
return self
|
||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||
|
|
@ -56,7 +99,7 @@ class DescriptorWrapper:
|
|||
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
|
||||
return value
|
||||
|
||||
def __set__(self, instance, value):
|
||||
def __set__(self, instance: models.Model, value: T) -> None:
|
||||
initialized = hasattr(instance, '_instance_initialized')
|
||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||
|
||||
|
|
@ -79,23 +122,23 @@ class DescriptorWrapper:
|
|||
else:
|
||||
instance.__dict__[self.field_name] = value
|
||||
|
||||
def __getattr__(self, attr):
|
||||
def __getattr__(self, attr: str) -> T:
|
||||
return getattr(self.descriptor, attr)
|
||||
|
||||
@staticmethod
|
||||
def cls_for_descriptor(descriptor):
|
||||
def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]:
|
||||
if hasattr(descriptor, '__delete__'):
|
||||
return FullDescriptorWrapper
|
||||
else:
|
||||
return DescriptorWrapper
|
||||
|
||||
|
||||
class FullDescriptorWrapper(DescriptorWrapper):
|
||||
class FullDescriptorWrapper(DescriptorWrapper[T]):
|
||||
"""
|
||||
Wrapper for descriptors with all three descriptor methods.
|
||||
"""
|
||||
def __delete__(self, obj):
|
||||
self.descriptor.__delete__(obj)
|
||||
def __delete__(self, obj: models.Model) -> None:
|
||||
cast(FullDescriptor[T], self.descriptor).__delete__(obj)
|
||||
|
||||
|
||||
class FieldsContext:
|
||||
|
|
@ -119,7 +162,12 @@ class FieldsContext:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, tracker, *fields, state=None):
|
||||
def __init__(
|
||||
self,
|
||||
tracker: FieldInstanceTracker,
|
||||
*fields: str,
|
||||
state: dict[str, int] | None = None
|
||||
):
|
||||
"""
|
||||
:param tracker: FieldInstanceTracker instance to be reset after
|
||||
context exit
|
||||
|
|
@ -137,7 +185,7 @@ class FieldsContext:
|
|||
self.fields = fields
|
||||
self.state = state
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> FieldsContext:
|
||||
"""
|
||||
Increments tracked fields occurrences count in shared state.
|
||||
"""
|
||||
|
|
@ -146,7 +194,12 @@ class FieldsContext:
|
|||
self.state[f] += 1
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
"""
|
||||
Decrements tracked fields occurrences count in shared state.
|
||||
|
||||
|
|
@ -164,29 +217,34 @@ class FieldsContext:
|
|||
|
||||
|
||||
class FieldInstanceTracker:
|
||||
def __init__(self, instance, fields, field_map):
|
||||
self.instance = instance
|
||||
def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]):
|
||||
self.instance = cast('_AugmentedModel', instance)
|
||||
self.fields = fields
|
||||
self.field_map = field_map
|
||||
self.context = FieldsContext(self, *self.fields)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> FieldsContext:
|
||||
return self.context.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
return self.context.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def __call__(self, *fields):
|
||||
def __call__(self, *fields: str) -> FieldsContext:
|
||||
return FieldsContext(self, *fields, state=self.context.state)
|
||||
|
||||
@property
|
||||
def deferred_fields(self):
|
||||
def deferred_fields(self) -> set[str]:
|
||||
return self.instance.get_deferred_fields()
|
||||
|
||||
def get_field_value(self, field):
|
||||
def get_field_value(self, field: str) -> Any:
|
||||
return getattr(self.instance, self.field_map[field])
|
||||
|
||||
def set_saved_fields(self, fields=None):
|
||||
def set_saved_fields(self, fields: Iterable[str] | None = None) -> None:
|
||||
if not self.instance.pk:
|
||||
self.saved_data = {}
|
||||
elif fields is None:
|
||||
|
|
@ -198,7 +256,7 @@ class FieldInstanceTracker:
|
|||
for field, field_value in self.saved_data.items():
|
||||
self.saved_data[field] = lightweight_deepcopy(field_value)
|
||||
|
||||
def current(self, fields=None):
|
||||
def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]:
|
||||
"""Returns dict of current values for all tracked fields"""
|
||||
if fields is None:
|
||||
deferred_fields = self.deferred_fields
|
||||
|
|
@ -212,17 +270,19 @@ class FieldInstanceTracker:
|
|||
|
||||
return {f: self.get_field_value(f) for f in fields}
|
||||
|
||||
def has_changed(self, field):
|
||||
def has_changed(self, field: str) -> bool:
|
||||
"""Returns ``True`` if field has changed from currently saved value"""
|
||||
if field in self.fields:
|
||||
# deferred fields haven't changed
|
||||
if field in self.deferred_fields and field not in self.instance.__dict__:
|
||||
return False
|
||||
return self.previous(field) != self.get_field_value(field)
|
||||
prev: object = self.previous(field)
|
||||
curr: object = self.get_field_value(field)
|
||||
return prev != curr
|
||||
else:
|
||||
raise FieldError('field "%s" not tracked' % field)
|
||||
|
||||
def previous(self, field):
|
||||
def previous(self, field: str) -> Any:
|
||||
"""Returns currently saved value of given field"""
|
||||
|
||||
# handle deferred fields that have not yet been loaded from the database
|
||||
|
|
@ -242,7 +302,7 @@ class FieldInstanceTracker:
|
|||
|
||||
return self.saved_data.get(field)
|
||||
|
||||
def changed(self):
|
||||
def changed(self) -> dict[str, Any]:
|
||||
"""Returns dict of fields that changed since save (with old values)"""
|
||||
return {
|
||||
field: self.previous(field)
|
||||
|
|
@ -255,13 +315,34 @@ class FieldTracker:
|
|||
|
||||
tracker_class = FieldInstanceTracker
|
||||
|
||||
def __init__(self, fields=None):
|
||||
self.fields = fields
|
||||
def __init__(self, fields: Iterable[str] | None = None):
|
||||
# finalize_class() will replace None; pretend it is never None.
|
||||
self.fields = cast(Iterable[str], fields)
|
||||
|
||||
def __call__(self, func=None, fields=None):
|
||||
def decorator(f):
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
func: None = None,
|
||||
fields: Iterable[str] | None = None
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
func: Callable[..., T],
|
||||
fields: Iterable[str] | None = None
|
||||
) -> Callable[..., T]:
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
func: Callable[..., T] | None = None,
|
||||
fields: Iterable[str] | None = None
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]:
|
||||
def decorator(f: Callable[..., T]) -> Callable[..., T]:
|
||||
@wraps(f)
|
||||
def inner(obj, *args, **kwargs):
|
||||
def inner(obj: models.Model, *args: object, **kwargs: object) -> T:
|
||||
tracker = getattr(obj, self.attname)
|
||||
field_list = tracker.fields if fields is None else fields
|
||||
with tracker(*field_list):
|
||||
|
|
@ -272,7 +353,7 @@ class FieldTracker:
|
|||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
def get_field_map(self, cls):
|
||||
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
|
||||
"""Returns dict mapping fields names to model attribute names"""
|
||||
field_map = {field: field for field in self.fields}
|
||||
all_fields = {f.name: f.attname for f in cls._meta.fields}
|
||||
|
|
@ -280,17 +361,17 @@ class FieldTracker:
|
|||
if k in field_map})
|
||||
return field_map
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str) -> None:
|
||||
self.name = name
|
||||
self.attname = '_%s' % name
|
||||
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
|
||||
|
||||
def finalize_class(self, sender, **kwargs):
|
||||
if self.fields is None:
|
||||
def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
|
||||
if self.fields is None or TYPE_CHECKING:
|
||||
self.fields = (field.attname for field in sender._meta.fields)
|
||||
self.fields = set(self.fields)
|
||||
for field_name in self.fields:
|
||||
descriptor = getattr(sender, field_name)
|
||||
descriptor: models.Field[Any, Any] = getattr(sender, field_name)
|
||||
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
|
||||
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
|
||||
setattr(sender, field_name, wrapped_descriptor)
|
||||
|
|
@ -300,34 +381,39 @@ class FieldTracker:
|
|||
setattr(sender, self.name, self)
|
||||
self.patch_save(sender)
|
||||
|
||||
def initialize_tracker(self, sender, instance, **kwargs):
|
||||
def initialize_tracker(
|
||||
self,
|
||||
sender: type[models.Model],
|
||||
instance: models.Model,
|
||||
**kwargs: object
|
||||
) -> None:
|
||||
if not isinstance(instance, self.model_class):
|
||||
return # Only init instances of given model (including children)
|
||||
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
||||
setattr(instance, self.attname, tracker)
|
||||
tracker.set_saved_fields()
|
||||
instance._instance_initialized = True
|
||||
cast('_AugmentedModel', instance)._instance_initialized = True
|
||||
|
||||
def patch_init(self, model):
|
||||
def patch_init(self, model: type[models.Model]) -> None:
|
||||
original = getattr(model, '__init__')
|
||||
|
||||
@wraps(original)
|
||||
def inner(instance, *args, **kwargs):
|
||||
def inner(instance: models.Model, *args: Any, **kwargs: Any) -> None:
|
||||
original(instance, *args, **kwargs)
|
||||
self.initialize_tracker(model, instance)
|
||||
|
||||
setattr(model, '__init__', inner)
|
||||
|
||||
def patch_save(self, model):
|
||||
def patch_save(self, model: type[models.Model]) -> None:
|
||||
self._patch(model, 'save_base', 'update_fields')
|
||||
self._patch(model, 'refresh_from_db', 'fields')
|
||||
|
||||
def _patch(self, model, method, fields_kwarg):
|
||||
def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None:
|
||||
original = getattr(model, method)
|
||||
|
||||
@wraps(original)
|
||||
def inner(instance, *args, **kwargs):
|
||||
update_fields = kwargs.get(fields_kwarg)
|
||||
def inner(instance: models.Model, *args: object, **kwargs: Any) -> object:
|
||||
update_fields: Iterable[str] | None = kwargs.get(fields_kwarg)
|
||||
if update_fields is None:
|
||||
fields = self.fields
|
||||
else:
|
||||
|
|
@ -341,7 +427,15 @@ class FieldTracker:
|
|||
|
||||
setattr(model, method, inner)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: type[models.Model]) -> FieldTracker:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: models.Model, owner: type[models.Model]) -> FieldInstanceTracker:
|
||||
...
|
||||
|
||||
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker | FieldInstanceTracker:
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
|
|
@ -350,16 +444,18 @@ class FieldTracker:
|
|||
|
||||
class ModelInstanceTracker(FieldInstanceTracker):
|
||||
|
||||
def has_changed(self, field):
|
||||
def has_changed(self, field: str) -> bool:
|
||||
"""Returns ``True`` if field has changed from currently saved value"""
|
||||
if not self.instance.pk:
|
||||
return True
|
||||
elif field in self.saved_data:
|
||||
return self.previous(field) != self.get_field_value(field)
|
||||
prev: object = self.previous(field)
|
||||
curr: object = self.get_field_value(field)
|
||||
return prev != curr
|
||||
else:
|
||||
raise FieldError('field "%s" not tracked' % field)
|
||||
|
||||
def changed(self):
|
||||
def changed(self) -> dict[str, Any]:
|
||||
"""Returns dict of fields that changed since save (with old values)"""
|
||||
if not self.instance.pk:
|
||||
return {}
|
||||
|
|
@ -371,5 +467,5 @@ class ModelInstanceTracker(FieldInstanceTracker):
|
|||
class ModelTracker(FieldTracker):
|
||||
tracker_class = ModelInstanceTracker
|
||||
|
||||
def get_field_map(self, cls):
|
||||
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
|
||||
return {field: field for field in self.fields}
|
||||
|
|
|
|||
2
mypy.ini
2
mypy.ini
|
|
@ -1,4 +1,6 @@
|
|||
[mypy]
|
||||
disallow_incomplete_defs=True
|
||||
disallow_untyped_defs=True
|
||||
implicit_reexport=False
|
||||
pretty=True
|
||||
show_error_codes=True
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
mypy==1.9.0
|
||||
django-stubs==4.2.7
|
||||
mypy==1.10.0
|
||||
django-stubs==5.0.2
|
||||
pytest==7.4.3
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.db import models
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
|
||||
|
||||
def mutable_from_db(value):
|
||||
def mutable_from_db(value: object) -> Any:
|
||||
if value == '':
|
||||
return None
|
||||
try:
|
||||
|
|
@ -12,7 +17,7 @@ def mutable_from_db(value):
|
|||
return value
|
||||
|
||||
|
||||
def mutable_to_db(value):
|
||||
def mutable_to_db(value: object) -> str:
|
||||
if value is None:
|
||||
return ''
|
||||
if isinstance(value, list):
|
||||
|
|
@ -21,12 +26,12 @@ def mutable_to_db(value):
|
|||
|
||||
|
||||
class MutableField(models.TextField):
|
||||
def to_python(self, value):
|
||||
def to_python(self, value: object) -> Any:
|
||||
return mutable_from_db(value)
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any:
|
||||
return mutable_from_db(value)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str:
|
||||
value = super().get_db_prep_save(value, connection)
|
||||
return mutable_to_db(value)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
from typing import Any, ClassVar, TypeVar, overload
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Manager
|
||||
from django.db.models.query import QuerySet
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
|
@ -11,7 +12,7 @@ from model_utils import Choices
|
|||
from model_utils.fields import MonitorField, SplitField, StatusField, UUIDField
|
||||
from model_utils.managers import (
|
||||
InheritanceManager,
|
||||
JoinManager,
|
||||
JoinQueryset,
|
||||
QueryManager,
|
||||
SoftDeletableManager,
|
||||
SoftDeletableQuerySet,
|
||||
|
|
@ -26,6 +27,8 @@ from model_utils.models import (
|
|||
from model_utils.tracker import FieldTracker, ModelTracker
|
||||
from tests.fields import MutableField
|
||||
|
||||
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
|
||||
|
||||
|
||||
class InheritanceManagerTestRelated(models.Model):
|
||||
pass
|
||||
|
|
@ -43,7 +46,7 @@ class InheritanceManagerTestParent(models.Model):
|
|||
on_delete=models.CASCADE)
|
||||
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "{}({})".format(
|
||||
self.__class__.__name__[len('InheritanceManagerTest'):],
|
||||
self.pk,
|
||||
|
|
@ -128,7 +131,7 @@ class DoubleMonitored(models.Model):
|
|||
|
||||
|
||||
class Status(StatusModel):
|
||||
STATUS = Choices(
|
||||
STATUS: Choices[str] = Choices(
|
||||
("active", _("active")),
|
||||
("deleted", _("deleted")),
|
||||
("on_hold", _("on hold")),
|
||||
|
|
@ -184,7 +187,8 @@ class Post(models.Model):
|
|||
public: ClassVar[QueryManager[Post]] = QueryManager(published=True)
|
||||
public_confirmed: ClassVar[QueryManager[Post]] = QueryManager(
|
||||
models.Q(published=True) & models.Q(confirmed=True))
|
||||
public_reversed = QueryManager(published=True).order_by("-order")
|
||||
public_reversed: ClassVar[QueryManager[Post]] = QueryManager(
|
||||
published=True).order_by("-order")
|
||||
|
||||
class Meta:
|
||||
ordering = ("order",)
|
||||
|
|
@ -203,7 +207,6 @@ class SplitFieldAbstractParent(models.Model):
|
|||
|
||||
|
||||
class AbstractTracked(models.Model):
|
||||
number: models.IntegerField
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
|
@ -216,7 +219,7 @@ class Tracked(models.Model):
|
|||
|
||||
tracker = FieldTracker()
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
""" No-op save() to ensure that FieldTracker.patch_save() works. """
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
|
@ -228,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel):
|
|||
|
||||
tracker = FieldTracker()
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
""" Automatically add "modified" to update_fields."""
|
||||
update_fields = kwargs.get('update_fields')
|
||||
if update_fields is not None:
|
||||
|
|
@ -263,7 +266,7 @@ class TrackedNonFieldAttr(models.Model):
|
|||
number = models.FloatField()
|
||||
|
||||
@property
|
||||
def rounded(self):
|
||||
def rounded(self) -> int | None:
|
||||
return round(self.number) if self.number is not None else None
|
||||
|
||||
tracker = FieldTracker(fields=['rounded'])
|
||||
|
|
@ -352,43 +355,52 @@ class SoftDeletable(SoftDeletableModel):
|
|||
all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager()
|
||||
|
||||
|
||||
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet):
|
||||
def only_read(self):
|
||||
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]):
|
||||
def only_read(self) -> QuerySet[ModelT]:
|
||||
return self.filter(is_read=True)
|
||||
|
||||
|
||||
class CustomSoftDelete(SoftDeletableModel):
|
||||
is_read = models.BooleanField(default=False)
|
||||
|
||||
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc]
|
||||
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)()
|
||||
|
||||
|
||||
class StringyDescriptor:
|
||||
"""
|
||||
Descriptor that returns a string version of the underlying integer value.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def __get__(self, obj, cls=None):
|
||||
@overload
|
||||
def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self, obj: models.Model, cls: type[models.Model]) -> str:
|
||||
...
|
||||
|
||||
def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str:
|
||||
if obj is None:
|
||||
return self
|
||||
if self.name in obj.get_deferred_fields():
|
||||
# This queries the database, and sets the value on the instance.
|
||||
assert cls is not None
|
||||
fields_map = {f.name: f for f in cls._meta.fields}
|
||||
field = fields_map[self.name]
|
||||
DeferredAttribute(field=field).__get__(obj, cls)
|
||||
return str(obj.__dict__[self.name])
|
||||
|
||||
def __set__(self, obj, value):
|
||||
def __set__(self, obj: object, value: str) -> None:
|
||||
obj.__dict__[self.name] = int(value)
|
||||
|
||||
def __delete__(self, obj):
|
||||
def __delete__(self, obj: object) -> None:
|
||||
del obj.__dict__[self.name]
|
||||
|
||||
|
||||
class CustomDescriptorField(models.IntegerField):
|
||||
def contribute_to_class(self, cls, name, *args, **kwargs):
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
|
||||
super().contribute_to_class(cls, name, *args, **kwargs)
|
||||
setattr(cls, name, StringyDescriptor(name))
|
||||
|
||||
|
|
@ -404,7 +416,7 @@ class ModelWithCustomDescriptor(models.Model):
|
|||
|
||||
class BoxJoinModel(models.Model):
|
||||
name = models.CharField(max_length=32)
|
||||
objects: ClassVar[JoinManager[BoxJoinModel]] = JoinManager()
|
||||
objects = JoinQueryset.as_manager()
|
||||
|
||||
|
||||
class JoinItemForeignKey(models.Model):
|
||||
|
|
@ -414,7 +426,7 @@ class JoinItemForeignKey(models.Model):
|
|||
null=True,
|
||||
on_delete=models.CASCADE
|
||||
)
|
||||
objects: ClassVar[JoinManager[JoinItemForeignKey]] = JoinManager()
|
||||
objects = JoinQueryset.as_manager()
|
||||
|
||||
|
||||
class CustomUUIDModel(UUIDModel):
|
||||
|
|
|
|||
|
|
@ -1,116 +1,129 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generic, TypeVar
|
||||
|
||||
import pytest
|
||||
from django.test import TestCase
|
||||
|
||||
from model_utils import Choices
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class ChoicesTests(TestCase):
|
||||
def setUp(self):
|
||||
|
||||
class ChoicesTestsMixin(Generic[T]):
|
||||
|
||||
STATUS: Choices[T]
|
||||
|
||||
def test_getattr(self) -> None:
|
||||
assert self.STATUS.DRAFT == 'DRAFT'
|
||||
|
||||
def test_len(self) -> None:
|
||||
assert len(self.STATUS) == 2
|
||||
|
||||
def test_repr(self) -> None:
|
||||
assert repr(self.STATUS) == "Choices" + repr((
|
||||
('DRAFT', 'DRAFT', 'DRAFT'),
|
||||
('PUBLISHED', 'PUBLISHED', 'PUBLISHED'),
|
||||
))
|
||||
|
||||
def test_wrong_length_tuple(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
Choices(('a',)) # type: ignore[arg-type]
|
||||
|
||||
def test_deepcopy(self) -> None:
|
||||
import copy
|
||||
assert list(self.STATUS) == list(copy.deepcopy(self.STATUS))
|
||||
|
||||
def test_equality(self) -> None:
|
||||
assert self.STATUS == Choices('DRAFT', 'PUBLISHED')
|
||||
|
||||
def test_inequality(self) -> None:
|
||||
assert self.STATUS != ['DRAFT', 'PUBLISHED']
|
||||
assert self.STATUS != Choices('DRAFT')
|
||||
|
||||
def test_composability(self) -> None:
|
||||
assert Choices('DRAFT') + Choices('PUBLISHED') == self.STATUS
|
||||
assert Choices('DRAFT') + ('PUBLISHED',) == self.STATUS
|
||||
assert ('DRAFT',) + Choices('PUBLISHED') == self.STATUS
|
||||
|
||||
def test_option_groups(self) -> None:
|
||||
# Note: The implementation accepts any kind of sequence, but the type system can only
|
||||
# track per-index types for tuples.
|
||||
if TYPE_CHECKING:
|
||||
c = Choices(('group a', ['one', 'two']), ('group b', ('three',)))
|
||||
else:
|
||||
c = Choices(('group a', ['one', 'two']), ['group b', ('three',)])
|
||||
assert list(c) == [
|
||||
('group a', [('one', 'one'), ('two', 'two')]),
|
||||
('group b', [('three', 'three')]),
|
||||
]
|
||||
|
||||
|
||||
class ChoicesTests(TestCase, ChoicesTestsMixin[str]):
|
||||
def setUp(self) -> None:
|
||||
self.STATUS = Choices('DRAFT', 'PUBLISHED')
|
||||
|
||||
def test_getattr(self):
|
||||
self.assertEqual(self.STATUS.DRAFT, 'DRAFT')
|
||||
|
||||
def test_indexing(self):
|
||||
def test_indexing(self) -> None:
|
||||
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
|
||||
|
||||
def test_iteration(self):
|
||||
def test_iteration(self) -> None:
|
||||
self.assertEqual(tuple(self.STATUS),
|
||||
(('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
|
||||
|
||||
def test_reversed(self):
|
||||
def test_reversed(self) -> None:
|
||||
self.assertEqual(tuple(reversed(self.STATUS)),
|
||||
(('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT')))
|
||||
|
||||
def test_len(self):
|
||||
self.assertEqual(len(self.STATUS), 2)
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(self.STATUS), "Choices" + repr((
|
||||
('DRAFT', 'DRAFT', 'DRAFT'),
|
||||
('PUBLISHED', 'PUBLISHED', 'PUBLISHED'),
|
||||
)))
|
||||
|
||||
def test_wrong_length_tuple(self):
|
||||
with self.assertRaises(ValueError):
|
||||
Choices(('a',))
|
||||
|
||||
def test_contains_value(self):
|
||||
def test_contains_value(self) -> None:
|
||||
self.assertTrue('PUBLISHED' in self.STATUS)
|
||||
self.assertTrue('DRAFT' in self.STATUS)
|
||||
|
||||
def test_doesnt_contain_value(self):
|
||||
def test_doesnt_contain_value(self) -> None:
|
||||
self.assertFalse('UNPUBLISHED' in self.STATUS)
|
||||
|
||||
def test_deepcopy(self):
|
||||
import copy
|
||||
self.assertEqual(list(self.STATUS),
|
||||
list(copy.deepcopy(self.STATUS)))
|
||||
|
||||
def test_equality(self):
|
||||
self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED'))
|
||||
|
||||
def test_inequality(self):
|
||||
self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED'])
|
||||
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
|
||||
|
||||
def test_composability(self):
|
||||
self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS)
|
||||
self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS)
|
||||
self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS)
|
||||
|
||||
def test_option_groups(self):
|
||||
c = Choices(('group a', ['one', 'two']), ['group b', ('three',)])
|
||||
self.assertEqual(
|
||||
list(c),
|
||||
[
|
||||
('group a', [('one', 'one'), ('two', 'two')]),
|
||||
('group b', [('three', 'three')]),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class LabelChoicesTests(ChoicesTests):
|
||||
def setUp(self):
|
||||
class LabelChoicesTests(TestCase, ChoicesTestsMixin[str]):
|
||||
def setUp(self) -> None:
|
||||
self.STATUS = Choices(
|
||||
('DRAFT', 'is draft'),
|
||||
('PUBLISHED', 'is published'),
|
||||
'DELETED',
|
||||
)
|
||||
|
||||
def test_iteration(self):
|
||||
def test_iteration(self) -> None:
|
||||
self.assertEqual(tuple(self.STATUS), (
|
||||
('DRAFT', 'is draft'),
|
||||
('PUBLISHED', 'is published'),
|
||||
('DELETED', 'DELETED'),
|
||||
))
|
||||
|
||||
def test_reversed(self):
|
||||
def test_reversed(self) -> None:
|
||||
self.assertEqual(tuple(reversed(self.STATUS)), (
|
||||
('DELETED', 'DELETED'),
|
||||
('PUBLISHED', 'is published'),
|
||||
('DRAFT', 'is draft'),
|
||||
))
|
||||
|
||||
def test_indexing(self):
|
||||
def test_indexing(self) -> None:
|
||||
self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
|
||||
|
||||
def test_default(self):
|
||||
def test_default(self) -> None:
|
||||
self.assertEqual(self.STATUS.DELETED, 'DELETED')
|
||||
|
||||
def test_provided(self):
|
||||
def test_provided(self) -> None:
|
||||
self.assertEqual(self.STATUS.DRAFT, 'DRAFT')
|
||||
|
||||
def test_len(self):
|
||||
def test_len(self) -> None:
|
||||
self.assertEqual(len(self.STATUS), 3)
|
||||
|
||||
def test_equality(self):
|
||||
def test_equality(self) -> None:
|
||||
self.assertEqual(self.STATUS, Choices(
|
||||
('DRAFT', 'is draft'),
|
||||
('PUBLISHED', 'is published'),
|
||||
'DELETED',
|
||||
))
|
||||
|
||||
def test_inequality(self):
|
||||
def test_inequality(self) -> None:
|
||||
self.assertNotEqual(self.STATUS, [
|
||||
('DRAFT', 'is draft'),
|
||||
('PUBLISHED', 'is published'),
|
||||
|
|
@ -118,27 +131,27 @@ class LabelChoicesTests(ChoicesTests):
|
|||
])
|
||||
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
|
||||
|
||||
def test_repr(self):
|
||||
def test_repr(self) -> None:
|
||||
self.assertEqual(repr(self.STATUS), "Choices" + repr((
|
||||
('DRAFT', 'DRAFT', 'is draft'),
|
||||
('PUBLISHED', 'PUBLISHED', 'is published'),
|
||||
('DELETED', 'DELETED', 'DELETED'),
|
||||
)))
|
||||
|
||||
def test_contains_value(self):
|
||||
def test_contains_value(self) -> None:
|
||||
self.assertTrue('PUBLISHED' in self.STATUS)
|
||||
self.assertTrue('DRAFT' in self.STATUS)
|
||||
# This should be True, because both the display value
|
||||
# and the internal representation are both DELETED.
|
||||
self.assertTrue('DELETED' in self.STATUS)
|
||||
|
||||
def test_doesnt_contain_value(self):
|
||||
def test_doesnt_contain_value(self) -> None:
|
||||
self.assertFalse('UNPUBLISHED' in self.STATUS)
|
||||
|
||||
def test_doesnt_contain_display_value(self):
|
||||
def test_doesnt_contain_display_value(self) -> None:
|
||||
self.assertFalse('is draft' in self.STATUS)
|
||||
|
||||
def test_composability(self):
|
||||
def test_composability(self) -> None:
|
||||
self.assertEqual(
|
||||
Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'),
|
||||
self.STATUS
|
||||
|
|
@ -154,11 +167,17 @@ class LabelChoicesTests(ChoicesTests):
|
|||
self.STATUS
|
||||
)
|
||||
|
||||
def test_option_groups(self):
|
||||
c = Choices(
|
||||
('group a', [(1, 'one'), (2, 'two')]),
|
||||
['group b', ((3, 'three'),)]
|
||||
)
|
||||
def test_option_groups(self) -> None:
|
||||
if TYPE_CHECKING:
|
||||
c = Choices[int](
|
||||
('group a', [(1, 'one'), (2, 'two')]),
|
||||
('group b', ((3, 'three'),))
|
||||
)
|
||||
else:
|
||||
c = Choices(
|
||||
('group a', [(1, 'one'), (2, 'two')]),
|
||||
['group b', ((3, 'three'),)]
|
||||
)
|
||||
self.assertEqual(
|
||||
list(c),
|
||||
[
|
||||
|
|
@ -168,65 +187,65 @@ class LabelChoicesTests(ChoicesTests):
|
|||
)
|
||||
|
||||
|
||||
class IdentifierChoicesTests(ChoicesTests):
|
||||
def setUp(self):
|
||||
class IdentifierChoicesTests(TestCase, ChoicesTestsMixin[int]):
|
||||
def setUp(self) -> None:
|
||||
self.STATUS = Choices(
|
||||
(0, 'DRAFT', 'is draft'),
|
||||
(1, 'PUBLISHED', 'is published'),
|
||||
(2, 'DELETED', 'is deleted'))
|
||||
|
||||
def test_iteration(self):
|
||||
def test_iteration(self) -> None:
|
||||
self.assertEqual(tuple(self.STATUS), (
|
||||
(0, 'is draft'),
|
||||
(1, 'is published'),
|
||||
(2, 'is deleted'),
|
||||
))
|
||||
|
||||
def test_reversed(self):
|
||||
def test_reversed(self) -> None:
|
||||
self.assertEqual(tuple(reversed(self.STATUS)), (
|
||||
(2, 'is deleted'),
|
||||
(1, 'is published'),
|
||||
(0, 'is draft'),
|
||||
))
|
||||
|
||||
def test_indexing(self):
|
||||
def test_indexing(self) -> None:
|
||||
self.assertEqual(self.STATUS[1], 'is published')
|
||||
|
||||
def test_getattr(self):
|
||||
def test_getattr(self) -> None:
|
||||
self.assertEqual(self.STATUS.DRAFT, 0)
|
||||
|
||||
def test_len(self):
|
||||
def test_len(self) -> None:
|
||||
self.assertEqual(len(self.STATUS), 3)
|
||||
|
||||
def test_repr(self):
|
||||
def test_repr(self) -> None:
|
||||
self.assertEqual(repr(self.STATUS), "Choices" + repr((
|
||||
(0, 'DRAFT', 'is draft'),
|
||||
(1, 'PUBLISHED', 'is published'),
|
||||
(2, 'DELETED', 'is deleted'),
|
||||
)))
|
||||
|
||||
def test_contains_value(self):
|
||||
def test_contains_value(self) -> None:
|
||||
self.assertTrue(0 in self.STATUS)
|
||||
self.assertTrue(1 in self.STATUS)
|
||||
self.assertTrue(2 in self.STATUS)
|
||||
|
||||
def test_doesnt_contain_value(self):
|
||||
def test_doesnt_contain_value(self) -> None:
|
||||
self.assertFalse(3 in self.STATUS)
|
||||
|
||||
def test_doesnt_contain_display_value(self):
|
||||
self.assertFalse('is draft' in self.STATUS)
|
||||
def test_doesnt_contain_display_value(self) -> None:
|
||||
self.assertFalse('is draft' in self.STATUS) # type: ignore[operator]
|
||||
|
||||
def test_doesnt_contain_python_attr(self):
|
||||
self.assertFalse('PUBLISHED' in self.STATUS)
|
||||
def test_doesnt_contain_python_attr(self) -> None:
|
||||
self.assertFalse('PUBLISHED' in self.STATUS) # type: ignore[operator]
|
||||
|
||||
def test_equality(self):
|
||||
def test_equality(self) -> None:
|
||||
self.assertEqual(self.STATUS, Choices(
|
||||
(0, 'DRAFT', 'is draft'),
|
||||
(1, 'PUBLISHED', 'is published'),
|
||||
(2, 'DELETED', 'is deleted')
|
||||
))
|
||||
|
||||
def test_inequality(self):
|
||||
def test_inequality(self) -> None:
|
||||
self.assertNotEqual(self.STATUS, [
|
||||
(0, 'DRAFT', 'is draft'),
|
||||
(1, 'PUBLISHED', 'is published'),
|
||||
|
|
@ -234,7 +253,7 @@ class IdentifierChoicesTests(ChoicesTests):
|
|||
])
|
||||
self.assertNotEqual(self.STATUS, Choices('DRAFT'))
|
||||
|
||||
def test_composability(self):
|
||||
def test_composability(self) -> None:
|
||||
self.assertEqual(
|
||||
Choices(
|
||||
(0, 'DRAFT', 'is draft'),
|
||||
|
|
@ -265,11 +284,17 @@ class IdentifierChoicesTests(ChoicesTests):
|
|||
self.STATUS
|
||||
)
|
||||
|
||||
def test_option_groups(self):
|
||||
c = Choices(
|
||||
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
|
||||
['group b', ((3, 'THREE', 'three'),)]
|
||||
)
|
||||
def test_option_groups(self) -> None:
|
||||
if TYPE_CHECKING:
|
||||
c = Choices[int](
|
||||
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
|
||||
('group b', ((3, 'THREE', 'three'),))
|
||||
)
|
||||
else:
|
||||
c = Choices(
|
||||
('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]),
|
||||
['group b', ((3, 'THREE', 'three'),)]
|
||||
)
|
||||
self.assertEqual(
|
||||
list(c),
|
||||
[
|
||||
|
|
@ -281,26 +306,26 @@ class IdentifierChoicesTests(ChoicesTests):
|
|||
|
||||
class SubsetChoicesTest(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.choices = Choices(
|
||||
def setUp(self) -> None:
|
||||
self.choices = Choices[int](
|
||||
(0, 'a', 'A'),
|
||||
(1, 'b', 'B'),
|
||||
)
|
||||
|
||||
def test_nonexistent_identifiers_raise(self):
|
||||
def test_nonexistent_identifiers_raise(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
self.choices.subset('a', 'c')
|
||||
|
||||
def test_solo_nonexistent_identifiers_raise(self):
|
||||
def test_solo_nonexistent_identifiers_raise(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
self.choices.subset('c')
|
||||
|
||||
def test_empty_subset_passes(self):
|
||||
def test_empty_subset_passes(self) -> None:
|
||||
subset = self.choices.subset()
|
||||
|
||||
self.assertEqual(subset, Choices())
|
||||
|
||||
def test_subset_returns_correct_subset(self):
|
||||
def test_subset_returns_correct_subset(self) -> None:
|
||||
subset = self.choices.subset('a')
|
||||
|
||||
self.assertEqual(subset, Choices((0, 'a', 'A')))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import models
|
||||
|
|
@ -7,7 +9,7 @@ from django.db.models.fields.files import FieldFile
|
|||
from django.test import TestCase
|
||||
|
||||
from model_utils import FieldTracker
|
||||
from model_utils.tracker import DescriptorWrapper
|
||||
from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker
|
||||
from tests.models import (
|
||||
InheritedModelTracked,
|
||||
InheritedTracked,
|
||||
|
|
@ -26,12 +28,18 @@ from tests.models import (
|
|||
TrackerTimeStamped,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
MixinBase = TestCase
|
||||
else:
|
||||
MixinBase = object
|
||||
|
||||
class FieldTrackerTestCase(TestCase):
|
||||
|
||||
tracker = None
|
||||
class FieldTrackerMixin(MixinBase):
|
||||
|
||||
def assertHasChanged(self, *, tracker=None, **kwargs):
|
||||
tracker: FieldInstanceTracker
|
||||
instance: models.Model
|
||||
|
||||
def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
for field, value in kwargs.items():
|
||||
|
|
@ -41,49 +49,57 @@ class FieldTrackerTestCase(TestCase):
|
|||
else:
|
||||
self.assertEqual(tracker.has_changed(field), value)
|
||||
|
||||
def assertPrevious(self, *, tracker=None, **kwargs):
|
||||
def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
for field, value in kwargs.items():
|
||||
self.assertEqual(tracker.previous(field), value)
|
||||
|
||||
def assertChanged(self, *, tracker=None, **kwargs):
|
||||
def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
self.assertEqual(tracker.changed(), kwargs)
|
||||
|
||||
def assertCurrent(self, *, tracker=None, **kwargs):
|
||||
def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
|
||||
if tracker is None:
|
||||
tracker = self.tracker
|
||||
self.assertEqual(tracker.current(), kwargs)
|
||||
|
||||
def update_instance(self, **kwargs):
|
||||
def update_instance(self, **kwargs: Any) -> None:
|
||||
for field, value in kwargs.items():
|
||||
setattr(self.instance, field, value)
|
||||
self.instance.save()
|
||||
|
||||
|
||||
class FieldTrackerCommonTests:
|
||||
class FieldTrackerCommonMixin(FieldTrackerMixin):
|
||||
|
||||
def test_pre_save_previous(self):
|
||||
instance: (
|
||||
Tracked | TrackedNotDefault | TrackedMultiple
|
||||
| ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
|
||||
| TrackedAbstract
|
||||
)
|
||||
|
||||
def test_pre_save_previous(self) -> None:
|
||||
self.assertPrevious(name=None, number=None)
|
||||
self.instance.name = 'new age'
|
||||
self.instance.number = 8
|
||||
self.assertPrevious(name=None, number=None)
|
||||
|
||||
|
||||
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||
class FieldTrackerTests(FieldTrackerCommonMixin, TestCase):
|
||||
|
||||
tracked_class: type[models.Model] = Tracked
|
||||
tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked
|
||||
instance: Tracked | ModelTracked | TrackedAbstract
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
self.tracker = self.instance.tracker
|
||||
|
||||
def test_descriptor(self):
|
||||
self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker))
|
||||
def test_descriptor(self) -> None:
|
||||
tracker = self.tracked_class.tracker
|
||||
self.assertTrue(isinstance(tracker, FieldTracker))
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.assertChanged(name=None)
|
||||
self.instance.name = 'new age'
|
||||
self.assertChanged(name=None)
|
||||
|
|
@ -94,7 +110,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.mutable = [1, 2, 3]
|
||||
self.assertChanged(name=None, number=None, mutable=None)
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.assertHasChanged(name=True, number=False, mutable=False)
|
||||
self.instance.name = 'new age'
|
||||
self.assertHasChanged(name=True, number=False, mutable=False)
|
||||
|
|
@ -103,12 +119,12 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.mutable = [1, 2, 3]
|
||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||
|
||||
def test_save_with_args(self):
|
||||
def test_save_with_args(self) -> None:
|
||||
self.instance.number = 1
|
||||
self.instance.save(False, False, None, None)
|
||||
self.assertChanged()
|
||||
|
||||
def test_first_save(self):
|
||||
def test_first_save(self) -> None:
|
||||
self.assertHasChanged(name=True, number=False, mutable=False)
|
||||
self.assertPrevious(name=None, number=None, mutable=None)
|
||||
self.assertCurrent(name='', number=None, id=None, mutable=None)
|
||||
|
|
@ -129,7 +145,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
with self.assertRaises(ValueError):
|
||||
self.instance.save(update_fields=['number'])
|
||||
|
||||
def test_post_save_has_changed(self):
|
||||
def test_post_save_has_changed(self) -> None:
|
||||
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||
self.assertHasChanged(name=False, number=False, mutable=False)
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -141,14 +157,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.name = 'retro'
|
||||
self.assertHasChanged(name=False, number=True, mutable=True)
|
||||
|
||||
def test_post_save_previous(self):
|
||||
def test_post_save_previous(self) -> None:
|
||||
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||
self.instance.name = 'new age'
|
||||
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
|
||||
self.instance.mutable[1] = 4
|
||||
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
|
||||
|
||||
def test_post_save_changed(self):
|
||||
def test_post_save_changed(self) -> None:
|
||||
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -162,7 +178,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.mutable = [1, 2, 3]
|
||||
self.assertChanged(number=4)
|
||||
|
||||
def test_current(self):
|
||||
def test_current(self) -> None:
|
||||
self.assertCurrent(id=None, name='', number=None, mutable=None)
|
||||
self.instance.name = 'new age'
|
||||
self.assertCurrent(id=None, name='new age', number=None, mutable=None)
|
||||
|
|
@ -175,7 +191,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.save()
|
||||
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3])
|
||||
|
||||
def test_update_fields(self):
|
||||
def test_update_fields(self) -> None:
|
||||
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -198,7 +214,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.assertEqual(in_db.number, self.instance.number)
|
||||
self.assertEqual(in_db.mutable, self.instance.mutable)
|
||||
|
||||
def test_refresh_from_db(self):
|
||||
def test_refresh_from_db(self) -> None:
|
||||
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||
self.tracked_class.objects.filter(pk=self.instance.pk).update(
|
||||
name='new age', number=8, mutable=[3, 2, 1])
|
||||
|
|
@ -214,11 +230,12 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.refresh_from_db()
|
||||
self.assertChanged()
|
||||
|
||||
def test_with_deferred(self):
|
||||
def test_with_deferred(self) -> None:
|
||||
self.instance.name = 'new age'
|
||||
self.instance.number = 1
|
||||
self.instance.save()
|
||||
item = self.tracked_class.objects.only('name').first()
|
||||
assert item is not None
|
||||
self.assertTrue(item.get_deferred_fields())
|
||||
|
||||
# has_changed() returns False for deferred fields, without un-deferring them.
|
||||
|
|
@ -234,6 +251,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
|
||||
# examining a deferred field un-defers it
|
||||
item = self.tracked_class.objects.only('name').first()
|
||||
assert item is not None
|
||||
self.assertEqual(item.number, 1)
|
||||
self.assertTrue('number' not in item.get_deferred_fields())
|
||||
self.assertEqual(item.tracker.previous('number'), 1)
|
||||
|
|
@ -252,6 +270,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
if self.tracked_class == Tracked:
|
||||
|
||||
item = self.tracked_class.objects.only('name').first()
|
||||
assert item is not None
|
||||
item.number = 2
|
||||
|
||||
# previous() fetches correct value from database after deferred field is assigned
|
||||
|
|
@ -268,7 +287,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
|
||||
class FieldTrackerMultipleInstancesTests(TestCase):
|
||||
|
||||
def test_with_deferred_fields_access_multiple(self):
|
||||
def test_with_deferred_fields_access_multiple(self) -> None:
|
||||
Tracked.objects.create(pk=1, name='foo', number=1)
|
||||
Tracked.objects.create(pk=2, name='bar', number=2)
|
||||
|
||||
|
|
@ -278,16 +297,16 @@ class FieldTrackerMultipleInstancesTests(TestCase):
|
|||
instance.name
|
||||
|
||||
|
||||
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase):
|
||||
|
||||
tracked_class: type[models.Model] = TrackedNotDefault
|
||||
tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault
|
||||
instance: TrackedNotDefault | ModelTrackedNotDefault
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
self.tracker = self.instance.name_tracker
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.assertChanged(name=None)
|
||||
self.instance.name = 'new age'
|
||||
self.assertChanged(name=None)
|
||||
|
|
@ -296,7 +315,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.instance.name = ''
|
||||
self.assertChanged(name=None)
|
||||
|
||||
def test_first_save(self):
|
||||
def test_first_save(self) -> None:
|
||||
self.assertHasChanged(name=True, number=None)
|
||||
self.assertPrevious(name=None, number=None)
|
||||
self.assertCurrent(name='')
|
||||
|
|
@ -308,14 +327,14 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.assertCurrent(name='retro')
|
||||
self.assertChanged(name=None)
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.assertHasChanged(name=True, number=None)
|
||||
self.instance.name = 'new age'
|
||||
self.assertHasChanged(name=True, number=None)
|
||||
self.instance.number = 7
|
||||
self.assertHasChanged(name=True, number=None)
|
||||
|
||||
def test_post_save_has_changed(self):
|
||||
def test_post_save_has_changed(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertHasChanged(name=False, number=None)
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -325,12 +344,12 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.instance.name = 'retro'
|
||||
self.assertHasChanged(name=False, number=None)
|
||||
|
||||
def test_post_save_previous(self):
|
||||
def test_post_save_previous(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.instance.name = 'new age'
|
||||
self.assertPrevious(name='retro', number=None)
|
||||
|
||||
def test_post_save_changed(self):
|
||||
def test_post_save_changed(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -340,7 +359,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.instance.name = 'retro'
|
||||
self.assertChanged()
|
||||
|
||||
def test_current(self):
|
||||
def test_current(self) -> None:
|
||||
self.assertCurrent(name='')
|
||||
self.instance.name = 'new age'
|
||||
self.assertCurrent(name='new age')
|
||||
|
|
@ -349,7 +368,7 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.instance.save()
|
||||
self.assertCurrent(name='new age')
|
||||
|
||||
def test_update_fields(self):
|
||||
def test_update_fields(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -358,15 +377,16 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.assertChanged()
|
||||
|
||||
|
||||
class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
||||
class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase):
|
||||
|
||||
tracked_class = TrackedNonFieldAttr
|
||||
instance: TrackedNonFieldAttr
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
self.tracker = self.instance.tracker
|
||||
|
||||
def test_previous(self):
|
||||
def test_previous(self) -> None:
|
||||
self.assertPrevious(rounded=None)
|
||||
self.instance.number = 7.5
|
||||
self.assertPrevious(rounded=None)
|
||||
|
|
@ -377,7 +397,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
|||
self.instance.save()
|
||||
self.assertPrevious(rounded=7)
|
||||
|
||||
def test_has_changed(self):
|
||||
def test_has_changed(self) -> None:
|
||||
self.assertHasChanged(rounded=False)
|
||||
self.instance.number = 7.5
|
||||
self.assertHasChanged(rounded=True)
|
||||
|
|
@ -388,7 +408,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
|||
self.instance.number = 7.8
|
||||
self.assertHasChanged(rounded=False)
|
||||
|
||||
def test_changed(self):
|
||||
def test_changed(self) -> None:
|
||||
self.assertChanged()
|
||||
self.instance.number = 7.5
|
||||
self.assertPrevious(rounded=None)
|
||||
|
|
@ -401,7 +421,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
|||
self.instance.save()
|
||||
self.assertPrevious()
|
||||
|
||||
def test_current(self):
|
||||
def test_current(self) -> None:
|
||||
self.assertCurrent(rounded=None)
|
||||
self.instance.number = 7.5
|
||||
self.assertCurrent(rounded=8)
|
||||
|
|
@ -409,17 +429,17 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
|||
self.assertCurrent(rounded=8)
|
||||
|
||||
|
||||
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase):
|
||||
|
||||
tracked_class: type[models.Model] = TrackedMultiple
|
||||
tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple
|
||||
instance: TrackedMultiple | ModelTrackedMultiple
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
self.trackers = [self.instance.name_tracker,
|
||||
self.instance.number_tracker]
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.tracker = self.instance.name_tracker
|
||||
self.assertChanged(name=None)
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -435,7 +455,7 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
|||
self.instance.number = 8
|
||||
self.assertChanged(number=None)
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.tracker = self.instance.name_tracker
|
||||
self.assertHasChanged(name=True, number=None)
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -445,12 +465,12 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
|||
self.instance.name = 'new age'
|
||||
self.assertHasChanged(name=None, number=False)
|
||||
|
||||
def test_pre_save_previous(self):
|
||||
def test_pre_save_previous(self) -> None:
|
||||
for tracker in self.trackers:
|
||||
self.tracker = tracker
|
||||
super().test_pre_save_previous()
|
||||
|
||||
def test_post_save_has_changed(self):
|
||||
def test_post_save_has_changed(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertHasChanged(tracker=self.trackers[0], name=False, number=None)
|
||||
self.assertHasChanged(tracker=self.trackers[1], name=None, number=False)
|
||||
|
|
@ -465,14 +485,14 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
|||
self.assertHasChanged(tracker=self.trackers[0], name=False, number=None)
|
||||
self.assertHasChanged(tracker=self.trackers[1], name=None, number=False)
|
||||
|
||||
def test_post_save_previous(self):
|
||||
def test_post_save_previous(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.instance.name = 'new age'
|
||||
self.instance.number = 8
|
||||
self.assertPrevious(tracker=self.trackers[0], name='retro', number=None)
|
||||
self.assertPrevious(tracker=self.trackers[1], name=None, number=4)
|
||||
|
||||
def test_post_save_changed(self):
|
||||
def test_post_save_changed(self) -> None:
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertChanged(tracker=self.trackers[0])
|
||||
self.assertChanged(tracker=self.trackers[1])
|
||||
|
|
@ -487,7 +507,7 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
|||
self.assertChanged(tracker=self.trackers[0])
|
||||
self.assertChanged(tracker=self.trackers[1])
|
||||
|
||||
def test_current(self):
|
||||
def test_current(self) -> None:
|
||||
self.assertCurrent(tracker=self.trackers[0], name='')
|
||||
self.assertCurrent(tracker=self.trackers[1], number=None)
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -501,88 +521,97 @@ class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
|||
self.assertCurrent(tracker=self.trackers[1], number=8)
|
||||
|
||||
|
||||
class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
|
||||
class FieldTrackerForeignKeyMixin(FieldTrackerMixin):
|
||||
|
||||
fk_class: type[models.Model] = Tracked
|
||||
tracked_class: type[models.Model] = TrackedFK
|
||||
fk_class: type[Tracked | ModelTracked]
|
||||
tracked_class: type[TrackedFK | ModelTrackedFK]
|
||||
instance: TrackedFK | ModelTrackedFK
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.old_fk = self.fk_class.objects.create(number=8)
|
||||
self.instance = self.tracked_class.objects.create(fk=self.old_fk)
|
||||
self.instance = self.tracked_class.objects.create(fk=self.old_fk) # type: ignore[misc]
|
||||
|
||||
def test_default(self):
|
||||
def test_default(self) -> None:
|
||||
self.tracker = self.instance.tracker
|
||||
self.assertChanged()
|
||||
self.assertPrevious()
|
||||
self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment]
|
||||
self.assertChanged(fk_id=self.old_fk.id)
|
||||
self.assertPrevious(fk_id=self.old_fk.id)
|
||||
self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id)
|
||||
|
||||
def test_custom(self):
|
||||
def test_custom(self) -> None:
|
||||
self.tracker = self.instance.custom_tracker
|
||||
self.assertChanged()
|
||||
self.assertPrevious()
|
||||
self.assertCurrent(fk_id=self.old_fk.id)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment]
|
||||
self.assertChanged(fk_id=self.old_fk.id)
|
||||
self.assertPrevious(fk_id=self.old_fk.id)
|
||||
self.assertCurrent(fk_id=self.instance.fk_id)
|
||||
|
||||
def test_custom_without_id(self):
|
||||
def test_custom_without_id(self) -> None:
|
||||
with self.assertNumQueries(1):
|
||||
self.tracked_class.objects.get()
|
||||
self.tracker = self.instance.custom_tracker_without_id
|
||||
self.assertChanged()
|
||||
self.assertPrevious()
|
||||
self.assertCurrent(fk=self.old_fk.id)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8)
|
||||
self.instance.fk = self.fk_class.objects.create(number=8) # type: ignore[assignment]
|
||||
self.assertChanged(fk=self.old_fk.id)
|
||||
self.assertPrevious(fk=self.old_fk.id)
|
||||
self.assertCurrent(fk=self.instance.fk_id)
|
||||
|
||||
|
||||
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
|
||||
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
|
||||
class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackedFK
|
||||
|
||||
def setUp(self):
|
||||
|
||||
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase):
|
||||
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackedFK
|
||||
instance: TrackedFK
|
||||
|
||||
def setUp(self) -> None:
|
||||
model_tracked = self.fk_class.objects.create(name="", number=0)
|
||||
self.instance = self.tracked_class.objects.create(fk=model_tracked)
|
||||
|
||||
def test_default(self):
|
||||
def test_default(self) -> None:
|
||||
self.tracker = self.instance.tracker
|
||||
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
|
||||
|
||||
def test_custom(self):
|
||||
def test_custom(self) -> None:
|
||||
self.tracker = self.instance.custom_tracker
|
||||
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
|
||||
|
||||
def test_custom_without_id(self):
|
||||
def test_custom_without_id(self) -> None:
|
||||
self.tracker = self.instance.custom_tracker_without_id
|
||||
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
|
||||
|
||||
|
||||
class FieldTrackerTimeStampedTests(FieldTrackerTestCase):
|
||||
class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase):
|
||||
|
||||
fk_class = Tracked
|
||||
tracked_class = TrackerTimeStamped
|
||||
instance: TrackerTimeStamped
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class.objects.create(name='old', number=1)
|
||||
self.tracker = self.instance.tracker
|
||||
|
||||
def test_set_modified_on_save(self):
|
||||
def test_set_modified_on_save(self) -> None:
|
||||
old_modified = self.instance.modified
|
||||
self.instance.name = 'new'
|
||||
self.instance.save()
|
||||
self.assertGreater(self.instance.modified, old_modified)
|
||||
self.assertChanged()
|
||||
|
||||
def test_set_modified_on_save_update_fields(self):
|
||||
def test_set_modified_on_save_update_fields(self) -> None:
|
||||
old_modified = self.instance.modified
|
||||
self.instance.name = 'new'
|
||||
self.instance.save(update_fields=('name',))
|
||||
|
|
@ -594,7 +623,7 @@ class InheritedFieldTrackerTests(FieldTrackerTests):
|
|||
|
||||
tracked_class = InheritedTracked
|
||||
|
||||
def test_child_fields_not_tracked(self):
|
||||
def test_child_fields_not_tracked(self) -> None:
|
||||
self.name2 = 'test'
|
||||
self.assertEqual(self.tracker.previous('name2'), None)
|
||||
self.assertRaises(FieldError, self.tracker.has_changed, 'name2')
|
||||
|
|
@ -605,17 +634,18 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
|
|||
tracked_class = InheritedTrackedFK
|
||||
|
||||
|
||||
class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
||||
class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase):
|
||||
|
||||
tracked_class = TrackedFileField
|
||||
instance: TrackedFileField
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = self.tracked_class()
|
||||
self.tracker = self.instance.tracker
|
||||
self.some_file = 'something.txt'
|
||||
self.another_file = 'another.txt'
|
||||
|
||||
def test_saved_data_without_instance(self):
|
||||
def test_saved_data_without_instance(self) -> None:
|
||||
"""
|
||||
Tests that instance won't get copied by the Field Tracker.
|
||||
|
||||
|
|
@ -629,27 +659,27 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
self.assertEqual(self.tracker.saved_data, {})
|
||||
self.update_instance(some_file=self.some_file)
|
||||
field_file_copy = self.tracker.saved_data.get('some_file')
|
||||
self.assertIsNotNone(field_file_copy)
|
||||
assert field_file_copy is not None
|
||||
self.assertEqual(field_file_copy.__getstate__().get('instance'), None)
|
||||
self.assertEqual(self.instance.some_file.instance, self.instance)
|
||||
self.assertIsInstance(self.instance.some_file, FieldFile)
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.assertChanged(some_file=None)
|
||||
self.instance.some_file = self.some_file
|
||||
self.assertChanged(some_file=None)
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.assertHasChanged(some_file=True)
|
||||
self.instance.some_file = self.some_file
|
||||
self.assertHasChanged(some_file=True)
|
||||
|
||||
def test_pre_save_previous(self):
|
||||
def test_pre_save_previous(self) -> None:
|
||||
self.assertPrevious(some_file=None)
|
||||
self.instance.some_file = self.some_file
|
||||
self.assertPrevious(some_file=None)
|
||||
|
||||
def test_post_save_changed(self):
|
||||
def test_post_save_changed(self) -> None:
|
||||
self.update_instance(some_file=self.some_file)
|
||||
self.assertChanged()
|
||||
previous_file = self.instance.some_file
|
||||
|
|
@ -667,7 +697,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
some_file=previous_file,
|
||||
)
|
||||
|
||||
def test_post_save_has_changed(self):
|
||||
def test_post_save_has_changed(self) -> None:
|
||||
self.update_instance(some_file=self.some_file)
|
||||
self.assertHasChanged(some_file=False)
|
||||
self.instance.some_file = self.another_file
|
||||
|
|
@ -687,7 +717,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
some_file=True,
|
||||
)
|
||||
|
||||
def test_post_save_previous(self):
|
||||
def test_post_save_previous(self) -> None:
|
||||
self.update_instance(some_file=self.some_file)
|
||||
previous_file = self.instance.some_file
|
||||
self.instance.some_file = self.another_file
|
||||
|
|
@ -707,7 +737,7 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
some_file=previous_file,
|
||||
)
|
||||
|
||||
def test_current(self):
|
||||
def test_current(self) -> None:
|
||||
self.assertCurrent(some_file=self.instance.some_file, id=None)
|
||||
self.instance.some_file = self.some_file
|
||||
self.assertCurrent(some_file=self.instance.some_file, id=None)
|
||||
|
|
@ -730,9 +760,10 @@ class FieldTrackerFileFieldTests(FieldTrackerTestCase):
|
|||
|
||||
class ModelTrackerTests(FieldTrackerTests):
|
||||
|
||||
tracked_class: type[models.Model] = ModelTracked
|
||||
tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked
|
||||
instance: ModelTracked
|
||||
|
||||
def test_cache_compatible(self):
|
||||
def test_cache_compatible(self) -> None:
|
||||
cache.set('key', self.instance)
|
||||
instance = cache.get('key')
|
||||
instance.number = 1
|
||||
|
|
@ -742,7 +773,7 @@ class ModelTrackerTests(FieldTrackerTests):
|
|||
instance.number = 2
|
||||
self.assertHasChanged(number=True)
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
self.assertChanged()
|
||||
|
|
@ -753,7 +784,7 @@ class ModelTrackerTests(FieldTrackerTests):
|
|||
self.instance.mutable = [1, 2, 3]
|
||||
self.assertChanged()
|
||||
|
||||
def test_first_save(self):
|
||||
def test_first_save(self) -> None:
|
||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||
self.assertPrevious(name=None, number=None, mutable=None)
|
||||
self.assertCurrent(name='', number=None, id=None, mutable=None)
|
||||
|
|
@ -774,7 +805,7 @@ class ModelTrackerTests(FieldTrackerTests):
|
|||
with self.assertRaises(ValueError):
|
||||
self.instance.save(update_fields=['number'])
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
self.instance.name = 'new age'
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
|
|
@ -786,7 +817,7 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
|
|||
|
||||
tracked_class = ModelTrackedNotDefault
|
||||
|
||||
def test_first_save(self):
|
||||
def test_first_save(self) -> None:
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
self.assertPrevious(name=None, number=None)
|
||||
self.assertCurrent(name='')
|
||||
|
|
@ -798,14 +829,14 @@ class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests):
|
|||
self.assertCurrent(name='retro')
|
||||
self.assertChanged()
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
self.instance.name = 'new age'
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
self.instance.number = 7
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
self.assertChanged()
|
||||
|
|
@ -819,7 +850,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
|
|||
|
||||
tracked_class = ModelTrackedMultiple
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
def test_pre_save_has_changed(self) -> None:
|
||||
self.tracker = self.instance.name_tracker
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -829,7 +860,7 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
|
|||
self.instance.name = 'new age'
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
|
||||
def test_pre_save_changed(self):
|
||||
def test_pre_save_changed(self) -> None:
|
||||
self.tracker = self.instance.name_tracker
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
|
|
@ -846,12 +877,13 @@ class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests):
|
|||
self.assertChanged()
|
||||
|
||||
|
||||
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
|
||||
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
|
||||
|
||||
fk_class = ModelTracked
|
||||
tracked_class = ModelTrackedFK
|
||||
instance: ModelTrackedFK
|
||||
|
||||
def test_custom_without_id(self):
|
||||
def test_custom_without_id(self) -> None:
|
||||
with self.assertNumQueries(2):
|
||||
self.tracked_class.objects.get()
|
||||
self.tracker = self.instance.custom_tracker_without_id
|
||||
|
|
@ -869,7 +901,7 @@ class InheritedModelTrackerTests(ModelTrackerTests):
|
|||
|
||||
tracked_class = InheritedModelTracked
|
||||
|
||||
def test_child_fields_not_tracked(self):
|
||||
def test_child_fields_not_tracked(self) -> None:
|
||||
self.name2 = 'test'
|
||||
self.assertEqual(self.tracker.previous('name2'), None)
|
||||
self.assertTrue(self.tracker.has_changed('name2'))
|
||||
|
|
@ -882,19 +914,19 @@ class AbstractModelTrackerTests(ModelTrackerTests):
|
|||
|
||||
class TrackerContextDecoratorTests(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = Tracked.objects.create(number=1)
|
||||
self.tracker = self.instance.tracker
|
||||
|
||||
def assertChanged(self, *fields):
|
||||
def assertChanged(self, *fields: str) -> None:
|
||||
for f in fields:
|
||||
self.assertTrue(self.tracker.has_changed(f))
|
||||
|
||||
def assertNotChanged(self, *fields):
|
||||
def assertNotChanged(self, *fields: str) -> None:
|
||||
for f in fields:
|
||||
self.assertFalse(self.tracker.has_changed(f))
|
||||
|
||||
def test_context_manager(self):
|
||||
def test_context_manager(self) -> None:
|
||||
with self.tracker:
|
||||
with self.tracker:
|
||||
self.instance.name = 'new'
|
||||
|
|
@ -905,7 +937,7 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
|
||||
self.assertNotChanged('name')
|
||||
|
||||
def test_context_manager_fields(self):
|
||||
def test_context_manager_fields(self) -> None:
|
||||
with self.tracker('number'):
|
||||
with self.tracker('number', 'name'):
|
||||
self.instance.name = 'new'
|
||||
|
|
@ -918,10 +950,10 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
|
||||
self.assertNotChanged('number', 'name')
|
||||
|
||||
def test_tracker_decorator(self):
|
||||
def test_tracker_decorator(self) -> None:
|
||||
|
||||
@Tracked.tracker
|
||||
def tracked_method(obj):
|
||||
def tracked_method(obj: Tracked) -> None:
|
||||
obj.name = 'new'
|
||||
self.assertChanged('name')
|
||||
|
||||
|
|
@ -929,10 +961,10 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
|
||||
self.assertNotChanged('name')
|
||||
|
||||
def test_tracker_decorator_fields(self):
|
||||
def test_tracker_decorator_fields(self) -> None:
|
||||
|
||||
@Tracked.tracker(fields=['name'])
|
||||
def tracked_method(obj):
|
||||
def tracked_method(obj: Tracked) -> None:
|
||||
obj.name = 'new'
|
||||
obj.number += 1
|
||||
self.assertChanged('name', 'number')
|
||||
|
|
@ -942,7 +974,7 @@ class TrackerContextDecoratorTests(TestCase):
|
|||
self.assertChanged('number')
|
||||
self.assertNotChanged('name')
|
||||
|
||||
def test_tracker_context_with_save(self):
|
||||
def test_tracker_context_with_save(self) -> None:
|
||||
|
||||
with self.tracker:
|
||||
self.instance.name = 'new'
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import time_machine
|
||||
|
|
@ -8,33 +10,33 @@ from tests.models import DoubleMonitored, Monitored, MonitorWhen, MonitorWhenEmp
|
|||
|
||||
|
||||
class MonitorFieldTests(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)):
|
||||
self.instance = Monitored(name='Charlie')
|
||||
self.created = self.instance.name_changed
|
||||
|
||||
def test_save_no_change(self):
|
||||
def test_save_no_change(self) -> None:
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, self.created)
|
||||
|
||||
def test_save_changed(self):
|
||||
def test_save_changed(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
|
||||
self.instance.name = 'Maria'
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
|
||||
|
||||
def test_double_save(self):
|
||||
def test_double_save(self) -> None:
|
||||
self.instance.name = 'Jose'
|
||||
self.instance.save()
|
||||
changed = self.instance.name_changed
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, changed)
|
||||
|
||||
def test_no_monitor_arg(self):
|
||||
def test_no_monitor_arg(self) -> None:
|
||||
with self.assertRaises(TypeError):
|
||||
MonitorField()
|
||||
MonitorField() # type: ignore[call-arg]
|
||||
|
||||
def test_monitor_default_is_none_when_nullable(self):
|
||||
def test_monitor_default_is_none_when_nullable(self) -> None:
|
||||
self.assertIsNone(self.instance.name_changed_nullable)
|
||||
expected_datetime = datetime(2022, 1, 18, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
|
@ -49,33 +51,33 @@ class MonitorWhenFieldTests(TestCase):
|
|||
"""
|
||||
Will record changes only when name is 'Jose' or 'Maria'
|
||||
"""
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, 10, 0, 0, tzinfo=timezone.utc)):
|
||||
self.instance = MonitorWhen(name='Charlie')
|
||||
self.created = self.instance.name_changed
|
||||
|
||||
def test_save_no_change(self):
|
||||
def test_save_no_change(self) -> None:
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, self.created)
|
||||
|
||||
def test_save_changed_to_Jose(self):
|
||||
def test_save_changed_to_Jose(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
|
||||
self.instance.name = 'Jose'
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
|
||||
|
||||
def test_save_changed_to_Maria(self):
|
||||
def test_save_changed_to_Maria(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc)):
|
||||
self.instance.name = 'Maria'
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0, tzinfo=timezone.utc))
|
||||
|
||||
def test_save_changed_to_Pedro(self):
|
||||
def test_save_changed_to_Pedro(self) -> None:
|
||||
self.instance.name = 'Pedro'
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, self.created)
|
||||
|
||||
def test_double_save(self):
|
||||
def test_double_save(self) -> None:
|
||||
self.instance.name = 'Jose'
|
||||
self.instance.save()
|
||||
changed = self.instance.name_changed
|
||||
|
|
@ -87,20 +89,20 @@ class MonitorWhenEmptyFieldTests(TestCase):
|
|||
"""
|
||||
Monitor should never be updated id when is an empty list.
|
||||
"""
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = MonitorWhenEmpty(name='Charlie')
|
||||
self.created = self.instance.name_changed
|
||||
|
||||
def test_save_no_change(self):
|
||||
def test_save_no_change(self) -> None:
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, self.created)
|
||||
|
||||
def test_save_changed_to_Jose(self):
|
||||
def test_save_changed_to_Jose(self) -> None:
|
||||
self.instance.name = 'Jose'
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, self.created)
|
||||
|
||||
def test_save_changed_to_Maria(self):
|
||||
def test_save_changed_to_Maria(self) -> None:
|
||||
self.instance.name = 'Maria'
|
||||
self.instance.save()
|
||||
self.assertEqual(self.instance.name_changed, self.created)
|
||||
|
|
@ -108,18 +110,18 @@ class MonitorWhenEmptyFieldTests(TestCase):
|
|||
|
||||
class MonitorDoubleFieldTests(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
DoubleMonitored.objects.create(name='Charlie', name2='Charlie2')
|
||||
|
||||
def test_recursion_error_with_only(self):
|
||||
def test_recursion_error_with_only(self) -> None:
|
||||
# Any field passed to only() is generating a recursion error
|
||||
list(DoubleMonitored.objects.only('id'))
|
||||
|
||||
def test_recursion_error_with_defer(self):
|
||||
def test_recursion_error_with_defer(self) -> None:
|
||||
# Only monitored fields passed to defer() are failing
|
||||
list(DoubleMonitored.objects.defer('name'))
|
||||
|
||||
def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self):
|
||||
def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self) -> None:
|
||||
obj = DoubleMonitored.objects.defer('name').get(name='Charlie')
|
||||
with time_machine.travel(datetime(2016, 12, 1, tzinfo=timezone.utc)):
|
||||
obj.name = 'Charlie2'
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import Article, SplitFieldAbstractParent
|
||||
|
|
@ -7,62 +9,62 @@ class SplitFieldTests(TestCase):
|
|||
full_text = 'summary\n\n<!-- split -->\n\nmore'
|
||||
excerpt = 'summary\n'
|
||||
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.post = Article.objects.create(
|
||||
title='example post', body=self.full_text)
|
||||
|
||||
def test_unicode_content(self):
|
||||
def test_unicode_content(self) -> None:
|
||||
self.assertEqual(str(self.post.body), self.full_text)
|
||||
|
||||
def test_excerpt(self):
|
||||
def test_excerpt(self) -> None:
|
||||
self.assertEqual(self.post.body.excerpt, self.excerpt)
|
||||
|
||||
def test_content(self):
|
||||
def test_content(self) -> None:
|
||||
self.assertEqual(self.post.body.content, self.full_text)
|
||||
|
||||
def test_has_more(self):
|
||||
def test_has_more(self) -> None:
|
||||
self.assertTrue(self.post.body.has_more)
|
||||
|
||||
def test_not_has_more(self):
|
||||
def test_not_has_more(self) -> None:
|
||||
post = Article.objects.create(title='example 2',
|
||||
body='some text\n\nsome more\n')
|
||||
self.assertFalse(post.body.has_more)
|
||||
|
||||
def test_load_back(self):
|
||||
def test_load_back(self) -> None:
|
||||
post = Article.objects.get(pk=self.post.pk)
|
||||
self.assertEqual(post.body.content, self.post.body.content)
|
||||
self.assertEqual(post.body.excerpt, self.post.body.excerpt)
|
||||
|
||||
def test_assign_to_body(self):
|
||||
def test_assign_to_body(self) -> None:
|
||||
new_text = 'different\n\n<!-- split -->\n\nother'
|
||||
self.post.body = new_text
|
||||
self.post.save()
|
||||
self.assertEqual(str(self.post.body), new_text)
|
||||
|
||||
def test_assign_to_content(self):
|
||||
def test_assign_to_content(self) -> None:
|
||||
new_text = 'different\n\n<!-- split -->\n\nother'
|
||||
self.post.body.content = new_text
|
||||
self.post.save()
|
||||
self.assertEqual(str(self.post.body), new_text)
|
||||
|
||||
def test_assign_to_excerpt(self):
|
||||
def test_assign_to_excerpt(self) -> None:
|
||||
with self.assertRaises(AttributeError):
|
||||
self.post.body.excerpt = 'this should fail'
|
||||
self.post.body.excerpt = 'this should fail' # type: ignore[misc]
|
||||
|
||||
def test_access_via_class(self):
|
||||
def test_access_via_class(self) -> None:
|
||||
with self.assertRaises(AttributeError):
|
||||
Article.body
|
||||
|
||||
def test_assign_splittext(self):
|
||||
def test_assign_splittext(self) -> None:
|
||||
a = Article(title='Some Title')
|
||||
a.body = self.post.body
|
||||
self.assertEqual(a.body.excerpt, 'summary\n')
|
||||
|
||||
def test_value_to_string(self):
|
||||
def test_value_to_string(self) -> None:
|
||||
f = self.post._meta.get_field('body')
|
||||
self.assertEqual(f.value_to_string(self.post), self.full_text)
|
||||
|
||||
def test_abstract_inheritance(self):
|
||||
def test_abstract_inheritance(self) -> None:
|
||||
class Child(SplitFieldAbstractParent):
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from model_utils.fields import StatusField
|
||||
|
|
@ -11,22 +13,22 @@ from tests.models import (
|
|||
|
||||
class StatusFieldTests(TestCase):
|
||||
|
||||
def test_status_with_default_filled(self):
|
||||
def test_status_with_default_filled(self) -> None:
|
||||
instance = StatusFieldDefaultFilled()
|
||||
self.assertEqual(instance.status, instance.STATUS.yes)
|
||||
|
||||
def test_status_with_default_not_filled(self):
|
||||
def test_status_with_default_not_filled(self) -> None:
|
||||
instance = StatusFieldDefaultNotFilled()
|
||||
self.assertEqual(instance.status, instance.STATUS.no)
|
||||
|
||||
def test_no_check_for_status(self):
|
||||
def test_no_check_for_status(self) -> None:
|
||||
field = StatusField(no_check_for_status=True)
|
||||
# this model has no STATUS attribute, so checking for it would error
|
||||
field.prepare_class(Article)
|
||||
|
||||
def test_get_status_display(self):
|
||||
def test_get_status_display(self) -> None:
|
||||
instance = StatusFieldDefaultFilled()
|
||||
self.assertEqual(instance.get_status_display(), "Yes")
|
||||
|
||||
def test_choices_name(self):
|
||||
def test_choices_name(self) -> None:
|
||||
StatusFieldChoicesName()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from django.db.models import NOT_PROVIDED
|
||||
|
|
@ -7,41 +9,41 @@ from model_utils.fields import UrlsafeTokenField
|
|||
|
||||
|
||||
class UrlsaftTokenFieldTests(TestCase):
|
||||
def test_editable_default(self):
|
||||
def test_editable_default(self) -> None:
|
||||
field = UrlsafeTokenField()
|
||||
self.assertFalse(field.editable)
|
||||
|
||||
def test_editable(self):
|
||||
def test_editable(self) -> None:
|
||||
field = UrlsafeTokenField(editable=True)
|
||||
self.assertTrue(field.editable)
|
||||
|
||||
def test_max_length_default(self):
|
||||
def test_max_length_default(self) -> None:
|
||||
field = UrlsafeTokenField()
|
||||
self.assertEqual(field.max_length, 128)
|
||||
|
||||
def test_max_length(self):
|
||||
def test_max_length(self) -> None:
|
||||
field = UrlsafeTokenField(max_length=256)
|
||||
self.assertEqual(field.max_length, 256)
|
||||
|
||||
def test_factory_default(self):
|
||||
def test_factory_default(self) -> None:
|
||||
field = UrlsafeTokenField()
|
||||
self.assertIsNone(field._factory)
|
||||
|
||||
def test_factory_not_callable(self):
|
||||
def test_factory_not_callable(self) -> None:
|
||||
with self.assertRaises(TypeError):
|
||||
UrlsafeTokenField(factory='INVALID')
|
||||
UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type]
|
||||
|
||||
def test_get_default(self):
|
||||
def test_get_default(self) -> None:
|
||||
field = UrlsafeTokenField()
|
||||
value = field.get_default()
|
||||
self.assertEqual(len(value), field.max_length)
|
||||
|
||||
def test_get_default_with_non_default_max_length(self):
|
||||
def test_get_default_with_non_default_max_length(self) -> None:
|
||||
field = UrlsafeTokenField(max_length=64)
|
||||
value = field.get_default()
|
||||
self.assertEqual(len(value), 64)
|
||||
|
||||
def test_get_default_with_factory(self):
|
||||
def test_get_default_with_factory(self) -> None:
|
||||
token = 'SAMPLE_TOKEN'
|
||||
factory = Mock(return_value=token)
|
||||
field = UrlsafeTokenField(factory=factory)
|
||||
|
|
@ -50,13 +52,13 @@ class UrlsaftTokenFieldTests(TestCase):
|
|||
self.assertEqual(value, token)
|
||||
factory.assert_called_once_with(field.max_length)
|
||||
|
||||
def test_no_default_param(self):
|
||||
def test_no_default_param(self) -> None:
|
||||
field = UrlsafeTokenField(default='DEFAULT')
|
||||
self.assertIs(field.default, NOT_PROVIDED)
|
||||
|
||||
def test_deconstruct(self):
|
||||
def test_factory():
|
||||
pass
|
||||
def test_deconstruct(self) -> None:
|
||||
def test_factory(max_length: int) -> str:
|
||||
assert False
|
||||
instance = UrlsafeTokenField(factory=test_factory)
|
||||
name, path, args, kwargs = instance.deconstruct()
|
||||
new_instance = UrlsafeTokenField(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
|
|
@ -8,31 +10,31 @@ from model_utils.fields import UUIDField
|
|||
|
||||
class UUIDFieldTests(TestCase):
|
||||
|
||||
def test_uuid_version_default(self):
|
||||
def test_uuid_version_default(self) -> None:
|
||||
instance = UUIDField()
|
||||
self.assertEqual(instance.default, uuid.uuid4)
|
||||
|
||||
def test_uuid_version_1(self):
|
||||
def test_uuid_version_1(self) -> None:
|
||||
instance = UUIDField(version=1)
|
||||
self.assertEqual(instance.default, uuid.uuid1)
|
||||
|
||||
def test_uuid_version_2_error(self):
|
||||
def test_uuid_version_2_error(self) -> None:
|
||||
self.assertRaises(ValidationError, UUIDField, 'version', 2)
|
||||
|
||||
def test_uuid_version_3(self):
|
||||
def test_uuid_version_3(self) -> None:
|
||||
instance = UUIDField(version=3)
|
||||
self.assertEqual(instance.default, uuid.uuid3)
|
||||
|
||||
def test_uuid_version_4(self):
|
||||
def test_uuid_version_4(self) -> None:
|
||||
instance = UUIDField(version=4)
|
||||
self.assertEqual(instance.default, uuid.uuid4)
|
||||
|
||||
def test_uuid_version_5(self):
|
||||
def test_uuid_version_5(self) -> None:
|
||||
instance = UUIDField(version=5)
|
||||
self.assertEqual(instance.default, uuid.uuid5)
|
||||
|
||||
def test_uuid_version_bellow_min(self):
|
||||
def test_uuid_version_bellow_min(self) -> None:
|
||||
self.assertRaises(ValidationError, UUIDField, 'version', 0)
|
||||
|
||||
def test_uuid_version_above_max(self):
|
||||
def test_uuid_version_above_max(self) -> None:
|
||||
self.assertRaises(ValidationError, UUIDField, 'version', 6)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.db.models import Prefetch
|
||||
from django.test import TestCase
|
||||
|
||||
|
|
@ -5,7 +7,7 @@ from tests.models import InheritanceManagerTestChild1, InheritanceManagerTestPar
|
|||
|
||||
|
||||
class InheritanceIterableTest(TestCase):
|
||||
def test_prefetch(self):
|
||||
def test_prefetch(self) -> None:
|
||||
qs = InheritanceManagerTestChild1.objects.all().prefetch_related(
|
||||
Prefetch(
|
||||
'normal_field',
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
|
||||
from model_utils.managers import InheritanceManager
|
||||
from tests.models import (
|
||||
InheritanceManagerTestChild1,
|
||||
InheritanceManagerTestChild2,
|
||||
|
|
@ -14,19 +19,22 @@ from tests.models import (
|
|||
TimeFrame,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models.fields.related_descriptors import RelatedManager
|
||||
|
||||
|
||||
class InheritanceManagerTests(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.child1 = InheritanceManagerTestChild1.objects.create()
|
||||
self.child2 = InheritanceManagerTestChild2.objects.create()
|
||||
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
|
||||
self.grandchild1_2 = \
|
||||
InheritanceManagerTestGrandChild1_2.objects.create()
|
||||
|
||||
def get_manager(self):
|
||||
def get_manager(self) -> InheritanceManager[InheritanceManagerTestParent]:
|
||||
return InheritanceManagerTestParent.objects
|
||||
|
||||
def test_normal(self):
|
||||
def test_normal(self) -> None:
|
||||
children = {
|
||||
InheritanceManagerTestParent(pk=self.child1.pk),
|
||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||
|
|
@ -35,14 +43,14 @@ class InheritanceManagerTests(TestCase):
|
|||
}
|
||||
self.assertEqual(set(self.get_manager().all()), children)
|
||||
|
||||
def test_select_all_subclasses(self):
|
||||
def test_select_all_subclasses(self) -> None:
|
||||
children = {self.child1, self.child2}
|
||||
children.add(self.grandchild1)
|
||||
children.add(self.grandchild1_2)
|
||||
self.assertEqual(
|
||||
set(self.get_manager().select_subclasses()), children)
|
||||
|
||||
def test_select_subclasses_invalid_relation(self):
|
||||
def test_select_subclasses_invalid_relation(self) -> None:
|
||||
"""
|
||||
If an invalid relation string is provided, we can provide the user
|
||||
with a list which is valid, rather than just have the select_related()
|
||||
|
|
@ -52,7 +60,7 @@ class InheritanceManagerTests(TestCase):
|
|||
with self.assertRaisesRegex(ValueError, regex):
|
||||
self.get_manager().select_subclasses('user')
|
||||
|
||||
def test_select_specific_subclasses(self):
|
||||
def test_select_specific_subclasses(self) -> None:
|
||||
children = {
|
||||
self.child1,
|
||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||
|
|
@ -67,7 +75,7 @@ class InheritanceManagerTests(TestCase):
|
|||
children,
|
||||
)
|
||||
|
||||
def test_select_specific_grandchildren(self):
|
||||
def test_select_specific_grandchildren(self) -> None:
|
||||
children = {
|
||||
InheritanceManagerTestParent(pk=self.child1.pk),
|
||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||
|
|
@ -83,7 +91,7 @@ class InheritanceManagerTests(TestCase):
|
|||
children,
|
||||
)
|
||||
|
||||
def test_children_and_grandchildren(self):
|
||||
def test_children_and_grandchildren(self) -> None:
|
||||
children = {
|
||||
self.child1,
|
||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||
|
|
@ -100,24 +108,24 @@ class InheritanceManagerTests(TestCase):
|
|||
children,
|
||||
)
|
||||
|
||||
def test_get_subclass(self):
|
||||
def test_get_subclass(self) -> None:
|
||||
self.assertEqual(
|
||||
self.get_manager().get_subclass(pk=self.child1.pk),
|
||||
self.child1)
|
||||
|
||||
def test_get_subclass_on_queryset(self):
|
||||
def test_get_subclass_on_queryset(self) -> None:
|
||||
self.assertEqual(
|
||||
self.get_manager().all().get_subclass(pk=self.child1.pk),
|
||||
self.child1)
|
||||
|
||||
def test_prior_select_related(self):
|
||||
def test_prior_select_related(self) -> None:
|
||||
with self.assertNumQueries(1):
|
||||
obj = self.get_manager().select_related(
|
||||
"inheritancemanagertestchild1").select_subclasses(
|
||||
"inheritancemanagertestchild2").get(pk=self.child1.pk)
|
||||
obj.inheritancemanagertestchild1
|
||||
|
||||
def test_manually_specifying_parent_fk_including_grandchildren(self):
|
||||
def test_manually_specifying_parent_fk_including_grandchildren(self) -> None:
|
||||
"""
|
||||
given a Model which inherits from another Model, but also declares
|
||||
the OneToOne link manually using `related_name` and `parent_link`,
|
||||
|
|
@ -148,7 +156,7 @@ class InheritanceManagerTests(TestCase):
|
|||
self.assertEqual(set(results.subclasses),
|
||||
set(expected_related_names))
|
||||
|
||||
def test_manually_specifying_parent_fk_single_subclass(self):
|
||||
def test_manually_specifying_parent_fk_single_subclass(self) -> None:
|
||||
"""
|
||||
Using a string related_name when the relation is manually defined
|
||||
instead of implicit should still work in the same way.
|
||||
|
|
@ -168,11 +176,11 @@ class InheritanceManagerTests(TestCase):
|
|||
self.assertEqual(set(results.subclasses),
|
||||
set(expected_related_names))
|
||||
|
||||
def test_filter_on_values_queryset(self):
|
||||
def test_filter_on_values_queryset(self) -> None:
|
||||
queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk)
|
||||
self.assertEqual(list(queryset), [{'id': self.child1.pk}])
|
||||
|
||||
def test_values_list_on_select_subclasses(self):
|
||||
def test_values_list_on_select_subclasses(self) -> None:
|
||||
"""
|
||||
Using `select_subclasses` in conjunction with `values_list()` raised an
|
||||
exception in `_get_sub_obj_recurse()` because the result of `values_list()`
|
||||
|
|
@ -217,14 +225,14 @@ class InheritanceManagerTests(TestCase):
|
|||
|
||||
|
||||
class InheritanceManagerUsingModelsTests(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.parent1 = InheritanceManagerTestParent.objects.create()
|
||||
self.child1 = InheritanceManagerTestChild1.objects.create()
|
||||
self.child2 = InheritanceManagerTestChild2.objects.create()
|
||||
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
|
||||
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create()
|
||||
|
||||
def test_select_subclass_by_child_model(self):
|
||||
def test_select_subclass_by_child_model(self) -> None:
|
||||
"""
|
||||
Confirm that passing a child model works the same as passing the
|
||||
select_related manually
|
||||
|
|
@ -236,7 +244,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
self.assertEqual(objs.subclasses, objsmodels.subclasses)
|
||||
self.assertEqual(list(objs), list(objsmodels))
|
||||
|
||||
def test_select_subclass_by_grandchild_model(self):
|
||||
def test_select_subclass_by_grandchild_model(self) -> None:
|
||||
"""
|
||||
Confirm that passing a grandchild model works the same as passing the
|
||||
select_related manually
|
||||
|
|
@ -249,7 +257,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
self.assertEqual(objs.subclasses, objsmodels.subclasses)
|
||||
self.assertEqual(list(objs), list(objsmodels))
|
||||
|
||||
def test_selecting_all_subclasses_specifically_grandchildren(self):
|
||||
def test_selecting_all_subclasses_specifically_grandchildren(self) -> None:
|
||||
"""
|
||||
A bare select_subclasses() should achieve the same results as doing
|
||||
select_subclasses and specifying all possible subclasses.
|
||||
|
|
@ -266,7 +274,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
|
||||
self.assertEqual(list(objs), list(objsmodels))
|
||||
|
||||
def test_selecting_all_subclasses_specifically_children(self):
|
||||
def test_selecting_all_subclasses_specifically_children(self) -> None:
|
||||
"""
|
||||
A bare select_subclasses() should achieve the same results as doing
|
||||
select_subclasses and specifying all possible subclasses.
|
||||
|
|
@ -294,7 +302,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
|
||||
self.assertEqual(list(objs), list(objsmodels))
|
||||
|
||||
def test_select_subclass_just_self(self):
|
||||
def test_select_subclass_just_self(self) -> None:
|
||||
"""
|
||||
Passing in the same model as the manager/queryset is bound against
|
||||
(ie: the root parent) should have no effect on the result set.
|
||||
|
|
@ -310,7 +318,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
|
||||
])
|
||||
|
||||
def test_select_subclass_invalid_related_model(self):
|
||||
def test_select_subclass_invalid_related_model(self) -> None:
|
||||
"""
|
||||
Confirming that giving a stupid model doesn't work.
|
||||
"""
|
||||
|
|
@ -319,7 +327,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
InheritanceManagerTestParent.objects.select_subclasses(
|
||||
TimeFrame).order_by('pk')
|
||||
|
||||
def test_mixing_strings_and_classes_with_grandchildren(self):
|
||||
def test_mixing_strings_and_classes_with_grandchildren(self) -> None:
|
||||
"""
|
||||
Given arguments consisting of both strings and model classes,
|
||||
ensure the right resolutions take place, accounting for the extra
|
||||
|
|
@ -340,7 +348,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
]
|
||||
self.assertEqual(list(objs), expecting2)
|
||||
|
||||
def test_mixing_strings_and_classes_with_children(self):
|
||||
def test_mixing_strings_and_classes_with_children(self) -> None:
|
||||
"""
|
||||
Given arguments consisting of both strings and model classes,
|
||||
ensure the right resolutions take place, walking down as far as
|
||||
|
|
@ -362,7 +370,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
]
|
||||
self.assertEqual(list(objs), expecting2)
|
||||
|
||||
def test_duplications(self):
|
||||
def test_duplications(self) -> None:
|
||||
"""
|
||||
Check that even if the same thing is provided as a string and a model
|
||||
that the right results are retrieved.
|
||||
|
|
@ -379,7 +387,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
|
||||
])
|
||||
|
||||
def test_child_doesnt_accidentally_get_parent(self):
|
||||
def test_child_doesnt_accidentally_get_parent(self) -> None:
|
||||
"""
|
||||
Given a Child model which also has an InheritanceManager,
|
||||
none of the returned objects should be Parent objects.
|
||||
|
|
@ -392,7 +400,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
InheritanceManagerTestChild1(pk=self.grandchild1_2.pk),
|
||||
], list(objs))
|
||||
|
||||
def test_manually_specifying_parent_fk_only_specific_child(self):
|
||||
def test_manually_specifying_parent_fk_only_specific_child(self) -> None:
|
||||
"""
|
||||
given a Model which inherits from another Model, but also declares
|
||||
the OneToOne link manually using `related_name` and `parent_link`,
|
||||
|
|
@ -416,7 +424,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
self.assertEqual(set(results.subclasses),
|
||||
set(expected_related_names))
|
||||
|
||||
def test_extras_descend(self):
|
||||
def test_extras_descend(self) -> None:
|
||||
"""
|
||||
Ensure that extra(select=) values are copied onto sub-classes.
|
||||
"""
|
||||
|
|
@ -425,25 +433,25 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
)
|
||||
self.assertTrue(all(result.foo == (result.id + 1) for result in results))
|
||||
|
||||
def test_limit_to_specific_subclass(self):
|
||||
def test_limit_to_specific_subclass(self) -> None:
|
||||
child3 = InheritanceManagerTestChild3.objects.create()
|
||||
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3)
|
||||
|
||||
self.assertEqual([child3], list(results))
|
||||
|
||||
def test_limit_to_specific_subclass_with_custom_db_column(self):
|
||||
def test_limit_to_specific_subclass_with_custom_db_column(self) -> None:
|
||||
item = InheritanceManagerTestChild3_1.objects.create()
|
||||
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3_1)
|
||||
|
||||
self.assertEqual([item], list(results))
|
||||
|
||||
def test_limit_to_specific_grandchild_class(self):
|
||||
def test_limit_to_specific_grandchild_class(self) -> None:
|
||||
grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
|
||||
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1)
|
||||
|
||||
self.assertEqual([grandchild1], list(results))
|
||||
|
||||
def test_limit_to_child_fetches_grandchildren_as_child_class(self):
|
||||
def test_limit_to_child_fetches_grandchildren_as_child_class(self) -> None:
|
||||
# Not sure if this is the desired behaviour...?
|
||||
children = InheritanceManagerTestChild1.objects.all()
|
||||
|
||||
|
|
@ -451,7 +459,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
|
||||
self.assertEqual(set(children), set(results))
|
||||
|
||||
def test_can_fetch_limited_class_grandchildren(self):
|
||||
def test_can_fetch_limited_class_grandchildren(self) -> None:
|
||||
# Not sure if this is the desired behaviour...?
|
||||
children = InheritanceManagerTestChild1.objects.select_subclasses()
|
||||
|
||||
|
|
@ -459,7 +467,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
|
||||
self.assertEqual(set(children), set(results))
|
||||
|
||||
def test_selecting_multiple_instance_classes(self):
|
||||
def test_selecting_multiple_instance_classes(self) -> None:
|
||||
child3 = InheritanceManagerTestChild3.objects.create()
|
||||
children1 = InheritanceManagerTestChild1.objects.all()
|
||||
|
||||
|
|
@ -467,7 +475,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
|
||||
self.assertEqual(set([child3] + list(children1)), set(results))
|
||||
|
||||
def test_selecting_multiple_instance_classes_including_grandchildren(self):
|
||||
def test_selecting_multiple_instance_classes_including_grandchildren(self) -> None:
|
||||
child3 = InheritanceManagerTestChild3.objects.create()
|
||||
grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
|
||||
|
||||
|
|
@ -475,7 +483,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
|
||||
self.assertEqual({child3, grandchild1}, set(results))
|
||||
|
||||
def test_select_subclasses_interaction_with_instance_of(self):
|
||||
def test_select_subclasses_interaction_with_instance_of(self) -> None:
|
||||
child3 = InheritanceManagerTestChild3.objects.create()
|
||||
|
||||
results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3)
|
||||
|
|
@ -484,7 +492,7 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
|||
|
||||
|
||||
class InheritanceManagerRelatedTests(InheritanceManagerTests):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.related = InheritanceManagerTestRelated.objects.create()
|
||||
self.child1 = InheritanceManagerTestChild1.objects.create(
|
||||
related=self.related)
|
||||
|
|
@ -493,16 +501,16 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
|
|||
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
|
||||
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related)
|
||||
|
||||
def get_manager(self):
|
||||
def get_manager(self) -> RelatedManager[InheritanceManagerTestParent]: # type: ignore[override]
|
||||
return self.related.imtests
|
||||
|
||||
def test_get_method_with_select_subclasses(self):
|
||||
def test_get_method_with_select_subclasses(self) -> None:
|
||||
self.assertEqual(
|
||||
InheritanceManagerTestParent.objects.select_subclasses().get(
|
||||
id=self.child1.id),
|
||||
self.child1)
|
||||
|
||||
def test_get_method_with_select_subclasses_check_for_useless_join(self):
|
||||
def test_get_method_with_select_subclasses_check_for_useless_join(self) -> None:
|
||||
child4 = InheritanceManagerTestChild4.objects.create(related=self.related, other_onetoone=self.child1)
|
||||
self.assertEqual(
|
||||
str(InheritanceManagerTestChild4.objects.select_subclasses().filter(
|
||||
|
|
@ -510,26 +518,26 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
|
|||
str(InheritanceManagerTestChild4.objects.select_subclasses().select_related(None).filter(
|
||||
id=child4.id).query))
|
||||
|
||||
def test_annotate_with_select_subclasses(self):
|
||||
def test_annotate_with_select_subclasses(self) -> None:
|
||||
qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
|
||||
models.Count('id'))
|
||||
self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
|
||||
|
||||
def test_annotate_with_named_arguments_with_select_subclasses(self):
|
||||
def test_annotate_with_named_arguments_with_select_subclasses(self) -> None:
|
||||
qs = InheritanceManagerTestParent.objects.select_subclasses().annotate(
|
||||
test_count=models.Count('id'))
|
||||
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
|
||||
|
||||
def test_annotate_before_select_subclasses(self):
|
||||
def test_annotate_before_select_subclasses(self) -> None:
|
||||
qs = InheritanceManagerTestParent.objects.annotate(
|
||||
models.Count('id')).select_subclasses()
|
||||
self.assertEqual(qs.get(id=self.child1.id).id__count, 1)
|
||||
|
||||
def test_annotate_with_named_arguments_before_select_subclasses(self):
|
||||
def test_annotate_with_named_arguments_before_select_subclasses(self) -> None:
|
||||
qs = InheritanceManagerTestParent.objects.annotate(
|
||||
test_count=models.Count('id')).select_subclasses()
|
||||
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
|
||||
|
||||
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self):
|
||||
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self) -> None:
|
||||
qs = InheritanceManagerTestParent.objects.select_subclasses()
|
||||
self.assertEqual(qs.subclasses, qs._clone().subclasses)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import BoxJoinModel, JoinItemForeignKey
|
||||
|
||||
|
||||
class JoinManagerTest(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
for i in range(20):
|
||||
BoxJoinModel.objects.create(name=f'name_{i}')
|
||||
|
||||
|
|
@ -13,24 +15,24 @@ class JoinManagerTest(TestCase):
|
|||
)
|
||||
JoinItemForeignKey.objects.create(weight=20)
|
||||
|
||||
def test_self_join(self):
|
||||
def test_self_join(self) -> None:
|
||||
a_slice = BoxJoinModel.objects.all()[0:10]
|
||||
with self.assertNumQueries(1):
|
||||
result = a_slice.join()
|
||||
self.assertEqual(result.count(), 10)
|
||||
|
||||
def test_self_join_with_where_statement(self):
|
||||
def test_self_join_with_where_statement(self) -> None:
|
||||
qs = BoxJoinModel.objects.filter(name='name_1')
|
||||
result = qs.join()
|
||||
self.assertEqual(result.count(), 1)
|
||||
|
||||
def test_join_with_other_qs(self):
|
||||
def test_join_with_other_qs(self) -> None:
|
||||
item_qs = JoinItemForeignKey.objects.filter(weight=10)
|
||||
boxes = BoxJoinModel.objects.all().join(qs=item_qs)
|
||||
self.assertEqual(boxes.count(), 1)
|
||||
self.assertEqual(boxes[0].name, 'name_1')
|
||||
|
||||
def test_reverse_join(self):
|
||||
def test_reverse_join(self) -> None:
|
||||
box_qs = BoxJoinModel.objects.filter(name='name_1')
|
||||
items = JoinItemForeignKey.objects.all().join(box_qs)
|
||||
self.assertEqual(items.count(), 1)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import Post
|
||||
|
||||
|
||||
class QueryManagerTests(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
data = ((True, True, 0),
|
||||
(True, False, 4),
|
||||
(False, False, 2),
|
||||
|
|
@ -14,14 +16,14 @@ class QueryManagerTests(TestCase):
|
|||
for p, c, o in data:
|
||||
Post.objects.create(published=p, confirmed=c, order=o)
|
||||
|
||||
def test_passing_kwargs(self):
|
||||
def test_passing_kwargs(self) -> None:
|
||||
qs = Post.public.all()
|
||||
self.assertEqual([p.order for p in qs], [0, 1, 4, 5])
|
||||
|
||||
def test_passing_Q(self):
|
||||
def test_passing_Q(self) -> None:
|
||||
qs = Post.public_confirmed.all()
|
||||
self.assertEqual([p.order for p in qs], [0, 1])
|
||||
|
||||
def test_ordering(self):
|
||||
def test_ordering(self) -> None:
|
||||
qs = Post.public_reversed.all()
|
||||
self.assertEqual([p.order for p in qs], [5, 4, 1, 0])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import CustomSoftDelete
|
||||
|
|
@ -5,21 +7,21 @@ from tests.models import CustomSoftDelete
|
|||
|
||||
class CustomSoftDeleteManagerTests(TestCase):
|
||||
|
||||
def test_custom_manager_empty(self):
|
||||
def test_custom_manager_empty(self) -> None:
|
||||
qs = CustomSoftDelete.available_objects.only_read()
|
||||
self.assertEqual(qs.count(), 0)
|
||||
|
||||
def test_custom_qs_empty(self):
|
||||
def test_custom_qs_empty(self) -> None:
|
||||
qs = CustomSoftDelete.available_objects.all().only_read()
|
||||
self.assertEqual(qs.count(), 0)
|
||||
|
||||
def test_is_read(self):
|
||||
def test_is_read(self) -> None:
|
||||
for is_read in [True, False, True, False]:
|
||||
CustomSoftDelete.available_objects.create(is_read=is_read)
|
||||
qs = CustomSoftDelete.available_objects.only_read()
|
||||
self.assertEqual(qs.count(), 2)
|
||||
|
||||
def test_is_read_removed(self):
|
||||
def test_is_read_removed(self) -> None:
|
||||
for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]:
|
||||
CustomSoftDelete.available_objects.create(is_read=is_read, is_removed=is_removed)
|
||||
qs = CustomSoftDelete.available_objects.only_read()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
|
|
@ -8,10 +10,10 @@ from tests.models import StatusManagerAdded
|
|||
|
||||
|
||||
class StatusManagerAddedTests(TestCase):
|
||||
def test_manager_available(self):
|
||||
def test_manager_available(self) -> None:
|
||||
self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager))
|
||||
|
||||
def test_conflict_error(self):
|
||||
def test_conflict_error(self) -> None:
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
class ErrorModel(StatusModel):
|
||||
STATUS = (
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
|
||||
|
|
@ -5,23 +7,23 @@ from model_utils.fields import get_excerpt
|
|||
|
||||
|
||||
class MigrationsTests(TestCase):
|
||||
def test_makemigrations(self):
|
||||
def test_makemigrations(self) -> None:
|
||||
call_command('makemigrations', dry_run=True)
|
||||
|
||||
|
||||
class GetExcerptTests(TestCase):
|
||||
def test_split(self):
|
||||
def test_split(self) -> None:
|
||||
e = get_excerpt("some content\n\n<!-- split -->\n\nsome more")
|
||||
self.assertEqual(e, 'some content\n')
|
||||
|
||||
def test_auto_split(self):
|
||||
def test_auto_split(self) -> None:
|
||||
e = get_excerpt("para one\n\npara two\n\npara three")
|
||||
self.assertEqual(e, 'para one\n\npara two')
|
||||
|
||||
def test_middle_of_para(self):
|
||||
def test_middle_of_para(self) -> None:
|
||||
e = get_excerpt("some text\n<!-- split -->\nmore text")
|
||||
self.assertEqual(e, 'some text')
|
||||
|
||||
def test_middle_of_line(self):
|
||||
def test_middle_of_line(self) -> None:
|
||||
e = get_excerpt("some text <!-- split --> more text")
|
||||
self.assertEqual(e, "some text <!-- split --> more text")
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import ModelWithCustomDescriptor
|
||||
|
||||
|
||||
class CustomDescriptorTests(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.instance = ModelWithCustomDescriptor.objects.create(
|
||||
custom_field='1',
|
||||
tracked_custom_field='1',
|
||||
|
|
@ -12,7 +14,7 @@ class CustomDescriptorTests(TestCase):
|
|||
tracked_regular_field=1,
|
||||
)
|
||||
|
||||
def test_custom_descriptor_works(self):
|
||||
def test_custom_descriptor_works(self) -> None:
|
||||
instance = self.instance
|
||||
self.assertEqual(instance.custom_field, '1')
|
||||
self.assertEqual(instance.__dict__['custom_field'], 1)
|
||||
|
|
@ -25,7 +27,7 @@ class CustomDescriptorTests(TestCase):
|
|||
self.assertEqual(instance.custom_field, '2')
|
||||
self.assertEqual(instance.__dict__['custom_field'], 2)
|
||||
|
||||
def test_deferred(self):
|
||||
def test_deferred(self) -> None:
|
||||
instance = ModelWithCustomDescriptor.objects.only('id').get(
|
||||
pk=self.instance.pk)
|
||||
self.assertIn('custom_field', instance.get_deferred_fields())
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
from django.utils.connection import ConnectionDoesNotExist
|
||||
|
||||
|
|
@ -5,7 +7,7 @@ from tests.models import SoftDeletable
|
|||
|
||||
|
||||
class SoftDeletableModelTests(TestCase):
|
||||
def test_can_only_see_not_removed_entries(self):
|
||||
def test_can_only_see_not_removed_entries(self) -> None:
|
||||
SoftDeletable.available_objects.create(name='a', is_removed=True)
|
||||
SoftDeletable.available_objects.create(name='b', is_removed=False)
|
||||
|
||||
|
|
@ -14,7 +16,7 @@ class SoftDeletableModelTests(TestCase):
|
|||
self.assertEqual(queryset.count(), 1)
|
||||
self.assertEqual(queryset[0].name, 'b')
|
||||
|
||||
def test_instance_cannot_be_fully_deleted(self):
|
||||
def test_instance_cannot_be_fully_deleted(self) -> None:
|
||||
instance = SoftDeletable.available_objects.create(name='a')
|
||||
|
||||
instance.delete()
|
||||
|
|
@ -22,7 +24,7 @@ class SoftDeletableModelTests(TestCase):
|
|||
self.assertEqual(SoftDeletable.available_objects.count(), 0)
|
||||
self.assertEqual(SoftDeletable.all_objects.count(), 1)
|
||||
|
||||
def test_instance_cannot_be_fully_deleted_via_queryset(self):
|
||||
def test_instance_cannot_be_fully_deleted_via_queryset(self) -> None:
|
||||
SoftDeletable.available_objects.create(name='a')
|
||||
|
||||
SoftDeletable.available_objects.all().delete()
|
||||
|
|
@ -30,12 +32,12 @@ class SoftDeletableModelTests(TestCase):
|
|||
self.assertEqual(SoftDeletable.available_objects.count(), 0)
|
||||
self.assertEqual(SoftDeletable.all_objects.count(), 1)
|
||||
|
||||
def test_delete_instance_no_connection(self):
|
||||
def test_delete_instance_no_connection(self) -> None:
|
||||
obj = SoftDeletable.available_objects.create(name='a')
|
||||
|
||||
self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other')
|
||||
|
||||
def test_instance_purge(self):
|
||||
def test_instance_purge(self) -> None:
|
||||
instance = SoftDeletable.available_objects.create(name='a')
|
||||
|
||||
instance.delete(soft=False)
|
||||
|
|
@ -43,11 +45,11 @@ class SoftDeletableModelTests(TestCase):
|
|||
self.assertEqual(SoftDeletable.available_objects.count(), 0)
|
||||
self.assertEqual(SoftDeletable.all_objects.count(), 0)
|
||||
|
||||
def test_instance_purge_no_connection(self):
|
||||
def test_instance_purge_no_connection(self) -> None:
|
||||
instance = SoftDeletable.available_objects.create(name='a')
|
||||
|
||||
self.assertRaises(ConnectionDoesNotExist, instance.delete,
|
||||
using='other', soft=False)
|
||||
|
||||
def test_deprecation_warning(self):
|
||||
def test_deprecation_warning(self) -> None:
|
||||
self.assertWarns(DeprecationWarning, SoftDeletable.objects.all)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import time_machine
|
||||
|
|
@ -7,12 +9,14 @@ from tests.models import CustomManagerStatusModel, Status, StatusPlainTuple
|
|||
|
||||
|
||||
class StatusModelTests(TestCase):
|
||||
def setUp(self):
|
||||
model: type[Status] | type[StatusPlainTuple]
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.model = Status
|
||||
self.on_hold = Status.STATUS.on_hold
|
||||
self.active = Status.STATUS.active
|
||||
|
||||
def test_created(self):
|
||||
def test_created(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1)):
|
||||
c1 = self.model.objects.create()
|
||||
self.assertTrue(c1.status_changed, datetime(2016, 1, 1))
|
||||
|
|
@ -21,7 +25,7 @@ class StatusModelTests(TestCase):
|
|||
self.assertEqual(self.model.active.count(), 2)
|
||||
self.assertEqual(self.model.deleted.count(), 0)
|
||||
|
||||
def test_modification(self):
|
||||
def test_modification(self) -> None:
|
||||
t1 = self.model.objects.create()
|
||||
date_created = t1.status_changed
|
||||
t1.status = self.on_hold
|
||||
|
|
@ -37,7 +41,7 @@ class StatusModelTests(TestCase):
|
|||
t1.save()
|
||||
self.assertTrue(t1.status_changed > date_active_again)
|
||||
|
||||
def test_save_with_update_fields_overrides_status_changed_provided(self):
|
||||
def test_save_with_update_fields_overrides_status_changed_provided(self) -> None:
|
||||
'''
|
||||
Tests if the save method updated status_changed field
|
||||
accordingly when update_fields is used as an argument
|
||||
|
|
@ -52,7 +56,7 @@ class StatusModelTests(TestCase):
|
|||
|
||||
self.assertEqual(t1.status_changed, datetime(2020, 1, 2, tzinfo=timezone.utc))
|
||||
|
||||
def test_save_with_update_fields_overrides_status_changed_not_provided(self):
|
||||
def test_save_with_update_fields_overrides_status_changed_not_provided(self) -> None:
|
||||
'''
|
||||
Tests if the save method updated status_changed field
|
||||
accordingly when update_fields is used as an argument
|
||||
|
|
@ -69,7 +73,7 @@ class StatusModelTests(TestCase):
|
|||
|
||||
|
||||
class StatusModelPlainTupleTests(StatusModelTests):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.model = StatusPlainTuple
|
||||
self.on_hold = StatusPlainTuple.STATUS[2][0]
|
||||
self.active = StatusPlainTuple.STATUS[0][0]
|
||||
|
|
@ -77,7 +81,7 @@ class StatusModelPlainTupleTests(StatusModelTests):
|
|||
|
||||
class StatusModelDefaultManagerTests(TestCase):
|
||||
|
||||
def test_default_manager_is_not_status_model_generated_ones(self):
|
||||
def test_default_manager_is_not_status_model_generated_ones(self) -> None:
|
||||
# Regression test for GH-251
|
||||
# The logic behind order for managers seems to have changed in Django 1.10
|
||||
# and affects default manager.
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
|
@ -10,36 +12,36 @@ from tests.models import TimeFrame, TimeFrameManagerAdded
|
|||
|
||||
|
||||
class TimeFramedModelTests(TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
self.now = datetime.now()
|
||||
|
||||
def test_not_yet_begun(self):
|
||||
def test_not_yet_begun(self) -> None:
|
||||
TimeFrame.objects.create(start=self.now + timedelta(days=2))
|
||||
self.assertEqual(TimeFrame.timeframed.count(), 0)
|
||||
|
||||
def test_finished(self):
|
||||
def test_finished(self) -> None:
|
||||
TimeFrame.objects.create(end=self.now - timedelta(days=1))
|
||||
self.assertEqual(TimeFrame.timeframed.count(), 0)
|
||||
|
||||
def test_no_end(self):
|
||||
def test_no_end(self) -> None:
|
||||
TimeFrame.objects.create(start=self.now - timedelta(days=10))
|
||||
self.assertEqual(TimeFrame.timeframed.count(), 1)
|
||||
|
||||
def test_no_start(self):
|
||||
def test_no_start(self) -> None:
|
||||
TimeFrame.objects.create(end=self.now + timedelta(days=2))
|
||||
self.assertEqual(TimeFrame.timeframed.count(), 1)
|
||||
|
||||
def test_within_range(self):
|
||||
def test_within_range(self) -> None:
|
||||
TimeFrame.objects.create(start=self.now - timedelta(days=1),
|
||||
end=self.now + timedelta(days=1))
|
||||
self.assertEqual(TimeFrame.timeframed.count(), 1)
|
||||
|
||||
|
||||
class TimeFrameManagerAddedTests(TestCase):
|
||||
def test_manager_available(self):
|
||||
def test_manager_available(self) -> None:
|
||||
self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager))
|
||||
|
||||
def test_conflict_error(self):
|
||||
def test_conflict_error(self) -> None:
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
class ErrorModel(TimeFramedModel):
|
||||
timeframed = models.BooleanField()
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import time_machine
|
||||
|
|
@ -7,19 +10,19 @@ from tests.models import TimeStamp, TimeStampWithStatusModel
|
|||
|
||||
|
||||
class TimeStampedModelTests(TestCase):
|
||||
def test_created(self):
|
||||
def test_created(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)):
|
||||
t1 = TimeStamp.objects.create()
|
||||
self.assertEqual(t1.created, datetime(2016, 1, 1, tzinfo=timezone.utc))
|
||||
|
||||
def test_created_sets_modified(self):
|
||||
def test_created_sets_modified(self) -> None:
|
||||
'''
|
||||
Ensure that on creation that modified is set exactly equal to created.
|
||||
'''
|
||||
t1 = TimeStamp.objects.create()
|
||||
self.assertEqual(t1.created, t1.modified)
|
||||
|
||||
def test_modified(self):
|
||||
def test_modified(self) -> None:
|
||||
with time_machine.travel(datetime(2016, 1, 1, tzinfo=timezone.utc)):
|
||||
t1 = TimeStamp.objects.create()
|
||||
|
||||
|
|
@ -28,7 +31,7 @@ class TimeStampedModelTests(TestCase):
|
|||
|
||||
self.assertEqual(t1.modified, datetime(2016, 1, 2, tzinfo=timezone.utc))
|
||||
|
||||
def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self):
|
||||
def test_overriding_created_via_object_creation_also_uses_creation_date_for_modified(self) -> None:
|
||||
"""
|
||||
Setting the created date when first creating an object
|
||||
should be permissible.
|
||||
|
|
@ -38,7 +41,7 @@ class TimeStampedModelTests(TestCase):
|
|||
self.assertEqual(t1.created, different_date)
|
||||
self.assertEqual(t1.modified, different_date)
|
||||
|
||||
def test_overriding_modified_via_object_creation(self):
|
||||
def test_overriding_modified_via_object_creation(self) -> None:
|
||||
"""
|
||||
Setting the modified date explicitly should be possible when
|
||||
first creating an object, but not thereafter.
|
||||
|
|
@ -48,7 +51,7 @@ class TimeStampedModelTests(TestCase):
|
|||
self.assertEqual(t1.modified, different_date)
|
||||
self.assertNotEqual(t1.created, different_date)
|
||||
|
||||
def test_overriding_created_after_object_created(self):
|
||||
def test_overriding_created_after_object_created(self) -> None:
|
||||
"""
|
||||
The created date may be changed post-create
|
||||
"""
|
||||
|
|
@ -58,7 +61,7 @@ class TimeStampedModelTests(TestCase):
|
|||
t1.save()
|
||||
self.assertEqual(t1.created, different_date)
|
||||
|
||||
def test_overriding_modified_after_object_created(self):
|
||||
def test_overriding_modified_after_object_created(self) -> None:
|
||||
"""
|
||||
The modified date should always be updated when the object
|
||||
is saved, regardless of attempts to change it.
|
||||
|
|
@ -69,7 +72,7 @@ class TimeStampedModelTests(TestCase):
|
|||
t1.save()
|
||||
self.assertNotEqual(t1.modified, different_date)
|
||||
|
||||
def test_overrides_using_save(self):
|
||||
def test_overrides_using_save(self) -> None:
|
||||
"""
|
||||
The first time an object is saved, allow modification of both
|
||||
created and modified fields.
|
||||
|
|
@ -90,7 +93,7 @@ class TimeStampedModelTests(TestCase):
|
|||
self.assertNotEqual(t1.modified, different_date2)
|
||||
self.assertNotEqual(t1.modified, different_date)
|
||||
|
||||
def test_save_with_update_fields_overrides_modified_provided_within_a(self):
|
||||
def test_save_with_update_fields_overrides_modified_provided_within_a(self) -> None:
|
||||
"""
|
||||
Tests if the save method updated modified field
|
||||
accordingly when update_fields is used as an argument
|
||||
|
|
@ -111,8 +114,8 @@ class TimeStampedModelTests(TestCase):
|
|||
t1.save(update_fields=update_fields)
|
||||
self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc))
|
||||
|
||||
def test_save_is_skipped_for_empty_update_fields_iterable(self):
|
||||
tests = (
|
||||
def test_save_is_skipped_for_empty_update_fields_iterable(self) -> None:
|
||||
tests: Iterable[Iterable[str]] = (
|
||||
[], # list
|
||||
(), # tuple
|
||||
set(), # set
|
||||
|
|
@ -131,7 +134,7 @@ class TimeStampedModelTests(TestCase):
|
|||
self.assertEqual(t1.test_field, 0)
|
||||
self.assertEqual(t1.modified, datetime(2020, 1, 1, tzinfo=timezone.utc))
|
||||
|
||||
def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self):
|
||||
def test_save_updates_modified_value_when_update_fields_explicitly_set_to_none(self) -> None:
|
||||
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)):
|
||||
t1 = TimeStamp.objects.create()
|
||||
|
||||
|
|
@ -140,7 +143,7 @@ class TimeStampedModelTests(TestCase):
|
|||
|
||||
self.assertEqual(t1.modified, datetime(2020, 1, 2, tzinfo=timezone.utc))
|
||||
|
||||
def test_model_inherit_timestampmodel_and_statusmodel(self):
|
||||
def test_model_inherit_timestampmodel_and_statusmodel(self) -> None:
|
||||
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc)):
|
||||
t1 = TimeStampWithStatusModel.objects.create()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel
|
||||
|
|
@ -5,13 +7,13 @@ from tests.models import CustomNotPrimaryUUIDModel, CustomUUIDModel
|
|||
|
||||
class UUIDFieldTests(TestCase):
|
||||
|
||||
def test_uuid_model_with_uuid_field_as_primary_key(self):
|
||||
def test_uuid_model_with_uuid_field_as_primary_key(self) -> None:
|
||||
instance = CustomUUIDModel()
|
||||
instance.save()
|
||||
self.assertEqual(instance.id.__class__.__name__, 'UUID')
|
||||
self.assertEqual(instance.id, instance.pk)
|
||||
|
||||
def test_uuid_model_with_uuid_field_as_not_primary_key(self):
|
||||
def test_uuid_model_with_uuid_field_as_not_primary_key(self) -> None:
|
||||
instance = CustomNotPrimaryUUIDModel()
|
||||
instance.save()
|
||||
self.assertEqual(instance.uuid.__class__.__name__, 'UUID')
|
||||
|
|
|
|||
Loading…
Reference in a new issue