Merge pull request #387 from jazzband/instance-of

Add ability to filter an InheritanceQuerySet by model.
This commit is contained in:
Asif Saif Uddin 2019-08-20 23:45:02 +06:00 committed by GitHub
commit d901b233d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 0 deletions

View file

@ -225,6 +225,31 @@ class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
pass
def instance_of(self, *models):
"""
Fetch only objects that are instances of the provided model(s).
"""
# If we aren't already selecting the subclasess, we need
# to in order to get this to work.
# How can we tell if we are not selecting subclasses?
# Is it safe to just apply .select_subclasses(*models)?
# Due to https://code.djangoproject.com/ticket/16572, we
# can't really do this for anything other than children (ie,
# no grandchildren+).
where_queries = []
for model in models:
where_queries.append('(' + ' AND '.join([
'"%s"."%s" IS NOT NULL' % (
model._meta.db_table,
field.attname, # Should this be something else?
) for field in model._meta.parents.values()
]) + ')')
return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])
class InheritanceManagerMixin(object):
_queryset_class = InheritanceQuerySet
@ -237,6 +262,8 @@ class InheritanceManagerMixin(object):
def get_subclass(self, *args, **kwargs):
return self.get_queryset().get_subclass(*args, **kwargs)
def instance_of(self, *models):
return self.get_queryset().instance_of(*models)
class InheritanceManager(InheritanceManagerMixin, models.Manager):
pass

View file

@ -424,6 +424,61 @@ class InheritanceManagerUsingModelsTests(TestCase):
)
self.assertTrue(all(result.foo == (result.id + 1) for result in results))
def test_limit_to_specific_subclass(self):
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3)
self.assertEqual([child3], list(results))
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_limit_to_specific_grandchild_class(self):
grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1)
self.assertEqual([grandchild1], list(results))
def test_limit_to_child_fetches_grandchildren_as_child_class(self):
# Not sure if this is the desired behaviour...?
children = InheritanceManagerTestChild1.objects.all()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1)
self.assertEqual(set(children), set(results))
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_can_fetch_limited_class_grandchildren(self):
# Not sure if this is the desired behaviour...?
children = InheritanceManagerTestChild1.objects.select_subclasses()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1).select_subclasses()
self.assertEqual(set(children), set(results))
def test_selecting_multiple_instance_classes(self):
child3 = InheritanceManagerTestChild3.objects.create()
children1 = InheritanceManagerTestChild1.objects.all()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestChild1)
self.assertEqual(set([child3] + list(children1)), set(results))
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
def test_selecting_multiple_instance_classes_including_grandchildren(self):
child3 = InheritanceManagerTestChild3.objects.create()
grandchild1 = InheritanceManagerTestGrandChild1.objects.get()
results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestGrandChild1).select_subclasses()
self.assertEqual(set([child3, grandchild1]), set(results))
def test_select_subclasses_interaction_with_instance_of(self):
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3)
self.assertEqual(set([child3]), set(results))
class InheritanceManagerRelatedTests(InheritanceManagerTests):
def setUp(self):