mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Annotate the managers module
This commit is contained in:
parent
56ea527286
commit
bde2d8f9a9
1 changed files with 139 additions and 83 deletions
|
|
@ -1,20 +1,35 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload
|
||||
|
||||
from django.core.exceptions import ObjectDoesNotExist
|
||||
from django.db import connection, models
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields.related import OneToOneField, OneToOneRel
|
||||
from django.db.models.query import ModelIterable, QuerySet
|
||||
from django.db.models.sql.datastructures import Join
|
||||
|
||||
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from django.db.models.query import BaseIterable
|
||||
from django.db.models.query import ModelIterable as ModelIterableGeneric
|
||||
from django.db.models.query import QuerySet as QuerySetGeneric
|
||||
|
||||
ModelIterable = ModelIterableGeneric[ModelT]
|
||||
QuerySet = QuerySetGeneric[ModelT]
|
||||
else:
|
||||
from django.db.models.query import ModelIterable, QuerySet
|
||||
|
||||
|
||||
class InheritanceIterable(ModelIterable):
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Iterator[ModelT]:
|
||||
queryset = self.queryset
|
||||
iter = ModelIterable(queryset)
|
||||
if getattr(queryset, 'subclasses', False):
|
||||
iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset)
|
||||
if hasattr(queryset, 'subclasses'):
|
||||
assert hasattr(queryset, '_get_sub_obj_recurse')
|
||||
extras = tuple(queryset.query.extra.keys())
|
||||
# sort the subclass names longest first,
|
||||
# so with 'a' and 'a__b' it goes as deep as possible
|
||||
|
|
@ -28,7 +43,7 @@ class InheritanceIterable(ModelIterable):
|
|||
if not sub_obj:
|
||||
sub_obj = obj
|
||||
|
||||
if getattr(queryset, '_annotated', False):
|
||||
if hasattr(queryset, '_annotated'):
|
||||
for k in queryset._annotated:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
|
||||
|
|
@ -40,26 +55,31 @@ class InheritanceIterable(ModelIterable):
|
|||
yield from iter
|
||||
|
||||
|
||||
class InheritanceQuerySetMixin:
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._iterable_class = InheritanceIterable
|
||||
class InheritanceQuerySetMixin(Generic[ModelT]):
|
||||
|
||||
def select_subclasses(self, *subclasses):
|
||||
calculated_subclasses = self._get_subclasses_recurse(self.model)
|
||||
model: type[ModelT]
|
||||
subclasses: Sequence[str]
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable
|
||||
|
||||
def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]:
|
||||
model: type[ModelT] = self.model
|
||||
calculated_subclasses = self._get_subclasses_recurse(model)
|
||||
# if none were passed in, we can just short circuit and select all
|
||||
if not subclasses:
|
||||
subclasses = calculated_subclasses
|
||||
selected_subclasses = calculated_subclasses
|
||||
else:
|
||||
verified_subclasses = []
|
||||
verified_subclasses: list[str] = []
|
||||
for subclass in subclasses:
|
||||
# special case for passing in the same model as the queryset
|
||||
# is bound against. Rather than raise an error later, we know
|
||||
# we can allow this through.
|
||||
if subclass is self.model:
|
||||
if subclass is model:
|
||||
continue
|
||||
|
||||
if not isinstance(subclass, (str,)):
|
||||
if not isinstance(subclass, str):
|
||||
subclass = self._get_ancestors_path(subclass)
|
||||
|
||||
if subclass in calculated_subclasses:
|
||||
|
|
@ -69,38 +89,39 @@ class InheritanceQuerySetMixin:
|
|||
'{!r} is not in the discovered subclasses, tried: {}'.format(
|
||||
subclass, ', '.join(calculated_subclasses))
|
||||
)
|
||||
subclasses = verified_subclasses
|
||||
selected_subclasses = verified_subclasses
|
||||
|
||||
if subclasses:
|
||||
new_qs = self.select_related(*subclasses)
|
||||
else:
|
||||
new_qs = self
|
||||
new_qs.subclasses = subclasses
|
||||
new_qs = cast('InheritanceQuerySet[ModelT]', self)
|
||||
if selected_subclasses:
|
||||
new_qs = new_qs.select_related(*selected_subclasses)
|
||||
new_qs.subclasses = selected_subclasses
|
||||
return new_qs
|
||||
|
||||
def _chain(self, **kwargs):
|
||||
def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]:
|
||||
update = {}
|
||||
for name in ['subclasses', '_annotated']:
|
||||
if hasattr(self, name):
|
||||
update[name] = getattr(self, name)
|
||||
|
||||
chained = super()._chain(**kwargs)
|
||||
# django-stubs doesn't include this private API.
|
||||
chained = super()._chain(**kwargs) # type: ignore[misc]
|
||||
chained.__dict__.update(update)
|
||||
return chained
|
||||
|
||||
def _clone(self):
|
||||
qs = super()._clone()
|
||||
def _clone(self) -> InheritanceQuerySet[ModelT]:
|
||||
# django-stubs doesn't include this private API.
|
||||
qs = super()._clone() # type: ignore[misc]
|
||||
for name in ['subclasses', '_annotated']:
|
||||
if hasattr(self, name):
|
||||
setattr(qs, name, getattr(self, name))
|
||||
return qs
|
||||
|
||||
def annotate(self, *args, **kwargs):
|
||||
qset = super().annotate(*args, **kwargs)
|
||||
def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs)
|
||||
qset._annotated = [a.default_alias for a in args] + list(kwargs.keys())
|
||||
return qset
|
||||
|
||||
def _get_subclasses_recurse(self, model):
|
||||
def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]:
|
||||
"""
|
||||
Given a Model class, find all related objects, exploring children
|
||||
recursively, returning a `list` of strings representing the
|
||||
|
|
@ -126,7 +147,7 @@ class InheritanceQuerySetMixin:
|
|||
subclasses.append(rel.get_accessor_name())
|
||||
return subclasses
|
||||
|
||||
def _get_ancestors_path(self, model):
|
||||
def _get_ancestors_path(self, model: type[models.Model]) -> str:
|
||||
"""
|
||||
Serves as an opposite to _get_subclasses_recurse, instead walking from
|
||||
the Model class up the Model's ancestry and constructing the desired
|
||||
|
|
@ -136,7 +157,7 @@ class InheritanceQuerySetMixin:
|
|||
raise ValueError(
|
||||
f"{model!r} is not a subclass of {self.model!r}")
|
||||
|
||||
ancestry = []
|
||||
ancestry: list[str] = []
|
||||
# should be a OneToOneField or None
|
||||
parent_link = model._meta.get_ancestor_link(self.model)
|
||||
|
||||
|
|
@ -149,7 +170,7 @@ class InheritanceQuerySetMixin:
|
|||
|
||||
return LOOKUP_SEP.join(ancestry)
|
||||
|
||||
def _get_sub_obj_recurse(self, obj, s):
|
||||
def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None:
|
||||
rel, _, s = s.partition(LOOKUP_SEP)
|
||||
|
||||
try:
|
||||
|
|
@ -162,12 +183,14 @@ class InheritanceQuerySetMixin:
|
|||
else:
|
||||
return node
|
||||
|
||||
def get_subclass(self, *args, **kwargs):
|
||||
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
|
||||
return self.select_subclasses().get(*args, **kwargs)
|
||||
|
||||
|
||||
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
|
||||
def instance_of(self, *models):
|
||||
# Defining the 'model' attribute using a generic type triggers a bug in mypy:
|
||||
# https://github.com/python/mypy/issues/9031
|
||||
class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
|
||||
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
|
||||
"""
|
||||
Fetch only objects that are instances of the provided model(s).
|
||||
"""
|
||||
|
|
@ -190,88 +213,118 @@ class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
|
|||
) for field in model._meta.parents.values()
|
||||
]) + ')')
|
||||
|
||||
return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
|
||||
return cast(
|
||||
'InheritanceQuerySet[ModelT]',
|
||||
self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
|
||||
)
|
||||
|
||||
|
||||
class InheritanceManagerMixin:
|
||||
class InheritanceManagerMixin(Generic[ModelT]):
|
||||
_queryset_class = InheritanceQuerySet
|
||||
|
||||
def get_queryset(self):
|
||||
return self._queryset_class(self.model)
|
||||
if TYPE_CHECKING:
|
||||
def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]:
|
||||
...
|
||||
|
||||
def select_subclasses(self, *subclasses):
|
||||
def get_queryset(self) -> InheritanceQuerySet[ModelT]:
|
||||
model: type[ModelT] = self.model # type: ignore[attr-defined]
|
||||
return self._queryset_class(model)
|
||||
|
||||
def select_subclasses(
|
||||
self, *subclasses: str | type[models.Model]
|
||||
) -> InheritanceQuerySet[ModelT]:
|
||||
return self.get_queryset().select_subclasses(*subclasses)
|
||||
|
||||
def get_subclass(self, *args, **kwargs):
|
||||
def get_subclass(self, *args: object, **kwargs: object) -> ModelT:
|
||||
return self.get_queryset().get_subclass(*args, **kwargs)
|
||||
|
||||
def instance_of(self, *models):
|
||||
def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]:
|
||||
return self.get_queryset().instance_of(*models)
|
||||
|
||||
|
||||
class InheritanceManager(InheritanceManagerMixin, models.Manager):
|
||||
class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]):
|
||||
pass
|
||||
|
||||
|
||||
class QueryManagerMixin:
|
||||
class QueryManagerMixin(Generic[ModelT]):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@overload
|
||||
def __init__(self, *args: models.Q):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(self, **kwargs: object):
|
||||
...
|
||||
|
||||
def __init__(self, *args: models.Q, **kwargs: object):
|
||||
if args:
|
||||
self._q = args[0]
|
||||
else:
|
||||
self._q = models.Q(**kwargs)
|
||||
self._order_by = None
|
||||
self._order_by: tuple[Any, ...] | None = None
|
||||
super().__init__()
|
||||
|
||||
def order_by(self, *args):
|
||||
def order_by(self, *args: Any) -> QueryManager[ModelT]:
|
||||
self._order_by = args
|
||||
return self
|
||||
return cast('QueryManager[ModelT]', self)
|
||||
|
||||
def get_queryset(self):
|
||||
qs = super().get_queryset().filter(self._q)
|
||||
def get_queryset(self) -> QuerySet[ModelT]:
|
||||
qs = super().get_queryset() # type: ignore[misc]
|
||||
qs = qs.filter(self._q)
|
||||
if self._order_by is not None:
|
||||
return qs.order_by(*self._order_by)
|
||||
return qs
|
||||
|
||||
|
||||
class QueryManager(QueryManagerMixin, models.Manager):
|
||||
class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
|
||||
class SoftDeletableQuerySetMixin:
|
||||
class SoftDeletableQuerySetMixin(Generic[ModelT]):
|
||||
"""
|
||||
QuerySet for SoftDeletableModel. Instead of removing instance sets
|
||||
its ``is_removed`` field to True.
|
||||
"""
|
||||
|
||||
def delete(self):
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Soft delete objects from queryset (set their ``is_removed``
|
||||
field to True)
|
||||
"""
|
||||
self.update(is_removed=True)
|
||||
cast(QuerySet[ModelT], self).update(is_removed=True)
|
||||
|
||||
|
||||
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet):
|
||||
# Note that our delete() method does not return anything, unlike Django's.
|
||||
# https://github.com/jazzband/django-model-utils/issues/541
|
||||
class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
|
||||
class SoftDeletableManagerMixin:
|
||||
class SoftDeletableManagerMixin(Generic[ModelT]):
|
||||
"""
|
||||
Manager that limits the queryset by default to show only not removed
|
||||
instances of model.
|
||||
"""
|
||||
_queryset_class = SoftDeletableQuerySet
|
||||
|
||||
def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs):
|
||||
_db: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args: object,
|
||||
_emit_deprecation_warnings: bool = False,
|
||||
**kwargs: object
|
||||
):
|
||||
self.emit_deprecation_warnings = _emit_deprecation_warnings
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
def get_queryset(self) -> SoftDeletableQuerySet[ModelT]:
|
||||
"""
|
||||
Return queryset limited to not removed entries.
|
||||
"""
|
||||
|
||||
model: type[ModelT] = self.model # type: ignore[attr-defined]
|
||||
|
||||
if self.emit_deprecation_warnings:
|
||||
warning_message = (
|
||||
"{0}.objects model manager will include soft-deleted objects in an "
|
||||
|
|
@ -279,23 +332,23 @@ class SoftDeletableManagerMixin:
|
|||
"excluding soft-deleted objects. See "
|
||||
"https://django-model-utils.readthedocs.io/en/stable/models.html"
|
||||
"#softdeletablemodel for more information."
|
||||
).format(self.model.__class__.__name__)
|
||||
).format(model.__class__.__name__)
|
||||
warnings.warn(warning_message, DeprecationWarning)
|
||||
|
||||
kwargs = {'model': self.model, 'using': self._db}
|
||||
if hasattr(self, '_hints'):
|
||||
kwargs['hints'] = self._hints
|
||||
|
||||
return self._queryset_class(**kwargs).filter(is_removed=False)
|
||||
return self._queryset_class(
|
||||
model=model,
|
||||
using=self._db,
|
||||
**({'hints': self._hints} if hasattr(self, '_hints') else {})
|
||||
).filter(is_removed=False)
|
||||
|
||||
|
||||
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager):
|
||||
class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]):
|
||||
pass
|
||||
|
||||
|
||||
class JoinQueryset(models.QuerySet):
|
||||
class JoinQueryset(models.QuerySet[Any]):
|
||||
|
||||
def join(self, qs=None):
|
||||
def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]:
|
||||
'''
|
||||
Join one queryset together with another using a temporary table. If
|
||||
no queryset is used, it will use the current queryset and join that
|
||||
|
|
@ -310,11 +363,11 @@ class JoinQueryset(models.QuerySet):
|
|||
to_field = 'id'
|
||||
|
||||
if qs:
|
||||
fk = [
|
||||
fks = [
|
||||
fk for fk in qs.model._meta.fields
|
||||
if getattr(fk, 'related_model', None) == self.model
|
||||
]
|
||||
fk = fk[0] if fk else None
|
||||
fk = fks[0] if fks else None
|
||||
model_set = f'{self.model.__name__.lower()}_set'
|
||||
key = fk or getattr(qs.model, model_set, None)
|
||||
|
||||
|
|
@ -371,21 +424,24 @@ class JoinQueryset(models.QuerySet):
|
|||
return new_qs
|
||||
|
||||
|
||||
class JoinManagerMixin:
|
||||
"""
|
||||
Manager that adds a method join. This method allows you to join two
|
||||
querysets together.
|
||||
"""
|
||||
_queryset_class = JoinQueryset
|
||||
if not TYPE_CHECKING:
|
||||
# Hide deprecated API during type checking, to encourage switch to
|
||||
# 'JoinQueryset.as_manager()', which is supported by the mypy plugin
|
||||
# of django-stubs.
|
||||
|
||||
def get_queryset(self):
|
||||
warnings.warn(
|
||||
"JoinManager and JoinManagerMixin are deprecated. "
|
||||
"Please use 'JoinQueryset.as_manager()' instead.",
|
||||
DeprecationWarning
|
||||
)
|
||||
return self._queryset_class(model=self.model, using=self._db)
|
||||
class JoinManagerMixin:
|
||||
"""
|
||||
Manager that adds a method join. This method allows you to join two
|
||||
querysets together.
|
||||
"""
|
||||
|
||||
def get_queryset(self):
|
||||
warnings.warn(
|
||||
"JoinManager and JoinManagerMixin are deprecated. "
|
||||
"Please use 'JoinQueryset.as_manager()' instead.",
|
||||
DeprecationWarning
|
||||
)
|
||||
return self._queryset_class(model=self.model, using=self._db)
|
||||
|
||||
class JoinManager(JoinManagerMixin, models.Manager):
|
||||
pass
|
||||
class JoinManager(JoinManagerMixin):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Reference in a new issue