From ac36cbf56cd9a9a5d2e9268616aa986b408f8bca Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 23 Nov 2010 12:48:23 -0500 Subject: [PATCH] Added InheritanceManager, contributed by Jeff Elmore. --- AUTHORS.rst | 1 + CHANGES.rst | 3 ++ README.rst | 96 ++++++++++++++++++++++++++++--------- TODO.rst | 4 +- model_utils/managers.py | 35 ++++++++++++++ model_utils/tests/models.py | 11 ++++- model_utils/tests/tests.py | 66 ++++++++++++++++--------- 7 files changed, 168 insertions(+), 48 deletions(-) diff --git a/AUTHORS.rst b/AUTHORS.rst index d18929e..a5da196 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -1,3 +1,4 @@ Carl Meyer Jannis Leidel Gregor Müllegger +Jeff Elmore diff --git a/CHANGES.rst b/CHANGES.rst index dcb91bb..2e344a8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,9 @@ CHANGES tip (unreleased) ---------------- +- added InheritanceManager, a better approach to selecting subclass instances + for Django 1.2+. Thanks Jeff Elmore. + - added InheritanceCastManager and InheritanceCastQuerySet, to allow bulk casting of a queryset to child types. Thanks Gregor Müllegger. diff --git a/README.rst b/README.rst index 9a4f9cb..838d3c3 100644 --- a/README.rst +++ b/README.rst @@ -24,7 +24,8 @@ your ``INSTALLED_APPS`` setting. Dependencies ------------ -``django-model-utils`` requires `Django`_ 1.0 or later. +Most of ``django-model-utils`` works with `Django`_ 1.0 or later. +`InheritanceManager`_ requires Django 1.2 or later. .. _Django: http://www.djangoproject.com/ @@ -218,38 +219,89 @@ returns objects with that status only:: # this query will only return published articles: Article.published.all() -InheritanceCastModel -==================== +InheritanceManager +================== -This abstract base class can be inherited by the root (parent) model -in a model-inheritance tree. It allows each model in the tree to -"know" what type it is (via an automatically-set foreign key to -``ContentType``), allowing for automatic casting of a parent instance -to its proper leaf (child) type. +This manager (`contributed by Jeff Elmore`_) should be attached to a base model +class in a model-inheritance tree. It allows queries on that base model to +return heterogenous results of the actual proper subtypes, without any +additional queries. -For instance, if you have a ``Place`` model with subclasses -``Restaurant`` and ``Bar``, you may want to query all Places:: +For instance, if you have a ``Place`` model with subclasses ``Restaurant`` and +``Bar``, you may want to query all Places:: nearby_places = Place.objects.filter(location='here') But when you iterate over ``nearby_places``, you'll get only ``Place`` -instances back, even for objects that are "really" ``Restaurant`` or -``Bar``. If you have ``Place`` inherit from ``InheritanceCastModel``, -you can just call the ``cast()`` method on each ``Place`` and it will -return an instance of the proper subtype, ``Restaurant`` or ``Bar``:: +instances back, even for objects that are "really" ``Restaurant`` or ``Bar``. +If you attach an ``InheritanceManager`` to ``Place``, you can just call the +``select_subclasses()`` method on the ``InheritanceManager`` or any +``QuerySet`` from it, and the resulting objects will be instances of +``Restaurant`` or ``Bar``:: + + from model_utils.managers import InheritanceManager + + class Place(models.Model): + # ... + objects = InheritanceManager() + + class Restaurant(Place): + # ... + + class Bar(Place): + # ... + + nearby_places = Place.objects.filter(location='here').select_subclasses() + for place in nearby_places: + # "place" will automatically be an instance of Place, Restaurant, or Bar + +The database query performed will have an extra join for each subclass; if you +want to reduce the number of joins and you only need particular subclasses to +be returned as their actual type, you can pass subclass names to +``select_subclasses()``, much like the built-in ``select_related()`` method:: + + nearby_places = Place.objects.select_subclasses("restaurant") + # restaurants will be Restaurant instances, bars will still be Place instances + +If you don't explicitly call ``select_subclasses()``, an ``InheritanceManager`` +behaves identically to a normal ``Manager``; so it's safe to use as your +default manager for the model. + +.. note:: + ``InheritanceManager`` currently only supports a single level of model + inheritance; it won't work for grandchild models. + +.. note:: + ``InheritanceManager`` requires Django 1.2 or later. + +.. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/ + + +InheritanceCastModel +==================== + +This abstract base class can be inherited by the root (parent) model in a +model-inheritance tree. It solves the same problem as `InheritanceManager`_ in +a way that requires more database queries and is less convenient to use, but is +compatible with Django versions prior to 1.2. Whenever possible, +`InheritanceManager`_ should be used instead. + +Usage:: from model_utils.models import InheritanceCastModel class Place(InheritanceCastModel): # ... - + class Restaurant(Place): # ... + class Bar(Place): + # ... + nearby_places = Place.objects.filter(location='here') for place in nearby_places: - restaurant_or_bar = place.cast() - # ... + restaurant_or_bar = place.cast() # ... This is inefficient for large querysets, as it results in a new query for every individual returned object. You can use the ``cast()`` method on a queryset to @@ -259,16 +311,16 @@ reduce this to as many queries as subtypes are involved:: for place in nearby_places.cast(): # ... -.. note:: The ``cast()`` queryset method does *not* return another - queryset but an already evaluated result of the database query. This means - that you cannot chain additional queryset methods after ``cast()``. +.. note:: + The ``cast()`` queryset method does *not* return another queryset but an + already evaluated result of the database query. This means that you cannot + chain additional queryset methods after ``cast()``. TimeStampedModel ================ This abstract base class just provides self-updating ``created`` and -``modified`` fields on any model that inherits from it. - +``modified`` fields on any model that inherits from it. QueryManager ============ diff --git a/TODO.rst b/TODO.rst index 53b2e52..05bc3ab 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,6 +1,4 @@ TODO ==== -* A version of InheritanceCastModel for 1.2+ (with reverse OneToOne - select_related now available) that doesn't require the added real_type - field. +* Add support for multiple levels of inheritance to ``InheritanceManager``. diff --git a/model_utils/managers.py b/model_utils/managers.py index 6df4a40..7b764dd 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -2,9 +2,44 @@ from types import ClassType from django.contrib.contenttypes.models import ContentType from django.db import models +from django.db.models.fields.related import SingleRelatedObjectDescriptor from django.db.models.manager import Manager from django.db.models.query import QuerySet +class InheritanceQuerySet(QuerySet): + def select_subclasses(self, *subclasses): + if not subclasses: + subclasses = [o for o in dir(self.model) + if isinstance(getattr(self.model, o), SingleRelatedObjectDescriptor) + and issubclass(getattr(self.model,o).related.model, self.model)] + new_qs = self.select_related(*subclasses) + new_qs.subclasses = subclasses + return new_qs + + def _clone(self, klass=None, setup=False, **kwargs): + try: + kwargs.update({'subclasses': self.subclasses}) + except AttributeError: + pass + return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs) + + def iterator(self): + iter = super(InheritanceQuerySet, self).iterator() + if getattr(self, 'subclasses', False): + for obj in iter: + obj = [getattr(obj, s) for s in self.subclasses if getattr(obj, s)] or [obj] + yield obj[0] + else: + for obj in iter: + yield obj + +class InheritanceManager(models.Manager): + def get_query_set(self): + return InheritanceQuerySet(self.model) + + def select_subclasses(self, *subclasses): + return self.get_query_set().select_subclasses(*subclasses) + class InheritanceCastMixin(object): def cast(self): diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index aea8244..5e63597 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -2,7 +2,7 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ from model_utils.models import InheritanceCastModel, TimeStampedModel, StatusModel, TimeFramedModel -from model_utils.managers import QueryManager, manager_from +from model_utils.managers import QueryManager, manager_from, InheritanceManager from model_utils.fields import SplitField, MonitorField from model_utils import Choices @@ -15,6 +15,15 @@ class InheritChild(InheritParent): class InheritChild2(InheritParent): pass +class InheritanceManagerTestParent(models.Model): + objects = InheritanceManager() + +class InheritanceManagerTestChild1(InheritanceManagerTestParent): + pass + +class InheritanceManagerTestChild2(InheritanceManagerTestParent): + pass + class TimeStamp(TimeStampedModel): pass diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 455d1aa..d073f64 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -12,11 +12,11 @@ from model_utils import ChoiceEnum, Choices from model_utils.fields import get_excerpt, MonitorField from model_utils.managers import QueryManager, manager_from from model_utils.models import StatusModel, TimeFramedModel -from model_utils.tests.models import (InheritParent, InheritChild, InheritChild2, - TimeStamp, Post, Article, Status, - StatusPlainTuple, TimeFrame, Monitored, - StatusManagerAdded, TimeFrameManagerAdded, - Entry) +from model_utils.tests.models import ( + InheritParent, InheritChild, InheritChild2, InheritanceManagerTestParent, + InheritanceManagerTestChild1, InheritanceManagerTestChild2, + TimeStamp, Post, Article, Status, StatusPlainTuple, TimeFrame, Monitored, + StatusManagerAdded, TimeFrameManagerAdded, Entry) class GetExcerptTests(TestCase): @@ -35,11 +35,11 @@ class GetExcerptTests(TestCase): def test_middle_of_line(self): e = get_excerpt("some text more text") self.assertEquals(e, "some text more text") - + class SplitFieldTests(TestCase): full_text = u'summary\n\n\n\nmore' excerpt = u'summary\n' - + def setUp(self): self.post = Article.objects.create( title='example post', body=self.full_text) @@ -60,7 +60,7 @@ class SplitFieldTests(TestCase): post = Article.objects.create(title='example 2', body='some text\n\nsome more\n') self.failIf(post.body.has_more) - + def test_load_back(self): post = Article.objects.get(pk=self.post.pk) self.assertEquals(post.body.content, self.post.body.content) @@ -126,7 +126,7 @@ class MonitorFieldTests(TestCase): def test_no_monitor_arg(self): self.assertRaises(TypeError, MonitorField) - + class ChoicesTests(TestCase): def setUp(self): self.STATUS = Choices('DRAFT', 'PUBLISHED') @@ -149,7 +149,7 @@ class ChoicesTests(TestCase): def test_wrong_length_tuple(self): self.assertRaises(ValueError, Choices, ('a',)) - + class LabelChoicesTests(ChoicesTests): def setUp(self): self.STATUS = Choices( @@ -181,7 +181,7 @@ class LabelChoicesTests(ChoicesTests): "('PUBLISHED', 'PUBLISHED', 'is published'), " "('DELETED', 'DELETED', 'DELETED'))") - + class IdentifierChoicesTests(ChoicesTests): def setUp(self): self.STATUS = Choices( @@ -200,7 +200,7 @@ class IdentifierChoicesTests(ChoicesTests): def test_getattr(self): self.assertEquals(self.STATUS.DRAFT, 0) - + def test_repr(self): self.assertEquals(repr(self.STATUS), "Choices(" @@ -208,12 +208,12 @@ class IdentifierChoicesTests(ChoicesTests): "(1, 'PUBLISHED', 'is published'), " "(2, 'DELETED', 'is deleted'))") - + class InheritanceCastModelTests(TestCase): def setUp(self): self.parent = InheritParent.objects.create() self.child = InheritChild.objects.create() - + def test_parent_real_type(self): self.assertEquals(self.parent.real_type, ContentType.objects.get_for_model(InheritParent)) @@ -246,6 +246,28 @@ class InheritanceCastQuerysetTests(TestCase): set([parent, self.child, self.child2])) +class InheritanceManagerTests(TestCase): + def setUp(self): + self.child1 = InheritanceManagerTestChild1.objects.create() + self.child2 = InheritanceManagerTestChild2.objects.create() + + def test_normal(self): + self.assertEquals(set(InheritanceManagerTestParent.objects.all()), + set([ + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + ])) + + def test_select_all_subclasses(self): + self.assertEquals(set(InheritanceManagerTestParent.objects.select_subclasses()), + set([self.child1, self.child2])) + + def test_select_specific_subclasses(self): + self.assertEquals(set(InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild1")), + set([self.child1, InheritanceManagerTestParent(pk=self.child2.pk)])) + + class TimeStampedModelTests(TestCase): def test_created(self): t1 = TimeStamp.objects.create() @@ -263,7 +285,7 @@ class TimeFramedModelTests(TestCase): def setUp(self): self.now = datetime.now() - + def test_not_yet_begun(self): TimeFrame.objects.create(start=self.now+timedelta(days=2)) self.assertEquals(TimeFrame.timeframed.count(), 0) @@ -295,8 +317,8 @@ class TimeFrameManagerAddedTests(TestCase): class ErrorModel(TimeFramedModel): timeframed = models.BooleanField() self.assertRaises(ImproperlyConfigured, _run) - - + + class StatusModelTests(TestCase): def setUp(self): self.model = Status @@ -326,7 +348,7 @@ class StatusModelTests(TestCase): t1.save() self.assert_(t1.status_changed > date_active_again) - + class StatusModelPlainTupleTests(StatusModelTests): def setUp(self): self.model = StatusPlainTuple @@ -347,7 +369,7 @@ class StatusManagerAddedTests(TestCase): ) active = models.BooleanField() self.assertRaises(ImproperlyConfigured, _run) - + class QueryManagerTests(TestCase): def setUp(self): @@ -379,7 +401,7 @@ if 'south' in settings.INSTALLED_APPS: mf = Article._meta.get_field('body') args, kwargs = introspector(mf) self.assertEquals(kwargs['no_excerpt_field'], 'True') - + def test_no_excerpt_field_works(self): from models import NoRendered self.assertRaises(FieldDoesNotExist, @@ -391,7 +413,7 @@ class ManagerFromTests(TestCase): Entry.objects.create(author='George', published=True) Entry.objects.create(author='George', published=False) Entry.objects.create(author='Paul', published=True, feature=True) - + def test_chaining(self): self.assertEqual(Entry.objects.by_author('George').published().count(), 1) @@ -401,7 +423,7 @@ class ManagerFromTests(TestCase): def test_typecheck(self): self.assertRaises(TypeError, manager_from, 'somestring') - + def test_custom_get_query_set(self): self.assertEqual(Entry.featured.published().count(), 1)