Annotate test helpers

This commit is contained in:
Maarten ter Huurne 2024-03-22 18:38:11 +01:00
parent 23f1811b9d
commit c83ef89bc9
2 changed files with 33 additions and 21 deletions

View file

@ -1,9 +1,12 @@
from __future__ import annotations
from typing import Any
from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
def mutable_from_db(value):
def mutable_from_db(value: object) -> Any:
if value == '':
return None
try:
@ -14,7 +17,7 @@ def mutable_from_db(value):
return value
def mutable_to_db(value):
def mutable_to_db(value: object) -> str:
if value is None:
return ''
if isinstance(value, list):
@ -23,12 +26,12 @@ def mutable_to_db(value):
class MutableField(models.TextField):
def to_python(self, value):
def to_python(self, value: object) -> Any:
return mutable_from_db(value)
def from_db_value(self, value, expression, connection):
def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any:
return mutable_from_db(value)
def get_db_prep_save(self, value, connection):
def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str:
value = super().get_db_prep_save(value, connection)
return mutable_to_db(value)

View file

@ -1,9 +1,10 @@
from __future__ import annotations
from typing import ClassVar, TypeVar
from typing import Any, ClassVar, TypeVar, overload
from django.db import models
from django.db.models import Manager
from django.db.models.query import QuerySet
from django.db.models.query_utils import DeferredAttribute
from django.utils.translation import gettext_lazy as _
@ -45,7 +46,7 @@ class InheritanceManagerTestParent(models.Model):
on_delete=models.CASCADE)
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()
def __str__(self):
def __str__(self) -> str:
return "{}({})".format(
self.__class__.__name__[len('InheritanceManagerTest'):],
self.pk,
@ -186,7 +187,7 @@ class Post(models.Model):
public: ClassVar[QueryManager[Post]] = QueryManager(published=True)
public_confirmed: ClassVar[QueryManager[Post]] = QueryManager(
models.Q(published=True) & models.Q(confirmed=True))
public_reversed: QueryManager[Post] = QueryManager(
public_reversed: ClassVar[QueryManager[Post]] = QueryManager(
published=True).order_by("-order")
class Meta:
@ -206,7 +207,6 @@ class SplitFieldAbstractParent(models.Model):
class AbstractTracked(models.Model):
number: models.IntegerField
class Meta:
abstract = True
@ -219,7 +219,7 @@ class Tracked(models.Model):
tracker = FieldTracker()
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
""" No-op save() to ensure that FieldTracker.patch_save() works. """
super().save(*args, **kwargs)
@ -231,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel):
tracker = FieldTracker()
def save(self, *args, **kwargs):
def save(self, *args: Any, **kwargs: Any) -> None:
""" Automatically add "modified" to update_fields."""
update_fields = kwargs.get('update_fields')
if update_fields is not None:
@ -249,7 +249,7 @@ class TrackedFK(models.Model):
class TrackedAbstract(AbstractTracked):
name = models.CharField(max_length=20)
number = models.IntegerField() # type: ignore[assignment]
number = models.IntegerField()
mutable = MutableField(default=None)
tracker = ModelTracker()
@ -266,7 +266,7 @@ class TrackedNonFieldAttr(models.Model):
number = models.FloatField()
@property
def rounded(self):
def rounded(self) -> int | None:
return round(self.number) if self.number is not None else None
tracker = FieldTracker(fields=['rounded'])
@ -355,43 +355,52 @@ class SoftDeletable(SoftDeletableModel):
all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager()
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet):
def only_read(self):
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]):
def only_read(self) -> QuerySet[ModelT]:
return self.filter(is_read=True)
class CustomSoftDelete(SoftDeletableModel):
is_read = models.BooleanField(default=False)
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc]
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)()
class StringyDescriptor:
"""
Descriptor that returns a string version of the underlying integer value.
"""
def __init__(self, name):
def __init__(self, name: str):
self.name = name
def __get__(self, obj, cls=None):
@overload
def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor:
...
@overload
def __get__(self, obj: models.Model, cls: type[models.Model]) -> str:
...
def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str:
if obj is None:
return self
if self.name in obj.get_deferred_fields():
# This queries the database, and sets the value on the instance.
assert cls is not None
fields_map = {f.name: f for f in cls._meta.fields}
field = fields_map[self.name]
DeferredAttribute(field=field).__get__(obj, cls)
return str(obj.__dict__[self.name])
def __set__(self, obj, value):
def __set__(self, obj: object, value: str) -> None:
obj.__dict__[self.name] = int(value)
def __delete__(self, obj):
def __delete__(self, obj: object) -> None:
del obj.__dict__[self.name]
class CustomDescriptorField(models.IntegerField):
def contribute_to_class(self, cls, name, *args, **kwargs):
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
super().contribute_to_class(cls, name, *args, **kwargs)
setattr(cls, name, StringyDescriptor(name))