Merge pull request #79 from kezabelle/feature/model_subclasses

Implementation for selecting subclasses by Model class rather than string.
This commit is contained in:
Carl Meyer 2013-10-22 11:12:48 -07:00
commit 0cee79ecd7
4 changed files with 306 additions and 10 deletions

View file

@ -51,6 +51,27 @@ be returned as their actual type, you can pass subclass names to
nearby_places = Place.objects.select_subclasses("restaurant")
# restaurants will be Restaurant instances, bars will still be Place instances
nearby_places = Place.objects.select_subclasses("restaurant", "bar")
# all Places will be converted to Restaurant and Bar instances.
It is also possible to use the subclasses themselves as arguments to
``select_subclasses``, leaving it to calculate the relationship for you:
.. code-block:: python
nearby_places = Place.objects.select_subclasses(Restaurant)
# restaurants will be Restaurant instances, bars will still be Place instances
nearby_places = Place.objects.select_subclasses(Restaurant, Bar)
# all Places will be converted to Restaurant and Bar instances.
It is even possible to mix and match the two:
.. code-block:: python
nearby_places = Place.objects.select_subclasses(Restaurant, "bar")
# all Places will be converted to Restaurant and Bar instances.
``InheritanceManager`` also provides a subclass-fetching alternative to the
``get()`` method:

View file

@ -7,20 +7,42 @@ from django.core.exceptions import ObjectDoesNotExist
try:
from django.db.models.constants import LOOKUP_SEP
from django.utils.six import string_types
except ImportError: # Django < 1.5
from django.db.models.sql.constants import LOOKUP_SEP
string_types = (basestring,)
class InheritanceQuerySet(QuerySet):
def select_subclasses(self, *subclasses):
levels = self._get_maximum_depth()
calculated_subclasses = self._get_subclasses_recurse(self.model,
levels=levels)
# if none were passed in, we can just short circuit and select all
if not subclasses:
# only recurse one level on Django < 1.6 to avoid triggering
# https://code.djangoproject.com/ticket/16572
levels = None
if django.VERSION < (1, 6, 0):
levels = 1
subclasses = self._get_subclasses_recurse(self.model, levels=levels)
subclasses = calculated_subclasses
else:
verified_subclasses = []
for subclass in subclasses:
# special case for passing in the same model as the queryset
# is bound against. Rather than raise an error later, we know
# we can allow this through.
if subclass is self.model:
continue
if not isinstance(subclass, string_types):
subclass = self._get_ancestors_path(subclass,
levels=levels)
if subclass in calculated_subclasses:
verified_subclasses.append(subclass)
else:
raise ValueError('%r is not in the discovered subclasses, '
'tried: %s' % (subclass,
', '.join(calculated_subclasses),
))
subclasses = verified_subclasses
# workaround https://code.djangoproject.com/ticket/16855
field_dict = self.query.select_related
new_qs = self.select_related(*subclasses)
@ -69,9 +91,14 @@ class InheritanceQuerySet(QuerySet):
def _get_subclasses_recurse(self, model, levels=None):
"""
Given a Model class, find all related objects, exploring children
recursively, returning a `list` of strings representing the
relations for select_related
"""
rels = [rel for rel in model._meta.get_all_related_objects()
if isinstance(rel.field, OneToOneField)
and issubclass(rel.field.model, model)]
if isinstance(rel.field, OneToOneField)
and issubclass(rel.field.model, model)]
subclasses = []
if levels:
levels -= 1
@ -84,6 +111,29 @@ class InheritanceQuerySet(QuerySet):
return subclasses
def _get_ancestors_path(self, model, levels=None):
"""
Serves as an opposite to _get_subclasses_recurse, instead walking from
the Model class up the Model's ancestry and constructing the desired
select_related string backwards.
"""
if not issubclass(model, self.model):
raise ValueError("%r is not a subclass of %r" % (model, self.model))
ancestry = []
# should be a OneToOneField or None
parent = model._meta.get_ancestor_link(self.model)
if levels:
levels -= 1
while parent is not None:
ancestry.insert(0, parent.related.var_name)
if levels or levels is None:
parent = parent.related.parent_model._meta.get_ancestor_link(self.model)
else:
parent = None
return LOOKUP_SEP.join(ancestry)
def _get_sub_obj_recurse(self, obj, s):
rel, _, s = s.partition(LOOKUP_SEP)
try:
@ -99,6 +149,18 @@ class InheritanceQuerySet(QuerySet):
def get_subclass(self, *args, **kwargs):
return self.select_subclasses().get(*args, **kwargs)
def _get_maximum_depth(self):
"""
Under Django versions < 1.6, to avoid triggering
https://code.djangoproject.com/ticket/16572 we can only look
as far as children.
"""
levels = None
if django.VERSION < (1, 6, 0):
levels = 1
return levels
class InheritanceManager(models.Manager):
use_for_related_fields = True

View file

@ -27,6 +27,8 @@ class InheritanceManagerTestParent(models.Model):
normal_field = models.TextField()
objects = InheritanceManager()
def __unicode__(self):
return unicode(self.pk)
def __str__(self):
return "%s(%s)" % (
@ -39,6 +41,7 @@ class InheritanceManagerTestParent(models.Model):
class InheritanceManagerTestChild1(InheritanceManagerTestParent):
non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
normal_field_2 = models.TextField()
objects = InheritanceManager()
class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1):
@ -52,7 +55,6 @@ class InheritanceManagerTestGrandChild1_2(InheritanceManagerTestChild1):
class InheritanceManagerTestChild2(InheritanceManagerTestParent):
non_related_field_using_descriptor_2 = models.FileField(upload_to="test")
normal_field_2 = models.TextField()
pass

View file

@ -589,6 +589,17 @@ class InheritanceManagerTests(TestCase):
set(self.get_manager().select_subclasses()), children)
def test_select_subclasses_invalid_relation(self):
"""
If an invalid relation string is provided, we can provide the user
with a list which is valid, rather than just have the select_related()
raise an AttributeError further in.
"""
regex = '^.+? is not in the discovered subclasses, tried:.+$'
with self.assertRaisesRegexp(ValueError, regex):
self.get_manager().select_subclasses('user')
def test_select_specific_subclasses(self):
children = set([
self.child1,
@ -662,6 +673,206 @@ class InheritanceManagerTests(TestCase):
obj.inheritancemanagertestchild1
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_version_determining_any_depth(self):
self.assertIsNone(self.get_manager().all()._get_maximum_depth())
@skipUnless(django.VERSION < (1, 6, 0), "test only applies to Django < 1.6")
def test_version_determining_only_child_depth(self):
self.assertEqual(1, self.get_manager().all()._get_maximum_depth())
class InheritanceManagerUsingModelsTests(TestCase):
def setUp(self):
self.parent1 = InheritanceManagerTestParent.objects.create()
self.child1 = InheritanceManagerTestChild1.objects.create()
self.child2 = InheritanceManagerTestChild2.objects.create()
self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create()
self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create()
def test_select_subclass_by_child_model(self):
"""
Confirm that passing a child model works the same as passing the
select_related manually
"""
objs = InheritanceManagerTestParent.objects.select_subclasses(
"inheritancemanagertestchild1").order_by('pk')
objsmodels = InheritanceManagerTestParent.objects.select_subclasses(
InheritanceManagerTestChild1).order_by('pk')
self.assertEqual(objs.subclasses, objsmodels.subclasses)
self.assertEqual(list(objs), list(objsmodels))
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_select_subclass_by_grandchild_model(self):
"""
Confirm that passing a grandchild model works the same as passing the
select_related manually
"""
objs = InheritanceManagerTestParent.objects.select_subclasses(
"inheritancemanagertestchild1__inheritancemanagertestgrandchild1")\
.order_by('pk')
objsmodels = InheritanceManagerTestParent.objects.select_subclasses(
InheritanceManagerTestGrandChild1).order_by('pk')
self.assertEqual(objs.subclasses, objsmodels.subclasses)
self.assertEqual(list(objs), list(objsmodels))
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_selecting_all_subclasses_specifically_grandchildren(self):
"""
A bare select_subclasses() should achieve the same results as doing
select_subclasses and specifying all possible subclasses.
This test checks grandchildren, so only works on 1.6>=
"""
objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk')
objsmodels = InheritanceManagerTestParent.objects.select_subclasses(
InheritanceManagerTestChild1, InheritanceManagerTestChild2,
InheritanceManagerTestGrandChild1,
InheritanceManagerTestGrandChild1_2).order_by('pk')
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
self.assertEqual(list(objs), list(objsmodels))
def test_selecting_all_subclasses_specifically_children(self):
"""
A bare select_subclasses() should achieve the same results as doing
select_subclasses and specifying all possible subclasses.
Note: This is sort of the same test as
`test_selecting_all_subclasses_specifically_grandchildren` but it
specifically switches what models are used because that happens
behind the scenes in a bare select_subclasses(), so we need to
emulate it.
"""
objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk')
if django.VERSION >= (1, 6, 0):
models = (InheritanceManagerTestChild1, InheritanceManagerTestChild2,
InheritanceManagerTestGrandChild1,
InheritanceManagerTestGrandChild1_2)
else:
models = (InheritanceManagerTestChild1, InheritanceManagerTestChild2)
objsmodels = InheritanceManagerTestParent.objects.select_subclasses(
*models).order_by('pk')
# order shouldn't matter, I don't think, as long as the resulting
# queryset (when cast to a list) is the same.
self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses))
self.assertEqual(list(objs), list(objsmodels))
def test_select_subclass_just_self(self):
"""
Passing in the same model as the manager/queryset is bound against
(ie: the root parent) should have no effect on the result set.
"""
objsmodels = InheritanceManagerTestParent.objects.select_subclasses(
InheritanceManagerTestParent).order_by('pk')
self.assertEqual([], objsmodels.subclasses)
self.assertEqual(list(objsmodels), [
InheritanceManagerTestParent(pk=self.parent1.pk),
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk),
InheritanceManagerTestParent(pk=self.grandchild1.pk),
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
])
def test_select_subclass_invalid_related_model(self):
"""
Confirming that giving a stupid model doesn't work.
"""
from django.contrib.auth.models import User
regex = '^.+? is not a subclass of .+$'
with self.assertRaisesRegexp(ValueError, regex):
InheritanceManagerTestParent.objects.select_subclasses(
User).order_by('pk')
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_mixing_strings_and_classes_with_grandchildren(self):
"""
Given arguments consisting of both strings and model classes,
ensure the right resolutions take place, accounting for the extra
depth (grandchildren etc) 1.6> allows.
"""
objs = InheritanceManagerTestParent.objects.select_subclasses(
"inheritancemanagertestchild2",
InheritanceManagerTestGrandChild1_2).order_by('pk')
expecting = ['inheritancemanagertestchild1__inheritancemanagertestgrandchild1_2',
'inheritancemanagertestchild2']
self.assertEqual(set(objs.subclasses), set(expecting))
expecting2 = [
InheritanceManagerTestParent(pk=self.parent1.pk),
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestChild2(pk=self.child2.pk),
InheritanceManagerTestParent(pk=self.grandchild1.pk),
InheritanceManagerTestGrandChild1_2(pk=self.grandchild1_2.pk),
]
self.assertEqual(list(objs), expecting2)
def test_mixing_strings_and_classes_with_children(self):
"""
Given arguments consisting of both strings and model classes,
ensure the right resolutions take place, walking down as far as
children.
"""
objs = InheritanceManagerTestParent.objects.select_subclasses(
"inheritancemanagertestchild2",
InheritanceManagerTestChild1).order_by('pk')
expecting = ['inheritancemanagertestchild1',
'inheritancemanagertestchild2']
self.assertEqual(set(objs.subclasses), set(expecting))
expecting2 = [
InheritanceManagerTestParent(pk=self.parent1.pk),
InheritanceManagerTestChild1(pk=self.child1.pk),
InheritanceManagerTestChild2(pk=self.child2.pk),
InheritanceManagerTestChild1(pk=self.grandchild1.pk),
InheritanceManagerTestChild1(pk=self.grandchild1_2.pk),
]
self.assertEqual(list(objs), expecting2)
def test_duplications(self):
"""
Check that even if the same thing is provided as a string and a model
that the right results are retrieved.
"""
# mixing strings and models which evaluate to the same thing is fine.
objs = InheritanceManagerTestParent.objects.select_subclasses(
"inheritancemanagertestchild2",
InheritanceManagerTestChild2).order_by('pk')
self.assertEqual(list(objs), [
InheritanceManagerTestParent(pk=self.parent1.pk),
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestChild2(pk=self.child2.pk),
InheritanceManagerTestParent(pk=self.grandchild1.pk),
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
])
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_child_doesnt_accidentally_get_parent(self):
"""
Given a Child model which also has an InheritanceManager,
none of the returned objects should be Parent objects.
"""
objs = InheritanceManagerTestChild1.objects.select_subclasses(
InheritanceManagerTestGrandChild1).order_by('pk')
self.assertEqual([
InheritanceManagerTestChild1(pk=self.child1.pk),
InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk),
InheritanceManagerTestChild1(pk=self.grandchild1_2.pk),
], list(objs))
class InheritanceManagerRelatedTests(InheritanceManagerTests):
def setUp(self):