mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Fix type generics in InheritanceIterable
This commit is contained in:
parent
f4653f08e5
commit
1db7d6ba33
1 changed files with 41 additions and 33 deletions
|
|
@ -7,6 +7,7 @@ from django.core.exceptions import ObjectDoesNotExist
|
|||
from django.db import connection, models
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields.related import OneToOneField, OneToOneRel
|
||||
from django.db.models.query import ModelIterable, QuerySet
|
||||
from django.db.models.sql.datastructures import Join
|
||||
|
||||
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
|
||||
|
|
@ -15,44 +16,51 @@ if TYPE_CHECKING:
|
|||
from collections.abc import Iterator
|
||||
|
||||
from django.db.models.query import BaseIterable
|
||||
from django.db.models.query import ModelIterable as ModelIterableGeneric
|
||||
from django.db.models.query import QuerySet as QuerySetGeneric
|
||||
|
||||
ModelIterable = ModelIterableGeneric[ModelT]
|
||||
QuerySet = QuerySetGeneric[ModelT]
|
||||
else:
|
||||
from django.db.models.query import ModelIterable, QuerySet
|
||||
|
||||
|
||||
class InheritanceIterable(ModelIterable):
|
||||
def __iter__(self) -> Iterator[ModelT]:
|
||||
queryset = self.queryset
|
||||
iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset)
|
||||
if hasattr(queryset, 'subclasses'):
|
||||
assert hasattr(queryset, '_get_sub_obj_recurse')
|
||||
extras = tuple(queryset.query.extra.keys())
|
||||
# sort the subclass names longest first,
|
||||
# so with 'a' and 'a__b' it goes as deep as possible
|
||||
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
|
||||
for obj in iter:
|
||||
sub_obj = None
|
||||
for s in subclasses:
|
||||
sub_obj = queryset._get_sub_obj_recurse(obj, s)
|
||||
if sub_obj:
|
||||
break
|
||||
if not sub_obj:
|
||||
sub_obj = obj
|
||||
def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
|
||||
iter: ModelIterable[ModelT] = ModelIterable(queryset)
|
||||
if hasattr(queryset, 'subclasses'):
|
||||
assert hasattr(queryset, '_get_sub_obj_recurse')
|
||||
extras = tuple(queryset.query.extra.keys())
|
||||
# sort the subclass names longest first,
|
||||
# so with 'a' and 'a__b' it goes as deep as possible
|
||||
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
|
||||
for obj in iter:
|
||||
sub_obj = None
|
||||
for s in subclasses:
|
||||
sub_obj = queryset._get_sub_obj_recurse(obj, s)
|
||||
if sub_obj:
|
||||
break
|
||||
if not sub_obj:
|
||||
sub_obj = obj
|
||||
|
||||
if hasattr(queryset, '_annotated'):
|
||||
for k in queryset._annotated:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
|
||||
for k in extras:
|
||||
if hasattr(queryset, '_annotated'):
|
||||
for k in queryset._annotated:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
|
||||
yield sub_obj
|
||||
else:
|
||||
yield from iter
|
||||
for k in extras:
|
||||
setattr(sub_obj, k, getattr(obj, k))
|
||||
|
||||
yield sub_obj
|
||||
else:
|
||||
yield from iter
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
class InheritanceIterable(ModelIterable[ModelT]):
|
||||
queryset: QuerySet[ModelT]
|
||||
|
||||
def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
|
||||
...
|
||||
|
||||
def __iter__(self) -> Iterator[ModelT]:
|
||||
...
|
||||
|
||||
else:
|
||||
class InheritanceIterable(ModelIterable):
|
||||
def __iter__(self):
|
||||
return _iter_inheritance_queryset(self.queryset)
|
||||
|
||||
|
||||
class InheritanceQuerySetMixin(Generic[ModelT]):
|
||||
|
|
|
|||
Loading…
Reference in a new issue