Annotate the managers module

This commit is contained in:
Maarten ter Huurne 2023-03-20 19:16:03 +01:00
parent 56ea527286
commit bde2d8f9a9

View file

@ -1,20 +1,35 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db import connection, models from django.db import connection, models
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.sql.datastructures import Join from django.db.models.sql.datastructures import Join
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
if TYPE_CHECKING:
from collections.abc import Iterator
from django.db.models.query import BaseIterable
from django.db.models.query import ModelIterable as ModelIterableGeneric
from django.db.models.query import QuerySet as QuerySetGeneric
ModelIterable = ModelIterableGeneric[ModelT]
QuerySet = QuerySetGeneric[ModelT]
else:
from django.db.models.query import ModelIterable, QuerySet
class InheritanceIterable(ModelIterable): class InheritanceIterable(ModelIterable):
def __iter__(self): def __iter__(self) -> Iterator[ModelT]:
queryset = self.queryset queryset = self.queryset
iter = ModelIterable(queryset) iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset)
if getattr(queryset, 'subclasses', False): if hasattr(queryset, 'subclasses'):
assert hasattr(queryset, '_get_sub_obj_recurse')
extras = tuple(queryset.query.extra.keys()) extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first, # sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible # so with 'a' and 'a__b' it goes as deep as possible
@ -28,7 +43,7 @@ class InheritanceIterable(ModelIterable):
if not sub_obj: if not sub_obj:
sub_obj = obj sub_obj = obj
if getattr(queryset, '_annotated', False): if hasattr(queryset, '_annotated'):
for k in queryset._annotated: for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k)) setattr(sub_obj, k, getattr(obj, k))
@ -40,26 +55,31 @@ class InheritanceIterable(ModelIterable):
yield from iter yield from iter
class InheritanceQuerySetMixin: class InheritanceQuerySetMixin(Generic[ModelT]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._iterable_class = InheritanceIterable
def select_subclasses(self, *subclasses): model: type[ModelT]
calculated_subclasses = self._get_subclasses_recurse(self.model) subclasses: Sequence[str]
def __init__(self, *args: object, **kwargs: object):
super().__init__(*args, **kwargs)
self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable
def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]:
model: type[ModelT] = self.model
calculated_subclasses = self._get_subclasses_recurse(model)
# if none were passed in, we can just short circuit and select all # if none were passed in, we can just short circuit and select all
if not subclasses: if not subclasses:
subclasses = calculated_subclasses selected_subclasses = calculated_subclasses
else: else:
verified_subclasses = [] verified_subclasses: list[str] = []
for subclass in subclasses: for subclass in subclasses:
# special case for passing in the same model as the queryset # special case for passing in the same model as the queryset
# is bound against. Rather than raise an error later, we know # is bound against. Rather than raise an error later, we know
# we can allow this through. # we can allow this through.
if subclass is self.model: if subclass is model:
continue continue
if not isinstance(subclass, (str,)): if not isinstance(subclass, str):
subclass = self._get_ancestors_path(subclass) subclass = self._get_ancestors_path(subclass)
if subclass in calculated_subclasses: if subclass in calculated_subclasses:
@ -69,38 +89,39 @@ class InheritanceQuerySetMixin:
'{!r} is not in the discovered subclasses, tried: {}'.format( '{!r} is not in the discovered subclasses, tried: {}'.format(
subclass, ', '.join(calculated_subclasses)) subclass, ', '.join(calculated_subclasses))
) )
subclasses = verified_subclasses selected_subclasses = verified_subclasses
if subclasses: new_qs = cast('InheritanceQuerySet[ModelT]', self)
new_qs = self.select_related(*subclasses) if selected_subclasses:
else: new_qs = new_qs.select_related(*selected_subclasses)
new_qs = self new_qs.subclasses = selected_subclasses
new_qs.subclasses = subclasses
return new_qs return new_qs
def _chain(self, **kwargs): def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]:
update = {} update = {}
for name in ['subclasses', '_annotated']: for name in ['subclasses', '_annotated']:
if hasattr(self, name): if hasattr(self, name):
update[name] = getattr(self, name) update[name] = getattr(self, name)
chained = super()._chain(**kwargs) # django-stubs doesn't include this private API.
chained = super()._chain(**kwargs) # type: ignore[misc]
chained.__dict__.update(update) chained.__dict__.update(update)
return chained return chained
def _clone(self): def _clone(self) -> InheritanceQuerySet[ModelT]:
qs = super()._clone() # django-stubs doesn't include this private API.
qs = super()._clone() # type: ignore[misc]
for name in ['subclasses', '_annotated']: for name in ['subclasses', '_annotated']:
if hasattr(self, name): if hasattr(self, name):
setattr(qs, name, getattr(self, name)) setattr(qs, name, getattr(self, name))
return qs return qs
def annotate(self, *args, **kwargs): def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
qset = super().annotate(*args, **kwargs) qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs)
qset._annotated = [a.default_alias for a in args] + list(kwargs.keys()) qset._annotated = [a.default_alias for a in args] + list(kwargs.keys())
return qset return qset
def _get_subclasses_recurse(self, model): def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]:
""" """
Given a Model class, find all related objects, exploring children Given a Model class, find all related objects, exploring children
recursively, returning a `list` of strings representing the recursively, returning a `list` of strings representing the
@ -126,7 +147,7 @@ class InheritanceQuerySetMixin:
subclasses.append(rel.get_accessor_name()) subclasses.append(rel.get_accessor_name())
return subclasses return subclasses
def _get_ancestors_path(self, model): def _get_ancestors_path(self, model: type[models.Model]) -> str:
""" """
Serves as an opposite to _get_subclasses_recurse, instead walking from Serves as an opposite to _get_subclasses_recurse, instead walking from
the Model class up the Model's ancestry and constructing the desired the Model class up the Model's ancestry and constructing the desired
@ -136,7 +157,7 @@ class InheritanceQuerySetMixin:
raise ValueError( raise ValueError(
f"{model!r} is not a subclass of {self.model!r}") f"{model!r} is not a subclass of {self.model!r}")
ancestry = [] ancestry: list[str] = []
# should be a OneToOneField or None # should be a OneToOneField or None
parent_link = model._meta.get_ancestor_link(self.model) parent_link = model._meta.get_ancestor_link(self.model)
@ -149,7 +170,7 @@ class InheritanceQuerySetMixin:
return LOOKUP_SEP.join(ancestry) return LOOKUP_SEP.join(ancestry)
def _get_sub_obj_recurse(self, obj, s): def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None:
rel, _, s = s.partition(LOOKUP_SEP) rel, _, s = s.partition(LOOKUP_SEP)
try: try:
@ -162,12 +183,14 @@ class InheritanceQuerySetMixin:
else: else:
return node return node
def get_subclass(self, *args, **kwargs): def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
return self.select_subclasses().get(*args, **kwargs) return self.select_subclasses().get(*args, **kwargs)
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): # Defining the 'model' attribute using a generic type triggers a bug in mypy:
def instance_of(self, *models): # https://github.com/python/mypy/issues/9031
class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
""" """
Fetch only objects that are instances of the provided model(s). Fetch only objects that are instances of the provided model(s).
""" """
@ -190,88 +213,118 @@ class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
) for field in model._meta.parents.values() ) for field in model._meta.parents.values()
]) + ')') ]) + ')')
return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) return cast(
'InheritanceQuerySet[ModelT]',
self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
)
class InheritanceManagerMixin: class InheritanceManagerMixin(Generic[ModelT]):
_queryset_class = InheritanceQuerySet _queryset_class = InheritanceQuerySet
def get_queryset(self): if TYPE_CHECKING:
return self._queryset_class(self.model) def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
...
def select_subclasses(self, *subclasses): def get_queryset(self) -> InheritanceQuerySet[ModelT]:
model: type[ModelT] = self.model # type: ignore[attr-defined]
return self._queryset_class(model)
def select_subclasses(
self, *subclasses: str | type[models.Model]
) -> InheritanceQuerySet[ModelT]:
return self.get_queryset().select_subclasses(*subclasses) return self.get_queryset().select_subclasses(*subclasses)
def get_subclass(self, *args, **kwargs): def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
return self.get_queryset().get_subclass(*args, **kwargs) return self.get_queryset().get_subclass(*args, **kwargs)
def instance_of(self, *models): def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
return self.get_queryset().instance_of(*models) return self.get_queryset().instance_of(*models)
class InheritanceManager(InheritanceManagerMixin, models.Manager): class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]):
pass pass
class QueryManagerMixin: class QueryManagerMixin(Generic[ModelT]):
def __init__(self, *args, **kwargs): @overload
def __init__(self, *args: models.Q):
...
@overload
def __init__(self, **kwargs: object):
...
def __init__(self, *args: models.Q, **kwargs: object):
if args: if args:
self._q = args[0] self._q = args[0]
else: else:
self._q = models.Q(**kwargs) self._q = models.Q(**kwargs)
self._order_by = None self._order_by: tuple[Any, ...] | None = None
super().__init__() super().__init__()
def order_by(self, *args): def order_by(self, *args: Any) -> QueryManager[ModelT]:
self._order_by = args self._order_by = args
return self return cast('QueryManager[ModelT]', self)
def get_queryset(self): def get_queryset(self) -> QuerySet[ModelT]:
qs = super().get_queryset().filter(self._q) qs = super().get_queryset() # type: ignore[misc]
qs = qs.filter(self._q)
if self._order_by is not None: if self._order_by is not None:
return qs.order_by(*self._order_by) return qs.order_by(*self._order_by)
return qs return qs
class QueryManager(QueryManagerMixin, models.Manager): class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc]
pass pass
class SoftDeletableQuerySetMixin: class SoftDeletableQuerySetMixin(Generic[ModelT]):
""" """
QuerySet for SoftDeletableModel. Instead of removing instance sets QuerySet for SoftDeletableModel. Instead of removing instance sets
its ``is_removed`` field to True. its ``is_removed`` field to True.
""" """
def delete(self): def delete(self) -> None:
""" """
Soft delete objects from queryset (set their ``is_removed`` Soft delete objects from queryset (set their ``is_removed``
field to True) field to True)
""" """
self.update(is_removed=True) cast(QuerySet[ModelT], self).update(is_removed=True)
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet): # Note that our delete() method does not return anything, unlike Django's.
# https://github.com/jazzband/django-model-utils/issues/541
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
pass pass
class SoftDeletableManagerMixin: class SoftDeletableManagerMixin(Generic[ModelT]):
""" """
Manager that limits the queryset by default to show only not removed Manager that limits the queryset by default to show only not removed
instances of model. instances of model.
""" """
_queryset_class = SoftDeletableQuerySet _queryset_class = SoftDeletableQuerySet
def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs): _db: str | None
def __init__(
self,
*args: object,
_emit_deprecation_warnings: bool = False,
**kwargs: object
):
self.emit_deprecation_warnings = _emit_deprecation_warnings self.emit_deprecation_warnings = _emit_deprecation_warnings
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def get_queryset(self): def get_queryset(self) -> SoftDeletableQuerySet[ModelT]:
""" """
Return queryset limited to not removed entries. Return queryset limited to not removed entries.
""" """
model: type[ModelT] = self.model # type: ignore[attr-defined]
if self.emit_deprecation_warnings: if self.emit_deprecation_warnings:
warning_message = ( warning_message = (
"{0}.objects model manager will include soft-deleted objects in an " "{0}.objects model manager will include soft-deleted objects in an "
@ -279,23 +332,23 @@ class SoftDeletableManagerMixin:
"excluding soft-deleted objects. See " "excluding soft-deleted objects. See "
"https://django-model-utils.readthedocs.io/en/stable/models.html" "https://django-model-utils.readthedocs.io/en/stable/models.html"
"#softdeletablemodel for more information." "#softdeletablemodel for more information."
).format(self.model.__class__.__name__) ).format(model.__class__.__name__)
warnings.warn(warning_message, DeprecationWarning) warnings.warn(warning_message, DeprecationWarning)
kwargs = {'model': self.model, 'using': self._db} return self._queryset_class(
if hasattr(self, '_hints'): model=model,
kwargs['hints'] = self._hints using=self._db,
**({'hints': self._hints} if hasattr(self, '_hints') else {})
return self._queryset_class(**kwargs).filter(is_removed=False) ).filter(is_removed=False)
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager): class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]):
pass pass
class JoinQueryset(models.QuerySet): class JoinQueryset(models.QuerySet[Any]):
def join(self, qs=None): def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]:
''' '''
Join one queryset together with another using a temporary table. If Join one queryset together with another using a temporary table. If
no queryset is used, it will use the current queryset and join that no queryset is used, it will use the current queryset and join that
@ -310,11 +363,11 @@ class JoinQueryset(models.QuerySet):
to_field = 'id' to_field = 'id'
if qs: if qs:
fk = [ fks = [
fk for fk in qs.model._meta.fields fk for fk in qs.model._meta.fields
if getattr(fk, 'related_model', None) == self.model if getattr(fk, 'related_model', None) == self.model
] ]
fk = fk[0] if fk else None fk = fks[0] if fks else None
model_set = f'{self.model.__name__.lower()}_set' model_set = f'{self.model.__name__.lower()}_set'
key = fk or getattr(qs.model, model_set, None) key = fk or getattr(qs.model, model_set, None)
@ -371,21 +424,24 @@ class JoinQueryset(models.QuerySet):
return new_qs return new_qs
class JoinManagerMixin: if not TYPE_CHECKING:
""" # Hide deprecated API during type checking, to encourage switch to
Manager that adds a method join. This method allows you to join two # 'JoinQueryset.as_manager()', which is supported by the mypy plugin
querysets together. # of django-stubs.
"""
_queryset_class = JoinQueryset
def get_queryset(self): class JoinManagerMixin:
warnings.warn( """
"JoinManager and JoinManagerMixin are deprecated. " Manager that adds a method join. This method allows you to join two
"Please use 'JoinQueryset.as_manager()' instead.", querysets together.
DeprecationWarning """
)
return self._queryset_class(model=self.model, using=self._db)
def get_queryset(self):
warnings.warn(
"JoinManager and JoinManagerMixin are deprecated. "
"Please use 'JoinQueryset.as_manager()' instead.",
DeprecationWarning
)
return self._queryset_class(model=self.model, using=self._db)
class JoinManager(JoinManagerMixin, models.Manager): class JoinManager(JoinManagerMixin):
pass pass