diff --git a/model_utils/managers.py b/model_utils/managers.py index bcd0ef8..b5a0bde 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -170,7 +170,22 @@ class InheritanceQuerySetMixin(object): levels = 1 return levels - def instance_of(self, model): + def instance_of(self, *models): + # If we aren't already selecting the subclasess, we need + # to in order to get this to work. + + where_queries = [] + for model in models: + where_queries.append('(' + ' AND '.join([ + '"%s"."%s" IS NOT NULL' % ( + model._meta.db_table, + field.attname + ) for field in model._meta.parents.values() + ]) + ')') + + return self.extra(where=[' OR '.join(where_queries)]) + + # We need to get the parent_field = model._meta.parents.values()[0].attname query = '"%s"."%s" IS NOT NULL' % ( model._meta.db_table, diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 830d3f7..ce3c3a5 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -994,9 +994,28 @@ class InheritanceManagerUsingModelsTests(TestCase): child3 = InheritanceManagerTestChild3.objects.create() results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestChild3) - self.assertEqual(1, len(results)) self.assertEqual([child3], list(results)) + + def test_limit_to_specific_grandchild_class(self): + grandchild1 = InheritanceManagerTestGrandChild1.objects.get() + results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestGrandChild1) + self.assertEqual([grandchild1], list(results)) + + def test_limit_to_child_fetches_grandchildren(self): + children = InheritanceManagerTestChild1.objects.select_subclasses().all() + + results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestChild1) + + self.assertEqual(set(children), set(results)) + + def test_selecting_multiple_instance_classes(self): + child3 = InheritanceManagerTestChild3.objects.create() + grandchild1 = InheritanceManagerTestGrandChild1.objects.get() + + results = InheritanceManagerTestParent.objects.select_subclasses().instance_of(InheritanceManagerTestChild3, InheritanceManagerTestGrandChild1) + + self.assertEqual(set([child3, grandchild1]), set(results)) class InheritanceManagerRelatedTests(InheritanceManagerTests): def setUp(self):