diff --git a/model_utils/managers.py b/model_utils/managers.py index 10e9c34..c4631d2 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,20 +1,35 @@ from __future__ import annotations import warnings +from typing import TYPE_CHECKING, Any, Generic, Sequence, TypeVar, cast, overload from django.core.exceptions import ObjectDoesNotExist from django.db import connection, models from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related import OneToOneField, OneToOneRel -from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join +ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) + +if TYPE_CHECKING: + from collections.abc import Iterator + + from django.db.models.query import BaseIterable + from django.db.models.query import ModelIterable as ModelIterableGeneric + from django.db.models.query import QuerySet as QuerySetGeneric + + ModelIterable = ModelIterableGeneric[ModelT] + QuerySet = QuerySetGeneric[ModelT] +else: + from django.db.models.query import ModelIterable, QuerySet + class InheritanceIterable(ModelIterable): - def __iter__(self): + def __iter__(self) -> Iterator[ModelT]: queryset = self.queryset - iter = ModelIterable(queryset) - if getattr(queryset, 'subclasses', False): + iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset) + if hasattr(queryset, 'subclasses'): + assert hasattr(queryset, '_get_sub_obj_recurse') extras = tuple(queryset.query.extra.keys()) # sort the subclass names longest first, # so with 'a' and 'a__b' it goes as deep as possible @@ -28,7 +43,7 @@ class InheritanceIterable(ModelIterable): if not sub_obj: sub_obj = obj - if getattr(queryset, '_annotated', False): + if hasattr(queryset, '_annotated'): for k in queryset._annotated: setattr(sub_obj, k, getattr(obj, k)) @@ -40,26 +55,31 @@ class InheritanceIterable(ModelIterable): yield from iter -class InheritanceQuerySetMixin: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._iterable_class = InheritanceIterable +class InheritanceQuerySetMixin(Generic[ModelT]): - def select_subclasses(self, *subclasses): - calculated_subclasses = self._get_subclasses_recurse(self.model) + model: type[ModelT] + subclasses: Sequence[str] + + def __init__(self, *args: object, **kwargs: object): + super().__init__(*args, **kwargs) + self._iterable_class: type[BaseIterable[ModelT]] = InheritanceIterable + + def select_subclasses(self, *subclasses: str | type[models.Model]) -> InheritanceQuerySet[ModelT]: + model: type[ModelT] = self.model + calculated_subclasses = self._get_subclasses_recurse(model) # if none were passed in, we can just short circuit and select all if not subclasses: - subclasses = calculated_subclasses + selected_subclasses = calculated_subclasses else: - verified_subclasses = [] + verified_subclasses: list[str] = [] for subclass in subclasses: # special case for passing in the same model as the queryset # is bound against. Rather than raise an error later, we know # we can allow this through. - if subclass is self.model: + if subclass is model: continue - if not isinstance(subclass, (str,)): + if not isinstance(subclass, str): subclass = self._get_ancestors_path(subclass) if subclass in calculated_subclasses: @@ -69,38 +89,39 @@ class InheritanceQuerySetMixin: '{!r} is not in the discovered subclasses, tried: {}'.format( subclass, ', '.join(calculated_subclasses)) ) - subclasses = verified_subclasses + selected_subclasses = verified_subclasses - if subclasses: - new_qs = self.select_related(*subclasses) - else: - new_qs = self - new_qs.subclasses = subclasses + new_qs = cast('InheritanceQuerySet[ModelT]', self) + if selected_subclasses: + new_qs = new_qs.select_related(*selected_subclasses) + new_qs.subclasses = selected_subclasses return new_qs - def _chain(self, **kwargs): + def _chain(self, **kwargs: object) -> InheritanceQuerySet[ModelT]: update = {} for name in ['subclasses', '_annotated']: if hasattr(self, name): update[name] = getattr(self, name) - chained = super()._chain(**kwargs) + # django-stubs doesn't include this private API. + chained = super()._chain(**kwargs) # type: ignore[misc] chained.__dict__.update(update) return chained - def _clone(self): - qs = super()._clone() + def _clone(self) -> InheritanceQuerySet[ModelT]: + # django-stubs doesn't include this private API. + qs = super()._clone() # type: ignore[misc] for name in ['subclasses', '_annotated']: if hasattr(self, name): setattr(qs, name, getattr(self, name)) return qs - def annotate(self, *args, **kwargs): - qset = super().annotate(*args, **kwargs) + def annotate(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + qset = cast(QuerySet[ModelT], super()).annotate(*args, **kwargs) qset._annotated = [a.default_alias for a in args] + list(kwargs.keys()) return qset - def _get_subclasses_recurse(self, model): + def _get_subclasses_recurse(self, model: type[models.Model]) -> list[str]: """ Given a Model class, find all related objects, exploring children recursively, returning a `list` of strings representing the @@ -126,7 +147,7 @@ class InheritanceQuerySetMixin: subclasses.append(rel.get_accessor_name()) return subclasses - def _get_ancestors_path(self, model): + def _get_ancestors_path(self, model: type[models.Model]) -> str: """ Serves as an opposite to _get_subclasses_recurse, instead walking from the Model class up the Model's ancestry and constructing the desired @@ -136,7 +157,7 @@ class InheritanceQuerySetMixin: raise ValueError( f"{model!r} is not a subclass of {self.model!r}") - ancestry = [] + ancestry: list[str] = [] # should be a OneToOneField or None parent_link = model._meta.get_ancestor_link(self.model) @@ -149,7 +170,7 @@ class InheritanceQuerySetMixin: return LOOKUP_SEP.join(ancestry) - def _get_sub_obj_recurse(self, obj, s): + def _get_sub_obj_recurse(self, obj: models.Model, s: str) -> ModelT | None: rel, _, s = s.partition(LOOKUP_SEP) try: @@ -162,12 +183,14 @@ class InheritanceQuerySetMixin: else: return node - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: object, **kwargs: object) -> ModelT: return self.select_subclasses().get(*args, **kwargs) -class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): - def instance_of(self, *models): +# Defining the 'model' attribute using a generic type triggers a bug in mypy: +# https://github.com/python/mypy/issues/9031 +class InheritanceQuerySet(InheritanceQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc] + def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: """ Fetch only objects that are instances of the provided model(s). """ @@ -190,88 +213,118 @@ class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): ) for field in model._meta.parents.values() ]) + ')') - return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + return cast( + 'InheritanceQuerySet[ModelT]', + self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) + ) -class InheritanceManagerMixin: +class InheritanceManagerMixin(Generic[ModelT]): _queryset_class = InheritanceQuerySet - def get_queryset(self): - return self._queryset_class(self.model) + if TYPE_CHECKING: + def filter(self, *args: Any, **kwargs: Any) -> InheritanceQuerySet[ModelT]: + ... - def select_subclasses(self, *subclasses): + def get_queryset(self) -> InheritanceQuerySet[ModelT]: + model: type[ModelT] = self.model # type: ignore[attr-defined] + return self._queryset_class(model) + + def select_subclasses( + self, *subclasses: str | type[models.Model] + ) -> InheritanceQuerySet[ModelT]: return self.get_queryset().select_subclasses(*subclasses) - def get_subclass(self, *args, **kwargs): + def get_subclass(self, *args: object, **kwargs: object) -> ModelT: return self.get_queryset().get_subclass(*args, **kwargs) - def instance_of(self, *models): + def instance_of(self, *models: type[ModelT]) -> InheritanceQuerySet[ModelT]: return self.get_queryset().instance_of(*models) -class InheritanceManager(InheritanceManagerMixin, models.Manager): +class InheritanceManager(InheritanceManagerMixin[ModelT], models.Manager[ModelT]): pass -class QueryManagerMixin: +class QueryManagerMixin(Generic[ModelT]): - def __init__(self, *args, **kwargs): + @overload + def __init__(self, *args: models.Q): + ... + + @overload + def __init__(self, **kwargs: object): + ... + + def __init__(self, *args: models.Q, **kwargs: object): if args: self._q = args[0] else: self._q = models.Q(**kwargs) - self._order_by = None + self._order_by: tuple[Any, ...] | None = None super().__init__() - def order_by(self, *args): + def order_by(self, *args: Any) -> QueryManager[ModelT]: self._order_by = args - return self + return cast('QueryManager[ModelT]', self) - def get_queryset(self): - qs = super().get_queryset().filter(self._q) + def get_queryset(self) -> QuerySet[ModelT]: + qs = super().get_queryset() # type: ignore[misc] + qs = qs.filter(self._q) if self._order_by is not None: return qs.order_by(*self._order_by) return qs -class QueryManager(QueryManagerMixin, models.Manager): +class QueryManager(QueryManagerMixin[ModelT], models.Manager[ModelT]): # type: ignore[misc] pass -class SoftDeletableQuerySetMixin: +class SoftDeletableQuerySetMixin(Generic[ModelT]): """ QuerySet for SoftDeletableModel. Instead of removing instance sets its ``is_removed`` field to True. """ - def delete(self): + def delete(self) -> None: """ Soft delete objects from queryset (set their ``is_removed`` field to True) """ - self.update(is_removed=True) + cast(QuerySet[ModelT], self).update(is_removed=True) -class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet): +# Note that our delete() method does not return anything, unlike Django's. +# https://github.com/jazzband/django-model-utils/issues/541 +class SoftDeletableQuerySet(SoftDeletableQuerySetMixin[ModelT], QuerySet[ModelT]): # type: ignore[misc] pass -class SoftDeletableManagerMixin: +class SoftDeletableManagerMixin(Generic[ModelT]): """ Manager that limits the queryset by default to show only not removed instances of model. """ _queryset_class = SoftDeletableQuerySet - def __init__(self, *args, _emit_deprecation_warnings=False, **kwargs): + _db: str | None + + def __init__( + self, + *args: object, + _emit_deprecation_warnings: bool = False, + **kwargs: object + ): self.emit_deprecation_warnings = _emit_deprecation_warnings super().__init__(*args, **kwargs) - def get_queryset(self): + def get_queryset(self) -> SoftDeletableQuerySet[ModelT]: """ Return queryset limited to not removed entries. """ + model: type[ModelT] = self.model # type: ignore[attr-defined] + if self.emit_deprecation_warnings: warning_message = ( "{0}.objects model manager will include soft-deleted objects in an " @@ -279,23 +332,23 @@ class SoftDeletableManagerMixin: "excluding soft-deleted objects. See " "https://django-model-utils.readthedocs.io/en/stable/models.html" "#softdeletablemodel for more information." - ).format(self.model.__class__.__name__) + ).format(model.__class__.__name__) warnings.warn(warning_message, DeprecationWarning) - kwargs = {'model': self.model, 'using': self._db} - if hasattr(self, '_hints'): - kwargs['hints'] = self._hints - - return self._queryset_class(**kwargs).filter(is_removed=False) + return self._queryset_class( + model=model, + using=self._db, + **({'hints': self._hints} if hasattr(self, '_hints') else {}) + ).filter(is_removed=False) -class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager): +class SoftDeletableManager(SoftDeletableManagerMixin[ModelT], models.Manager[ModelT]): pass -class JoinQueryset(models.QuerySet): +class JoinQueryset(models.QuerySet[Any]): - def join(self, qs=None): + def join(self, qs: QuerySet[Any] | None = None) -> QuerySet[Any]: ''' Join one queryset together with another using a temporary table. If no queryset is used, it will use the current queryset and join that @@ -310,11 +363,11 @@ class JoinQueryset(models.QuerySet): to_field = 'id' if qs: - fk = [ + fks = [ fk for fk in qs.model._meta.fields if getattr(fk, 'related_model', None) == self.model ] - fk = fk[0] if fk else None + fk = fks[0] if fks else None model_set = f'{self.model.__name__.lower()}_set' key = fk or getattr(qs.model, model_set, None) @@ -371,21 +424,24 @@ class JoinQueryset(models.QuerySet): return new_qs -class JoinManagerMixin: - """ - Manager that adds a method join. This method allows you to join two - querysets together. - """ - _queryset_class = JoinQueryset +if not TYPE_CHECKING: + # Hide deprecated API during type checking, to encourage switch to + # 'JoinQueryset.as_manager()', which is supported by the mypy plugin + # of django-stubs. - def get_queryset(self): - warnings.warn( - "JoinManager and JoinManagerMixin are deprecated. " - "Please use 'JoinQueryset.as_manager()' instead.", - DeprecationWarning - ) - return self._queryset_class(model=self.model, using=self._db) + class JoinManagerMixin: + """ + Manager that adds a method join. This method allows you to join two + querysets together. + """ + def get_queryset(self): + warnings.warn( + "JoinManager and JoinManagerMixin are deprecated. " + "Please use 'JoinQueryset.as_manager()' instead.", + DeprecationWarning + ) + return self._queryset_class(model=self.model, using=self._db) -class JoinManager(JoinManagerMixin, models.Manager): - pass + class JoinManager(JoinManagerMixin): + pass