diff --git a/model_utils/managers.py b/model_utils/managers.py index d4a3450..c9ec416 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals import django from django.db import models -from django.contrib.gis.db import models as geomodels from django.db.models.fields.related import OneToOneField from django.db.models.query import QuerySet from django.core.exceptions import ObjectDoesNotExist @@ -14,7 +13,7 @@ except ImportError: # Django < 1.5 string_types = (basestring,) -class InheritanceQuerySet(QuerySet): +class InheritanceMixin(object): def select_subclasses(self, *subclasses): levels = self._get_maximum_depth() calculated_subclasses = self._get_subclasses_recurse( @@ -59,17 +58,17 @@ class InheritanceQuerySet(QuerySet): for name in ['subclasses', '_annotated']: if hasattr(self, name): kwargs[name] = getattr(self, name) - return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs) + return super(InheritanceMixin, self)._clone(klass, setup, **kwargs) def annotate(self, *args, **kwargs): - qset = super(InheritanceQuerySet, self).annotate(*args, **kwargs) + qset = super(InheritanceMixin, self).annotate(*args, **kwargs) qset._annotated = [a.default_alias for a in args] + list(kwargs.keys()) return qset def iterator(self): - iter = super(InheritanceQuerySet, self).iterator() + iter = super(InheritanceMixin, self).iterator() if getattr(self, 'subclasses', False): # sort the subclass names longest first, # so with 'a' and 'a__b' it goes as deep as possible @@ -166,14 +165,9 @@ class InheritanceQuerySet(QuerySet): levels = 1 return levels - - -class InheritanceManager(models.Manager): +class InheritanceManagerMixin(object): use_for_related_fields = True - def get_queryset(self): - return InheritanceQuerySet(self.model) - get_query_set = get_queryset def select_subclasses(self, *subclasses): @@ -183,7 +177,15 @@ class InheritanceManager(models.Manager): return self.get_queryset().get_subclass(*args, **kwargs) -class QueryManager(models.Manager): +class InheritanceQuerySet(InheritanceMixin, QuerySet): + pass + +class InheritanceManager(InheritanceManagerMixin, models.Manager): + def get_queryset(self): + return InheritanceQuerySet(self.model) + + +class QueryMixin(object): use_for_related_fields = True def __init__(self, *args, **kwargs): @@ -192,7 +194,7 @@ class QueryManager(models.Manager): else: self._q = models.Q(**kwargs) self._order_by = None - super(QueryManager, self).__init__() + super(QueryMixin, self).__init__() def order_by(self, *args): self._order_by = args @@ -200,15 +202,17 @@ class QueryManager(models.Manager): def get_queryset(self): try: - qs = super(QueryManager, self).get_queryset().filter(self._q) + qs = super(QueryMixin, self).get_queryset().filter(self._q) except AttributeError: - qs = super(QueryManager, self).get_query_set().filter(self._q) + qs = super(QueryMixin, self).get_query_set().filter(self._q) if self._order_by is not None: return qs.order_by(*self._order_by) return qs get_query_set = get_queryset +class QueryManager(QueryMixin, models.Manager): + pass class PassThroughMixin(object): """ @@ -267,28 +271,6 @@ class PassThroughManager(PassThroughMixin, models.Manager): """ pass -class PassThroughGeoManager(PassThroughMixin, geomodels.GeoManager): - """ - For use with GeoDjango's GeoManager to enable spatial lookups. - Inherit from this Manager to enable you to call any methods from your - custom GeoQuerySet class from your manager. Simply define your GeoQuerySet - class, and return an instance of it from your manager's `get_queryset` - method. - - Alternately, if you don't need any extra methods on your manager that - aren't on your GeoQuerySet, then just pass your GeoQuerySet class to the - ``for_queryset_class`` class method. - - class LocationQuerySet(GeoQuerySet): - def within_boundary(self): - return self.filter(point__within=geom) - - class Location(models.Model): - objects = PassThroughGeoManager.for_queryset_class(LocationQuerySet)() - - """ - pass - def create_pass_through_manager_for_queryset_class(base, queryset_cls): class _PassThroughManager(base):