diff --git a/model_utils/managers.py b/model_utils/managers.py index 1a870d1..899b988 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -7,6 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.db import connection, models from django.db.models.constants import LOOKUP_SEP from django.db.models.fields.related import OneToOneField, OneToOneRel +from django.db.models.query import ModelIterable, QuerySet from django.db.models.sql.datastructures import Join ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) @@ -15,44 +16,51 @@ if TYPE_CHECKING: from collections.abc import Iterator from django.db.models.query import BaseIterable - from django.db.models.query import ModelIterable as ModelIterableGeneric - from django.db.models.query import QuerySet as QuerySetGeneric - - ModelIterable = ModelIterableGeneric[ModelT] - QuerySet = QuerySetGeneric[ModelT] -else: - from django.db.models.query import ModelIterable, QuerySet -class InheritanceIterable(ModelIterable): - def __iter__(self) -> Iterator[ModelT]: - queryset = self.queryset - iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset) - if hasattr(queryset, 'subclasses'): - assert hasattr(queryset, '_get_sub_obj_recurse') - extras = tuple(queryset.query.extra.keys()) - # sort the subclass names longest first, - # so with 'a' and 'a__b' it goes as deep as possible - subclasses = sorted(queryset.subclasses, key=len, reverse=True) - for obj in iter: - sub_obj = None - for s in subclasses: - sub_obj = queryset._get_sub_obj_recurse(obj, s) - if sub_obj: - break - if not sub_obj: - sub_obj = obj +def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]: + iter: ModelIterable[ModelT] = ModelIterable(queryset) + if hasattr(queryset, 'subclasses'): + assert hasattr(queryset, '_get_sub_obj_recurse') + extras = tuple(queryset.query.extra.keys()) + # sort the subclass names longest first, + # so with 'a' and 'a__b' it goes as deep as possible + subclasses = sorted(queryset.subclasses, key=len, reverse=True) + for obj in iter: + sub_obj = None + for s in subclasses: + sub_obj = queryset._get_sub_obj_recurse(obj, s) + if sub_obj: + break + if not sub_obj: + sub_obj = obj - if hasattr(queryset, '_annotated'): - for k in queryset._annotated: - setattr(sub_obj, k, getattr(obj, k)) - - for k in extras: + if hasattr(queryset, '_annotated'): + for k in queryset._annotated: setattr(sub_obj, k, getattr(obj, k)) - yield sub_obj - else: - yield from iter + for k in extras: + setattr(sub_obj, k, getattr(obj, k)) + + yield sub_obj + else: + yield from iter + + +if TYPE_CHECKING: + class InheritanceIterable(ModelIterable[ModelT]): + queryset: QuerySet[ModelT] + + def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any): + ... + + def __iter__(self) -> Iterator[ModelT]: + ... + +else: + class InheritanceIterable(ModelIterable): + def __iter__(self): + return _iter_inheritance_queryset(self.queryset) class InheritanceQuerySetMixin(Generic[ModelT]):