diff --git a/CHANGES.rst b/CHANGES.rst index b48c821..52d547d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,8 @@ CHANGES tip (unreleased) ---------------- +- Fixed annotation of InheritanceQuerysets. Thanks Jeff Elmore. + - Dropped support for Python 2.5 and Django 1.1. Both are no longer supported even for security fixes, and should not be used. diff --git a/model_utils/managers.py b/model_utils/managers.py index 459de51..14a6b09 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -7,7 +7,12 @@ from django.db.models.fields.related import OneToOneField from django.db.models.manager import Manager from django.db.models.query import QuerySet + class InheritanceQuerySet(QuerySet): + def __init__(self, *args, **kwargs): + self._annotated = None + super(InheritanceQuerySet, self).__init__(*args, **kwargs) + def select_subclasses(self, *subclasses): if not subclasses: subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects() @@ -19,21 +24,33 @@ class InheritanceQuerySet(QuerySet): def _clone(self, klass=None, setup=False, **kwargs): try: - kwargs.update({'subclasses': self.subclasses}) + kwargs.update({'subclasses': self.subclasses, + '_annotated': self._annotated}) except AttributeError: pass return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs) + def annotate(self, *args, **kwargs): + qset = super(InheritanceQuerySet, self).annotate(*args, **kwargs) + qset._annotated = [a.default_alias for a in args] + kwargs.keys() + return qset + def iterator(self): iter = super(InheritanceQuerySet, self).iterator() if getattr(self, 'subclasses', False): for obj in iter: - obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj] - yield obj[0] + sub_obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj] + sub_obj = sub_obj[0] + if self._annotated: + for k in self._annotated: + setattr(sub_obj, k, getattr(obj, k)) + + yield sub_obj else: for obj in iter: yield obj + class InheritanceManager(models.Manager): use_for_related_fields = True diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index fa16e6c..d38522c 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -373,6 +373,19 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests): self.child1) + def test_annotate_with_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( + models.Count('id')) + self.assertEqual(qs.get(id=self.child1.id).id__count, 1) + + + def test_annotate_with_named_arguments_with_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( + test_count=models.Count('id')) + self.assertEqual(qs.get(id=self.child1.id).test_count, 1) + + + class TimeStampedModelTests(TestCase): def test_created(self): t1 = TimeStamp.objects.create()