Fix type generics in InheritanceIterable

This commit is contained in:
Maarten ter Huurne 2024-04-17 17:58:42 +02:00
parent f4653f08e5
commit 1db7d6ba33

View file

@ -7,6 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist
from django.db import connection, models from django.db import connection, models
from django.db.models.constants import LOOKUP_SEP from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
from django.db.models.sql.datastructures import Join from django.db.models.sql.datastructures import Join
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True) ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
@ -15,44 +16,51 @@ if TYPE_CHECKING:
from collections.abc import Iterator from collections.abc import Iterator
from django.db.models.query import BaseIterable from django.db.models.query import BaseIterable
from django.db.models.query import ModelIterable as ModelIterableGeneric
from django.db.models.query import QuerySet as QuerySetGeneric
ModelIterable = ModelIterableGeneric[ModelT]
QuerySet = QuerySetGeneric[ModelT]
else:
from django.db.models.query import ModelIterable, QuerySet
class InheritanceIterable(ModelIterable): def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
def __iter__(self) -> Iterator[ModelT]: iter: ModelIterable[ModelT] = ModelIterable(queryset)
queryset = self.queryset if hasattr(queryset, 'subclasses'):
iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset) assert hasattr(queryset, '_get_sub_obj_recurse')
if hasattr(queryset, 'subclasses'): extras = tuple(queryset.query.extra.keys())
assert hasattr(queryset, '_get_sub_obj_recurse') # sort the subclass names longest first,
extras = tuple(queryset.query.extra.keys()) # so with 'a' and 'a__b' it goes as deep as possible
# sort the subclass names longest first, subclasses = sorted(queryset.subclasses, key=len, reverse=True)
# so with 'a' and 'a__b' it goes as deep as possible for obj in iter:
subclasses = sorted(queryset.subclasses, key=len, reverse=True) sub_obj = None
for obj in iter: for s in subclasses:
sub_obj = None sub_obj = queryset._get_sub_obj_recurse(obj, s)
for s in subclasses: if sub_obj:
sub_obj = queryset._get_sub_obj_recurse(obj, s) break
if sub_obj: if not sub_obj:
break sub_obj = obj
if not sub_obj:
sub_obj = obj
if hasattr(queryset, '_annotated'): if hasattr(queryset, '_annotated'):
for k in queryset._annotated: for k in queryset._annotated:
setattr(sub_obj, k, getattr(obj, k))
for k in extras:
setattr(sub_obj, k, getattr(obj, k)) setattr(sub_obj, k, getattr(obj, k))
yield sub_obj for k in extras:
else: setattr(sub_obj, k, getattr(obj, k))
yield from iter
yield sub_obj
else:
yield from iter
if TYPE_CHECKING:
class InheritanceIterable(ModelIterable[ModelT]):
queryset: QuerySet[ModelT]
def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
...
def __iter__(self) -> Iterator[ModelT]:
...
else:
class InheritanceIterable(ModelIterable):
def __iter__(self):
return _iter_inheritance_queryset(self.queryset)
class InheritanceQuerySetMixin(Generic[ModelT]): class InheritanceQuerySetMixin(Generic[ModelT]):