mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Annotate test helpers
This commit is contained in:
parent
23f1811b9d
commit
c83ef89bc9
2 changed files with 33 additions and 21 deletions
|
|
@ -1,9 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.db import models
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
|
||||
|
||||
def mutable_from_db(value):
|
||||
def mutable_from_db(value: object) -> Any:
|
||||
if value == '':
|
||||
return None
|
||||
try:
|
||||
|
|
@ -14,7 +17,7 @@ def mutable_from_db(value):
|
|||
return value
|
||||
|
||||
|
||||
def mutable_to_db(value):
|
||||
def mutable_to_db(value: object) -> str:
|
||||
if value is None:
|
||||
return ''
|
||||
if isinstance(value, list):
|
||||
|
|
@ -23,12 +26,12 @@ def mutable_to_db(value):
|
|||
|
||||
|
||||
class MutableField(models.TextField):
|
||||
def to_python(self, value):
|
||||
def to_python(self, value: object) -> Any:
|
||||
return mutable_from_db(value)
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any:
|
||||
return mutable_from_db(value)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str:
|
||||
value = super().get_db_prep_save(value, connection)
|
||||
return mutable_to_db(value)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar, TypeVar
|
||||
from typing import Any, ClassVar, TypeVar, overload
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Manager
|
||||
from django.db.models.query import QuerySet
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
|
@ -45,7 +46,7 @@ class InheritanceManagerTestParent(models.Model):
|
|||
on_delete=models.CASCADE)
|
||||
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return "{}({})".format(
|
||||
self.__class__.__name__[len('InheritanceManagerTest'):],
|
||||
self.pk,
|
||||
|
|
@ -186,7 +187,7 @@ class Post(models.Model):
|
|||
public: ClassVar[QueryManager[Post]] = QueryManager(published=True)
|
||||
public_confirmed: ClassVar[QueryManager[Post]] = QueryManager(
|
||||
models.Q(published=True) & models.Q(confirmed=True))
|
||||
public_reversed: QueryManager[Post] = QueryManager(
|
||||
public_reversed: ClassVar[QueryManager[Post]] = QueryManager(
|
||||
published=True).order_by("-order")
|
||||
|
||||
class Meta:
|
||||
|
|
@ -206,7 +207,6 @@ class SplitFieldAbstractParent(models.Model):
|
|||
|
||||
|
||||
class AbstractTracked(models.Model):
|
||||
number: models.IntegerField
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
|
@ -219,7 +219,7 @@ class Tracked(models.Model):
|
|||
|
||||
tracker = FieldTracker()
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
""" No-op save() to ensure that FieldTracker.patch_save() works. """
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
|
@ -231,7 +231,7 @@ class TrackerTimeStamped(TimeStampedModel):
|
|||
|
||||
tracker = FieldTracker()
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
def save(self, *args: Any, **kwargs: Any) -> None:
|
||||
""" Automatically add "modified" to update_fields."""
|
||||
update_fields = kwargs.get('update_fields')
|
||||
if update_fields is not None:
|
||||
|
|
@ -249,7 +249,7 @@ class TrackedFK(models.Model):
|
|||
|
||||
class TrackedAbstract(AbstractTracked):
|
||||
name = models.CharField(max_length=20)
|
||||
number = models.IntegerField() # type: ignore[assignment]
|
||||
number = models.IntegerField()
|
||||
mutable = MutableField(default=None)
|
||||
|
||||
tracker = ModelTracker()
|
||||
|
|
@ -266,7 +266,7 @@ class TrackedNonFieldAttr(models.Model):
|
|||
number = models.FloatField()
|
||||
|
||||
@property
|
||||
def rounded(self):
|
||||
def rounded(self) -> int | None:
|
||||
return round(self.number) if self.number is not None else None
|
||||
|
||||
tracker = FieldTracker(fields=['rounded'])
|
||||
|
|
@ -355,43 +355,52 @@ class SoftDeletable(SoftDeletableModel):
|
|||
all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager()
|
||||
|
||||
|
||||
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet):
|
||||
def only_read(self):
|
||||
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]):
|
||||
def only_read(self) -> QuerySet[ModelT]:
|
||||
return self.filter(is_read=True)
|
||||
|
||||
|
||||
class CustomSoftDelete(SoftDeletableModel):
|
||||
is_read = models.BooleanField(default=False)
|
||||
|
||||
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)() # type: ignore[misc]
|
||||
available_objects = SoftDeletableManager.from_queryset(CustomSoftDeleteQuerySet)()
|
||||
|
||||
|
||||
class StringyDescriptor:
|
||||
"""
|
||||
Descriptor that returns a string version of the underlying integer value.
|
||||
"""
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
def __get__(self, obj, cls=None):
|
||||
@overload
|
||||
def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self, obj: models.Model, cls: type[models.Model]) -> str:
|
||||
...
|
||||
|
||||
def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str:
|
||||
if obj is None:
|
||||
return self
|
||||
if self.name in obj.get_deferred_fields():
|
||||
# This queries the database, and sets the value on the instance.
|
||||
assert cls is not None
|
||||
fields_map = {f.name: f for f in cls._meta.fields}
|
||||
field = fields_map[self.name]
|
||||
DeferredAttribute(field=field).__get__(obj, cls)
|
||||
return str(obj.__dict__[self.name])
|
||||
|
||||
def __set__(self, obj, value):
|
||||
def __set__(self, obj: object, value: str) -> None:
|
||||
obj.__dict__[self.name] = int(value)
|
||||
|
||||
def __delete__(self, obj):
|
||||
def __delete__(self, obj: object) -> None:
|
||||
del obj.__dict__[self.name]
|
||||
|
||||
|
||||
class CustomDescriptorField(models.IntegerField):
|
||||
def contribute_to_class(self, cls, name, *args, **kwargs):
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
|
||||
super().contribute_to_class(cls, name, *args, **kwargs)
|
||||
setattr(cls, name, StringyDescriptor(name))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue