diff --git a/model_utils/managers.py b/model_utils/managers.py index 0a2afc7..48911b2 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -171,6 +171,9 @@ class InheritanceQuerySetMixin(object): return levels def instance_of(self, *models): + """ + Fetch only objects that are instances of the provided model(s). + """ # If we aren't already selecting the subclasess, we need # to in order to get this to work. @@ -190,7 +193,7 @@ class InheritanceQuerySetMixin(object): ) for field in model._meta.parents.values() ]) + ')') - return self.extra(where=[' OR '.join(where_queries)]) + return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) class InheritanceManagerMixin(object): use_for_related_fields = True @@ -206,8 +209,8 @@ class InheritanceManagerMixin(object): def get_subclass(self, *args, **kwargs): return self.get_queryset().get_subclass(*args, **kwargs) - def instance_of(self, model): - return self.get_queryset().instance_of(model) + def instance_of(self, *models): + return self.get_queryset().instance_of(*models) class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): pass diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index ce3c3a5..6fea2ce 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -992,30 +992,51 @@ class InheritanceManagerUsingModelsTests(TestCase): def test_limit_to_specific_subclass(self): child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestChild3) + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3) self.assertEqual([child3], list(results)) + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") def test_limit_to_specific_grandchild_class(self): grandchild1 = InheritanceManagerTestGrandChild1.objects.get() - results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestGrandChild1) + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1) self.assertEqual([grandchild1], list(results)) - def test_limit_to_child_fetches_grandchildren(self): - children = InheritanceManagerTestChild1.objects.select_subclasses().all() + def test_limit_to_child_fetches_grandchildren_as_child_class(self): + # Not sure if this is the desired behaviour...? + children = InheritanceManagerTestChild1.objects.all() - results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestChild1) + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1) + + self.assertEqual(set(children), set(results)) + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_can_fetch_limited_class_grandchildren(self): + # Not sure if this is the desired behaviour...? + children = InheritanceManagerTestChild1.objects.select_subclasses() + + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1).select_subclasses() self.assertEqual(set(children), set(results)) def test_selecting_multiple_instance_classes(self): + child3 = InheritanceManagerTestChild3.objects.create() + children1 = InheritanceManagerTestChild1.objects.all() + + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestChild1) + + self.assertEqual(set([child3] + list(children1)), set(results)) + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_selecting_multiple_instance_classes_including_grandchildren(self): child3 = InheritanceManagerTestChild3.objects.create() grandchild1 = InheritanceManagerTestGrandChild1.objects.get() - results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestChild3, InheritanceManagerTestGrandChild1) + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestGrandChild1).select_subclasses() self.assertEqual(set([child3, grandchild1]), set(results)) + class InheritanceManagerRelatedTests(InheritanceManagerTests): def setUp(self):