Merge branch 'master' into master

This commit is contained in:
Reece Dunham 2019-01-10 15:47:22 -05:00 committed by GitHub
commit 73dd5aa8fd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 816 additions and 189 deletions

View file

@ -1,5 +1,2 @@
[run]
source = model_utils
omit = .*
tests/*
*/_*
include = model_utils/*.py

View file

@ -8,7 +8,9 @@ python:
- 3.6
install: pip install tox-travis codecov
# positional args ({posargs}) to pass into tox.ini
script: tox -- --cov
script: tox -- --cov --cov-append
services:
- postgresql
after_success: codecov
deploy:
provider: pypi

View file

@ -3,6 +3,22 @@ CHANGES
master (unreleased)
-------------------
- Fix handling of deferred attributes on Django 1.10+, fixes GH-278
- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return
correct responses for deferred fields.
- Update AutoLastModifiedField so that at instance creation it will
always be set equal to created to make querying easier. Fixes GH-254
- Support `reversed` for all kinds of `Choices` objects, fixes GH-309
- Fix Model instance non picklable GH-330
- Fix patched `save` in FieldTracker
3.1.2 (2018.05.09)
------------------
* Update InheritanceIterable to inherit from
ModelIterable instead of BaseIterable, fixes GH-277.
* Add all_objects Manager for 'SoftDeletableModel' to include soft
deleted objects on queries as per issue GH-255
3.1.1 (2017.12.17)
------------------

View file

@ -1,4 +1,4 @@
Copyright (c) 2009-2015, Carl Meyer and contributors
Copyright (c) 2009-2019, Carl Meyer and contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without

View file

@ -14,7 +14,7 @@ django-model-utils
Django model mixins and utilities.
``django-model-utils`` supports `Django`_ 1.8 to 2.0.
``django-model-utils`` supports `Django`_ 1.8 to 2.1.
.. _Django: http://www.djangoproject.com/
@ -28,6 +28,15 @@ Getting Help
Documentation for django-model-utils is available
https://django-model-utils.readthedocs.io/
Run tests
---------
.. code-block
pip install -e .
py.test
Contributing
============

View file

@ -86,6 +86,33 @@ it's safe to use as your default manager for the model.
.. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/
JoinManager
-----------
The ``JoinManager`` will create a temporary table of your current queryset
and join that temporary table with the model of your current queryset. This can
be advantageous if you have to page through your entire DB and using django's
slice mechanism to do that. ``LIMIT .. OFFSET ..`` becomes slower the bigger
offset you use.
.. code-block:: python
sliced_qs = Place.objects.all()[2000:2010]
qs = sliced_qs.join()
# qs contains 10 objects, and there will be a much smaller performance hit
# for paging through all of first 2000 objects.
Alternatively, you can give it a queryset and the manager will create a temporary
table and join that to your current queryset. This can work as a more performant
alternative to using django's ``__in`` as described in the following
(`StackExchange answer`_).
.. code-block:: python
big_qs = Restaurant.objects.filter(menu='vegetarian')
qs = Country.objects.filter(country_code='SE').join(big_qs)
.. _StackExchange answer: https://dba.stackexchange.com/questions/91247/optimizing-a-postgres-query-with-a-large-in
.. _QueryManager:

View file

@ -17,7 +17,7 @@ modify your ``INSTALLED_APPS`` setting.
Dependencies
============
``django-model-utils`` supports `Django`_ 1.8 through 1.10 (latest bugfix
release in each series only) on Python 2.7, 3.3 (Django 1.8 only), 3.4 and 3.5.
``django-model-utils`` supports `Django`_ 1.8 through 2.1 (latest bugfix
release in each series only) on Python 2.7, 3.4, 3.5 and 3.6.
.. _Django: http://www.djangoproject.com/

View file

@ -150,6 +150,10 @@ Returns the value of the given field during the last save:
Returns ``None`` when the model instance isn't saved yet.
If a field is `deferred`_, calling ``previous()`` will load the previous value from the database.
.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer
has_changed
~~~~~~~~~~~
@ -167,6 +171,8 @@ Returns ``True`` if the given field has changed since the last save. The ``has_c
The ``has_changed`` method relies on ``previous`` to determine whether a
field's values has changed.
If a field is `deferred`_ and has been assigned locally, calling ``has_changed()``
will load the previous value from the database to perform the comparison.
changed
~~~~~~~

View file

@ -1,4 +1,4 @@
from .choices import Choices
from .tracker import FieldTracker, ModelTracker
from .choices import Choices # noqa:F401
from .tracker import FieldTracker, ModelTracker # noqa:F401
__version__ = '3.1.1'
__version__ = '3.2.0'

View file

@ -57,7 +57,6 @@ class Choices(object):
self._process(choices)
def _store(self, triple, triple_collector, double_collector):
self._identifier_map[triple[1]] = triple[0]
self._display_map[triple[0]] = triple[2]
@ -65,7 +64,6 @@ class Choices(object):
triple_collector.append(triple)
double_collector.append((triple[0], triple[2]))
def _process(self, choices, triple_collector=None, double_collector=None):
if triple_collector is None:
triple_collector = self._triples
@ -94,18 +92,18 @@ class Choices(object):
raise ValueError(
"Choices can't take a list of length %s, only 2 or 3"
% len(choice)
)
)
else:
store((choice, choice, choice))
def __len__(self):
return len(self._doubles)
def __iter__(self):
return iter(self._doubles)
def __reversed__(self):
return reversed(self._doubles)
def __getattr__(self, attname):
try:
@ -113,11 +111,9 @@ class Choices(object):
except KeyError:
raise AttributeError(attname)
def __getitem__(self, key):
return self._display_map[key]
def __add__(self, other):
if isinstance(other, self.__class__):
other = other._triples
@ -125,29 +121,24 @@ class Choices(object):
other = list(other)
return Choices(*(self._triples + other))
def __radd__(self, other):
# radd is never called for matching types, so we don't check here
other = list(other)
return Choices(*(other + self._triples))
def __eq__(self, other):
if isinstance(other, self.__class__):
return self._triples == other._triples
return False
def __repr__(self):
return '%s(%s)' % (
self.__class__.__name__,
', '.join(("%s" % repr(i) for i in self._triples))
)
)
def __contains__(self, item):
return item in self._db_values
def __deepcopy__(self, memo):
return self.__class__(*copy.deepcopy(self._triples, memo))

View file

@ -32,6 +32,11 @@ class AutoLastModifiedField(AutoCreatedField):
"""
def pre_save(self, model_instance, add):
value = now()
if not model_instance.pk:
for field in model_instance._meta.get_fields():
if isinstance(field, AutoCreatedField):
value = getattr(model_instance, field.name)
break
setattr(model_instance, self.attname, value)
return value
@ -141,6 +146,7 @@ SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2)
_excerpt_field_name = lambda name: '_%s_excerpt' % name
def get_excerpt(content):
excerpt = []
default_excerpt = []
@ -156,6 +162,7 @@ def get_excerpt(content):
return '\n'.join(default_excerpt)
@python_2_unicode_compatible
class SplitText(object):
def __init__(self, instance, field_name, excerpt_field_name):
@ -166,11 +173,13 @@ class SplitText(object):
self.excerpt_field_name = excerpt_field_name
# content is read/write
def _get_content(self):
@property
def content(self):
return self.instance.__dict__[self.field_name]
def _set_content(self, val):
@content.setter
def content(self, val):
setattr(self.instance, self.field_name, val)
content = property(_get_content, _set_content)
# excerpt is a read only property
def _get_excerpt(self):
@ -185,6 +194,7 @@ class SplitText(object):
def __str__(self):
return self.content
class SplitDescriptor(object):
def __init__(self, field):
self.field = field
@ -205,6 +215,7 @@ class SplitDescriptor(object):
else:
obj.__dict__[self.field.name] = value
class SplitField(models.TextField):
def __init__(self, *args, **kwargs):
# for South FakeORM compatibility: the frozen version of a

Binary file not shown.

View file

@ -0,0 +1,46 @@
# Czech translations of django-model-utils
#
# This file is distributed under the same license as the django-model-utils package.
#
# Translators:
# ------------
# Václav Dohnal <vaclav.dohnal@gmail.com>, 2018.
#
msgid ""
msgstr ""
"Project-Id-Version: django-model-utils\n"
"Report-Msgid-Bugs-To: https://github.com/jazzband/django-model-utils/issues\n"
"POT-Creation-Date: 2018-05-04 13:40+0200\n"
"PO-Revision-Date: 2018-05-04 13:46+0200\n"
"Language: cs\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=UTF-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=3; plural=(n==1) ? 0 : (n>=2 && n<=4) ? 1 : 2;\n"
"Last-Translator: Václav Dohnal <vaclav.dohnal@gmail.com>\n"
"Language-Team: N/A\n"
"X-Generator: Poedit 2.0.7\n"
#: .\models.py:24
msgid "created"
msgstr "vytvořeno"
#: .\models.py:25
msgid "modified"
msgstr "upraveno"
#: .\models.py:37
msgid "start"
msgstr "začátek"
#: .\models.py:38
msgid "end"
msgstr "konec"
#: .\models.py:53
msgid "status"
msgstr "stav"
#: .\models.py:54
msgid "status changed"
msgstr "změna stavu"

Binary file not shown.

View file

@ -0,0 +1,43 @@
# This file is distributed under the same license as the django-model-utils package.
#
# Translators:
# Arseny Sysolyatin <arseny.sysolyatin@gmail.com>, 2017.
msgid ""
msgstr ""
"Project-Id-Version: django-model-utils\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2017-05-22 19:46+0300\n"
"PO-Revision-Date: 2017-05-22 19:46+0300\n"
"Last-Translator: Arseny Sysolyatin <arseny.sysolyatin@gmail.com>\n"
"Language-Team: \n"
"Language: ru\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=UTF-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=4; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n"
"%10<=4 && (n%100<12 || n%100>14) ? 1 : n%10==0 || (n%10>=5 && n%10<=9) || (n"
"%100>=11 && n%100<=14)? 2 : 3);\n"
#: models.py:24
msgid "created"
msgstr "создано"
#: models.py:25
msgid "modified"
msgstr "изменено"
#: models.py:37
msgid "start"
msgstr "начало"
#: models.py:38
msgid "end"
msgstr "конец"
#: models.py:53
msgid "status"
msgstr "статус"
#: models.py:54
msgid "status changed"
msgstr "статус изменен"

View file

@ -3,18 +3,17 @@ import django
from django.db import models
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import QuerySet
try:
from django.db.models.query import BaseIterable, ModelIterable
except ImportError:
# Django 1.8 does not have iterable classes
BaseIterable = object
from django.db.models.query import ModelIterable
from django.core.exceptions import ObjectDoesNotExist
from django.db.models.constants import LOOKUP_SEP
from django.utils.six import string_types
from django.db import connection
from django.db.models.sql.datastructures import Join
class InheritanceIterable(BaseIterable):
class InheritanceIterable(ModelIterable):
def __iter__(self):
queryset = self.queryset
iter = ModelIterable(queryset)
@ -48,11 +47,10 @@ class InheritanceIterable(BaseIterable):
class InheritanceQuerySetMixin(object):
def __init__(self, *args, **kwargs):
super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs)
if django.VERSION > (1, 8):
self._iterable_class = InheritanceIterable
self._iterable_class = InheritanceIterable
def select_subclasses(self, *subclasses):
levels = self._get_maximum_depth()
levels = None
calculated_subclasses = self._get_subclasses_recurse(
self.model, levels=levels)
# if none were passed in, we can just short circuit and select all
@ -77,7 +75,7 @@ class InheritanceQuerySetMixin(object):
raise ValueError(
'%r is not in the discovered subclasses, tried: %s' % (
subclass, ', '.join(calculated_subclasses))
)
)
subclasses = verified_subclasses
# workaround https://code.djangoproject.com/ticket/16855
@ -99,16 +97,16 @@ class InheritanceQuerySetMixin(object):
def _clone(self, klass=None, setup=False, **kwargs):
if django.VERSION >= (2, 0):
return super(InheritanceQuerySetMixin, self)._clone()
qs = super(InheritanceQuerySetMixin, self)._clone()
for name in ['subclasses', '_annotated']:
if hasattr(self, name):
setattr(qs, name, getattr(self, name))
return qs
for name in ['subclasses', '_annotated']:
if hasattr(self, name):
kwargs[name] = getattr(self, name)
if django.VERSION < (1, 9):
kwargs['klass'] = klass
kwargs['setup'] = setup
return super(InheritanceQuerySetMixin, self)._clone(**kwargs)
def annotate(self, *args, **kwargs):
@ -151,19 +149,16 @@ class InheritanceQuerySetMixin(object):
recursively, returning a `list` of strings representing the
relations for select_related
"""
if django.VERSION < (1, 8):
related_objects = model._meta.get_all_related_objects()
else:
related_objects = [
f for f in model._meta.get_fields()
if isinstance(f, OneToOneRel)]
related_objects = [
f for f in model._meta.get_fields()
if isinstance(f, OneToOneRel)]
rels = [
rel for rel in related_objects
if isinstance(rel.field, OneToOneField)
and issubclass(rel.field.model, model)
and model is not rel.field.model
]
]
subclasses = []
if levels:
@ -193,16 +188,10 @@ class InheritanceQuerySetMixin(object):
if levels:
levels -= 1
while parent_link is not None:
if django.VERSION < (1, 9):
related = parent_link.rel
else:
related = parent_link.remote_field
related = parent_link.remote_field
ancestry.insert(0, related.get_accessor_name())
if levels or levels is None:
if django.VERSION < (1, 8):
parent_model = related.parent_model
else:
parent_model = related.model
parent_model = related.model
parent_link = parent_model._meta.get_ancestor_link(
self.model)
else:
@ -230,17 +219,6 @@ class InheritanceQuerySetMixin(object):
def get_subclass(self, *args, **kwargs):
return self.select_subclasses().get(*args, **kwargs)
def _get_maximum_depth(self):
"""
Under Django versions < 1.6, to avoid triggering
https://code.djangoproject.com/ticket/16572 we can only look
as far as children.
"""
levels = None
if django.VERSION < (1, 6, 0):
levels = 1
return levels
class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet):
pass
@ -326,3 +304,111 @@ class SoftDeletableManagerMixin(object):
class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager):
pass
class JoinQueryset(models.QuerySet):
def get_quoted_query(self, query):
query, params = query.sql_with_params()
# Put additional quotes around string.
params = [
'\'{}\''.format(p)
if isinstance(p, str) else p
for p in params
]
# Cast list of parameters to tuple because I got
# "not enough format characters" otherwise.
params = tuple(params)
return query % params
def join(self, qs=None):
'''
Join one queryset together with another using a temporary table. If
no queryset is used, it will use the current queryset and join that
to itself.
`Join` either uses the current queryset and effectively does a self-join to
create a new limited queryset OR it uses a querset given by the user.
The model of a given queryset needs to contain a valid foreign key to
the current queryset to perform a join. A new queryset is then created.
'''
to_field = 'id'
if qs:
fk = [
fk for fk in qs.model._meta.fields
if getattr(fk, 'related_model', None) == self.model
]
fk = fk[0] if fk else None
model_set = '{}_set'.format(self.model.__name__.lower())
key = fk or getattr(qs.model, model_set, None)
if not key:
raise ValueError('QuerySet is not related to current model')
try:
fk_column = key.column
except AttributeError:
fk_column = 'id'
to_field = key.field.column
qs = qs.only(fk_column)
# if we give a qs we need to keep the model qs to not lose anything
new_qs = self
else:
fk_column = 'id'
qs = self.only(fk_column)
new_qs = self.model.objects.all()
TABLE_NAME = 'temp_stuff'
query = self.get_quoted_query(qs.query)
sql = '''
DROP TABLE IF EXISTS {table_name};
DROP INDEX IF EXISTS {table_name}_id;
CREATE TEMPORARY TABLE {table_name} AS {query};
CREATE INDEX {table_name}_{fk_column} ON {table_name} ({fk_column});
'''.format(table_name=TABLE_NAME, fk_column=fk_column, query=str(query))
with connection.cursor() as cursor:
cursor.execute(sql)
class TempModel(models.Model):
temp_key = models.ForeignKey(
self.model,
on_delete=models.DO_NOTHING,
db_column=fk_column,
to_field=to_field
)
class Meta:
managed = False
db_table = TABLE_NAME
conn = Join(
table_name=TempModel._meta.db_table,
parent_alias=new_qs.query.get_initial_alias(),
table_alias=None,
join_type='INNER JOIN',
join_field=self.model.tempmodel_set.rel,
nullable=False
)
new_qs.query.join(conn, reuse=None)
return new_qs
class JoinManagerMixin(object):
"""
Manager that adds a method join. This method allows you to join two
querysets together.
"""
_queryset_class = JoinQueryset
def get_queryset(self):
return self._queryset_class(model=self.model, using=self._db)
class JoinManager(JoinManagerMixin, models.Manager):
pass

View file

@ -123,6 +123,7 @@ class SoftDeletableModel(models.Model):
abstract = True
objects = SoftDeletableManager()
all_objects = models.Manager()
def delete(self, using=None, soft=True, *args, **kwargs):
"""

View file

@ -29,12 +29,73 @@ class DescriptorMixin(object):
return self.field_name
class DescriptorWrapper(object):
def __init__(self, field_name, descriptor, tracker_attname):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname
def __get__(self, instance, owner):
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
value = self.descriptor.__get__(instance, owner)
if was_deferred:
tracker_instance = getattr(instance, self.tracker_attname)
tracker_instance.saved_data[self.field_name] = deepcopy(value)
return value
def __set__(self, instance, value):
initialized = hasattr(instance, '_instance_intialized')
was_deferred = self.field_name in instance.get_deferred_fields()
# Sentinel attribute to detect whether we are already trying to
# set the attribute higher up the stack. This prevents infinite
# recursion when retrieving deferred values from the database.
recursion_sentinel_attname = '_setting_' + self.field_name
already_setting = hasattr(instance, recursion_sentinel_attname)
if initialized and was_deferred and not already_setting:
setattr(instance, recursion_sentinel_attname, True)
try:
# Retrieve the value to set the saved_data value.
# This will undefer the field
getattr(instance, self.field_name)
finally:
instance.__dict__.pop(recursion_sentinel_attname, None)
if hasattr(self.descriptor, '__set__'):
self.descriptor.__set__(instance, value)
else:
instance.__dict__[self.field_name] = value
@staticmethod
def cls_for_descriptor(descriptor):
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
return DescriptorWrapper
class FullDescriptorWrapper(DescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj):
self.descriptor.__delete__(obj)
class FieldInstanceTracker(object):
def __init__(self, instance, fields, field_map):
self.instance = instance
self.fields = fields
self.field_map = field_map
self.init_deferred_fields()
if django.VERSION < (1, 10):
self.init_deferred_fields()
@property
def deferred_fields(self):
return self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields()
def get_field_value(self, field):
return getattr(self.instance, self.field_map[field])
@ -54,10 +115,11 @@ class FieldInstanceTracker(object):
def current(self, fields=None):
"""Returns dict of current values for all tracked fields"""
if fields is None:
if self.instance._deferred_fields:
deferred_fields = self.deferred_fields
if deferred_fields:
fields = [
field for field in self.fields
if field not in self.instance._deferred_fields
if field not in deferred_fields
]
else:
fields = self.fields
@ -67,12 +129,31 @@ class FieldInstanceTracker(object):
def has_changed(self, field):
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
if field in self.deferred_fields and field not in self.instance.__dict__:
return False
return self.previous(field) != self.get_field_value(field)
else:
raise FieldError('field "%s" not tracked' % field)
def previous(self, field):
"""Returns currently saved value of given field"""
# handle deferred fields that have not yet been loaded from the database
if self.instance.pk and field in self.deferred_fields and field not in self.saved_data:
# if the field has not been assigned locally, simply fetch and un-defer the value
if field not in self.instance.__dict__:
self.get_field_value(field)
# if the field has been assigned locally, store the local value, fetch the database value,
# store database value to saved_data, and restore the local value
else:
current_value = self.get_field_value(field)
self.instance.refresh_from_db(fields=[field])
self.saved_data[field] = deepcopy(self.get_field_value(field))
setattr(self.instance, self.field_map[field], current_value)
return self.saved_data.get(field)
def changed(self):
@ -97,35 +178,15 @@ class FieldInstanceTracker(object):
def _get_field_name(self):
return self.field.name
if django.VERSION >= (1, 8):
self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
if django.VERSION >= (1, 10):
field_obj = getattr(self.instance.__class__, field)
else:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
else:
field_tracker = DeferredAttributeTracker(
field_obj.field_name, None)
setattr(self.instance.__class__, field, field_tracker)
else:
for field in self.fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, DeferredAttribute):
self.instance._deferred_fields.add(field)
# Django 1.4
if django.VERSION >= (1, 5):
model = None
else:
model = field_obj.model_ref()
field_tracker = DeferredAttributeTracker(
field_obj.field_name, model)
setattr(self.instance.__class__, field, field_tracker)
self.instance._deferred_fields = self.instance.get_deferred_fields()
for field in self.instance._deferred_fields:
field_obj = self.instance.__class__.__dict__.get(field)
if isinstance(field_obj, FileDescriptor):
field_tracker = FileDescriptorTracker(field_obj.field)
setattr(self.instance.__class__, field, field_tracker)
else:
field_tracker = DeferredAttributeTracker(field, type(self.instance))
setattr(self.instance.__class__, field, field_tracker)
class FieldTracker(object):
@ -146,12 +207,19 @@ class FieldTracker(object):
def contribute_to_class(self, cls, name):
self.name = name
self.attname = '_%s' % name
self.patch_save(cls)
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
def finalize_class(self, sender, **kwargs):
if self.fields is None:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
if django.VERSION >= (1, 10):
for field_name in self.fields:
descriptor = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
self.field_map = self.get_field_map(sender)
models.signals.post_init.connect(self.initialize_tracker)
self.model_class = sender
@ -163,12 +231,13 @@ class FieldTracker(object):
tracker = self.tracker_class(instance, self.fields, self.field_map)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
self.patch_save(instance)
instance._instance_intialized = True
def patch_save(self, instance):
original_save = instance.save
def save(**kwargs):
ret = original_save(**kwargs)
def patch_save(self, model):
original_save = model.save
def save(instance, *args, **kwargs):
ret = original_save(instance, *args, **kwargs)
update_fields = kwargs.get('update_fields')
if not update_fields and update_fields is not None: # () or []
fields = update_fields
@ -183,7 +252,8 @@ class FieldTracker(object):
fields=fields
)
return ret
instance.save = save
model.save = save
def __get__(self, instance, owner):
if instance is None:

View file

@ -1,2 +1,3 @@
pytest==3.3.1
pytest-django==3.1.2
psycopg2==2.7.6.1

View file

@ -1,13 +1,19 @@
from __future__ import unicode_literals, absolute_import
import django
from django.db import models
from django.db.models.query_utils import DeferredAttribute
from django.db.models import Manager
from django.utils.encoding import python_2_unicode_compatible
from django.utils.translation import ugettext_lazy as _
from model_utils import Choices
from model_utils.fields import SplitField, MonitorField, StatusField
from model_utils.managers import QueryManager, InheritanceManager
from model_utils.managers import (
QueryManager,
InheritanceManager,
JoinManagerMixin
)
from model_utils.models import (
SoftDeletableModel,
StatusModel,
@ -36,9 +42,6 @@ class InheritanceManagerTestParent(models.Model):
on_delete=models.CASCADE)
objects = InheritanceManager()
def __unicode__(self):
return unicode(self.pk)
def __str__(self):
return "%s(%s)" % (
self.__class__.__name__[len('InheritanceManagerTest'):],
@ -331,3 +334,62 @@ class CustomSoftDelete(SoftDeletableModel):
is_read = models.BooleanField(default=False)
objects = CustomSoftDeleteManager()
class StringyDescriptor(object):
"""
Descriptor that returns a string version of the underlying integer value.
"""
def __init__(self, name):
self.name = name
def __get__(self, obj, cls=None):
if obj is None:
return self
if self.name in obj.get_deferred_fields():
# This queries the database, and sets the value on the instance.
if django.VERSION < (2, 1):
DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls)
else:
DeferredAttribute(field_name=self.name).__get__(obj, cls)
return str(obj.__dict__[self.name])
def __set__(self, obj, value):
obj.__dict__[self.name] = int(value)
def __delete__(self, obj):
del obj.__dict__[self.name]
class CustomDescriptorField(models.IntegerField):
def contribute_to_class(self, cls, name, **kwargs):
super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs)
setattr(cls, name, StringyDescriptor(name))
class ModelWithCustomDescriptor(models.Model):
custom_field = CustomDescriptorField()
tracked_custom_field = CustomDescriptorField()
regular_field = models.IntegerField()
tracked_regular_field = models.IntegerField()
tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field'])
class JoinManager(JoinManagerMixin, models.Manager):
pass
class BoxJoinModel(models.Model):
name = models.CharField(max_length=32)
objects = JoinManager()
class JoinItemForeignKey(models.Model):
weight = models.IntegerField()
belonging = models.ForeignKey(
BoxJoinModel,
null=True,
on_delete=models.CASCADE
)
objects = JoinManager()

View file

@ -1,10 +1,22 @@
import os
INSTALLED_APPS = (
'model_utils',
'tests',
)
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3'
}
"default": {
"ENGINE": "django.db.backends.postgresql_psycopg2",
"NAME": os.environ.get("DJANGO_DATABASE_NAME_POSTGRES", "modelutils"),
"USER": os.environ.get("DJANGO_DATABASE_USER_POSTGRES", 'postgres'),
"PASSWORD": os.environ.get("DJANGO_DATABASE_PASSWORD_POSTGRES", ""),
"HOST": os.environ.get("DJANGO_DATABASE_HOST_POSTGRES", ""),
},
}
SECRET_KEY = 'dummy'
CACHES = {
'default': {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
}
}

View file

@ -16,7 +16,12 @@ class ChoicesTests(TestCase):
self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED')
def test_iteration(self):
self.assertEqual(tuple(self.STATUS), (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
self.assertEqual(tuple(self.STATUS),
(('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')))
def test_reversed(self):
self.assertEqual(tuple(reversed(self.STATUS)),
(('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT')))
def test_len(self):
self.assertEqual(len(self.STATUS), 2)
@ -78,8 +83,15 @@ class LabelChoicesTests(ChoicesTests):
self.assertEqual(tuple(self.STATUS), (
('DRAFT', 'is draft'),
('PUBLISHED', 'is published'),
('DELETED', 'DELETED'))
)
('DELETED', 'DELETED'),
))
def test_reversed(self):
self.assertEqual(tuple(reversed(self.STATUS)), (
('DELETED', 'DELETED'),
('PUBLISHED', 'is published'),
('DRAFT', 'is draft'),
))
def test_indexing(self):
self.assertEqual(self.STATUS['PUBLISHED'], 'is published')
@ -169,7 +181,15 @@ class IdentifierChoicesTests(ChoicesTests):
self.assertEqual(tuple(self.STATUS), (
(0, 'is draft'),
(1, 'is published'),
(2, 'is deleted')))
(2, 'is deleted'),
))
def test_reversed(self):
self.assertEqual(tuple(reversed(self.STATUS)), (
(2, 'is deleted'),
(1, 'is published'),
(0, 'is draft'),
))
def test_indexing(self):
self.assertEqual(self.STATUS[1], 'is published')

View file

@ -1,12 +1,11 @@
from __future__ import unicode_literals
from unittest import skipUnless
import django
from django.core.exceptions import FieldError
from django.test import TestCase
from django.core.cache import cache
from model_utils import FieldTracker
from model_utils.tracker import DescriptorWrapper
from tests.models import (
Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple,
InheritedTracked, TrackedFileField,
@ -74,7 +73,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertChanged(name=None, number=None)
self.instance.name = ''
self.assertChanged(name=None, number=None)
self.instance.mutable = [1,2,3]
self.instance.mutable = [1, 2, 3]
self.assertChanged(name=None, number=None, mutable=None)
def test_pre_save_has_changed(self):
@ -83,9 +82,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertHasChanged(name=True, number=False, mutable=False)
self.instance.number = 7
self.assertHasChanged(name=True, number=True)
self.instance.mutable = [1,2,3]
self.instance.mutable = [1, 2, 3]
self.assertHasChanged(name=True, number=True, mutable=True)
def test_save_with_args(self):
self.instance.number = 1
self.instance.save(False, False, None, None)
self.assertChanged()
def test_first_save(self):
self.assertHasChanged(name=True, number=False, mutable=False)
self.assertPrevious(name=None, number=None, mutable=None)
@ -93,22 +97,22 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertChanged(name=None)
self.instance.name = 'retro'
self.instance.number = 4
self.instance.mutable = [1,2,3]
self.instance.mutable = [1, 2, 3]
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
self.assertChanged(name=None, number=None, mutable=None)
self.instance.save(update_fields=[])
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
self.assertChanged(name=None, number=None, mutable=None)
with self.assertRaises(ValueError):
self.instance.save(update_fields=['number'])
def test_post_save_has_changed(self):
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertHasChanged(name=False, number=False, mutable=False)
self.instance.name = 'new age'
self.assertHasChanged(name=True, number=False)
@ -120,14 +124,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertHasChanged(name=False, number=True, mutable=True)
def test_post_save_previous(self):
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.instance.name = 'new age'
self.assertPrevious(name='retro', number=4, mutable=[1,2,3])
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
self.instance.mutable[1] = 4
self.assertPrevious(name='retro', number=4, mutable=[1,2,3])
self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3])
def test_post_save_changed(self):
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertChanged()
self.instance.name = 'new age'
self.assertChanged(name='retro')
@ -136,8 +140,8 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.name = 'retro'
self.assertChanged(number=4)
self.instance.mutable[1] = 4
self.assertChanged(number=4, mutable=[1,2,3])
self.instance.mutable = [1,2,3]
self.assertChanged(number=4, mutable=[1, 2, 3])
self.instance.mutable = [1, 2, 3]
self.assertChanged(number=4)
def test_current(self):
@ -146,29 +150,29 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.assertCurrent(id=None, name='new age', number=None, mutable=None)
self.instance.number = 8
self.assertCurrent(id=None, name='new age', number=8, mutable=None)
self.instance.mutable = [1,2,3]
self.assertCurrent(id=None, name='new age', number=8, mutable=[1,2,3])
self.instance.mutable = [1, 2, 3]
self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 2, 3])
self.instance.mutable[1] = 4
self.assertCurrent(id=None, name='new age', number=8, mutable=[1,4,3])
self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 4, 3])
self.instance.save()
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1,4,3])
self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3])
def test_update_fields(self):
self.update_instance(name='retro', number=4, mutable=[1,2,3])
self.update_instance(name='retro', number=4, mutable=[1, 2, 3])
self.assertChanged()
self.instance.name = 'new age'
self.instance.number = 8
self.instance.mutable = [4,5,6]
self.assertChanged(name='retro', number=4, mutable=[1,2,3])
self.instance.mutable = [4, 5, 6]
self.assertChanged(name='retro', number=4, mutable=[1, 2, 3])
self.instance.save(update_fields=[])
self.assertChanged(name='retro', number=4, mutable=[1,2,3])
self.assertChanged(name='retro', number=4, mutable=[1, 2, 3])
self.instance.save(update_fields=['name'])
in_db = self.tracked_class.objects.get(id=self.instance.id)
self.assertEqual(in_db.name, self.instance.name)
self.assertNotEqual(in_db.number, self.instance.number)
self.assertChanged(number=4, mutable=[1,2,3])
self.assertChanged(number=4, mutable=[1, 2, 3])
self.instance.save(update_fields=['number'])
self.assertChanged(mutable=[1,2,3])
self.assertChanged(mutable=[1, 2, 3])
self.instance.save(update_fields=['mutable'])
self.assertChanged()
in_db = self.tracked_class.objects.get(id=self.instance.id)
@ -180,20 +184,65 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
self.instance.name = 'new age'
self.instance.number = 1
self.instance.save()
item = list(self.tracked_class.objects.only('name').all())[0]
self.assertTrue(item._deferred_fields)
item = self.tracked_class.objects.only('name').first()
if django.VERSION >= (1, 10):
self.assertTrue(item.get_deferred_fields())
else:
self.assertTrue(item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), None)
self.assertTrue('number' in item._deferred_fields)
# has_changed() returns False for deferred fields, without un-deferring them.
# Use an if because ModelTracked doesn't support has_changed() in this case.
if self.tracked_class == Tracked:
self.assertFalse(item.tracker.has_changed('number'))
if django.VERSION >= (1, 10):
self.assertIsInstance(item.__class__.number, DescriptorWrapper)
self.assertTrue('number' in item.get_deferred_fields())
else:
self.assertTrue('number' in item._deferred_fields)
# previous() un-defers field and returns value
self.assertEqual(item.tracker.previous('number'), 1)
if django.VERSION >= (1, 10):
self.assertNotIn('number', item.get_deferred_fields())
else:
self.assertNotIn('number', item._deferred_fields)
# examining a deferred field un-defers it
item = self.tracked_class.objects.only('name').first()
self.assertEqual(item.number, 1)
self.assertTrue('number' not in item._deferred_fields)
if django.VERSION >= (1, 10):
self.assertTrue('number' not in item.get_deferred_fields())
else:
self.assertTrue('number' not in item._deferred_fields)
self.assertEqual(item.tracker.previous('number'), 1)
self.assertFalse(item.tracker.has_changed('number'))
# has_changed() returns correct values after deferred field is examined
self.assertFalse(item.tracker.has_changed('number'))
item.number = 2
self.assertTrue(item.tracker.has_changed('number'))
# previous() returns correct value after deferred field is examined
self.assertEqual(item.tracker.previous('number'), 1)
# assigning to a deferred field un-defers it
# Use an if because ModelTracked doesn't handle this case.
if self.tracked_class == Tracked:
item = self.tracked_class.objects.only('name').first()
item.number = 2
# previous() fetches correct value from database after deferred field is assigned
self.assertEqual(item.tracker.previous('number'), 1)
# database fetch of previous() value doesn't affect current value
self.assertEqual(item.number, 2)
# has_changed() returns correct values after deferred field is assigned
self.assertTrue(item.tracker.has_changed('number'))
item.number = 1
self.assertFalse(item.tracker.has_changed('number'))
class FieldTrackerMultipleInstancesTests(TestCase):
@ -595,6 +644,16 @@ class ModelTrackerTests(FieldTrackerTests):
tracked_class = ModelTracked
def test_cache_compatible(self):
cache.set('key', self.instance)
instance = cache.get('key')
instance.number = 1
instance.name = 'cached'
instance.save()
self.assertChanged()
instance.number = 2
self.assertHasChanged(number=True)
def test_pre_save_changed(self):
self.assertChanged()
self.instance.name = 'new age'
@ -603,7 +662,7 @@ class ModelTrackerTests(FieldTrackerTests):
self.assertChanged()
self.instance.name = ''
self.assertChanged()
self.instance.mutable = [1,2,3]
self.instance.mutable = [1, 2, 3]
self.assertChanged()
def test_first_save(self):
@ -613,16 +672,16 @@ class ModelTrackerTests(FieldTrackerTests):
self.assertChanged()
self.instance.name = 'retro'
self.instance.number = 4
self.instance.mutable = [1,2,3]
self.instance.mutable = [1, 2, 3]
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
self.assertChanged()
self.instance.save(update_fields=[])
self.assertHasChanged(name=True, number=True, mutable=True)
self.assertPrevious(name=None, number=None, mutable=None)
self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3])
self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3])
self.assertChanged()
with self.assertRaises(ValueError):
self.instance.save(update_fields=['number'])

View file

@ -6,7 +6,7 @@ from model_utils.fields import StatusField
from tests.models import (
Article, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled,
StatusFieldChoicesName,
)
)
class StatusFieldTests(TestCase):

View file

@ -0,0 +1,22 @@
from __future__ import unicode_literals
from unittest import skipIf
import django
from django.test import TestCase
from django.db.models import Prefetch
from tests.models import InheritanceManagerTestParent, InheritanceManagerTestChild1
class InheritanceIterableTest(TestCase):
@skipIf(django.VERSION[:2] == (1, 10), "Django 1.10 expects ModelIterable not a subclass of it")
def test_prefetch(self):
qs = InheritanceManagerTestChild1.objects.all().prefetch_related(
Prefetch(
'normal_field',
queryset=InheritanceManagerTestParent.objects.all(),
to_attr='normal_field_prefetched'
)
)
self.assertEquals(qs.count(), 0)

View file

@ -6,11 +6,12 @@ import django
from django.db import models
from django.test import TestCase
from tests.models import (InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1,
InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent,
InheritanceManagerTestChild1,
InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3
)
from tests.models import (
InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1,
InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent,
InheritanceManagerTestChild1,
InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3
)
class InheritanceManagerTests(TestCase):
@ -115,9 +116,6 @@ class InheritanceManagerTests(TestCase):
"inheritancemanagertestchild2").get(pk=self.child1.pk)
obj.inheritancemanagertestchild1
def test_version_determining_any_depth(self):
self.assertIsNone(self.get_manager().all()._get_maximum_depth())
def test_manually_specifying_parent_fk_including_grandchildren(self):
"""
given a Model which inherits from another Model, but also declares
@ -125,10 +123,16 @@ class InheritanceManagerTests(TestCase):
ensure that the relation names and subclasses are obtained correctly.
"""
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.all().select_subclasses()
qs = InheritanceManagerTestParent.objects.all()
results = qs.select_subclasses().order_by('pk')
expected_objs = [self.child1, self.child2, self.grandchild1,
self.grandchild1_2, child3]
expected_objs = [
self.child1,
self.child2,
self.grandchild1,
self.grandchild1_2,
child3
]
self.assertEqual(list(results), expected_objs)
expected_related_names = [
@ -148,7 +152,8 @@ class InheritanceManagerTests(TestCase):
"""
related_name = 'manual_onetoone'
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.all().select_subclasses(related_name)
qs = InheritanceManagerTestParent.objects.all()
results = qs.select_subclasses(related_name).order_by('pk')
expected_objs = [InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk),
@ -180,27 +185,26 @@ class InheritanceManagerTests(TestCase):
# No argument to select_subclasses
objs_1 = list(
self.get_manager().
select_subclasses().
values_list('id')
self.get_manager()
.select_subclasses()
.values_list('id')
)
# String argument to select_subclasses
objs_2 = list(
self.get_manager().
select_subclasses(
self.get_manager()
.select_subclasses(
"inheritancemanagertestchild2"
).
values_list('id')
)
.values_list('id')
)
# String argument to select_subclasses
objs_3 = list(
self.get_manager().
select_subclasses(
self.get_manager()
.select_subclasses(
InheritanceManagerTestChild2
).
values_list('id')
).values_list('id')
)
assert all((
@ -392,14 +396,16 @@ class InheritanceManagerUsingModelsTests(TestCase):
"""
child3 = InheritanceManagerTestChild3.objects.create()
results = InheritanceManagerTestParent.objects.all().select_subclasses(
InheritanceManagerTestChild3)
InheritanceManagerTestChild3).order_by('pk')
expected_objs = [InheritanceManagerTestParent(pk=self.parent1.pk),
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk),
InheritanceManagerTestParent(pk=self.grandchild1.pk),
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
child3]
expected_objs = [
InheritanceManagerTestParent(pk=self.parent1.pk),
InheritanceManagerTestParent(pk=self.child1.pk),
InheritanceManagerTestParent(pk=self.child2.pk),
InheritanceManagerTestParent(pk=self.grandchild1.pk),
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
child3
]
self.assertEqual(list(results), expected_objs)
expected_related_names = ['manual_onetoone']
@ -454,3 +460,7 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests):
qs = InheritanceManagerTestParent.objects.annotate(
test_count=models.Count('id')).select_subclasses()
self.assertEqual(qs.get(id=self.child1.id).test_count, 1)
def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self):
qs = InheritanceManagerTestParent.objects.select_subclasses()
self.assertEqual(qs.subclasses, qs._clone().subclasses)

View file

@ -0,0 +1,38 @@
from django.test import TestCase
from tests.models import JoinItemForeignKey, BoxJoinModel
class JoinManagerTest(TestCase):
def setUp(self):
for i in range(20):
BoxJoinModel.objects.create(name='name_{i}'.format(i=i))
JoinItemForeignKey.objects.create(
weight=10, belonging=BoxJoinModel.objects.get(name='name_1')
)
JoinItemForeignKey.objects.create(weight=20)
def test_self_join(self):
a_slice = BoxJoinModel.objects.all()[0:10]
with self.assertNumQueries(1):
result = a_slice.join()
self.assertEquals(result.count(), 10)
def test_self_join_with_where_statement(self):
qs = BoxJoinModel.objects.filter(name='name_1')
result = qs.join()
self.assertEquals(result.count(), 1)
def test_join_with_other_qs(self):
item_qs = JoinItemForeignKey.objects.filter(weight=10)
boxes = BoxJoinModel.objects.all().join(qs=item_qs)
self.assertEquals(boxes.count(), 1)
self.assertEquals(boxes[0].name, 'name_1')
def test_reverse_join(self):
box_qs = BoxJoinModel.objects.filter(name='name_1')
items = JoinItemForeignKey.objects.all().join(box_qs)
self.assertEquals(items.count(), 1)
self.assertEquals(items[0].weight, 10)

View file

@ -0,0 +1,71 @@
from __future__ import unicode_literals
import django
from django.test import TestCase
from tests.models import ModelWithCustomDescriptor
class CustomDescriptorTests(TestCase):
def setUp(self):
self.instance = ModelWithCustomDescriptor.objects.create(
custom_field='1',
tracked_custom_field='1',
regular_field=1,
tracked_regular_field=1,
)
def test_custom_descriptor_works(self):
instance = self.instance
self.assertEqual(instance.custom_field, '1')
self.assertEqual(instance.__dict__['custom_field'], 1)
self.assertEqual(instance.regular_field, 1)
instance.custom_field = 2
self.assertEqual(instance.custom_field, '2')
self.assertEqual(instance.__dict__['custom_field'], 2)
instance.save()
instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk)
self.assertEqual(instance.custom_field, '2')
self.assertEqual(instance.__dict__['custom_field'], 2)
def test_deferred(self):
instance = ModelWithCustomDescriptor.objects.only('id').get(
pk=self.instance.pk)
if django.VERSION >= (1, 10):
self.assertIn('custom_field', instance.get_deferred_fields())
else:
self.assertIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.custom_field, '1')
if django.VERSION >= (1, 10):
self.assertNotIn('custom_field', instance.get_deferred_fields())
else:
self.assertNotIn('custom_field', instance._deferred_fields)
self.assertEqual(instance.regular_field, 1)
self.assertEqual(instance.tracked_custom_field, '1')
self.assertEqual(instance.tracked_regular_field, 1)
self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))
self.assertFalse(instance.tracker.has_changed('tracked_regular_field'))
instance.tracked_custom_field = 2
instance.tracked_regular_field = 2
self.assertTrue(instance.tracker.has_changed('tracked_custom_field'))
self.assertTrue(instance.tracker.has_changed('tracked_regular_field'))
instance.save()
instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk)
self.assertEqual(instance.custom_field, '1')
self.assertEqual(instance.regular_field, 1)
self.assertEqual(instance.tracked_custom_field, '2')
self.assertEqual(instance.tracked_regular_field, 2)
instance = ModelWithCustomDescriptor.objects.only('id').get(pk=instance.pk)
if django.VERSION >= (1, 10):
# This fails on 1.8 and 1.9, which is a bug in the deferred field
# implementation on those versions.
instance.tracked_custom_field = 3
self.assertEqual(instance.tracked_custom_field, '3')
self.assertTrue(instance.tracker.has_changed('tracked_custom_field'))
del instance.tracked_custom_field
self.assertEqual(instance.tracked_custom_field, '2')
self.assertFalse(instance.tracker.has_changed('tracked_custom_field'))

View file

@ -18,7 +18,7 @@ class StatusModelTests(TestCase):
c1 = self.model.objects.create()
self.assertTrue(c1.status_changed, datetime(2016, 1, 1))
c2 = self.model.objects.create()
self.model.objects.create()
self.assertEqual(self.model.active.count(), 2)
self.assertEqual(self.model.deleted.count(), 0)

View file

@ -15,6 +15,13 @@ class TimeStampedModelTests(TestCase):
t1 = TimeStamp.objects.create()
self.assertEqual(t1.created, datetime(2016, 1, 1))
def test_created_sets_modified(self):
'''
Ensure that on creation that modifed is set exactly equal to created.
'''
t1 = TimeStamp.objects.create()
self.assertEqual(t1.created, t1.modified)
def test_modified(self):
with freeze_time(datetime(2016, 1, 1)):
t1 = TimeStamp.objects.create()

30
tox.ini
View file

@ -1,24 +1,44 @@
[tox]
envlist =
py27-django{18,19,110,111}
py34-django{18,19,110,111,200}
py35-django{18,19,110,111,200,trunk}
py36-django{111,200,trunk}
py27-django{19,110,111}
py34-django{19,110,111,200}
py35-django{19,110,111,200,201,trunk}
py36-django{111,200,201,trunk}
flake8
[testenv]
deps =
django18: Django>=1.8,<1.9
django19: Django>=1.9,<1.10
django110: Django>=1.10,<1.11
django111: Django>=1.11,<1.12
django200: Django>=2.0,<2.1
django201: Django>=2.1,<2.2
djangotrunk: https://github.com/django/django/archive/master.tar.gz
freezegun == 0.3.8
-rrequirements-test.txt
pytest-cov
ignore_outcome =
djangotrunk: True
passenv =
CI
TRAVIS
TRAVIS_*
commands =
pip install -e .
py.test {posargs}
[testenv:flake8]
basepython =
python3.6
deps =
flake8
commands =
flake8 model_utils tests
[flake8]
ignore =
E731 ; do not assign a lambda expression, use a def
W503 ; line break before binary operator
E402 ; module level import not at top of file
E501 ; line too long