Support for multi-level inheritance in InheritanceManager

This commit is contained in:
Ivan Virabyan 2013-01-31 18:27:16 +04:00
parent 3abefce3dd
commit 937b3e018f
3 changed files with 66 additions and 19 deletions

View file

@ -5,13 +5,15 @@ from django.db.models.fields.related import OneToOneField
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
try:
from django.db.models.constants import LOOKUP_SEP
except ImportError: # Django <= 1.5
from django.db.models.sql.constants import LOOKUP_SEP
class InheritanceQuerySet(QuerySet): class InheritanceQuerySet(QuerySet):
def select_subclasses(self, *subclasses): def select_subclasses(self, *subclasses):
if not subclasses: if not subclasses:
subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects() subclasses = self._get_subclasses_recurse(self.model)
if isinstance(rel.field, OneToOneField)
and issubclass(rel.field.model, self.model)]
new_qs = self.select_related(*subclasses) new_qs = self.select_related(*subclasses)
new_qs.subclasses = subclasses new_qs.subclasses = subclasses
return new_qs return new_qs
@ -31,15 +33,14 @@ class InheritanceQuerySet(QuerySet):
iter = super(InheritanceQuerySet, self).iterator() iter = super(InheritanceQuerySet, self).iterator()
if getattr(self, 'subclasses', False): if getattr(self, 'subclasses', False):
for obj in iter: for obj in iter:
sub_obj = None
for s in self.subclasses: for s in self.subclasses:
try: sub_obj = self._get_sub_obj_recurse(obj, s)
sub_obj = getattr(obj, s)
except ObjectDoesNotExist:
sub_obj = None
if sub_obj: if sub_obj:
break break
if not sub_obj: if not sub_obj:
sub_obj = obj sub_obj = obj
if getattr(self, '_annotated', False): if getattr(self, '_annotated', False):
for k in self._annotated: for k in self._annotated:
setattr(sub_obj, k, getattr(obj, k)) setattr(sub_obj, k, getattr(obj, k))
@ -49,6 +50,29 @@ class InheritanceQuerySet(QuerySet):
for obj in iter: for obj in iter:
yield obj yield obj
def _get_subclasses_recurse(self, model):
rels = [rel for rel in model._meta.get_all_related_objects()
if isinstance(rel.field, OneToOneField)
and issubclass(rel.field.model, model)]
subclasses = []
for rel in rels:
for subclass in self._get_subclasses_recurse(rel.field.model):
subclasses.append(rel.var_name + LOOKUP_SEP + subclass)
subclasses.append(rel.var_name)
return subclasses
def _get_sub_obj_recurse(self, obj, s):
rel, _, s = s.partition(LOOKUP_SEP)
try:
node = getattr(obj, rel)
except ObjectDoesNotExist:
return None
if s:
child = self._get_sub_obj_recurse(node, s)
return child or node
else:
return node
class InheritanceManager(models.Manager): class InheritanceManager(models.Manager):
use_for_related_fields = True use_for_related_fields = True

View file

@ -1,3 +1,4 @@
import django
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -29,6 +30,10 @@ class InheritanceManagerTestChild1(InheritanceManagerTestParent):
pass pass
if django.VERSION >= (1, 6, 0):
class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1):
text_field = models.TextField()
class InheritanceManagerTestChild2(InheritanceManagerTestParent): class InheritanceManagerTestChild2(InheritanceManagerTestParent):
non_related_field_using_descriptor_2 = models.FileField(upload_to="test") non_related_field_using_descriptor_2 = models.FileField(upload_to="test")

View file

@ -1,9 +1,10 @@
from __future__ import with_statement from __future__ import with_statement
import unittest
import pickle import pickle
from datetime import datetime, timedelta from datetime import datetime, timedelta
import django
from django.db import models from django.db import models
from django.db.models.fields import FieldDoesNotExist from django.db.models.fields import FieldDoesNotExist
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
@ -20,6 +21,8 @@ from model_utils.tests.models import (
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded, StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot) TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot)
if django.VERSION >= (1, 6, 0):
from model_utils.tests.models import InheritanceManagerTestGrandChild1
class GetExcerptTests(TestCase): class GetExcerptTests(TestCase):
@ -279,32 +282,45 @@ class InheritanceManagerTests(TestCase):
def setUp(self): def setUp(self):
self.child1 = InheritanceManagerTestChild1.objects.create() self.child1 = InheritanceManagerTestChild1.objects.create()
self.child2 = InheritanceManagerTestChild2.objects.create() self.child2 = InheritanceManagerTestChild2.objects.create()
if django.VERSION >= (1, 6, 0):
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
def get_manager(self): def get_manager(self):
return InheritanceManagerTestParent.objects return InheritanceManagerTestParent.objects
def test_normal(self): def test_normal(self):
self.assertEquals(set(self.get_manager().all()), children = set([InheritanceManagerTestParent(pk=self.child1.pk),
set([ InheritanceManagerTestParent(pk=self.child2.pk)])
InheritanceManagerTestParent(pk=self.child1.pk), if django.VERSION >= (1, 6, 0):
InheritanceManagerTestParent(pk=self.child2.pk), children.add(InheritanceManagerTestParent(pk=self.grandchild1.pk))
])) self.assertEquals(set(self.get_manager().all()), children)
def test_select_all_subclasses(self): def test_select_all_subclasses(self):
children = set([self.child1, self.child2])
if django.VERSION >= (1, 6, 0):
children.add(self.grandchild1)
self.assertEquals( self.assertEquals(
set(self.get_manager().select_subclasses()), set(self.get_manager().select_subclasses()), children)
set([self.child1, self.child2]))
def test_select_specific_subclasses(self): def test_select_specific_subclasses(self):
children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])
if django.VERSION >= (1, 6, 0):
children.add(InheritanceManagerTestChild1(pk=self.grandchild1.pk))
self.assertEquals( self.assertEquals(
set(self.get_manager().select_subclasses( set(self.get_manager().select_subclasses(
"inheritancemanagertestchild1")), "inheritancemanagertestchild1")), children)
set([self.child1,
InheritanceManagerTestParent(pk=self.child2.pk)])) @unittest.skipIf(django.VERSION < (1, 6, 0), "not supported in this django version")
def test_select_specific_grandchildren(self):
children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])
if django.VERSION >= (1, 6, 0):
children.add(InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk))
self.assertEquals(
set(self.get_manager().select_subclasses(
"inheritancemanagertestchild1__inheritancemanagertestgrandchild1")), children)
def test_get_subclass(self): def test_get_subclass(self):
self.assertEquals( self.assertEquals(
@ -319,6 +335,8 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
related=self.related) related=self.related)
self.child2 = InheritanceManagerTestChild2.objects.create( self.child2 = InheritanceManagerTestChild2.objects.create(
related=self.related) related=self.related)
if django.VERSION >= (1, 6, 0):
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
def get_manager(self): def get_manager(self):