diff --git a/README.rst b/README.rst index e8c3776..0cbc412 100644 --- a/README.rst +++ b/README.rst @@ -411,3 +411,17 @@ define a custom manager that inherits from ``PassThroughManager``:: fly, which broke pickling of those querysets. For this reason, ``PassThroughManager`` is recommended instead. +If you would like your custom ``QuerySet`` methods available through related +managers, use the convenience ``PassThroughManager.for_queryset_class``. For +example:: + + class Post(models.Model): + user = models.ForeignKey(User) + published = models.DateTimeField() + + objects = PassThroughManager.for_queryset_class(PostQuerySet)() + +Now you will be able to make queries like:: + + >>> u = User.objects.all()[0] + >>> a.post_set.published() diff --git a/model_utils/managers.py b/model_utils/managers.py index 02ab29b..828ce2a 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -126,6 +126,17 @@ class PassThroughManager(models.Manager): return self._queryset_cls(**kargs) return super(PassThroughManager, self).get_query_set() + @classmethod + def for_queryset_class(cls, queryset_cls): + class _PassThroughManager(cls): + def get_query_set(self): + kwargs = {} + if hasattr(self, "_db"): + kwargs["using"] = self._db + return queryset_cls(self.model, **kwargs) + + return _PassThroughManager + def manager_from(*mixins, **kwds): """ diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index fb939e4..1247436 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -225,3 +225,27 @@ class Dude(models.Model): objects = PassThroughManager(DudeQuerySet) abiders = AbidingManager() + + +class Car(models.Model): + name = models.CharField(max_length=20) + owner = models.ForeignKey(Dude, related_name='cars_owned') + + objects = PassThroughManager(DudeQuerySet) + + +class SpotQuerySet(models.query.QuerySet): + def closed(self): + return self.filter(closed=True) + + def secured(self): + return self.filter(secure=True) + + +class Spot(models.Model): + name = models.CharField(max_length=20) + secure = models.BooleanField(default=True) + closed = models.BooleanField(default=False) + owner = models.ForeignKey(Dude, related_name='spots_owned') + + objects = PassThroughManager.for_queryset_class(SpotQuerySet)() diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 810e78a..1da2903 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -21,7 +21,7 @@ from model_utils.tests.models import ( InheritanceManagerTestParent, InheritanceManagerTestChild1, InheritanceManagerTestChild2, TimeStamp, Post, Article, Status, StatusPlainTuple, TimeFrame, Monitored, StatusManagerAdded, - TimeFrameManagerAdded, Entry, Dude, SplitFieldAbstractParent) + TimeFrameManagerAdded, Entry, Dude, SplitFieldAbstractParent, Car, Spot) @@ -621,3 +621,24 @@ class PassThroughManagerTests(TestCase): saltyqs = pickle.dumps(qs) unqs = pickle.loads(saltyqs) self.assertEqual(unqs.by_name('The Dude').count(), 1) + + def test_queryset_not_available_on_related_manager(self): + dude = Dude.objects.by_name('Duder').get() + Car.objects.create(name='Ford', owner=dude) + self.assertFalse(hasattr(dude.cars_owned, 'by_name')) + + +class CreatePassThroughManagerTests(TestCase): + def setUp(self): + self.dude = Dude.objects.create(name='El Duderino') + Spot.objects.create(name='The Crib', owner=self.dude, closed=True, + secure=True) + + def test_reverse_manager(self): + self.assertEqual(self.dude.spots_owned.closed().count(), 1) + + def test_related_queryset_pickling(self): + qs = self.dude.spots_owned.closed() + pickled_qs = pickle.dumps(qs) + unpickled_qs = pickle.loads(pickled_qs) + self.assertEqual(unpickled_qs.secured().count(), 1)