From 937b3e018f4e5739dd096d511d80f5cc063e65a0 Mon Sep 17 00:00:00 2001 From: Ivan Virabyan Date: Thu, 31 Jan 2013 18:27:16 +0400 Subject: [PATCH] Support for multi-level inheritance in InheritanceManager --- model_utils/managers.py | 38 ++++++++++++++++++++++++++------- model_utils/tests/models.py | 5 +++++ model_utils/tests/tests.py | 42 ++++++++++++++++++++++++++----------- 3 files changed, 66 insertions(+), 19 deletions(-) diff --git a/model_utils/managers.py b/model_utils/managers.py index 60aae40..6efc546 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -5,13 +5,15 @@ from django.db.models.fields.related import OneToOneField from django.db.models.query import QuerySet from django.core.exceptions import ObjectDoesNotExist +try: + from django.db.models.constants import LOOKUP_SEP +except ImportError: # Django <= 1.5 + from django.db.models.sql.constants import LOOKUP_SEP class InheritanceQuerySet(QuerySet): def select_subclasses(self, *subclasses): if not subclasses: - subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects() - if isinstance(rel.field, OneToOneField) - and issubclass(rel.field.model, self.model)] + subclasses = self._get_subclasses_recurse(self.model) new_qs = self.select_related(*subclasses) new_qs.subclasses = subclasses return new_qs @@ -31,15 +33,14 @@ class InheritanceQuerySet(QuerySet): iter = super(InheritanceQuerySet, self).iterator() if getattr(self, 'subclasses', False): for obj in iter: + sub_obj = None for s in self.subclasses: - try: - sub_obj = getattr(obj, s) - except ObjectDoesNotExist: - sub_obj = None + sub_obj = self._get_sub_obj_recurse(obj, s) if sub_obj: break if not sub_obj: sub_obj = obj + if getattr(self, '_annotated', False): for k in self._annotated: setattr(sub_obj, k, getattr(obj, k)) @@ -49,6 +50,29 @@ class InheritanceQuerySet(QuerySet): for obj in iter: yield obj + def _get_subclasses_recurse(self, model): + rels = [rel for rel in model._meta.get_all_related_objects() + if isinstance(rel.field, OneToOneField) + and issubclass(rel.field.model, model)] + subclasses = [] + for rel in rels: + for subclass in self._get_subclasses_recurse(rel.field.model): + subclasses.append(rel.var_name + LOOKUP_SEP + subclass) + subclasses.append(rel.var_name) + return subclasses + + def _get_sub_obj_recurse(self, obj, s): + rel, _, s = s.partition(LOOKUP_SEP) + try: + node = getattr(obj, rel) + except ObjectDoesNotExist: + return None + if s: + child = self._get_sub_obj_recurse(node, s) + return child or node + else: + return node + class InheritanceManager(models.Manager): use_for_related_fields = True diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index 881ed7d..0f50914 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -1,3 +1,4 @@ +import django from django.db import models from django.utils.translation import ugettext_lazy as _ @@ -29,6 +30,10 @@ class InheritanceManagerTestChild1(InheritanceManagerTestParent): pass +if django.VERSION >= (1, 6, 0): + class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1): + text_field = models.TextField() + class InheritanceManagerTestChild2(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index bf7637d..ed23b2d 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -1,9 +1,10 @@ from __future__ import with_statement - +import unittest import pickle from datetime import datetime, timedelta +import django from django.db import models from django.db.models.fields import FieldDoesNotExist from django.core.exceptions import ImproperlyConfigured @@ -20,6 +21,8 @@ from model_utils.tests.models import ( StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded, TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot) +if django.VERSION >= (1, 6, 0): + from model_utils.tests.models import InheritanceManagerTestGrandChild1 class GetExcerptTests(TestCase): @@ -279,32 +282,45 @@ class InheritanceManagerTests(TestCase): def setUp(self): self.child1 = InheritanceManagerTestChild1.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create() - + if django.VERSION >= (1, 6, 0): + self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() def get_manager(self): return InheritanceManagerTestParent.objects def test_normal(self): - self.assertEquals(set(self.get_manager().all()), - set([ - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - ])) + children = set([InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk)]) + if django.VERSION >= (1, 6, 0): + children.add(InheritanceManagerTestParent(pk=self.grandchild1.pk)) + self.assertEquals(set(self.get_manager().all()), children) def test_select_all_subclasses(self): + children = set([self.child1, self.child2]) + if django.VERSION >= (1, 6, 0): + children.add(self.grandchild1) self.assertEquals( - set(self.get_manager().select_subclasses()), - set([self.child1, self.child2])) + set(self.get_manager().select_subclasses()), children) def test_select_specific_subclasses(self): + children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)]) + if django.VERSION >= (1, 6, 0): + children.add(InheritanceManagerTestChild1(pk=self.grandchild1.pk)) self.assertEquals( set(self.get_manager().select_subclasses( - "inheritancemanagertestchild1")), - set([self.child1, - InheritanceManagerTestParent(pk=self.child2.pk)])) + "inheritancemanagertestchild1")), children) + + @unittest.skipIf(django.VERSION < (1, 6, 0), "not supported in this django version") + def test_select_specific_grandchildren(self): + children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)]) + if django.VERSION >= (1, 6, 0): + children.add(InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk)) + self.assertEquals( + set(self.get_manager().select_subclasses( + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1")), children) def test_get_subclass(self): self.assertEquals( @@ -319,6 +335,8 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests): related=self.related) self.child2 = InheritanceManagerTestChild2.objects.create( related=self.related) + if django.VERSION >= (1, 6, 0): + self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related) def get_manager(self):