mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-05-28 11:28:15 +00:00
Support for multi-level inheritance in InheritanceManager
This commit is contained in:
parent
3abefce3dd
commit
937b3e018f
3 changed files with 66 additions and 19 deletions
|
|
@ -5,13 +5,15 @@ from django.db.models.fields.related import OneToOneField
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
|
|
||||||
|
try:
|
||||||
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
|
except ImportError: # Django <= 1.5
|
||||||
|
from django.db.models.sql.constants import LOOKUP_SEP
|
||||||
|
|
||||||
class InheritanceQuerySet(QuerySet):
|
class InheritanceQuerySet(QuerySet):
|
||||||
def select_subclasses(self, *subclasses):
|
def select_subclasses(self, *subclasses):
|
||||||
if not subclasses:
|
if not subclasses:
|
||||||
subclasses = [rel.var_name for rel in self.model._meta.get_all_related_objects()
|
subclasses = self._get_subclasses_recurse(self.model)
|
||||||
if isinstance(rel.field, OneToOneField)
|
|
||||||
and issubclass(rel.field.model, self.model)]
|
|
||||||
new_qs = self.select_related(*subclasses)
|
new_qs = self.select_related(*subclasses)
|
||||||
new_qs.subclasses = subclasses
|
new_qs.subclasses = subclasses
|
||||||
return new_qs
|
return new_qs
|
||||||
|
|
@ -31,15 +33,14 @@ class InheritanceQuerySet(QuerySet):
|
||||||
iter = super(InheritanceQuerySet, self).iterator()
|
iter = super(InheritanceQuerySet, self).iterator()
|
||||||
if getattr(self, 'subclasses', False):
|
if getattr(self, 'subclasses', False):
|
||||||
for obj in iter:
|
for obj in iter:
|
||||||
|
sub_obj = None
|
||||||
for s in self.subclasses:
|
for s in self.subclasses:
|
||||||
try:
|
sub_obj = self._get_sub_obj_recurse(obj, s)
|
||||||
sub_obj = getattr(obj, s)
|
|
||||||
except ObjectDoesNotExist:
|
|
||||||
sub_obj = None
|
|
||||||
if sub_obj:
|
if sub_obj:
|
||||||
break
|
break
|
||||||
if not sub_obj:
|
if not sub_obj:
|
||||||
sub_obj = obj
|
sub_obj = obj
|
||||||
|
|
||||||
if getattr(self, '_annotated', False):
|
if getattr(self, '_annotated', False):
|
||||||
for k in self._annotated:
|
for k in self._annotated:
|
||||||
setattr(sub_obj, k, getattr(obj, k))
|
setattr(sub_obj, k, getattr(obj, k))
|
||||||
|
|
@ -49,6 +50,29 @@ class InheritanceQuerySet(QuerySet):
|
||||||
for obj in iter:
|
for obj in iter:
|
||||||
yield obj
|
yield obj
|
||||||
|
|
||||||
|
def _get_subclasses_recurse(self, model):
|
||||||
|
rels = [rel for rel in model._meta.get_all_related_objects()
|
||||||
|
if isinstance(rel.field, OneToOneField)
|
||||||
|
and issubclass(rel.field.model, model)]
|
||||||
|
subclasses = []
|
||||||
|
for rel in rels:
|
||||||
|
for subclass in self._get_subclasses_recurse(rel.field.model):
|
||||||
|
subclasses.append(rel.var_name + LOOKUP_SEP + subclass)
|
||||||
|
subclasses.append(rel.var_name)
|
||||||
|
return subclasses
|
||||||
|
|
||||||
|
def _get_sub_obj_recurse(self, obj, s):
|
||||||
|
rel, _, s = s.partition(LOOKUP_SEP)
|
||||||
|
try:
|
||||||
|
node = getattr(obj, rel)
|
||||||
|
except ObjectDoesNotExist:
|
||||||
|
return None
|
||||||
|
if s:
|
||||||
|
child = self._get_sub_obj_recurse(node, s)
|
||||||
|
return child or node
|
||||||
|
else:
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
class InheritanceManager(models.Manager):
|
class InheritanceManager(models.Manager):
|
||||||
use_for_related_fields = True
|
use_for_related_fields = True
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import django
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
|
|
@ -29,6 +30,10 @@ class InheritanceManagerTestChild1(InheritanceManagerTestParent):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1):
|
||||||
|
text_field = models.TextField()
|
||||||
|
|
||||||
|
|
||||||
class InheritanceManagerTestChild2(InheritanceManagerTestParent):
|
class InheritanceManagerTestChild2(InheritanceManagerTestParent):
|
||||||
non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
|
non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from __future__ import with_statement
|
from __future__ import with_statement
|
||||||
|
import unittest
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import django
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models.fields import FieldDoesNotExist
|
from django.db.models.fields import FieldDoesNotExist
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
@ -20,6 +21,8 @@ from model_utils.tests.models import (
|
||||||
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
|
StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded,
|
||||||
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot)
|
TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot)
|
||||||
|
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
from model_utils.tests.models import InheritanceManagerTestGrandChild1
|
||||||
|
|
||||||
|
|
||||||
class GetExcerptTests(TestCase):
|
class GetExcerptTests(TestCase):
|
||||||
|
|
@ -279,32 +282,45 @@ class InheritanceManagerTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.child1 = InheritanceManagerTestChild1.objects.create()
|
self.child1 = InheritanceManagerTestChild1.objects.create()
|
||||||
self.child2 = InheritanceManagerTestChild2.objects.create()
|
self.child2 = InheritanceManagerTestChild2.objects.create()
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
|
||||||
|
|
||||||
def get_manager(self):
|
def get_manager(self):
|
||||||
return InheritanceManagerTestParent.objects
|
return InheritanceManagerTestParent.objects
|
||||||
|
|
||||||
|
|
||||||
def test_normal(self):
|
def test_normal(self):
|
||||||
self.assertEquals(set(self.get_manager().all()),
|
children = set([InheritanceManagerTestParent(pk=self.child1.pk),
|
||||||
set([
|
InheritanceManagerTestParent(pk=self.child2.pk)])
|
||||||
InheritanceManagerTestParent(pk=self.child1.pk),
|
if django.VERSION >= (1, 6, 0):
|
||||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
children.add(InheritanceManagerTestParent(pk=self.grandchild1.pk))
|
||||||
]))
|
self.assertEquals(set(self.get_manager().all()), children)
|
||||||
|
|
||||||
|
|
||||||
def test_select_all_subclasses(self):
|
def test_select_all_subclasses(self):
|
||||||
|
children = set([self.child1, self.child2])
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
children.add(self.grandchild1)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
set(self.get_manager().select_subclasses()),
|
set(self.get_manager().select_subclasses()), children)
|
||||||
set([self.child1, self.child2]))
|
|
||||||
|
|
||||||
|
|
||||||
def test_select_specific_subclasses(self):
|
def test_select_specific_subclasses(self):
|
||||||
|
children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
children.add(InheritanceManagerTestChild1(pk=self.grandchild1.pk))
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
set(self.get_manager().select_subclasses(
|
set(self.get_manager().select_subclasses(
|
||||||
"inheritancemanagertestchild1")),
|
"inheritancemanagertestchild1")), children)
|
||||||
set([self.child1,
|
|
||||||
InheritanceManagerTestParent(pk=self.child2.pk)]))
|
@unittest.skipIf(django.VERSION < (1, 6, 0), "not supported in this django version")
|
||||||
|
def test_select_specific_grandchildren(self):
|
||||||
|
children = set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
children.add(InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk))
|
||||||
|
self.assertEquals(
|
||||||
|
set(self.get_manager().select_subclasses(
|
||||||
|
"inheritancemanagertestchild1__inheritancemanagertestgrandchild1")), children)
|
||||||
|
|
||||||
def test_get_subclass(self):
|
def test_get_subclass(self):
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
|
|
@ -319,6 +335,8 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
|
||||||
related=self.related)
|
related=self.related)
|
||||||
self.child2 = InheritanceManagerTestChild2.objects.create(
|
self.child2 = InheritanceManagerTestChild2.objects.create(
|
||||||
related=self.related)
|
related=self.related)
|
||||||
|
if django.VERSION >= (1, 6, 0):
|
||||||
|
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related)
|
||||||
|
|
||||||
|
|
||||||
def get_manager(self):
|
def get_manager(self):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue