From 5da38ee0c6b48f132a983d8033389f152fcc6812 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Mon, 16 Aug 2010 17:58:16 -0400 Subject: [PATCH] manager_from bugfix from George Sakkis; fixes #1 --- CHANGES.rst | 1 + model_utils/managers.py | 32 ++++++++++++++++++++------------ model_utils/tests/models.py | 12 ++++++++++++ model_utils/tests/tests.py | 7 ++++++- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 5860908..383680e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,7 @@ CHANGES tip (unreleased) ---------------- +- incorporated manager_from inherited queryset bugfix (thanks George Sakkis) - added manager_from (thanks George Sakkis) - added StatusField, MonitorField, TimeFramedModel, and StatusModel (thanks Jannis Leidel) diff --git a/model_utils/managers.py b/model_utils/managers.py index 33c3f6f..76d615c 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -41,26 +41,34 @@ def manager_from(*mixins, **kwds): :keyword manager_cls: The base manager class to extend from (``django.db.models.manager.Manager`` by default). ''' + # collect separately the mixin classes and methods bases = [kwds.get('queryset_cls', QuerySet)] - attrs = {} + methods = {} for mixin in mixins: if isinstance(mixin, (ClassType, type)): bases.append(mixin) else: - try: attrs[mixin.__name__] = mixin + try: methods[mixin.__name__] = mixin except AttributeError: raise TypeError('Mixin must be class or function, not %s' % mixin.__class__) # create the QuerySet subclass id = hash(mixins + tuple(kwds.iteritems())) - qset_cls = type('Queryset_%d' % id, tuple(bases), attrs) + new_queryset_cls = type('Queryset_%d' % id, tuple(bases), methods) # create the Manager subclass - bases[0] = kwds.get('manager_cls', Manager) - def _get_query_set(self): - if hasattr(self, '_db'): - return qset_cls(self.model, using=self._db) - else: - return qset_cls(self.model) - attrs['get_query_set'] = _get_query_set - manager_cls = type('Manager_%d' % id, tuple(bases), attrs) - return manager_cls() + bases[0] = manager_cls = kwds.get('manager_cls', Manager) + new_manager_cls = type('Manager_%d' % id, tuple(bases), methods) + # and finally override new manager's get_query_set + super_get_query_set = manager_cls.get_query_set + def get_query_set(self): + # first honor the super manager's get_query_set + qs = super_get_query_set(self) + # and then try to bless the returned queryset by reassigning it to the + # newly created Queryset class, though this may not be feasible + if not issubclass(new_queryset_cls, qs.__class__): + raise TypeError('QuerySet subclass conflict: cannot determine a ' + 'unique class for queryset instance') + qs.__class__ = new_queryset_cls + return qs + new_manager_cls.get_query_set = get_query_set + return new_manager_cls() diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index 63abd91..cbe0048 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -83,8 +83,20 @@ class PublishedMixin(object): def unpublished(self): return self.filter(published=False) +class ByAuthorQuerySet(models.query.QuerySet, AuthorMixin): + pass + +class FeaturedManager(models.Manager): + def get_query_set(self): + return ByAuthorQuerySet(self.model, using=self._db).filter(feature=True) + class Entry(models.Model): author = models.CharField(max_length=20) published = models.BooleanField() + feature = models.BooleanField(default=False) objects = manager_from(AuthorMixin, PublishedMixin, unpublished) + broken = manager_from(PublishedMixin, manager_cls=FeaturedManager) + featured = manager_from(PublishedMixin, + manager_cls=FeaturedManager, + queryset_cls=ByAuthorQuerySet) diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index b446b95..0099751 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -369,7 +369,7 @@ class ManagerFromTests(TestCase): def setUp(self): Entry.objects.create(author='George', published=True) Entry.objects.create(author='George', published=False) - Entry.objects.create(author='Paul', published=True) + Entry.objects.create(author='Paul', published=True, feature=True) def test_chaining(self): self.assertEqual(Entry.objects.by_author('George').published().count(), @@ -381,3 +381,8 @@ class ManagerFromTests(TestCase): def test_typecheck(self): self.assertRaises(TypeError, manager_from, 'somestring') + def test_custom_get_query_set(self): + self.assertEqual(Entry.featured.published().count(), 1) + + def test_cant_reconcile_qs_class(self): + self.assertRaises(TypeError, Entry.broken.all)