mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-17 04:10:24 +00:00
via a OneToOne should present the same behaviour as the implicit ptr Django generates on subclasses. Thanks to Eran Rundstein for reporting the issue and proposing the fix.
272 lines
9.4 KiB
Python
272 lines
9.4 KiB
Python
from __future__ import unicode_literals
|
|
import django
|
|
from django.db import models
|
|
from django.db.models.fields.related import OneToOneField
|
|
from django.db.models.query import QuerySet
|
|
from django.core.exceptions import ObjectDoesNotExist
|
|
|
|
try:
|
|
from django.db.models.constants import LOOKUP_SEP
|
|
from django.utils.six import string_types
|
|
except ImportError: # Django < 1.5
|
|
from django.db.models.sql.constants import LOOKUP_SEP
|
|
string_types = (basestring,)
|
|
|
|
|
|
class InheritanceQuerySet(QuerySet):
    """
    QuerySet that can yield multi-table-inheritance subclass instances in
    place of the base-model instances it selects.

    ``select_subclasses()`` records which reverse one-to-one accessor paths
    to follow (and joins them in with ``select_related``); ``iterator()``
    then replaces each base instance with the object found at the end of
    the longest such path that resolves.
    """

    def select_subclasses(self, *subclasses):
        """
        Return a clone that will yield subclass instances when iterated.

        Each item of ``subclasses`` may be a model class (a subclass of
        ``self.model``; ``self.model`` itself is silently skipped) or a
        ``select_related`` accessor path string such as ``'child__grandchild'``.
        With no arguments, every discovered subclass path is selected.

        Raises ValueError if an item is not among the discovered subclass
        paths, or if a model class is not a subclass of ``self.model``.
        """
        levels = self._get_maximum_depth()
        calculated_subclasses = self._get_subclasses_recurse(
            self.model, levels=levels)
        # if none were passed in, we can just short circuit and select all
        if not subclasses:
            subclasses = calculated_subclasses
        else:
            verified_subclasses = []
            for subclass in subclasses:
                # special case for passing in the same model as the queryset
                # is bound against. Rather than raise an error later, we know
                # we can allow this through.
                if subclass is self.model:
                    continue

                if not isinstance(subclass, string_types):
                    subclass = self._get_ancestors_path(
                        subclass, levels=levels)

                if subclass in calculated_subclasses:
                    verified_subclasses.append(subclass)
                else:
                    raise ValueError(
                        '%r is not in the discovered subclasses, tried: %s' % (
                            subclass, ', '.join(calculated_subclasses))
                    )
            subclasses = verified_subclasses

        if subclasses:
            # workaround https://code.djangoproject.com/ticket/16855
            previous_select_related = self.query.select_related
            new_qs = self.select_related(*subclasses)
            previous_is_dict = isinstance(previous_select_related, dict)
            new_is_dict = isinstance(new_qs.query.select_related, dict)
            if previous_is_dict and new_is_dict:
                # NOTE(review): dict.update is shallow - if the previous
                # select_related already had an entry for one of these
                # accessors, its nested dict replaces the new one; confirm
                # deeper subclass joins are not lost in that case.
                new_qs.query.select_related.update(previous_select_related)
        else:
            # select_related() with no arguments would follow every non-null
            # foreign key; with nothing to join, a plain clone suffices.
            new_qs = self._clone()
        new_qs.subclasses = subclasses
        return new_qs

    def _clone(self, klass=None, setup=False, **kwargs):
        # Hand the subclass selection and annotation aliases to
        # QuerySet._clone as keyword arguments so the clone keeps them.
        for name in ['subclasses', '_annotated']:
            if hasattr(self, name):
                kwargs[name] = getattr(self, name)
        return super(InheritanceQuerySet, self)._clone(klass, setup, **kwargs)

    def annotate(self, *args, **kwargs):
        """
        Annotate as usual, remembering the annotation aliases so iterator()
        can copy their values onto substituted subclass instances.
        """
        qset = super(InheritanceQuerySet, self).annotate(*args, **kwargs)
        # Accumulate across chained annotate() calls; keeping only the new
        # aliases would drop earlier annotations from subclass instances.
        previous = list(getattr(self, '_annotated', []))
        qset._annotated = (
            previous
            + [a.default_alias for a in args]
            + list(kwargs.keys())
        )
        return qset

    def iterator(self):
        """
        Yield results, substituting subclass instances when
        ``select_subclasses()`` selected any subclass paths.

        For each row, the selected paths are tried longest first; the first
        one that resolves (see ``_get_sub_obj_recurse``) supplies the object
        yielded, otherwise the base object is yielded. Remembered annotation
        values are copied from the base object onto the yielded one.
        """
        base_iter = super(InheritanceQuerySet, self).iterator()
        if getattr(self, 'subclasses', False):
            # sort the subclass names longest first,
            # so with 'a' and 'a__b' it goes as deep as possible
            subclasses = sorted(self.subclasses, key=len, reverse=True)
            annotated = getattr(self, '_annotated', None) or []
            for obj in base_iter:
                sub_obj = None
                for s in subclasses:
                    sub_obj = self._get_sub_obj_recurse(obj, s)
                    # Compare against None explicitly: a model instance may
                    # define __bool__/__len__ and be falsy.
                    if sub_obj is not None:
                        break
                if sub_obj is None:
                    sub_obj = obj

                for k in annotated:
                    setattr(sub_obj, k, getattr(obj, k))

                yield sub_obj
        else:
            for obj in base_iter:
                yield obj

    def _get_subclasses_recurse(self, model, levels=None):
        """
        Given a Model class, find all related objects, exploring children
        recursively, returning a `list` of strings representing the
        relations for select_related

        ``levels`` limits recursion depth (None means unlimited). Deeper
        paths for a child appear before the child's own accessor.
        """
        rels = [
            rel for rel in model._meta.get_all_related_objects()
            if isinstance(rel.field, OneToOneField)
            and issubclass(rel.field.model, model)
            # A OneToOneField from a model to itself passes issubclass() but
            # is not inheritance; following it would recurse forever.
            and rel.field.model is not model
        ]
        subclasses = []
        if levels:
            levels -= 1
        for rel in rels:
            if levels or levels is None:
                for subclass in self._get_subclasses_recurse(
                        rel.field.model, levels=levels):
                    subclasses.append(
                        rel.get_accessor_name() + LOOKUP_SEP + subclass)
            subclasses.append(rel.get_accessor_name())
        return subclasses

    def _get_ancestors_path(self, model, levels=None):
        """
        Serves as an opposite to _get_subclasses_recurse, instead walking from
        the Model class up the Model's ancestry and constructing the desired
        select_related string backwards.

        ``levels`` limits how many parent links are followed (None means no
        limit). Raises ValueError if ``model`` is not a subclass of
        ``self.model``.
        """
        if not issubclass(model, self.model):
            raise ValueError("%r is not a subclass of %r" % (model, self.model))

        ancestry = []
        # should be a OneToOneField or None
        parent_link = model._meta.get_ancestor_link(self.model)
        while parent_link is not None:
            related = parent_link.related
            ancestry.insert(0, related.get_accessor_name())
            # Count down one level per link followed, mirroring how
            # _get_subclasses_recurse limits its depth.
            if levels is not None:
                levels -= 1
                if levels <= 0:
                    break
            parent_link = related.parent_model._meta.get_ancestor_link(
                self.model)
        return LOOKUP_SEP.join(ancestry)

    def _get_sub_obj_recurse(self, obj, s):
        """
        Follow the LOOKUP_SEP-separated accessor path ``s`` starting at
        ``obj`` and return the object at its end, or None if any step
        raises ObjectDoesNotExist.
        """
        rel, _, rest = s.partition(LOOKUP_SEP)
        try:
            node = getattr(obj, rel)
        except ObjectDoesNotExist:
            return None
        if rest:
            return self._get_sub_obj_recurse(node, rest)
        return node

    def get_subclass(self, *args, **kwargs):
        """Shortcut for ``select_subclasses().get(*args, **kwargs)``."""
        return self.select_subclasses().get(*args, **kwargs)

    def _get_maximum_depth(self):
        """
        Under Django versions < 1.6, to avoid triggering
        https://code.djangoproject.com/ticket/16572 we can only look
        as far as children.

        Returns 1 on those versions, otherwise None (unlimited).
        """
        levels = None
        if django.VERSION < (1, 6, 0):
            levels = 1
        return levels
|
|
|
|
|
|
|
|
class InheritanceManager(models.Manager):
    """
    Manager whose querysets are ``InheritanceQuerySet`` instances, exposing
    ``select_subclasses()`` and ``get_subclass()`` directly on the manager.
    """
    use_for_related_fields = True

    def get_query_set(self):
        # Pass the manager's database alias through, as the stock
        # Manager.get_query_set does, so db_manager()/using() routing is
        # honoured instead of silently falling back to the default database.
        return InheritanceQuerySet(self.model, using=self._db)

    def select_subclasses(self, *subclasses):
        """Proxy to ``InheritanceQuerySet.select_subclasses``."""
        return self.get_query_set().select_subclasses(*subclasses)

    def get_subclass(self, *args, **kwargs):
        """Proxy to ``InheritanceQuerySet.get_subclass``."""
        return self.get_query_set().get_subclass(*args, **kwargs)
|
|
|
|
|
|
class QueryManager(models.Manager):
    """
    Manager whose queryset is always filtered by a fixed condition, with an
    optional default ordering.

    If a positional argument is given, the first one (e.g. a ``Q`` object)
    is used as the filter and keyword arguments are ignored; otherwise the
    keyword arguments are wrapped in ``models.Q``. ``order_by()`` stores an
    ordering and returns the manager itself, so it can be chained where the
    manager is declared.
    """
    use_for_related_fields = True

    def __init__(self, *args, **kwargs):
        self._q = args[0] if args else models.Q(**kwargs)
        self._order_by = None
        super(QueryManager, self).__init__()

    def order_by(self, *args):
        # Remember the ordering for get_query_set and allow chaining.
        self._order_by = args
        return self

    def get_query_set(self):
        filtered = super(QueryManager, self).get_query_set().filter(self._q)
        if self._order_by is None:
            return filtered
        return filtered.order_by(*self._order_by)
|
|
|
|
|
|
class PassThroughManager(models.Manager):
    """
    Inherit from this Manager to enable you to call any methods from your
    custom QuerySet class from your manager. Simply define your QuerySet
    class, and return an instance of it from your manager's `get_query_set`
    method.

    Alternately, if you don't need any extra methods on your manager that
    aren't on your QuerySet, then just pass your QuerySet class to the
    ``for_queryset_class`` class method.

    class PostQuerySet(QuerySet):
        def enabled(self):
            return self.filter(disabled=False)

    class Post(models.Model):
        objects = PassThroughManager.for_queryset_class(PostQuerySet)()

    """
    # Names that must never be forwarded to the queryset: forwarding these
    # pickling/copy hooks (and _db) causes recursion errors.
    _deny_methods = ['__getstate__', '__setstate__', '__getinitargs__',
                     '__getnewargs__', '__copy__', '__deepcopy__', '_db']

    def __init__(self, queryset_cls=None):
        self._queryset_cls = queryset_cls
        super(PassThroughManager, self).__init__()

    def __getattr__(self, name):
        # Only consulted when normal attribute lookup fails; delegate the
        # lookup to a fresh queryset unless the name is deny-listed.
        if name not in self._deny_methods:
            return getattr(self.get_query_set(), name)
        raise AttributeError(name)

    def get_query_set(self):
        base_qs = super(PassThroughManager, self).get_query_set()
        if self._queryset_cls is None:
            return base_qs
        return base_qs._clone(klass=self._queryset_cls)

    @classmethod
    def for_queryset_class(cls, queryset_cls):
        """Return a manager subclass of ``cls`` bound to ``queryset_cls``."""
        return create_pass_through_manager_for_queryset_class(cls, queryset_cls)
|
|
|
|
|
|
def create_pass_through_manager_for_queryset_class(base, queryset_cls):
    """
    Build and return a subclass of the manager class ``base`` whose
    ``get_query_set`` clones the parent's queryset into ``queryset_cls``.

    Direct instances of the returned class pickle through
    ``unpickle_pass_through_manager_for_queryset_class`` with
    ``(base, queryset_cls)`` and the instance ``__dict__``; instances of
    further subclasses fall back to the parent's ``__reduce__``.
    """
    class _PassThroughManager(base):
        def __init__(self):
            super(_PassThroughManager, self).__init__()

        def get_query_set(self):
            parent_qs = super(_PassThroughManager, self).get_query_set()
            return parent_qs._clone(klass=queryset_cls)

        def __reduce__(self):
            if self.__class__ is _PassThroughManager:
                return (
                    unpickle_pass_through_manager_for_queryset_class,
                    (base, queryset_cls),
                    self.__dict__,
                )
            # our pickling support breaks for subclasses (e.g. RelatedManager)
            return super(_PassThroughManager, self).__reduce__()

    return _PassThroughManager
|
|
|
|
|
|
def unpickle_pass_through_manager_for_queryset_class(base, queryset_cls):
    """
    Pickle reconstructor: rebuild the dynamic manager class for
    ``(base, queryset_cls)`` and return an uninitialised instance of it,
    whose state is then restored by pickle from the saved ``__dict__``.
    """
    manager_class = create_pass_through_manager_for_queryset_class(
        base, queryset_cls)
    return manager_class.__new__(manager_class)
|