Fix type generics in InheritanceIterable

This commit is contained in:
Maarten ter Huurne 2024-04-17 17:58:42 +02:00
parent f4653f08e5
commit 1db7d6ba33

View file

@ -7,6 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist
from django.db import connection, models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.sql.datastructures import Join
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
@ -15,44 +16,51 @@ if TYPE_CHECKING:
from collections.abc import Iterator
from django.db.models.query import BaseIterable
from django.db.models.query import ModelIterable as ModelIterableGeneric
from django.db.models.query import QuerySet as QuerySetGeneric
ModelIterable = ModelIterableGeneric[ModelT]
QuerySet = QuerySetGeneric[ModelT]
else:
from django.db.models.query import ModelIterable, QuerySet
class InheritanceIterable(ModelIterable):
def __iter__(self) -> Iterator[ModelT]:
queryset = self.queryset
iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset)
if hasattr(queryset, 'subclasses'):
assert hasattr(queryset, '_get_sub_obj_recurse')
extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in subclasses:
sub_obj = queryset._get_sub_obj_recurse(obj, s)
if sub_obj:
break
if not sub_obj:
sub_obj = obj
def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
iter: ModelIterable[ModelT] = ModelIterable(queryset)
if hasattr(queryset, 'subclasses'):
assert hasattr(queryset, '_get_sub_obj_recurse')
extras = tuple(queryset.query.extra.keys())
# sort the subclass names longest first,
# so with 'a' and 'a__b' it goes as deep as possible
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
for obj in iter:
sub_obj = None
for s in subclasses:
sub_obj = queryset._get_sub_obj_recurse(obj, s)
if sub_obj:
break
if not sub_obj:
sub_obj = obj
if hasattr(queryset, '_annotated'):
for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))
for k in extras:
if hasattr(queryset, '_annotated'):
for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))
yield sub_obj
else:
yield from iter
for k in extras:
setattr(sub_obj, k, getattr(obj, k))
yield sub_obj
else:
yield from iter
if TYPE_CHECKING:
class InheritanceIterable(ModelIterable[ModelT]):
queryset: QuerySet[ModelT]
def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
...
def __iter__(self) -> Iterator[ModelT]:
...
else:
class InheritanceIterable(ModelIterable):
def __iter__(self):
return _iter_inheritance_queryset(self.queryset)
class InheritanceQuerySetMixin(Generic[ModelT]):