diff --git a/docs/managers.rst b/docs/managers.rst index d2c074c..55d8bd4 100644 --- a/docs/managers.rst +++ b/docs/managers.rst @@ -51,6 +51,27 @@ be returned as their actual type, you can pass subclass names to nearby_places = Place.objects.select_subclasses("restaurant") # restaurants will be Restaurant instances, bars will still be Place instances + nearby_places = Place.objects.select_subclasses("restaurant", "bar") + # all Places will be converted to Restaurant and Bar instances. + +It is also possible to use the subclasses themselves as arguments to +``select_subclasses``, leaving it to calculate the relationship for you: + +.. code-block:: python + + nearby_places = Place.objects.select_subclasses(Restaurant) + # restaurants will be Restaurant instances, bars will still be Place instances + + nearby_places = Place.objects.select_subclasses(Restaurant, Bar) + # all Places will be converted to Restaurant and Bar instances. + +It is even possible to mix and match the two: + +.. code-block:: python + + nearby_places = Place.objects.select_subclasses(Restaurant, "bar") + # all Places will be converted to Restaurant and Bar instances. + ``InheritanceManager`` also provides a subclass-fetching alternative to the ``get()`` method: diff --git a/model_utils/managers.py b/model_utils/managers.py index 3c72585..349f980 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -7,20 +7,42 @@ from django.core.exceptions import ObjectDoesNotExist try: from django.db.models.constants import LOOKUP_SEP + from django.utils.six import string_types except ImportError: # Django < 1.5 from django.db.models.sql.constants import LOOKUP_SEP - + string_types = (basestring,) class InheritanceQuerySet(QuerySet): def select_subclasses(self, *subclasses): + levels = self._get_maximum_depth() + calculated_subclasses = self._get_subclasses_recurse(self.model, + levels=levels) + # if none were passed in, we can just short circuit and select all if not subclasses: - # only recurse one level on Django < 1.6 to avoid triggering - # https://code.djangoproject.com/ticket/16572 - levels = None - if django.VERSION < (1, 6, 0): - levels = 1 - subclasses = self._get_subclasses_recurse(self.model, levels=levels) + subclasses = calculated_subclasses + else: + verified_subclasses = [] + for subclass in subclasses: + # special case for passing in the same model as the queryset + # is bound against. Rather than raise an error later, we know + # we can allow this through. + if subclass is self.model: + continue + + if not isinstance(subclass, string_types): + subclass = self._get_ancestors_path(subclass, + levels=levels) + + if subclass in calculated_subclasses: + verified_subclasses.append(subclass) + else: + raise ValueError('%r is not in the discovered subclasses, ' + 'tried: %s' % (subclass, + ', '.join(calculated_subclasses), + )) + subclasses = verified_subclasses + # workaround https://code.djangoproject.com/ticket/16855 field_dict = self.query.select_related new_qs = self.select_related(*subclasses) @@ -69,9 +91,14 @@ class InheritanceQuerySet(QuerySet): def _get_subclasses_recurse(self, model, levels=None): + """ + Given a Model class, find all related objects, exploring children + recursively, returning a `list` of strings representing the + relations for select_related + """ rels = [rel for rel in model._meta.get_all_related_objects() - if isinstance(rel.field, OneToOneField) - and issubclass(rel.field.model, model)] + if isinstance(rel.field, OneToOneField) + and issubclass(rel.field.model, model)] subclasses = [] if levels: levels -= 1 @@ -84,6 +111,29 @@ class InheritanceQuerySet(QuerySet): return subclasses + def _get_ancestors_path(self, model, levels=None): + """ + Serves as an opposite to _get_subclasses_recurse, instead walking from + the Model class up the Model's ancestry and constructing the desired + select_related string backwards. + """ + if not issubclass(model, self.model): + raise ValueError("%r is not a subclass of %r" % (model, self.model)) + + ancestry = [] + # should be a OneToOneField or None + parent = model._meta.get_ancestor_link(self.model) + if levels: + levels -= 1 + while parent is not None: + ancestry.insert(0, parent.related.var_name) + if levels or levels is None: + parent = parent.related.parent_model._meta.get_ancestor_link(self.model) + else: + parent = None + return LOOKUP_SEP.join(ancestry) + + def _get_sub_obj_recurse(self, obj, s): rel, _, s = s.partition(LOOKUP_SEP) try: @@ -99,6 +149,18 @@ class InheritanceQuerySet(QuerySet): def get_subclass(self, *args, **kwargs): return self.select_subclasses().get(*args, **kwargs) + def _get_maximum_depth(self): + """ + Under Django versions < 1.6, to avoid triggering + https://code.djangoproject.com/ticket/16572 we can only look + as far as children. + """ + levels = None + if django.VERSION < (1, 6, 0): + levels = 1 + return levels + + class InheritanceManager(models.Manager): use_for_related_fields = True diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index 80537ed..88897f9 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -27,6 +27,8 @@ class InheritanceManagerTestParent(models.Model): normal_field = models.TextField() objects = InheritanceManager() + def __unicode__(self): + return unicode(self.pk) def __str__(self): return "%s(%s)" % ( @@ -39,6 +41,7 @@ class InheritanceManagerTestParent(models.Model): class InheritanceManagerTestChild1(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") normal_field_2 = models.TextField() + objects = InheritanceManager() class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1): @@ -52,7 +55,6 @@ class InheritanceManagerTestGrandChild1_2(InheritanceManagerTestChild1): class InheritanceManagerTestChild2(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") normal_field_2 = models.TextField() - pass diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 2d86596..4dda351 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -589,6 +589,17 @@ class InheritanceManagerTests(TestCase): set(self.get_manager().select_subclasses()), children) + def test_select_subclasses_invalid_relation(self): + """ + If an invalid relation string is provided, we can provide the user + with a list which is valid, rather than just have the select_related() + raise an AttributeError further in. + """ + regex = '^.+? is not in the discovered subclasses, tried:.+$' + with self.assertRaisesRegexp(ValueError, regex): + self.get_manager().select_subclasses('user') + + def test_select_specific_subclasses(self): children = set([ self.child1, @@ -662,6 +673,206 @@ class InheritanceManagerTests(TestCase): obj.inheritancemanagertestchild1 + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_version_determining_any_depth(self): + self.assertIsNone(self.get_manager().all()._get_maximum_depth()) + + + @skipUnless(django.VERSION < (1, 6, 0), "test only applies to Django < 1.6") + def test_version_determining_only_child_depth(self): + self.assertEqual(1, self.get_manager().all()._get_maximum_depth()) + + +class InheritanceManagerUsingModelsTests(TestCase): + + def setUp(self): + self.parent1 = InheritanceManagerTestParent.objects.create() + self.child1 = InheritanceManagerTestChild1.objects.create() + self.child2 = InheritanceManagerTestChild2.objects.create() + self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() + self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create() + + + def test_select_subclass_by_child_model(self): + """ + Confirm that passing a child model works the same as passing the + select_related manually + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild1").order_by('pk') + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestChild1).order_by('pk') + self.assertEqual(objs.subclasses, objsmodels.subclasses) + self.assertEqual(list(objs), list(objsmodels)) + + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_select_subclass_by_grandchild_model(self): + """ + Confirm that passing a grandchild model works the same as passing the + select_related manually + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1")\ + .order_by('pk') + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestGrandChild1).order_by('pk') + self.assertEqual(objs.subclasses, objsmodels.subclasses) + self.assertEqual(list(objs), list(objsmodels)) + + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_selecting_all_subclasses_specifically_grandchildren(self): + """ + A bare select_subclasses() should achieve the same results as doing + select_subclasses and specifying all possible subclasses. + This test checks grandchildren, so only works on 1.6>= + """ + objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk') + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestChild1, InheritanceManagerTestChild2, + InheritanceManagerTestGrandChild1, + InheritanceManagerTestGrandChild1_2).order_by('pk') + self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) + self.assertEqual(list(objs), list(objsmodels)) + + + def test_selecting_all_subclasses_specifically_children(self): + """ + A bare select_subclasses() should achieve the same results as doing + select_subclasses and specifying all possible subclasses. + + Note: This is sort of the same test as + `test_selecting_all_subclasses_specifically_grandchildren` but it + specifically switches what models are used because that happens + behind the scenes in a bare select_subclasses(), so we need to + emulate it. + """ + objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk') + + if django.VERSION >= (1, 6, 0): + models = (InheritanceManagerTestChild1, InheritanceManagerTestChild2, + InheritanceManagerTestGrandChild1, + InheritanceManagerTestGrandChild1_2) + else: + models = (InheritanceManagerTestChild1, InheritanceManagerTestChild2) + + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + *models).order_by('pk') + # order shouldn't matter, I don't think, as long as the resulting + # queryset (when cast to a list) is the same. + self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) + self.assertEqual(list(objs), list(objsmodels)) + + + def test_select_subclass_just_self(self): + """ + Passing in the same model as the manager/queryset is bound against + (ie: the root parent) should have no effect on the result set. + """ + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestParent).order_by('pk') + self.assertEqual([], objsmodels.subclasses) + self.assertEqual(list(objsmodels), [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + ]) + + + def test_select_subclass_invalid_related_model(self): + """ + Confirming that giving a stupid model doesn't work. + """ + from django.contrib.auth.models import User + regex = '^.+? is not a subclass of .+$' + with self.assertRaisesRegexp(ValueError, regex): + InheritanceManagerTestParent.objects.select_subclasses( + User).order_by('pk') + + + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_mixing_strings_and_classes_with_grandchildren(self): + """ + Given arguments consisting of both strings and model classes, + ensure the right resolutions take place, accounting for the extra + depth (grandchildren etc) 1.6> allows. + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild2", + InheritanceManagerTestGrandChild1_2).order_by('pk') + expecting = ['inheritancemanagertestchild1__inheritancemanagertestgrandchild1_2', + 'inheritancemanagertestchild2'] + self.assertEqual(set(objs.subclasses), set(expecting)) + expecting2 = [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestChild2(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestGrandChild1_2(pk=self.grandchild1_2.pk), + ] + self.assertEqual(list(objs), expecting2) + + + def test_mixing_strings_and_classes_with_children(self): + """ + Given arguments consisting of both strings and model classes, + ensure the right resolutions take place, walking down as far as + children. + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild2", + InheritanceManagerTestChild1).order_by('pk') + expecting = ['inheritancemanagertestchild1', + 'inheritancemanagertestchild2'] + + self.assertEqual(set(objs.subclasses), set(expecting)) + expecting2 = [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestChild1(pk=self.child1.pk), + InheritanceManagerTestChild2(pk=self.child2.pk), + InheritanceManagerTestChild1(pk=self.grandchild1.pk), + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ] + self.assertEqual(list(objs), expecting2) + + + def test_duplications(self): + """ + Check that even if the same thing is provided as a string and a model + that the right results are retrieved. + """ + # mixing strings and models which evaluate to the same thing is fine. + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild2", + InheritanceManagerTestChild2).order_by('pk') + self.assertEqual(list(objs), [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestChild2(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + ]) + + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_child_doesnt_accidentally_get_parent(self): + """ + Given a Child model which also has an InheritanceManager, + none of the returned objects should be Parent objects. + """ + objs = InheritanceManagerTestChild1.objects.select_subclasses( + InheritanceManagerTestGrandChild1).order_by('pk') + self.assertEqual([ + InheritanceManagerTestChild1(pk=self.child1.pk), + InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk), + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ], list(objs)) + + class InheritanceManagerRelatedTests(InheritanceManagerTests): def setUp(self):