diff --git a/model_utils/managers.py b/model_utils/managers.py index 1f01465..19a1425 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -70,6 +70,7 @@ class InheritanceQuerySet(QuerySet): def iterator(self): iter = super(InheritanceQuerySet, self).iterator() if getattr(self, 'subclasses', False): + extras = self.query.extra_select.keys() # sort the subclass names longest first, # so with 'a' and 'a__b' it goes as deep as possible subclasses = sorted(self.subclasses, key=len, reverse=True) @@ -86,6 +87,9 @@ class InheritanceQuerySet(QuerySet): for k in self._annotated: setattr(sub_obj, k, getattr(obj, k)) + for k in extras: + setattr(sub_obj, k, getattr(obj, k)) + yield sub_obj else: for obj in iter: diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 070ba46..bbc0e37 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -974,6 +974,15 @@ class InheritanceManagerUsingModelsTests(TestCase): self.assertEqual(set(results.subclasses), set(expected_related_names)) + def test_extras_descend(self): + """ + Ensure that if extras(select=) is passed, we copy the values down onto + sub classes. + """ + results = InheritanceManagerTestParent.objects.all().extra( + select={'foo': 'id + 1'} + ) + self.assertTrue(all(result.foo == (result.id + 1) for result in results)) class InheritanceManagerRelatedTests(InheritanceManagerTests):