mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-05-23 17:15:49 +00:00
Merge branch 'master' into master
This commit is contained in:
commit
73dd5aa8fd
31 changed files with 816 additions and 189 deletions
|
|
@ -1,5 +1,2 @@
|
||||||
[run]
|
[run]
|
||||||
source = model_utils
|
include = model_utils/*.py
|
||||||
omit = .*
|
|
||||||
tests/*
|
|
||||||
*/_*
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,9 @@ python:
|
||||||
- 3.6
|
- 3.6
|
||||||
install: pip install tox-travis codecov
|
install: pip install tox-travis codecov
|
||||||
# positional args ({posargs}) to pass into tox.ini
|
# positional args ({posargs}) to pass into tox.ini
|
||||||
script: tox -- --cov
|
script: tox -- --cov --cov-append
|
||||||
|
services:
|
||||||
|
- postgresql
|
||||||
after_success: codecov
|
after_success: codecov
|
||||||
deploy:
|
deploy:
|
||||||
provider: pypi
|
provider: pypi
|
||||||
|
|
|
||||||
16
CHANGES.rst
16
CHANGES.rst
|
|
@ -3,6 +3,22 @@ CHANGES
|
||||||
|
|
||||||
master (unreleased)
|
master (unreleased)
|
||||||
-------------------
|
-------------------
|
||||||
|
- Fix handling of deferred attributes on Django 1.10+, fixes GH-278
|
||||||
|
- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return
|
||||||
|
correct responses for deferred fields.
|
||||||
|
- Update AutoLastModifiedField so that at instance creation it will
|
||||||
|
always be set equal to created to make querying easier. Fixes GH-254
|
||||||
|
- Support `reversed` for all kinds of `Choices` objects, fixes GH-309
|
||||||
|
- Fix Model instance non picklable GH-330
|
||||||
|
- Fix patched `save` in FieldTracker
|
||||||
|
|
||||||
|
3.1.2 (2018.05.09)
|
||||||
|
------------------
|
||||||
|
* Update InheritanceIterable to inherit from
|
||||||
|
ModelIterable instead of BaseIterable, fixes GH-277.
|
||||||
|
|
||||||
|
* Add all_objects Manager for 'SoftDeletableModel' to include soft
|
||||||
|
deleted objects on queries as per issue GH-255
|
||||||
|
|
||||||
3.1.1 (2017.12.17)
|
3.1.1 (2017.12.17)
|
||||||
------------------
|
------------------
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
Copyright (c) 2009-2015, Carl Meyer and contributors
|
Copyright (c) 2009-2019, Carl Meyer and contributors
|
||||||
All rights reserved.
|
All rights reserved.
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
|
|
||||||
11
README.rst
11
README.rst
|
|
@ -14,7 +14,7 @@ django-model-utils
|
||||||
|
|
||||||
Django model mixins and utilities.
|
Django model mixins and utilities.
|
||||||
|
|
||||||
``django-model-utils`` supports `Django`_ 1.8 to 2.0.
|
``django-model-utils`` supports `Django`_ 1.8 to 2.1.
|
||||||
|
|
||||||
.. _Django: http://www.djangoproject.com/
|
.. _Django: http://www.djangoproject.com/
|
||||||
|
|
||||||
|
|
@ -28,6 +28,15 @@ Getting Help
|
||||||
Documentation for django-model-utils is available
|
Documentation for django-model-utils is available
|
||||||
https://django-model-utils.readthedocs.io/
|
https://django-model-utils.readthedocs.io/
|
||||||
|
|
||||||
|
|
||||||
|
Run tests
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. code-block
|
||||||
|
|
||||||
|
pip install -e .
|
||||||
|
py.test
|
||||||
|
|
||||||
Contributing
|
Contributing
|
||||||
============
|
============
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,33 @@ it's safe to use as your default manager for the model.
|
||||||
|
|
||||||
.. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/
|
.. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/
|
||||||
|
|
||||||
|
JoinManager
|
||||||
|
-----------
|
||||||
|
|
||||||
|
The ``JoinManager`` will create a temporary table of your current queryset
|
||||||
|
and join that temporary table with the model of your current queryset. This can
|
||||||
|
be advantageous if you have to page through your entire DB and using django's
|
||||||
|
slice mechanism to do that. ``LIMIT .. OFFSET ..`` becomes slower the bigger
|
||||||
|
offset you use.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
sliced_qs = Place.objects.all()[2000:2010]
|
||||||
|
qs = sliced_qs.join()
|
||||||
|
# qs contains 10 objects, and there will be a much smaller performance hit
|
||||||
|
# for paging through all of first 2000 objects.
|
||||||
|
|
||||||
|
Alternatively, you can give it a queryset and the manager will create a temporary
|
||||||
|
table and join that to your current queryset. This can work as a more performant
|
||||||
|
alternative to using django's ``__in`` as described in the following
|
||||||
|
(`StackExchange answer`_).
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
big_qs = Restaurant.objects.filter(menu='vegetarian')
|
||||||
|
qs = Country.objects.filter(country_code='SE').join(big_qs)
|
||||||
|
|
||||||
|
.. _StackExchange answer: https://dba.stackexchange.com/questions/91247/optimizing-a-postgres-query-with-a-large-in
|
||||||
|
|
||||||
.. _QueryManager:
|
.. _QueryManager:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ modify your ``INSTALLED_APPS`` setting.
|
||||||
Dependencies
|
Dependencies
|
||||||
============
|
============
|
||||||
|
|
||||||
``django-model-utils`` supports `Django`_ 1.8 through 1.10 (latest bugfix
|
``django-model-utils`` supports `Django`_ 1.8 through 2.1 (latest bugfix
|
||||||
release in each series only) on Python 2.7, 3.3 (Django 1.8 only), 3.4 and 3.5.
|
release in each series only) on Python 2.7, 3.4, 3.5 and 3.6.
|
||||||
|
|
||||||
.. _Django: http://www.djangoproject.com/
|
.. _Django: http://www.djangoproject.com/
|
||||||
|
|
|
||||||
|
|
@ -150,6 +150,10 @@ Returns the value of the given field during the last save:
|
||||||
|
|
||||||
Returns ``None`` when the model instance isn't saved yet.
|
Returns ``None`` when the model instance isn't saved yet.
|
||||||
|
|
||||||
|
If a field is `deferred`_, calling ``previous()`` will load the previous value from the database.
|
||||||
|
|
||||||
|
.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer
|
||||||
|
|
||||||
|
|
||||||
has_changed
|
has_changed
|
||||||
~~~~~~~~~~~
|
~~~~~~~~~~~
|
||||||
|
|
@ -167,6 +171,8 @@ Returns ``True`` if the given field has changed since the last save. The ``has_c
|
||||||
The ``has_changed`` method relies on ``previous`` to determine whether a
|
The ``has_changed`` method relies on ``previous`` to determine whether a
|
||||||
field's values has changed.
|
field's values has changed.
|
||||||
|
|
||||||
|
If a field is `deferred`_ and has been assigned locally, calling ``has_changed()``
|
||||||
|
will load the previous value from the database to perform the comparison.
|
||||||
|
|
||||||
changed
|
changed
|
||||||
~~~~~~~
|
~~~~~~~
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from .choices import Choices
|
from .choices import Choices # noqa:F401
|
||||||
from .tracker import FieldTracker, ModelTracker
|
from .tracker import FieldTracker, ModelTracker # noqa:F401
|
||||||
|
|
||||||
__version__ = '3.1.1'
|
__version__ = '3.2.0'
|
||||||
|
|
|
||||||
|
|
@ -57,7 +57,6 @@ class Choices(object):
|
||||||
|
|
||||||
self._process(choices)
|
self._process(choices)
|
||||||
|
|
||||||
|
|
||||||
def _store(self, triple, triple_collector, double_collector):
|
def _store(self, triple, triple_collector, double_collector):
|
||||||
self._identifier_map[triple[1]] = triple[0]
|
self._identifier_map[triple[1]] = triple[0]
|
||||||
self._display_map[triple[0]] = triple[2]
|
self._display_map[triple[0]] = triple[2]
|
||||||
|
|
@ -65,7 +64,6 @@ class Choices(object):
|
||||||
triple_collector.append(triple)
|
triple_collector.append(triple)
|
||||||
double_collector.append((triple[0], triple[2]))
|
double_collector.append((triple[0], triple[2]))
|
||||||
|
|
||||||
|
|
||||||
def _process(self, choices, triple_collector=None, double_collector=None):
|
def _process(self, choices, triple_collector=None, double_collector=None):
|
||||||
if triple_collector is None:
|
if triple_collector is None:
|
||||||
triple_collector = self._triples
|
triple_collector = self._triples
|
||||||
|
|
@ -94,18 +92,18 @@ class Choices(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Choices can't take a list of length %s, only 2 or 3"
|
"Choices can't take a list of length %s, only 2 or 3"
|
||||||
% len(choice)
|
% len(choice)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
store((choice, choice, choice))
|
store((choice, choice, choice))
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._doubles)
|
return len(self._doubles)
|
||||||
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._doubles)
|
return iter(self._doubles)
|
||||||
|
|
||||||
|
def __reversed__(self):
|
||||||
|
return reversed(self._doubles)
|
||||||
|
|
||||||
def __getattr__(self, attname):
|
def __getattr__(self, attname):
|
||||||
try:
|
try:
|
||||||
|
|
@ -113,11 +111,9 @@ class Choices(object):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AttributeError(attname)
|
raise AttributeError(attname)
|
||||||
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return self._display_map[key]
|
return self._display_map[key]
|
||||||
|
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
other = other._triples
|
other = other._triples
|
||||||
|
|
@ -125,29 +121,24 @@ class Choices(object):
|
||||||
other = list(other)
|
other = list(other)
|
||||||
return Choices(*(self._triples + other))
|
return Choices(*(self._triples + other))
|
||||||
|
|
||||||
|
|
||||||
def __radd__(self, other):
|
def __radd__(self, other):
|
||||||
# radd is never called for matching types, so we don't check here
|
# radd is never called for matching types, so we don't check here
|
||||||
other = list(other)
|
other = list(other)
|
||||||
return Choices(*(other + self._triples))
|
return Choices(*(other + self._triples))
|
||||||
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, self.__class__):
|
if isinstance(other, self.__class__):
|
||||||
return self._triples == other._triples
|
return self._triples == other._triples
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '%s(%s)' % (
|
return '%s(%s)' % (
|
||||||
self.__class__.__name__,
|
self.__class__.__name__,
|
||||||
', '.join(("%s" % repr(i) for i in self._triples))
|
', '.join(("%s" % repr(i) for i in self._triples))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def __contains__(self, item):
|
def __contains__(self, item):
|
||||||
return item in self._db_values
|
return item in self._db_values
|
||||||
|
|
||||||
|
|
||||||
def __deepcopy__(self, memo):
|
def __deepcopy__(self, memo):
|
||||||
return self.__class__(*copy.deepcopy(self._triples, memo))
|
return self.__class__(*copy.deepcopy(self._triples, memo))
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,11 @@ class AutoLastModifiedField(AutoCreatedField):
|
||||||
"""
|
"""
|
||||||
def pre_save(self, model_instance, add):
|
def pre_save(self, model_instance, add):
|
||||||
value = now()
|
value = now()
|
||||||
|
if not model_instance.pk:
|
||||||
|
for field in model_instance._meta.get_fields():
|
||||||
|
if isinstance(field, AutoCreatedField):
|
||||||
|
value = getattr(model_instance, field.name)
|
||||||
|
break
|
||||||
setattr(model_instance, self.attname, value)
|
setattr(model_instance, self.attname, value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
@ -141,6 +146,7 @@ SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2)
|
||||||
|
|
||||||
_excerpt_field_name = lambda name: '_%s_excerpt' % name
|
_excerpt_field_name = lambda name: '_%s_excerpt' % name
|
||||||
|
|
||||||
|
|
||||||
def get_excerpt(content):
|
def get_excerpt(content):
|
||||||
excerpt = []
|
excerpt = []
|
||||||
default_excerpt = []
|
default_excerpt = []
|
||||||
|
|
@ -156,6 +162,7 @@ def get_excerpt(content):
|
||||||
|
|
||||||
return '\n'.join(default_excerpt)
|
return '\n'.join(default_excerpt)
|
||||||
|
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class SplitText(object):
|
class SplitText(object):
|
||||||
def __init__(self, instance, field_name, excerpt_field_name):
|
def __init__(self, instance, field_name, excerpt_field_name):
|
||||||
|
|
@ -166,11 +173,13 @@ class SplitText(object):
|
||||||
self.excerpt_field_name = excerpt_field_name
|
self.excerpt_field_name = excerpt_field_name
|
||||||
|
|
||||||
# content is read/write
|
# content is read/write
|
||||||
def _get_content(self):
|
@property
|
||||||
|
def content(self):
|
||||||
return self.instance.__dict__[self.field_name]
|
return self.instance.__dict__[self.field_name]
|
||||||
def _set_content(self, val):
|
|
||||||
|
@content.setter
|
||||||
|
def content(self, val):
|
||||||
setattr(self.instance, self.field_name, val)
|
setattr(self.instance, self.field_name, val)
|
||||||
content = property(_get_content, _set_content)
|
|
||||||
|
|
||||||
# excerpt is a read only property
|
# excerpt is a read only property
|
||||||
def _get_excerpt(self):
|
def _get_excerpt(self):
|
||||||
|
|
@ -185,6 +194,7 @@ class SplitText(object):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
class SplitDescriptor(object):
|
class SplitDescriptor(object):
|
||||||
def __init__(self, field):
|
def __init__(self, field):
|
||||||
self.field = field
|
self.field = field
|
||||||
|
|
@ -205,6 +215,7 @@ class SplitDescriptor(object):
|
||||||
else:
|
else:
|
||||||
obj.__dict__[self.field.name] = value
|
obj.__dict__[self.field.name] = value
|
||||||
|
|
||||||
|
|
||||||
class SplitField(models.TextField):
|
class SplitField(models.TextField):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
# for South FakeORM compatibility: the frozen version of a
|
# for South FakeORM compatibility: the frozen version of a
|
||||||
|
|
|
||||||
BIN
model_utils/locale/cs/LC_MESSAGES/django.mo
Normal file
BIN
model_utils/locale/cs/LC_MESSAGES/django.mo
Normal file
Binary file not shown.
46
model_utils/locale/cs/LC_MESSAGES/django.po
Normal file
46
model_utils/locale/cs/LC_MESSAGES/django.po
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Czech translations of django-model-utils
|
||||||
|
#
|
||||||
|
# This file is distributed under the same license as the django-model-utils package.
|
||||||
|
#
|
||||||
|
# Translators:
|
||||||
|
# ------------
|
||||||
|
# Václav Dohnal <vaclav.dohnal@gmail.com>, 2018.
|
||||||
|
#
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: django-model-utils\n"
|
||||||
|
"Report-Msgid-Bugs-To: https://github.com/jazzband/django-model-utils/issues\n"
|
||||||
|
"POT-Creation-Date: 2018-05-04 13:40+0200\n"
|
||||||
|
"PO-Revision-Date: 2018-05-04 13:46+0200\n"
|
||||||
|
"Language: cs\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=UTF-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Plural-Forms: nplurals=3; plural=(n==1) ? 0 : (n>=2 && n<=4) ? 1 : 2;\n"
|
||||||
|
"Last-Translator: Václav Dohnal <vaclav.dohnal@gmail.com>\n"
|
||||||
|
"Language-Team: N/A\n"
|
||||||
|
"X-Generator: Poedit 2.0.7\n"
|
||||||
|
|
||||||
|
#: .\models.py:24
|
||||||
|
msgid "created"
|
||||||
|
msgstr "vytvořeno"
|
||||||
|
|
||||||
|
#: .\models.py:25
|
||||||
|
msgid "modified"
|
||||||
|
msgstr "upraveno"
|
||||||
|
|
||||||
|
#: .\models.py:37
|
||||||
|
msgid "start"
|
||||||
|
msgstr "začátek"
|
||||||
|
|
||||||
|
#: .\models.py:38
|
||||||
|
msgid "end"
|
||||||
|
msgstr "konec"
|
||||||
|
|
||||||
|
#: .\models.py:53
|
||||||
|
msgid "status"
|
||||||
|
msgstr "stav"
|
||||||
|
|
||||||
|
#: .\models.py:54
|
||||||
|
msgid "status changed"
|
||||||
|
msgstr "změna stavu"
|
||||||
BIN
model_utils/locale/ru/LC_MESSAGES/django.mo
Normal file
BIN
model_utils/locale/ru/LC_MESSAGES/django.mo
Normal file
Binary file not shown.
43
model_utils/locale/ru/LC_MESSAGES/django.po
Normal file
43
model_utils/locale/ru/LC_MESSAGES/django.po
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
# This file is distributed under the same license as the django-model-utils package.
|
||||||
|
#
|
||||||
|
# Translators:
|
||||||
|
# Arseny Sysolyatin <arseny.sysolyatin@gmail.com>, 2017.
|
||||||
|
msgid ""
|
||||||
|
msgstr ""
|
||||||
|
"Project-Id-Version: django-model-utils\n"
|
||||||
|
"Report-Msgid-Bugs-To: \n"
|
||||||
|
"POT-Creation-Date: 2017-05-22 19:46+0300\n"
|
||||||
|
"PO-Revision-Date: 2017-05-22 19:46+0300\n"
|
||||||
|
"Last-Translator: Arseny Sysolyatin <arseny.sysolyatin@gmail.com>\n"
|
||||||
|
"Language-Team: \n"
|
||||||
|
"Language: ru\n"
|
||||||
|
"MIME-Version: 1.0\n"
|
||||||
|
"Content-Type: text/plain; charset=UTF-8\n"
|
||||||
|
"Content-Transfer-Encoding: 8bit\n"
|
||||||
|
"Plural-Forms: nplurals=4; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n"
|
||||||
|
"%10<=4 && (n%100<12 || n%100>14) ? 1 : n%10==0 || (n%10>=5 && n%10<=9) || (n"
|
||||||
|
"%100>=11 && n%100<=14)? 2 : 3);\n"
|
||||||
|
|
||||||
|
#: models.py:24
|
||||||
|
msgid "created"
|
||||||
|
msgstr "создано"
|
||||||
|
|
||||||
|
#: models.py:25
|
||||||
|
msgid "modified"
|
||||||
|
msgstr "изменено"
|
||||||
|
|
||||||
|
#: models.py:37
|
||||||
|
msgid "start"
|
||||||
|
msgstr "начало"
|
||||||
|
|
||||||
|
#: models.py:38
|
||||||
|
msgid "end"
|
||||||
|
msgstr "конец"
|
||||||
|
|
||||||
|
#: models.py:53
|
||||||
|
msgid "status"
|
||||||
|
msgstr "статус"
|
||||||
|
|
||||||
|
#: models.py:54
|
||||||
|
msgid "status changed"
|
||||||
|
msgstr "статус изменен"
|
||||||
|
|
@ -3,18 +3,17 @@ import django
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.db.models.fields.related import OneToOneField, OneToOneRel
|
from django.db.models.fields.related import OneToOneField, OneToOneRel
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
try:
|
from django.db.models.query import ModelIterable
|
||||||
from django.db.models.query import BaseIterable, ModelIterable
|
|
||||||
except ImportError:
|
|
||||||
# Django 1.8 does not have iterable classes
|
|
||||||
BaseIterable = object
|
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
|
|
||||||
from django.db.models.constants import LOOKUP_SEP
|
from django.db.models.constants import LOOKUP_SEP
|
||||||
from django.utils.six import string_types
|
from django.utils.six import string_types
|
||||||
|
|
||||||
|
from django.db import connection
|
||||||
|
from django.db.models.sql.datastructures import Join
|
||||||
|
|
||||||
class InheritanceIterable(BaseIterable):
|
|
||||||
|
class InheritanceIterable(ModelIterable):
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
queryset = self.queryset
|
queryset = self.queryset
|
||||||
iter = ModelIterable(queryset)
|
iter = ModelIterable(queryset)
|
||||||
|
|
@ -48,11 +47,10 @@ class InheritanceIterable(BaseIterable):
|
||||||
class InheritanceQuerySetMixin(object):
|
class InheritanceQuerySetMixin(object):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs)
|
super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs)
|
||||||
if django.VERSION > (1, 8):
|
self._iterable_class = InheritanceIterable
|
||||||
self._iterable_class = InheritanceIterable
|
|
||||||
|
|
||||||
def select_subclasses(self, *subclasses):
|
def select_subclasses(self, *subclasses):
|
||||||
levels = self._get_maximum_depth()
|
levels = None
|
||||||
calculated_subclasses = self._get_subclasses_recurse(
|
calculated_subclasses = self._get_subclasses_recurse(
|
||||||
self.model, levels=levels)
|
self.model, levels=levels)
|
||||||
# if none were passed in, we can just short circuit and select all
|
# if none were passed in, we can just short circuit and select all
|
||||||
|
|
@ -77,7 +75,7 @@ class InheritanceQuerySetMixin(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'%r is not in the discovered subclasses, tried: %s' % (
|
'%r is not in the discovered subclasses, tried: %s' % (
|
||||||
subclass, ', '.join(calculated_subclasses))
|
subclass, ', '.join(calculated_subclasses))
|
||||||
)
|
)
|
||||||
subclasses = verified_subclasses
|
subclasses = verified_subclasses
|
||||||
|
|
||||||
# workaround https://code.djangoproject.com/ticket/16855
|
# workaround https://code.djangoproject.com/ticket/16855
|
||||||
|
|
@ -99,16 +97,16 @@ class InheritanceQuerySetMixin(object):
|
||||||
|
|
||||||
def _clone(self, klass=None, setup=False, **kwargs):
|
def _clone(self, klass=None, setup=False, **kwargs):
|
||||||
if django.VERSION >= (2, 0):
|
if django.VERSION >= (2, 0):
|
||||||
return super(InheritanceQuerySetMixin, self)._clone()
|
qs = super(InheritanceQuerySetMixin, self)._clone()
|
||||||
|
for name in ['subclasses', '_annotated']:
|
||||||
|
if hasattr(self, name):
|
||||||
|
setattr(qs, name, getattr(self, name))
|
||||||
|
return qs
|
||||||
|
|
||||||
for name in ['subclasses', '_annotated']:
|
for name in ['subclasses', '_annotated']:
|
||||||
if hasattr(self, name):
|
if hasattr(self, name):
|
||||||
kwargs[name] = getattr(self, name)
|
kwargs[name] = getattr(self, name)
|
||||||
|
|
||||||
if django.VERSION < (1, 9):
|
|
||||||
kwargs['klass'] = klass
|
|
||||||
kwargs['setup'] = setup
|
|
||||||
|
|
||||||
return super(InheritanceQuerySetMixin, self)._clone(**kwargs)
|
return super(InheritanceQuerySetMixin, self)._clone(**kwargs)
|
||||||
|
|
||||||
def annotate(self, *args, **kwargs):
|
def annotate(self, *args, **kwargs):
|
||||||
|
|
@ -151,19 +149,16 @@ class InheritanceQuerySetMixin(object):
|
||||||
recursively, returning a `list` of strings representing the
|
recursively, returning a `list` of strings representing the
|
||||||
relations for select_related
|
relations for select_related
|
||||||
"""
|
"""
|
||||||
if django.VERSION < (1, 8):
|
related_objects = [
|
||||||
related_objects = model._meta.get_all_related_objects()
|
f for f in model._meta.get_fields()
|
||||||
else:
|
if isinstance(f, OneToOneRel)]
|
||||||
related_objects = [
|
|
||||||
f for f in model._meta.get_fields()
|
|
||||||
if isinstance(f, OneToOneRel)]
|
|
||||||
|
|
||||||
rels = [
|
rels = [
|
||||||
rel for rel in related_objects
|
rel for rel in related_objects
|
||||||
if isinstance(rel.field, OneToOneField)
|
if isinstance(rel.field, OneToOneField)
|
||||||
and issubclass(rel.field.model, model)
|
and issubclass(rel.field.model, model)
|
||||||
and model is not rel.field.model
|
and model is not rel.field.model
|
||||||
]
|
]
|
||||||
|
|
||||||
subclasses = []
|
subclasses = []
|
||||||
if levels:
|
if levels:
|
||||||
|
|
@ -193,16 +188,10 @@ class InheritanceQuerySetMixin(object):
|
||||||
if levels:
|
if levels:
|
||||||
levels -= 1
|
levels -= 1
|
||||||
while parent_link is not None:
|
while parent_link is not None:
|
||||||
if django.VERSION < (1, 9):
|
related = parent_link.remote_field
|
||||||
related = parent_link.rel
|
|
||||||
else:
|
|
||||||
related = parent_link.remote_field
|
|
||||||
ancestry.insert(0, related.get_accessor_name())
|
ancestry.insert(0, related.get_accessor_name())
|
||||||
if levels or levels is None:
|
if levels or levels is None:
|
||||||
if django.VERSION < (1, 8):
|
parent_model = related.model
|
||||||
parent_model = related.parent_model
|
|
||||||
else:
|
|
||||||
parent_model = related.model
|
|
||||||
parent_link = parent_model._meta.get_ancestor_link(
|
parent_link = parent_model._meta.get_ancestor_link(
|
||||||
self.model)
|
self.model)
|
||||||
else:
|
else:
|
||||||
|
|
@ -230,17 +219,6 @@ class InheritanceQuerySetMixin(object):
|
||||||
def get_subclass(self, *args, **kwargs):
|
def get_subclass(self, *args, **kwargs):
|
||||||
return self.select_subclasses().get(*args, **kwargs)
|
return self.select_subclasses().get(*args, **kwargs)
|
||||||
|
|
||||||
def _get_maximum_depth(self):
|
|
||||||
"""
|
|
||||||
Under Django versions < 1.6, to avoid triggering
|
|
||||||
https://code.djangoproject.com/ticket/16572 we can only look
|
|
||||||
as far as children.
|
|
||||||
"""
|
|
||||||
levels = None
|
|
||||||
if django.VERSION < (1, 6, 0):
|
|
||||||
levels = 1
|
|
||||||
return levels
|
|
||||||
|
|
||||||
|
|
||||||
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
|
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
|
||||||
pass
|
pass
|
||||||
|
|
@ -326,3 +304,111 @@ class SoftDeletableManagerMixin(object):
|
||||||
|
|
||||||
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager):
|
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JoinQueryset(models.QuerySet):
|
||||||
|
|
||||||
|
def get_quoted_query(self, query):
|
||||||
|
query, params = query.sql_with_params()
|
||||||
|
|
||||||
|
# Put additional quotes around string.
|
||||||
|
params = [
|
||||||
|
'\'{}\''.format(p)
|
||||||
|
if isinstance(p, str) else p
|
||||||
|
for p in params
|
||||||
|
]
|
||||||
|
|
||||||
|
# Cast list of parameters to tuple because I got
|
||||||
|
# "not enough format characters" otherwise.
|
||||||
|
params = tuple(params)
|
||||||
|
return query % params
|
||||||
|
|
||||||
|
def join(self, qs=None):
|
||||||
|
'''
|
||||||
|
Join one queryset together with another using a temporary table. If
|
||||||
|
no queryset is used, it will use the current queryset and join that
|
||||||
|
to itself.
|
||||||
|
|
||||||
|
`Join` either uses the current queryset and effectively does a self-join to
|
||||||
|
create a new limited queryset OR it uses a querset given by the user.
|
||||||
|
|
||||||
|
The model of a given queryset needs to contain a valid foreign key to
|
||||||
|
the current queryset to perform a join. A new queryset is then created.
|
||||||
|
'''
|
||||||
|
to_field = 'id'
|
||||||
|
|
||||||
|
if qs:
|
||||||
|
fk = [
|
||||||
|
fk for fk in qs.model._meta.fields
|
||||||
|
if getattr(fk, 'related_model', None) == self.model
|
||||||
|
]
|
||||||
|
fk = fk[0] if fk else None
|
||||||
|
model_set = '{}_set'.format(self.model.__name__.lower())
|
||||||
|
key = fk or getattr(qs.model, model_set, None)
|
||||||
|
|
||||||
|
if not key:
|
||||||
|
raise ValueError('QuerySet is not related to current model')
|
||||||
|
|
||||||
|
try:
|
||||||
|
fk_column = key.column
|
||||||
|
except AttributeError:
|
||||||
|
fk_column = 'id'
|
||||||
|
to_field = key.field.column
|
||||||
|
|
||||||
|
qs = qs.only(fk_column)
|
||||||
|
# if we give a qs we need to keep the model qs to not lose anything
|
||||||
|
new_qs = self
|
||||||
|
else:
|
||||||
|
fk_column = 'id'
|
||||||
|
qs = self.only(fk_column)
|
||||||
|
new_qs = self.model.objects.all()
|
||||||
|
|
||||||
|
TABLE_NAME = 'temp_stuff'
|
||||||
|
query = self.get_quoted_query(qs.query)
|
||||||
|
sql = '''
|
||||||
|
DROP TABLE IF EXISTS {table_name};
|
||||||
|
DROP INDEX IF EXISTS {table_name}_id;
|
||||||
|
CREATE TEMPORARY TABLE {table_name} AS {query};
|
||||||
|
CREATE INDEX {table_name}_{fk_column} ON {table_name} ({fk_column});
|
||||||
|
'''.format(table_name=TABLE_NAME, fk_column=fk_column, query=str(query))
|
||||||
|
|
||||||
|
with connection.cursor() as cursor:
|
||||||
|
cursor.execute(sql)
|
||||||
|
|
||||||
|
class TempModel(models.Model):
|
||||||
|
temp_key = models.ForeignKey(
|
||||||
|
self.model,
|
||||||
|
on_delete=models.DO_NOTHING,
|
||||||
|
db_column=fk_column,
|
||||||
|
to_field=to_field
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
managed = False
|
||||||
|
db_table = TABLE_NAME
|
||||||
|
|
||||||
|
conn = Join(
|
||||||
|
table_name=TempModel._meta.db_table,
|
||||||
|
parent_alias=new_qs.query.get_initial_alias(),
|
||||||
|
table_alias=None,
|
||||||
|
join_type='INNER JOIN',
|
||||||
|
join_field=self.model.tempmodel_set.rel,
|
||||||
|
nullable=False
|
||||||
|
)
|
||||||
|
new_qs.query.join(conn, reuse=None)
|
||||||
|
return new_qs
|
||||||
|
|
||||||
|
|
||||||
|
class JoinManagerMixin(object):
|
||||||
|
"""
|
||||||
|
Manager that adds a method join. This method allows you to join two
|
||||||
|
querysets together.
|
||||||
|
"""
|
||||||
|
_queryset_class = JoinQueryset
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
return self._queryset_class(model=self.model, using=self._db)
|
||||||
|
|
||||||
|
|
||||||
|
class JoinManager(JoinManagerMixin, models.Manager):
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,7 @@ class SoftDeletableModel(models.Model):
|
||||||
abstract = True
|
abstract = True
|
||||||
|
|
||||||
objects = SoftDeletableManager()
|
objects = SoftDeletableManager()
|
||||||
|
all_objects = models.Manager()
|
||||||
|
|
||||||
def delete(self, using=None, soft=True, *args, **kwargs):
|
def delete(self, using=None, soft=True, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -29,12 +29,73 @@ class DescriptorMixin(object):
|
||||||
return self.field_name
|
return self.field_name
|
||||||
|
|
||||||
|
|
||||||
|
class DescriptorWrapper(object):
|
||||||
|
|
||||||
|
def __init__(self, field_name, descriptor, tracker_attname):
|
||||||
|
self.field_name = field_name
|
||||||
|
self.descriptor = descriptor
|
||||||
|
self.tracker_attname = tracker_attname
|
||||||
|
|
||||||
|
def __get__(self, instance, owner):
|
||||||
|
if instance is None:
|
||||||
|
return self
|
||||||
|
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||||
|
value = self.descriptor.__get__(instance, owner)
|
||||||
|
if was_deferred:
|
||||||
|
tracker_instance = getattr(instance, self.tracker_attname)
|
||||||
|
tracker_instance.saved_data[self.field_name] = deepcopy(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __set__(self, instance, value):
|
||||||
|
initialized = hasattr(instance, '_instance_intialized')
|
||||||
|
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||||
|
|
||||||
|
# Sentinel attribute to detect whether we are already trying to
|
||||||
|
# set the attribute higher up the stack. This prevents infinite
|
||||||
|
# recursion when retrieving deferred values from the database.
|
||||||
|
recursion_sentinel_attname = '_setting_' + self.field_name
|
||||||
|
already_setting = hasattr(instance, recursion_sentinel_attname)
|
||||||
|
|
||||||
|
if initialized and was_deferred and not already_setting:
|
||||||
|
setattr(instance, recursion_sentinel_attname, True)
|
||||||
|
try:
|
||||||
|
# Retrieve the value to set the saved_data value.
|
||||||
|
# This will undefer the field
|
||||||
|
getattr(instance, self.field_name)
|
||||||
|
finally:
|
||||||
|
instance.__dict__.pop(recursion_sentinel_attname, None)
|
||||||
|
if hasattr(self.descriptor, '__set__'):
|
||||||
|
self.descriptor.__set__(instance, value)
|
||||||
|
else:
|
||||||
|
instance.__dict__[self.field_name] = value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cls_for_descriptor(descriptor):
|
||||||
|
if hasattr(descriptor, '__delete__'):
|
||||||
|
return FullDescriptorWrapper
|
||||||
|
else:
|
||||||
|
return DescriptorWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class FullDescriptorWrapper(DescriptorWrapper):
|
||||||
|
"""
|
||||||
|
Wrapper for descriptors with all three descriptor methods.
|
||||||
|
"""
|
||||||
|
def __delete__(self, obj):
|
||||||
|
self.descriptor.__delete__(obj)
|
||||||
|
|
||||||
|
|
||||||
class FieldInstanceTracker(object):
|
class FieldInstanceTracker(object):
|
||||||
def __init__(self, instance, fields, field_map):
|
def __init__(self, instance, fields, field_map):
|
||||||
self.instance = instance
|
self.instance = instance
|
||||||
self.fields = fields
|
self.fields = fields
|
||||||
self.field_map = field_map
|
self.field_map = field_map
|
||||||
self.init_deferred_fields()
|
if django.VERSION < (1, 10):
|
||||||
|
self.init_deferred_fields()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def deferred_fields(self):
|
||||||
|
return self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()
|
||||||
|
|
||||||
def get_field_value(self, field):
|
def get_field_value(self, field):
|
||||||
return getattr(self.instance, self.field_map[field])
|
return getattr(self.instance, self.field_map[field])
|
||||||
|
|
@ -54,10 +115,11 @@ class FieldInstanceTracker(object):
|
||||||
def current(self, fields=None):
|
def current(self, fields=None):
|
||||||
"""Returns dict of current values for all tracked fields"""
|
"""Returns dict of current values for all tracked fields"""
|
||||||
if fields is None:
|
if fields is None:
|
||||||
if self.instance._deferred_fields:
|
deferred_fields = self.deferred_fields
|
||||||
|
if deferred_fields:
|
||||||
fields = [
|
fields = [
|
||||||
field for field in self.fields
|
field for field in self.fields
|
||||||
if field not in self.instance._deferred_fields
|
if field not in deferred_fields
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
fields = self.fields
|
fields = self.fields
|
||||||
|
|
@ -67,12 +129,31 @@ class FieldInstanceTracker(object):
|
||||||
def has_changed(self, field):
|
def has_changed(self, field):
|
||||||
"""Returns ``True`` if field has changed from currently saved value"""
|
"""Returns ``True`` if field has changed from currently saved value"""
|
||||||
if field in self.fields:
|
if field in self.fields:
|
||||||
|
# deferred fields haven't changed
|
||||||
|
if field in self.deferred_fields and field not in self.instance.__dict__:
|
||||||
|
return False
|
||||||
return self.previous(field) != self.get_field_value(field)
|
return self.previous(field) != self.get_field_value(field)
|
||||||
else:
|
else:
|
||||||
raise FieldError('field "%s" not tracked' % field)
|
raise FieldError('field "%s" not tracked' % field)
|
||||||
|
|
||||||
def previous(self, field):
|
def previous(self, field):
|
||||||
"""Returns currently saved value of given field"""
|
"""Returns currently saved value of given field"""
|
||||||
|
|
||||||
|
# handle deferred fields that have not yet been loaded from the database
|
||||||
|
if self.instance.pk and field in self.deferred_fields and field not in self.saved_data:
|
||||||
|
|
||||||
|
# if the field has not been assigned locally, simply fetch and un-defer the value
|
||||||
|
if field not in self.instance.__dict__:
|
||||||
|
self.get_field_value(field)
|
||||||
|
|
||||||
|
# if the field has been assigned locally, store the local value, fetch the database value,
|
||||||
|
# store database value to saved_data, and restore the local value
|
||||||
|
else:
|
||||||
|
current_value = self.get_field_value(field)
|
||||||
|
self.instance.refresh_from_db(fields=[field])
|
||||||
|
self.saved_data[field] = deepcopy(self.get_field_value(field))
|
||||||
|
setattr(self.instance, self.field_map[field], current_value)
|
||||||
|
|
||||||
return self.saved_data.get(field)
|
return self.saved_data.get(field)
|
||||||
|
|
||||||
def changed(self):
|
def changed(self):
|
||||||
|
|
@ -97,35 +178,15 @@ class FieldInstanceTracker(object):
|
||||||
def _get_field_name(self):
|
def _get_field_name(self):
|
||||||
return self.field.name
|
return self.field.name
|
||||||
|
|
||||||
if django.VERSION >= (1, 8):
|
self.instance._deferred_fields = self.instance.get_deferred_fields()
|
||||||
self.instance._deferred_fields = self.instance.get_deferred_fields()
|
for field in self.instance._deferred_fields:
|
||||||
for field in self.instance._deferred_fields:
|
field_obj = self.instance.__class__.__dict__.get(field)
|
||||||
if django.VERSION >= (1, 10):
|
if isinstance(field_obj, FileDescriptor):
|
||||||
field_obj = getattr(self.instance.__class__, field)
|
field_tracker = FileDescriptorTracker(field_obj.field)
|
||||||
else:
|
setattr(self.instance.__class__, field, field_tracker)
|
||||||
field_obj = self.instance.__class__.__dict__.get(field)
|
else:
|
||||||
if isinstance(field_obj, FileDescriptor):
|
field_tracker = DeferredAttributeTracker(field, type(self.instance))
|
||||||
field_tracker = FileDescriptorTracker(field_obj.field)
|
setattr(self.instance.__class__, field, field_tracker)
|
||||||
setattr(self.instance.__class__, field, field_tracker)
|
|
||||||
else:
|
|
||||||
field_tracker = DeferredAttributeTracker(
|
|
||||||
field_obj.field_name, None)
|
|
||||||
setattr(self.instance.__class__, field, field_tracker)
|
|
||||||
else:
|
|
||||||
for field in self.fields:
|
|
||||||
field_obj = self.instance.__class__.__dict__.get(field)
|
|
||||||
if isinstance(field_obj, DeferredAttribute):
|
|
||||||
self.instance._deferred_fields.add(field)
|
|
||||||
|
|
||||||
# Django 1.4
|
|
||||||
if django.VERSION >= (1, 5):
|
|
||||||
model = None
|
|
||||||
else:
|
|
||||||
model = field_obj.model_ref()
|
|
||||||
|
|
||||||
field_tracker = DeferredAttributeTracker(
|
|
||||||
field_obj.field_name, model)
|
|
||||||
setattr(self.instance.__class__, field, field_tracker)
|
|
||||||
|
|
||||||
|
|
||||||
class FieldTracker(object):
|
class FieldTracker(object):
|
||||||
|
|
@ -146,12 +207,19 @@ class FieldTracker(object):
|
||||||
def contribute_to_class(self, cls, name):
|
def contribute_to_class(self, cls, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.attname = '_%s' % name
|
self.attname = '_%s' % name
|
||||||
|
self.patch_save(cls)
|
||||||
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
|
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
|
||||||
|
|
||||||
def finalize_class(self, sender, **kwargs):
|
def finalize_class(self, sender, **kwargs):
|
||||||
if self.fields is None:
|
if self.fields is None:
|
||||||
self.fields = (field.attname for field in sender._meta.fields)
|
self.fields = (field.attname for field in sender._meta.fields)
|
||||||
self.fields = set(self.fields)
|
self.fields = set(self.fields)
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
for field_name in self.fields:
|
||||||
|
descriptor = getattr(sender, field_name)
|
||||||
|
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
|
||||||
|
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
|
||||||
|
setattr(sender, field_name, wrapped_descriptor)
|
||||||
self.field_map = self.get_field_map(sender)
|
self.field_map = self.get_field_map(sender)
|
||||||
models.signals.post_init.connect(self.initialize_tracker)
|
models.signals.post_init.connect(self.initialize_tracker)
|
||||||
self.model_class = sender
|
self.model_class = sender
|
||||||
|
|
@ -163,12 +231,13 @@ class FieldTracker(object):
|
||||||
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
||||||
setattr(instance, self.attname, tracker)
|
setattr(instance, self.attname, tracker)
|
||||||
tracker.set_saved_fields()
|
tracker.set_saved_fields()
|
||||||
self.patch_save(instance)
|
instance._instance_intialized = True
|
||||||
|
|
||||||
def patch_save(self, instance):
|
def patch_save(self, model):
|
||||||
original_save = instance.save
|
original_save = model.save
|
||||||
def save(**kwargs):
|
|
||||||
ret = original_save(**kwargs)
|
def save(instance, *args, **kwargs):
|
||||||
|
ret = original_save(instance, *args, **kwargs)
|
||||||
update_fields = kwargs.get('update_fields')
|
update_fields = kwargs.get('update_fields')
|
||||||
if not update_fields and update_fields is not None: # () or []
|
if not update_fields and update_fields is not None: # () or []
|
||||||
fields = update_fields
|
fields = update_fields
|
||||||
|
|
@ -183,7 +252,8 @@ class FieldTracker(object):
|
||||||
fields=fields
|
fields=fields
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
instance.save = save
|
|
||||||
|
model.save = save
|
||||||
|
|
||||||
def __get__(self, instance, owner):
|
def __get__(self, instance, owner):
|
||||||
if instance is None:
|
if instance is None:
|
||||||
|
|
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
pytest==3.3.1
|
pytest==3.3.1
|
||||||
pytest-django==3.1.2
|
pytest-django==3.1.2
|
||||||
|
psycopg2==2.7.6.1
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,19 @@
|
||||||
from __future__ import unicode_literals, absolute_import
|
from __future__ import unicode_literals, absolute_import
|
||||||
|
|
||||||
|
import django
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.db.models.query_utils import DeferredAttribute
|
||||||
from django.db.models import Manager
|
from django.db.models import Manager
|
||||||
from django.utils.encoding import python_2_unicode_compatible
|
from django.utils.encoding import python_2_unicode_compatible
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
from model_utils import Choices
|
from model_utils import Choices
|
||||||
from model_utils.fields import SplitField, MonitorField, StatusField
|
from model_utils.fields import SplitField, MonitorField, StatusField
|
||||||
from model_utils.managers import QueryManager, InheritanceManager
|
from model_utils.managers import (
|
||||||
|
QueryManager,
|
||||||
|
InheritanceManager,
|
||||||
|
JoinManagerMixin
|
||||||
|
)
|
||||||
from model_utils.models import (
|
from model_utils.models import (
|
||||||
SoftDeletableModel,
|
SoftDeletableModel,
|
||||||
StatusModel,
|
StatusModel,
|
||||||
|
|
@ -36,9 +42,6 @@ class InheritanceManagerTestParent(models.Model):
|
||||||
on_delete=models.CASCADE)
|
on_delete=models.CASCADE)
|
||||||
objects = InheritanceManager()
|
objects = InheritanceManager()
|
||||||
|
|
||||||
def __unicode__(self):
|
|
||||||
return unicode(self.pk)
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "%s(%s)" % (
|
return "%s(%s)" % (
|
||||||
self.__class__.__name__[len('InheritanceManagerTest'):],
|
self.__class__.__name__[len('InheritanceManagerTest'):],
|
||||||
|
|
@ -331,3 +334,62 @@ class CustomSoftDelete(SoftDeletableModel):
|
||||||
is_read = models.BooleanField(default=False)
|
is_read = models.BooleanField(default=False)
|
||||||
|
|
||||||
objects = CustomSoftDeleteManager()
|
objects = CustomSoftDeleteManager()
|
||||||
|
|
||||||
|
|
||||||
|
class StringyDescriptor(object):
|
||||||
|
"""
|
||||||
|
Descriptor that returns a string version of the underlying integer value.
|
||||||
|
"""
|
||||||
|
def __init__(self, name):
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __get__(self, obj, cls=None):
|
||||||
|
if obj is None:
|
||||||
|
return self
|
||||||
|
if self.name in obj.get_deferred_fields():
|
||||||
|
# This queries the database, and sets the value on the instance.
|
||||||
|
if django.VERSION < (2, 1):
|
||||||
|
DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls)
|
||||||
|
else:
|
||||||
|
DeferredAttribute(field_name=self.name).__get__(obj, cls)
|
||||||
|
return str(obj.__dict__[self.name])
|
||||||
|
|
||||||
|
def __set__(self, obj, value):
|
||||||
|
obj.__dict__[self.name] = int(value)
|
||||||
|
|
||||||
|
def __delete__(self, obj):
|
||||||
|
del obj.__dict__[self.name]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDescriptorField(models.IntegerField):
|
||||||
|
def contribute_to_class(self, cls, name, **kwargs):
|
||||||
|
super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs)
|
||||||
|
setattr(cls, name, StringyDescriptor(name))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelWithCustomDescriptor(models.Model):
|
||||||
|
custom_field = CustomDescriptorField()
|
||||||
|
tracked_custom_field = CustomDescriptorField()
|
||||||
|
regular_field = models.IntegerField()
|
||||||
|
tracked_regular_field = models.IntegerField()
|
||||||
|
|
||||||
|
tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field'])
|
||||||
|
|
||||||
|
|
||||||
|
class JoinManager(JoinManagerMixin, models.Manager):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoxJoinModel(models.Model):
|
||||||
|
name = models.CharField(max_length=32)
|
||||||
|
objects = JoinManager()
|
||||||
|
|
||||||
|
|
||||||
|
class JoinItemForeignKey(models.Model):
|
||||||
|
weight = models.IntegerField()
|
||||||
|
belonging = models.ForeignKey(
|
||||||
|
BoxJoinModel,
|
||||||
|
null=True,
|
||||||
|
on_delete=models.CASCADE
|
||||||
|
)
|
||||||
|
objects = JoinManager()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,22 @@
|
||||||
|
import os
|
||||||
|
|
||||||
INSTALLED_APPS = (
|
INSTALLED_APPS = (
|
||||||
'model_utils',
|
'model_utils',
|
||||||
'tests',
|
'tests',
|
||||||
)
|
)
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
'default': {
|
"default": {
|
||||||
'ENGINE': 'django.db.backends.sqlite3'
|
"ENGINE": "django.db.backends.postgresql_psycopg2",
|
||||||
}
|
"NAME": os.environ.get("DJANGO_DATABASE_NAME_POSTGRES", "modelutils"),
|
||||||
|
"USER": os.environ.get("DJANGO_DATABASE_USER_POSTGRES", 'postgres'),
|
||||||
|
"PASSWORD": os.environ.get("DJANGO_DATABASE_PASSWORD_POSTGRES", ""),
|
||||||
|
"HOST": os.environ.get("DJANGO_DATABASE_HOST_POSTGRES", ""),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
SECRET_KEY = 'dummy'
|
SECRET_KEY = 'dummy'
|
||||||
|
|
||||||
|
CACHES = {
|
||||||
|
'default': {
|
||||||
|
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,12 @@ class ChoicesTests(TestCase):
|
||||||
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
|
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
|
||||||
|
|
||||||
def test_iteration(self):
|
def test_iteration(self):
|
||||||
self.assertEqual(tuple(self.STATUS), (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
|
self.assertEqual(tuple(self.STATUS),
|
||||||
|
(('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
|
||||||
|
|
||||||
|
def test_reversed(self):
|
||||||
|
self.assertEqual(tuple(reversed(self.STATUS)),
|
||||||
|
(('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT')))
|
||||||
|
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
self.assertEqual(len(self.STATUS), 2)
|
self.assertEqual(len(self.STATUS), 2)
|
||||||
|
|
@ -78,8 +83,15 @@ class LabelChoicesTests(ChoicesTests):
|
||||||
self.assertEqual(tuple(self.STATUS), (
|
self.assertEqual(tuple(self.STATUS), (
|
||||||
('DRAFT', 'is draft'),
|
('DRAFT', 'is draft'),
|
||||||
('PUBLISHED', 'is published'),
|
('PUBLISHED', 'is published'),
|
||||||
('DELETED', 'DELETED'))
|
('DELETED', 'DELETED'),
|
||||||
)
|
))
|
||||||
|
|
||||||
|
def test_reversed(self):
|
||||||
|
self.assertEqual(tuple(reversed(self.STATUS)), (
|
||||||
|
('DELETED', 'DELETED'),
|
||||||
|
('PUBLISHED', 'is published'),
|
||||||
|
('DRAFT', 'is draft'),
|
||||||
|
))
|
||||||
|
|
||||||
def test_indexing(self):
|
def test_indexing(self):
|
||||||
self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
|
self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
|
||||||
|
|
@ -169,7 +181,15 @@ class IdentifierChoicesTests(ChoicesTests):
|
||||||
self.assertEqual(tuple(self.STATUS), (
|
self.assertEqual(tuple(self.STATUS), (
|
||||||
(0, 'is draft'),
|
(0, 'is draft'),
|
||||||
(1, 'is published'),
|
(1, 'is published'),
|
||||||
(2, 'is deleted')))
|
(2, 'is deleted'),
|
||||||
|
))
|
||||||
|
|
||||||
|
def test_reversed(self):
|
||||||
|
self.assertEqual(tuple(reversed(self.STATUS)), (
|
||||||
|
(2, 'is deleted'),
|
||||||
|
(1, 'is published'),
|
||||||
|
(0, 'is draft'),
|
||||||
|
))
|
||||||
|
|
||||||
def test_indexing(self):
|
def test_indexing(self):
|
||||||
self.assertEqual(self.STATUS[1], 'is published')
|
self.assertEqual(self.STATUS[1], 'is published')
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from unittest import skipUnless
|
|
||||||
|
|
||||||
import django
|
import django
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
from django.core.cache import cache
|
||||||
from model_utils import FieldTracker
|
from model_utils import FieldTracker
|
||||||
|
from model_utils.tracker import DescriptorWrapper
|
||||||
from tests.models import (
|
from tests.models import (
|
||||||
Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple,
|
Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple,
|
||||||
InheritedTracked, TrackedFileField,
|
InheritedTracked, TrackedFileField,
|
||||||
|
|
@ -74,7 +73,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.assertChanged(name=None, number=None)
|
self.assertChanged(name=None, number=None)
|
||||||
self.instance.name = ''
|
self.instance.name = ''
|
||||||
self.assertChanged(name=None, number=None)
|
self.assertChanged(name=None, number=None)
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertChanged(name=None, number=None, mutable=None)
|
self.assertChanged(name=None, number=None, mutable=None)
|
||||||
|
|
||||||
def test_pre_save_has_changed(self):
|
def test_pre_save_has_changed(self):
|
||||||
|
|
@ -83,9 +82,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.assertHasChanged(name=True, number=False, mutable=False)
|
self.assertHasChanged(name=True, number=False, mutable=False)
|
||||||
self.instance.number = 7
|
self.instance.number = 7
|
||||||
self.assertHasChanged(name=True, number=True)
|
self.assertHasChanged(name=True, number=True)
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||||
|
|
||||||
|
def test_save_with_args(self):
|
||||||
|
self.instance.number = 1
|
||||||
|
self.instance.save(False, False, None, None)
|
||||||
|
self.assertChanged()
|
||||||
|
|
||||||
def test_first_save(self):
|
def test_first_save(self):
|
||||||
self.assertHasChanged(name=True, number=False, mutable=False)
|
self.assertHasChanged(name=True, number=False, mutable=False)
|
||||||
self.assertPrevious(name=None, number=None, mutable=None)
|
self.assertPrevious(name=None, number=None, mutable=None)
|
||||||
|
|
@ -93,22 +97,22 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.assertChanged(name=None)
|
self.assertChanged(name=None)
|
||||||
self.instance.name = 'retro'
|
self.instance.name = 'retro'
|
||||||
self.instance.number = 4
|
self.instance.number = 4
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||||
self.assertPrevious(name=None, number=None, mutable=None)
|
self.assertPrevious(name=None, number=None, mutable=None)
|
||||||
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
|
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
|
||||||
self.assertChanged(name=None, number=None, mutable=None)
|
self.assertChanged(name=None, number=None, mutable=None)
|
||||||
|
|
||||||
self.instance.save(update_fields=[])
|
self.instance.save(update_fields=[])
|
||||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||||
self.assertPrevious(name=None, number=None, mutable=None)
|
self.assertPrevious(name=None, number=None, mutable=None)
|
||||||
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
|
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
|
||||||
self.assertChanged(name=None, number=None, mutable=None)
|
self.assertChanged(name=None, number=None, mutable=None)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.instance.save(update_fields=['number'])
|
self.instance.save(update_fields=['number'])
|
||||||
|
|
||||||
def test_post_save_has_changed(self):
|
def test_post_save_has_changed(self):
|
||||||
self.update_instance(name='retro', number=4, mutable=[1,2,3])
|
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.assertHasChanged(name=False, number=False, mutable=False)
|
self.assertHasChanged(name=False, number=False, mutable=False)
|
||||||
self.instance.name = 'new age'
|
self.instance.name = 'new age'
|
||||||
self.assertHasChanged(name=True, number=False)
|
self.assertHasChanged(name=True, number=False)
|
||||||
|
|
@ -120,14 +124,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.assertHasChanged(name=False, number=True, mutable=True)
|
self.assertHasChanged(name=False, number=True, mutable=True)
|
||||||
|
|
||||||
def test_post_save_previous(self):
|
def test_post_save_previous(self):
|
||||||
self.update_instance(name='retro', number=4, mutable=[1,2,3])
|
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.instance.name = 'new age'
|
self.instance.name = 'new age'
|
||||||
self.assertPrevious(name='retro', number=4, mutable=[1,2,3])
|
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.instance.mutable[1] = 4
|
self.instance.mutable[1] = 4
|
||||||
self.assertPrevious(name='retro', number=4, mutable=[1,2,3])
|
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
|
|
||||||
def test_post_save_changed(self):
|
def test_post_save_changed(self):
|
||||||
self.update_instance(name='retro', number=4, mutable=[1,2,3])
|
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
self.instance.name = 'new age'
|
self.instance.name = 'new age'
|
||||||
self.assertChanged(name='retro')
|
self.assertChanged(name='retro')
|
||||||
|
|
@ -136,8 +140,8 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.instance.name = 'retro'
|
self.instance.name = 'retro'
|
||||||
self.assertChanged(number=4)
|
self.assertChanged(number=4)
|
||||||
self.instance.mutable[1] = 4
|
self.instance.mutable[1] = 4
|
||||||
self.assertChanged(number=4, mutable=[1,2,3])
|
self.assertChanged(number=4, mutable=[1, 2, 3])
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertChanged(number=4)
|
self.assertChanged(number=4)
|
||||||
|
|
||||||
def test_current(self):
|
def test_current(self):
|
||||||
|
|
@ -146,29 +150,29 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.assertCurrent(id=None, name='new age', number=None, mutable=None)
|
self.assertCurrent(id=None, name='new age', number=None, mutable=None)
|
||||||
self.instance.number = 8
|
self.instance.number = 8
|
||||||
self.assertCurrent(id=None, name='new age', number=8, mutable=None)
|
self.assertCurrent(id=None, name='new age', number=8, mutable=None)
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertCurrent(id=None, name='new age', number=8, mutable=[1,2,3])
|
self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 2, 3])
|
||||||
self.instance.mutable[1] = 4
|
self.instance.mutable[1] = 4
|
||||||
self.assertCurrent(id=None, name='new age', number=8, mutable=[1,4,3])
|
self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 4, 3])
|
||||||
self.instance.save()
|
self.instance.save()
|
||||||
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1,4,3])
|
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3])
|
||||||
|
|
||||||
def test_update_fields(self):
|
def test_update_fields(self):
|
||||||
self.update_instance(name='retro', number=4, mutable=[1,2,3])
|
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
self.instance.name = 'new age'
|
self.instance.name = 'new age'
|
||||||
self.instance.number = 8
|
self.instance.number = 8
|
||||||
self.instance.mutable = [4,5,6]
|
self.instance.mutable = [4, 5, 6]
|
||||||
self.assertChanged(name='retro', number=4, mutable=[1,2,3])
|
self.assertChanged(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.instance.save(update_fields=[])
|
self.instance.save(update_fields=[])
|
||||||
self.assertChanged(name='retro', number=4, mutable=[1,2,3])
|
self.assertChanged(name='retro', number=4, mutable=[1, 2, 3])
|
||||||
self.instance.save(update_fields=['name'])
|
self.instance.save(update_fields=['name'])
|
||||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||||
self.assertEqual(in_db.name, self.instance.name)
|
self.assertEqual(in_db.name, self.instance.name)
|
||||||
self.assertNotEqual(in_db.number, self.instance.number)
|
self.assertNotEqual(in_db.number, self.instance.number)
|
||||||
self.assertChanged(number=4, mutable=[1,2,3])
|
self.assertChanged(number=4, mutable=[1, 2, 3])
|
||||||
self.instance.save(update_fields=['number'])
|
self.instance.save(update_fields=['number'])
|
||||||
self.assertChanged(mutable=[1,2,3])
|
self.assertChanged(mutable=[1, 2, 3])
|
||||||
self.instance.save(update_fields=['mutable'])
|
self.instance.save(update_fields=['mutable'])
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||||
|
|
@ -180,20 +184,65 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
||||||
self.instance.name = 'new age'
|
self.instance.name = 'new age'
|
||||||
self.instance.number = 1
|
self.instance.number = 1
|
||||||
self.instance.save()
|
self.instance.save()
|
||||||
item = list(self.tracked_class.objects.only('name').all())[0]
|
item = self.tracked_class.objects.only('name').first()
|
||||||
self.assertTrue(item._deferred_fields)
|
if django.VERSION >= (1, 10):
|
||||||
|
self.assertTrue(item.get_deferred_fields())
|
||||||
|
else:
|
||||||
|
self.assertTrue(item._deferred_fields)
|
||||||
|
|
||||||
self.assertEqual(item.tracker.previous('number'), None)
|
# has_changed() returns False for deferred fields, without un-deferring them.
|
||||||
self.assertTrue('number' in item._deferred_fields)
|
# Use an if because ModelTracked doesn't support has_changed() in this case.
|
||||||
|
if self.tracked_class == Tracked:
|
||||||
|
self.assertFalse(item.tracker.has_changed('number'))
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
self.assertIsInstance(item.__class__.number, DescriptorWrapper)
|
||||||
|
self.assertTrue('number' in item.get_deferred_fields())
|
||||||
|
else:
|
||||||
|
self.assertTrue('number' in item._deferred_fields)
|
||||||
|
|
||||||
|
# previous() un-defers field and returns value
|
||||||
|
self.assertEqual(item.tracker.previous('number'), 1)
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
self.assertNotIn('number', item.get_deferred_fields())
|
||||||
|
else:
|
||||||
|
self.assertNotIn('number', item._deferred_fields)
|
||||||
|
|
||||||
|
# examining a deferred field un-defers it
|
||||||
|
item = self.tracked_class.objects.only('name').first()
|
||||||
self.assertEqual(item.number, 1)
|
self.assertEqual(item.number, 1)
|
||||||
self.assertTrue('number' not in item._deferred_fields)
|
if django.VERSION >= (1, 10):
|
||||||
|
self.assertTrue('number' not in item.get_deferred_fields())
|
||||||
|
else:
|
||||||
|
self.assertTrue('number' not in item._deferred_fields)
|
||||||
self.assertEqual(item.tracker.previous('number'), 1)
|
self.assertEqual(item.tracker.previous('number'), 1)
|
||||||
self.assertFalse(item.tracker.has_changed('number'))
|
self.assertFalse(item.tracker.has_changed('number'))
|
||||||
|
|
||||||
|
# has_changed() returns correct values after deferred field is examined
|
||||||
|
self.assertFalse(item.tracker.has_changed('number'))
|
||||||
item.number = 2
|
item.number = 2
|
||||||
self.assertTrue(item.tracker.has_changed('number'))
|
self.assertTrue(item.tracker.has_changed('number'))
|
||||||
|
|
||||||
|
# previous() returns correct value after deferred field is examined
|
||||||
|
self.assertEqual(item.tracker.previous('number'), 1)
|
||||||
|
|
||||||
|
# assigning to a deferred field un-defers it
|
||||||
|
# Use an if because ModelTracked doesn't handle this case.
|
||||||
|
if self.tracked_class == Tracked:
|
||||||
|
|
||||||
|
item = self.tracked_class.objects.only('name').first()
|
||||||
|
item.number = 2
|
||||||
|
|
||||||
|
# previous() fetches correct value from database after deferred field is assigned
|
||||||
|
self.assertEqual(item.tracker.previous('number'), 1)
|
||||||
|
|
||||||
|
# database fetch of previous() value doesn't affect current value
|
||||||
|
self.assertEqual(item.number, 2)
|
||||||
|
|
||||||
|
# has_changed() returns correct values after deferred field is assigned
|
||||||
|
self.assertTrue(item.tracker.has_changed('number'))
|
||||||
|
item.number = 1
|
||||||
|
self.assertFalse(item.tracker.has_changed('number'))
|
||||||
|
|
||||||
|
|
||||||
class FieldTrackerMultipleInstancesTests(TestCase):
|
class FieldTrackerMultipleInstancesTests(TestCase):
|
||||||
|
|
||||||
|
|
@ -595,6 +644,16 @@ class ModelTrackerTests(FieldTrackerTests):
|
||||||
|
|
||||||
tracked_class = ModelTracked
|
tracked_class = ModelTracked
|
||||||
|
|
||||||
|
def test_cache_compatible(self):
|
||||||
|
cache.set('key', self.instance)
|
||||||
|
instance = cache.get('key')
|
||||||
|
instance.number = 1
|
||||||
|
instance.name = 'cached'
|
||||||
|
instance.save()
|
||||||
|
self.assertChanged()
|
||||||
|
instance.number = 2
|
||||||
|
self.assertHasChanged(number=True)
|
||||||
|
|
||||||
def test_pre_save_changed(self):
|
def test_pre_save_changed(self):
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
self.instance.name = 'new age'
|
self.instance.name = 'new age'
|
||||||
|
|
@ -603,7 +662,7 @@ class ModelTrackerTests(FieldTrackerTests):
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
self.instance.name = ''
|
self.instance.name = ''
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
|
|
||||||
def test_first_save(self):
|
def test_first_save(self):
|
||||||
|
|
@ -613,16 +672,16 @@ class ModelTrackerTests(FieldTrackerTests):
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
self.instance.name = 'retro'
|
self.instance.name = 'retro'
|
||||||
self.instance.number = 4
|
self.instance.number = 4
|
||||||
self.instance.mutable = [1,2,3]
|
self.instance.mutable = [1, 2, 3]
|
||||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||||
self.assertPrevious(name=None, number=None, mutable=None)
|
self.assertPrevious(name=None, number=None, mutable=None)
|
||||||
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
|
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
|
|
||||||
self.instance.save(update_fields=[])
|
self.instance.save(update_fields=[])
|
||||||
self.assertHasChanged(name=True, number=True, mutable=True)
|
self.assertHasChanged(name=True, number=True, mutable=True)
|
||||||
self.assertPrevious(name=None, number=None, mutable=None)
|
self.assertPrevious(name=None, number=None, mutable=None)
|
||||||
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
|
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
|
||||||
self.assertChanged()
|
self.assertChanged()
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.instance.save(update_fields=['number'])
|
self.instance.save(update_fields=['number'])
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from model_utils.fields import StatusField
|
||||||
from tests.models import (
|
from tests.models import (
|
||||||
Article, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled,
|
Article, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled,
|
||||||
StatusFieldChoicesName,
|
StatusFieldChoicesName,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class StatusFieldTests(TestCase):
|
class StatusFieldTests(TestCase):
|
||||||
|
|
|
||||||
22
tests/test_inheritance_iterable.py
Normal file
22
tests/test_inheritance_iterable.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from unittest import skipIf
|
||||||
|
|
||||||
|
import django
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.db.models import Prefetch
|
||||||
|
|
||||||
|
from tests.models import InheritanceManagerTestParent, InheritanceManagerTestChild1
|
||||||
|
|
||||||
|
|
||||||
|
class InheritanceIterableTest(TestCase):
|
||||||
|
@skipIf(django.VERSION[:2] == (1, 10), "Django 1.10 expects ModelIterable not a subclass of it")
|
||||||
|
def test_prefetch(self):
|
||||||
|
qs = InheritanceManagerTestChild1.objects.all().prefetch_related(
|
||||||
|
Prefetch(
|
||||||
|
'normal_field',
|
||||||
|
queryset=InheritanceManagerTestParent.objects.all(),
|
||||||
|
to_attr='normal_field_prefetched'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEquals(qs.count(), 0)
|
||||||
|
|
@ -6,11 +6,12 @@ import django
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from tests.models import (InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1,
|
from tests.models import (
|
||||||
InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent,
|
InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1,
|
||||||
InheritanceManagerTestChild1,
|
InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent,
|
||||||
InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3
|
InheritanceManagerTestChild1,
|
||||||
)
|
InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class InheritanceManagerTests(TestCase):
|
class InheritanceManagerTests(TestCase):
|
||||||
|
|
@ -115,9 +116,6 @@ class InheritanceManagerTests(TestCase):
|
||||||
"inheritancemanagertestchild2").get(pk=self.child1.pk)
|
"inheritancemanagertestchild2").get(pk=self.child1.pk)
|
||||||
obj.inheritancemanagertestchild1
|
obj.inheritancemanagertestchild1
|
||||||
|
|
||||||
def test_version_determining_any_depth(self):
|
|
||||||
self.assertIsNone(self.get_manager().all()._get_maximum_depth())
|
|
||||||
|
|
||||||
def test_manually_specifying_parent_fk_including_grandchildren(self):
|
def test_manually_specifying_parent_fk_including_grandchildren(self):
|
||||||
"""
|
"""
|
||||||
given a Model which inherits from another Model, but also declares
|
given a Model which inherits from another Model, but also declares
|
||||||
|
|
@ -125,10 +123,16 @@ class InheritanceManagerTests(TestCase):
|
||||||
ensure that the relation names and subclasses are obtained correctly.
|
ensure that the relation names and subclasses are obtained correctly.
|
||||||
"""
|
"""
|
||||||
child3 = InheritanceManagerTestChild3.objects.create()
|
child3 = InheritanceManagerTestChild3.objects.create()
|
||||||
results = InheritanceManagerTestParent.objects.all().select_subclasses()
|
qs = InheritanceManagerTestParent.objects.all()
|
||||||
|
results = qs.select_subclasses().order_by('pk')
|
||||||
|
|
||||||
expected_objs = [self.child1, self.child2, self.grandchild1,
|
expected_objs = [
|
||||||
self.grandchild1_2, child3]
|
self.child1,
|
||||||
|
self.child2,
|
||||||
|
self.grandchild1,
|
||||||
|
self.grandchild1_2,
|
||||||
|
child3
|
||||||
|
]
|
||||||
self.assertEqual(list(results), expected_objs)
|
self.assertEqual(list(results), expected_objs)
|
||||||
|
|
||||||
expected_related_names = [
|
expected_related_names = [
|
||||||
|
|
@ -148,7 +152,8 @@ class InheritanceManagerTests(TestCase):
|
||||||
"""
|
"""
|
||||||
related_name = 'manual_onetoone'
|
related_name = 'manual_onetoone'
|
||||||
child3 = InheritanceManagerTestChild3.objects.create()
|
child3 = InheritanceManagerTestChild3.objects.create()
|
||||||
results = InheritanceManagerTestParent.objects.all().select_subclasses(related_name)
|
qs = InheritanceManagerTestParent.objects.all()
|
||||||
|
results = qs.select_subclasses(related_name).order_by('pk')
|
||||||
|
|
||||||
expected_objs = [InheritanceManagerTestParent(pk=self.child1.pk),
|
expected_objs = [InheritanceManagerTestParent(pk=self.child1.pk),
|
||||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||||
|
|
@ -180,27 +185,26 @@ class InheritanceManagerTests(TestCase):
|
||||||
|
|
||||||
# No argument to select_subclasses
|
# No argument to select_subclasses
|
||||||
objs_1 = list(
|
objs_1 = list(
|
||||||
self.get_manager().
|
self.get_manager()
|
||||||
select_subclasses().
|
.select_subclasses()
|
||||||
values_list('id')
|
.values_list('id')
|
||||||
)
|
)
|
||||||
|
|
||||||
# String argument to select_subclasses
|
# String argument to select_subclasses
|
||||||
objs_2 = list(
|
objs_2 = list(
|
||||||
self.get_manager().
|
self.get_manager()
|
||||||
select_subclasses(
|
.select_subclasses(
|
||||||
"inheritancemanagertestchild2"
|
"inheritancemanagertestchild2"
|
||||||
).
|
)
|
||||||
values_list('id')
|
.values_list('id')
|
||||||
)
|
)
|
||||||
|
|
||||||
# String argument to select_subclasses
|
# String argument to select_subclasses
|
||||||
objs_3 = list(
|
objs_3 = list(
|
||||||
self.get_manager().
|
self.get_manager()
|
||||||
select_subclasses(
|
.select_subclasses(
|
||||||
InheritanceManagerTestChild2
|
InheritanceManagerTestChild2
|
||||||
).
|
).values_list('id')
|
||||||
values_list('id')
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert all((
|
assert all((
|
||||||
|
|
@ -392,14 +396,16 @@ class InheritanceManagerUsingModelsTests(TestCase):
|
||||||
"""
|
"""
|
||||||
child3 = InheritanceManagerTestChild3.objects.create()
|
child3 = InheritanceManagerTestChild3.objects.create()
|
||||||
results = InheritanceManagerTestParent.objects.all().select_subclasses(
|
results = InheritanceManagerTestParent.objects.all().select_subclasses(
|
||||||
InheritanceManagerTestChild3)
|
InheritanceManagerTestChild3).order_by('pk')
|
||||||
|
|
||||||
expected_objs = [InheritanceManagerTestParent(pk=self.parent1.pk),
|
expected_objs = [
|
||||||
InheritanceManagerTestParent(pk=self.child1.pk),
|
InheritanceManagerTestParent(pk=self.parent1.pk),
|
||||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
InheritanceManagerTestParent(pk=self.child1.pk),
|
||||||
InheritanceManagerTestParent(pk=self.grandchild1.pk),
|
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||||
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
|
InheritanceManagerTestParent(pk=self.grandchild1.pk),
|
||||||
child3]
|
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
|
||||||
|
child3
|
||||||
|
]
|
||||||
self.assertEqual(list(results), expected_objs)
|
self.assertEqual(list(results), expected_objs)
|
||||||
|
|
||||||
expected_related_names = ['manual_onetoone']
|
expected_related_names = ['manual_onetoone']
|
||||||
|
|
@ -454,3 +460,7 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
|
||||||
qs = InheritanceManagerTestParent.objects.annotate(
|
qs = InheritanceManagerTestParent.objects.annotate(
|
||||||
test_count=models.Count('id')).select_subclasses()
|
test_count=models.Count('id')).select_subclasses()
|
||||||
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
|
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
|
||||||
|
|
||||||
|
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self):
|
||||||
|
qs = InheritanceManagerTestParent.objects.select_subclasses()
|
||||||
|
self.assertEqual(qs.subclasses, qs._clone().subclasses)
|
||||||
|
|
|
||||||
38
tests/test_managers/test_join_manager.py
Normal file
38
tests/test_managers/test_join_manager.py
Normal file
|
|
@ -0,0 +1,38 @@
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from tests.models import JoinItemForeignKey, BoxJoinModel
|
||||||
|
|
||||||
|
|
||||||
|
class JoinManagerTest(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
for i in range(20):
|
||||||
|
BoxJoinModel.objects.create(name='name_{i}'.format(i=i))
|
||||||
|
|
||||||
|
JoinItemForeignKey.objects.create(
|
||||||
|
weight=10, belonging=BoxJoinModel.objects.get(name='name_1')
|
||||||
|
)
|
||||||
|
JoinItemForeignKey.objects.create(weight=20)
|
||||||
|
|
||||||
|
def test_self_join(self):
|
||||||
|
a_slice = BoxJoinModel.objects.all()[0:10]
|
||||||
|
with self.assertNumQueries(1):
|
||||||
|
result = a_slice.join()
|
||||||
|
self.assertEquals(result.count(), 10)
|
||||||
|
|
||||||
|
def test_self_join_with_where_statement(self):
|
||||||
|
qs = BoxJoinModel.objects.filter(name='name_1')
|
||||||
|
result = qs.join()
|
||||||
|
self.assertEquals(result.count(), 1)
|
||||||
|
|
||||||
|
def test_join_with_other_qs(self):
|
||||||
|
item_qs = JoinItemForeignKey.objects.filter(weight=10)
|
||||||
|
boxes = BoxJoinModel.objects.all().join(qs=item_qs)
|
||||||
|
self.assertEquals(boxes.count(), 1)
|
||||||
|
self.assertEquals(boxes[0].name, 'name_1')
|
||||||
|
|
||||||
|
def test_reverse_join(self):
|
||||||
|
box_qs = BoxJoinModel.objects.filter(name='name_1')
|
||||||
|
items = JoinItemForeignKey.objects.all().join(box_qs)
|
||||||
|
self.assertEquals(items.count(), 1)
|
||||||
|
self.assertEquals(items[0].weight, 10)
|
||||||
71
tests/test_models/test_deferred_fields.py
Normal file
71
tests/test_models/test_deferred_fields.py
Normal file
|
|
@ -0,0 +1,71 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import django
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from tests.models import ModelWithCustomDescriptor
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDescriptorTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.instance = ModelWithCustomDescriptor.objects.create(
|
||||||
|
custom_field='1',
|
||||||
|
tracked_custom_field='1',
|
||||||
|
regular_field=1,
|
||||||
|
tracked_regular_field=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_custom_descriptor_works(self):
|
||||||
|
instance = self.instance
|
||||||
|
self.assertEqual(instance.custom_field, '1')
|
||||||
|
self.assertEqual(instance.__dict__['custom_field'], 1)
|
||||||
|
self.assertEqual(instance.regular_field, 1)
|
||||||
|
instance.custom_field = 2
|
||||||
|
self.assertEqual(instance.custom_field, '2')
|
||||||
|
self.assertEqual(instance.__dict__['custom_field'], 2)
|
||||||
|
instance.save()
|
||||||
|
instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk)
|
||||||
|
self.assertEqual(instance.custom_field, '2')
|
||||||
|
self.assertEqual(instance.__dict__['custom_field'], 2)
|
||||||
|
|
||||||
|
def test_deferred(self):
|
||||||
|
instance = ModelWithCustomDescriptor.objects.only('id').get(
|
||||||
|
pk=self.instance.pk)
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
self.assertIn('custom_field', instance.get_deferred_fields())
|
||||||
|
else:
|
||||||
|
self.assertIn('custom_field', instance._deferred_fields)
|
||||||
|
self.assertEqual(instance.custom_field, '1')
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
self.assertNotIn('custom_field', instance.get_deferred_fields())
|
||||||
|
else:
|
||||||
|
self.assertNotIn('custom_field', instance._deferred_fields)
|
||||||
|
self.assertEqual(instance.regular_field, 1)
|
||||||
|
self.assertEqual(instance.tracked_custom_field, '1')
|
||||||
|
self.assertEqual(instance.tracked_regular_field, 1)
|
||||||
|
|
||||||
|
self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))
|
||||||
|
self.assertFalse(instance.tracker.has_changed('tracked_regular_field'))
|
||||||
|
|
||||||
|
instance.tracked_custom_field = 2
|
||||||
|
instance.tracked_regular_field = 2
|
||||||
|
self.assertTrue(instance.tracker.has_changed('tracked_custom_field'))
|
||||||
|
self.assertTrue(instance.tracker.has_changed('tracked_regular_field'))
|
||||||
|
instance.save()
|
||||||
|
|
||||||
|
instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk)
|
||||||
|
self.assertEqual(instance.custom_field, '1')
|
||||||
|
self.assertEqual(instance.regular_field, 1)
|
||||||
|
self.assertEqual(instance.tracked_custom_field, '2')
|
||||||
|
self.assertEqual(instance.tracked_regular_field, 2)
|
||||||
|
|
||||||
|
instance = ModelWithCustomDescriptor.objects.only('id').get(pk=instance.pk)
|
||||||
|
if django.VERSION >= (1, 10):
|
||||||
|
# This fails on 1.8 and 1.9, which is a bug in the deferred field
|
||||||
|
# implementation on those versions.
|
||||||
|
instance.tracked_custom_field = 3
|
||||||
|
self.assertEqual(instance.tracked_custom_field, '3')
|
||||||
|
self.assertTrue(instance.tracker.has_changed('tracked_custom_field'))
|
||||||
|
del instance.tracked_custom_field
|
||||||
|
self.assertEqual(instance.tracked_custom_field, '2')
|
||||||
|
self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))
|
||||||
|
|
@ -18,7 +18,7 @@ class StatusModelTests(TestCase):
|
||||||
c1 = self.model.objects.create()
|
c1 = self.model.objects.create()
|
||||||
self.assertTrue(c1.status_changed, datetime(2016, 1, 1))
|
self.assertTrue(c1.status_changed, datetime(2016, 1, 1))
|
||||||
|
|
||||||
c2 = self.model.objects.create()
|
self.model.objects.create()
|
||||||
self.assertEqual(self.model.active.count(), 2)
|
self.assertEqual(self.model.active.count(), 2)
|
||||||
self.assertEqual(self.model.deleted.count(), 0)
|
self.assertEqual(self.model.deleted.count(), 0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,13 @@ class TimeStampedModelTests(TestCase):
|
||||||
t1 = TimeStamp.objects.create()
|
t1 = TimeStamp.objects.create()
|
||||||
self.assertEqual(t1.created, datetime(2016, 1, 1))
|
self.assertEqual(t1.created, datetime(2016, 1, 1))
|
||||||
|
|
||||||
|
def test_created_sets_modified(self):
|
||||||
|
'''
|
||||||
|
Ensure that on creation that modifed is set exactly equal to created.
|
||||||
|
'''
|
||||||
|
t1 = TimeStamp.objects.create()
|
||||||
|
self.assertEqual(t1.created, t1.modified)
|
||||||
|
|
||||||
def test_modified(self):
|
def test_modified(self):
|
||||||
with freeze_time(datetime(2016, 1, 1)):
|
with freeze_time(datetime(2016, 1, 1)):
|
||||||
t1 = TimeStamp.objects.create()
|
t1 = TimeStamp.objects.create()
|
||||||
|
|
|
||||||
30
tox.ini
30
tox.ini
|
|
@ -1,24 +1,44 @@
|
||||||
[tox]
|
[tox]
|
||||||
envlist =
|
envlist =
|
||||||
py27-django{18,19,110,111}
|
py27-django{19,110,111}
|
||||||
py34-django{18,19,110,111,200}
|
py34-django{19,110,111,200}
|
||||||
py35-django{18,19,110,111,200,trunk}
|
py35-django{19,110,111,200,201,trunk}
|
||||||
py36-django{111,200,trunk}
|
py36-django{111,200,201,trunk}
|
||||||
|
flake8
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
deps =
|
deps =
|
||||||
django18: Django>=1.8,<1.9
|
|
||||||
django19: Django>=1.9,<1.10
|
django19: Django>=1.9,<1.10
|
||||||
django110: Django>=1.10,<1.11
|
django110: Django>=1.10,<1.11
|
||||||
django111: Django>=1.11,<1.12
|
django111: Django>=1.11,<1.12
|
||||||
django200: Django>=2.0,<2.1
|
django200: Django>=2.0,<2.1
|
||||||
|
django201: Django>=2.1,<2.2
|
||||||
djangotrunk: https://github.com/django/django/archive/master.tar.gz
|
djangotrunk: https://github.com/django/django/archive/master.tar.gz
|
||||||
freezegun == 0.3.8
|
freezegun == 0.3.8
|
||||||
-rrequirements-test.txt
|
-rrequirements-test.txt
|
||||||
pytest-cov
|
pytest-cov
|
||||||
ignore_outcome =
|
ignore_outcome =
|
||||||
djangotrunk: True
|
djangotrunk: True
|
||||||
|
passenv =
|
||||||
|
CI
|
||||||
|
TRAVIS
|
||||||
|
TRAVIS_*
|
||||||
|
|
||||||
commands =
|
commands =
|
||||||
pip install -e .
|
pip install -e .
|
||||||
py.test {posargs}
|
py.test {posargs}
|
||||||
|
|
||||||
|
[testenv:flake8]
|
||||||
|
basepython =
|
||||||
|
python3.6
|
||||||
|
deps =
|
||||||
|
flake8
|
||||||
|
commands =
|
||||||
|
flake8 model_utils tests
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
ignore =
|
||||||
|
E731 ; do not assign a lambda expression, use a def
|
||||||
|
W503 ; line break before binary operator
|
||||||
|
E402 ; module level import not at top of file
|
||||||
|
E501 ; line too long
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue