diff --git a/AUTHORS.rst b/AUTHORS.rst index 3502203..97616c0 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -1,5 +1,6 @@ Carl Meyer Jannis Leidel +Facundo Gaich Gregor Müllegger Jeff Elmore Paul McLanahan diff --git a/CHANGES.rst b/CHANGES.rst index 52d547d..7378e7d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,7 +4,8 @@ CHANGES tip (unreleased) ---------------- -- Fixed annotation of InheritanceQuerysets. Thanks Jeff Elmore. +- Fixed annotation of InheritanceQuerysets. Thanks Jeff Elmore and Facundo + Gaich. - Dropped support for Python 2.5 and Django 1.1. Both are no longer supported even for security fixes, and should not be used. diff --git a/model_utils/managers.py b/model_utils/managers.py index 14a6b09..45139eb 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -9,10 +9,6 @@ from django.db.models.query import QuerySet class InheritanceQuerySet(QuerySet): - def __init__(self, *args, **kwargs): - self._annotated = None - super(InheritanceQuerySet, self).__init__(*args, **kwargs) - def select_subclasses(self, *subclasses): if not subclasses: subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects() @@ -23,11 +19,9 @@ class InheritanceQuerySet(QuerySet): return new_qs def _clone(self, klass=None, setup=False, **kwargs): - try: - kwargs.update({'subclasses': self.subclasses, - '_annotated': self._annotated}) - except AttributeError: - pass + for name in ['subclasses', '_annotated']: + if hasattr(self, name): + kwargs[name] = getattr(self, name) return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs) def annotate(self, *args, **kwargs): @@ -41,7 +35,7 @@ class InheritanceQuerySet(QuerySet): for obj in iter: sub_obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj] sub_obj = sub_obj[0] - if self._annotated: + if getattr(self, '_annotated', False): for k in self._annotated: setattr(sub_obj, k, getattr(obj, k)) diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index d38522c..3b5a46a 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -385,6 +385,18 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests): self.assertEqual(qs.get(id=self.child1.id).test_count, 1) + def test_annotate_before_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.annotate( + models.Count('id')).select_subclasses() + self.assertEqual(qs.get(id=self.child1.id).id__count, 1) + + + def test_annotate_with_named_arguments_before_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.annotate( + test_count=models.Count('id')).select_subclasses() + self.assertEqual(qs.get(id=self.child1.id).test_count, 1) + + class TimeStampedModelTests(TestCase): def test_created(self):