From c83ef89bc9b8f50e48697bb53b0f8696f01d4c11 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Fri, 22 Mar 2024 18:38:11 +0100 Subject: [PATCH] Annotate test helpers --- tests/fields.py | 13 ++++++++----- tests/models.py | 41 +++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/tests/fields.py b/tests/fields.py index df074f1..d57960c 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1,9 +1,12 @@ from __future__ import annotations +from typing import Any + from django.db import models +from django.db.backends.base.base import BaseDatabaseWrapper -def mutable_from_db(value): +def mutable_from_db(value: object) -> Any: if value == '': return None try: @@ -14,7 +17,7 @@ def mutable_from_db(value): return value -def mutable_to_db(value): +def mutable_to_db(value: object) -> str: if value is None: return '' if isinstance(value, list): @@ -23,12 +26,12 @@ def mutable_to_db(value): class MutableField(models.TextField): - def to_python(self, value): + def to_python(self, value: object) -> Any: return mutable_from_db(value) - def from_db_value(self, value, expression, connection): + def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any: return mutable_from_db(value) - def get_db_prep_save(self, value, connection): + def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str: value = super().get_db_prep_save(value, connection) return mutable_to_db(value) diff --git a/tests/models.py b/tests/models.py index 58c97c1..f06d15b 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import ClassVar, TypeVar +from typing import Any, ClassVar, TypeVar, overload from django.db import models from django.db.models import Manager +from django.db.models.query import QuerySet from django.db.models.query_utils import DeferredAttribute from django.utils.translation import gettext_lazy as _ @@ -45,7 +46,7 @@ class InheritanceManagerTestParent(models.Model): on_delete=models.CASCADE) objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager() - def __str__(self): + def __str__(self) -> str: return "{}({})".format( self.__class__.__name__[len('InheritanceManagerTest'):], self.pk, @@ -186,7 +187,7 @@ class Post(models.Model): public: ClassVar[QueryManager[Post]] = QueryManager(published=True) public_confirmed: ClassVar[QueryManager[Post]] = QueryManager( models.Q(published=True) & models.Q(confirmed=True)) - public_reversed: QueryManager[Post] = QueryManager( + public_reversed: ClassVar[QueryManager[Post]] = QueryManager( published=True).order_by("-order") class Meta: @@ -206,7 +207,6 @@ class SplitFieldAbstractParent(models.Model): class AbstractTracked(models.Model): - number: models.IntegerField class Meta: abstract = True @@ -219,7 +219,7 @@ class Tracked(models.Model): tracker = FieldTracker() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ No-op save() to ensure that FieldTracker.patch_save() works. """ super().save(*args, **kwargs) @@ -231,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel): tracker = FieldTracker() - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: """ Automatically add "modified" to update_fields.""" update_fields = kwargs.get('update_fields') if update_fields is not None: @@ -249,7 +249,7 @@ class TrackedFK(models.Model): class TrackedAbstract(AbstractTracked): name = models.CharField(max_length=20) - number = models.IntegerField() # type: ignore[assignment] + number = models.IntegerField() mutable = MutableField(default=None) tracker = ModelTracker() @@ -266,7 +266,7 @@ class TrackedNonFieldAttr(models.Model): number = models.FloatField() @property - def rounded(self): + def rounded(self) -> int | None: return round(self.number) if self.number is not None else None tracker = FieldTracker(fields=['rounded']) @@ -355,43 +355,52 @@ class SoftDeletable(SoftDeletableModel): all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager() -class CustomSoftDeleteQuerySet(SoftDeletableQuerySet): - def only_read(self): +class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]): + def only_read(self) -> QuerySet[ModelT]: return self.filter(is_read=True) class CustomSoftDelete(SoftDeletableModel): is_read = models.BooleanField(default=False) - available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc] + available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() class StringyDescriptor: """ Descriptor that returns a string version of the underlying integer value. """ - def __init__(self, name): + def __init__(self, name: str): self.name = name - def __get__(self, obj, cls=None): + @overload + def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor: + ... + + @overload + def __get__(self, obj: models.Model, cls: type[models.Model]) -> str: + ... + + def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str: if obj is None: return self if self.name in obj.get_deferred_fields(): # This queries the database, and sets the value on the instance. + assert cls is not None fields_map = {f.name: f for f in cls._meta.fields} field = fields_map[self.name] DeferredAttribute(field=field).__get__(obj, cls) return str(obj.__dict__[self.name]) - def __set__(self, obj, value): + def __set__(self, obj: object, value: str) -> None: obj.__dict__[self.name] = int(value) - def __delete__(self, obj): + def __delete__(self, obj: object) -> None: del obj.__dict__[self.name] class CustomDescriptorField(models.IntegerField): - def contribute_to_class(self, cls, name, *args, **kwargs): + def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None: super().contribute_to_class(cls, name, *args, **kwargs) setattr(cls, name, StringyDescriptor(name))