diff --git a/.coveragerc b/.coveragerc index 62d6d1c..8708371 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,2 @@ [run] -source = model_utils -omit = .* - tests/* - */_* +include = model_utils/*.py diff --git a/.travis.yml b/.travis.yml index 83a070b..a7d4622 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,9 @@ python: - 3.6 install: pip install tox-travis codecov # positional args ({posargs}) to pass into tox.ini -script: tox -- --cov +script: tox -- --cov --cov-append +services: + - postgresql after_success: codecov deploy: provider: pypi diff --git a/CHANGES.rst b/CHANGES.rst index 922821d..7e6039d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,22 @@ CHANGES master (unreleased) ------------------- +- Fix handling of deferred attributes on Django 1.10+, fixes GH-278 +- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return + correct responses for deferred fields. +- Update AutoLastModifiedField so that at instance creation it will + always be set equal to created to make querying easier. Fixes GH-254 +- Support `reversed` for all kinds of `Choices` objects, fixes GH-309 +- Fix Model instance non picklable GH-330 +- Fix patched `save` in FieldTracker + +3.1.2 (2018.05.09) +------------------ +* Update InheritanceIterable to inherit from + ModelIterable instead of BaseIterable, fixes GH-277. + +* Add all_objects Manager for 'SoftDeletableModel' to include soft + deleted objects on queries as per issue GH-255 3.1.1 (2017.12.17) ------------------ diff --git a/LICENSE.txt b/LICENSE.txt index 0eadf47..01e3613 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2009-2015, Carl Meyer and contributors +Copyright (c) 2009-2019, Carl Meyer and contributors All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.rst b/README.rst index 3014eb0..6d74e7a 100644 --- a/README.rst +++ b/README.rst @@ -14,7 +14,7 @@ django-model-utils Django model mixins and utilities. -``django-model-utils`` supports `Django`_ 1.8 to 2.0. +``django-model-utils`` supports `Django`_ 1.8 to 2.1. .. _Django: http://www.djangoproject.com/ @@ -28,6 +28,15 @@ Getting Help Documentation for django-model-utils is available https://django-model-utils.readthedocs.io/ + +Run tests +--------- + +.. code-block + + pip install -e . + py.test + Contributing ============ diff --git a/docs/managers.rst b/docs/managers.rst index 43aa030..b90f3d4 100644 --- a/docs/managers.rst +++ b/docs/managers.rst @@ -86,6 +86,33 @@ it's safe to use as your default manager for the model. .. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/ +JoinManager +----------- + +The ``JoinManager`` will create a temporary table of your current queryset +and join that temporary table with the model of your current queryset. This can +be advantageous if you have to page through your entire DB and using django's +slice mechanism to do that. ``LIMIT .. OFFSET ..`` becomes slower the bigger +offset you use. + +.. code-block:: python + + sliced_qs = Place.objects.all()[2000:2010] + qs = sliced_qs.join() + # qs contains 10 objects, and there will be a much smaller performance hit + # for paging through all of first 2000 objects. + +Alternatively, you can give it a queryset and the manager will create a temporary +table and join that to your current queryset. This can work as a more performant +alternative to using django's ``__in`` as described in the following +(`StackExchange answer`_). + +.. code-block:: python + + big_qs = Restaurant.objects.filter(menu='vegetarian') + qs = Country.objects.filter(country_code='SE').join(big_qs) + +.. _StackExchange answer: https://dba.stackexchange.com/questions/91247/optimizing-a-postgres-query-with-a-large-in .. _QueryManager: diff --git a/docs/setup.rst b/docs/setup.rst index db2a34e..1fca10c 100644 --- a/docs/setup.rst +++ b/docs/setup.rst @@ -17,7 +17,7 @@ modify your ``INSTALLED_APPS`` setting. Dependencies ============ -``django-model-utils`` supports `Django`_ 1.8 through 1.10 (latest bugfix -release in each series only) on Python 2.7, 3.3 (Django 1.8 only), 3.4 and 3.5. +``django-model-utils`` supports `Django`_ 1.8 through 2.1 (latest bugfix +release in each series only) on Python 2.7, 3.4, 3.5 and 3.6. .. _Django: http://www.djangoproject.com/ diff --git a/docs/utilities.rst b/docs/utilities.rst index 44824f5..b763ba0 100644 --- a/docs/utilities.rst +++ b/docs/utilities.rst @@ -150,6 +150,10 @@ Returns the value of the given field during the last save: Returns ``None`` when the model instance isn't saved yet. +If a field is `deferred`_, calling ``previous()`` will load the previous value from the database. + +.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer + has_changed ~~~~~~~~~~~ @@ -167,6 +171,8 @@ Returns ``True`` if the given field has changed since the last save. The ``has_c The ``has_changed`` method relies on ``previous`` to determine whether a field's values has changed. +If a field is `deferred`_ and has been assigned locally, calling ``has_changed()`` +will load the previous value from the database to perform the comparison. changed ~~~~~~~ diff --git a/model_utils/__init__.py b/model_utils/__init__.py index f23bd9c..2fa87c0 100644 --- a/model_utils/__init__.py +++ b/model_utils/__init__.py @@ -1,4 +1,4 @@ -from .choices import Choices -from .tracker import FieldTracker, ModelTracker +from .choices import Choices # noqa:F401 +from .tracker import FieldTracker, ModelTracker # noqa:F401 -__version__ = '3.1.1' +__version__ = '3.2.0' diff --git a/model_utils/choices.py b/model_utils/choices.py index d48ba90..6339503 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -57,7 +57,6 @@ class Choices(object): self._process(choices) - def _store(self, triple, triple_collector, double_collector): self._identifier_map[triple[1]] = triple[0] self._display_map[triple[0]] = triple[2] @@ -65,7 +64,6 @@ class Choices(object): triple_collector.append(triple) double_collector.append((triple[0], triple[2])) - def _process(self, choices, triple_collector=None, double_collector=None): if triple_collector is None: triple_collector = self._triples @@ -94,18 +92,18 @@ class Choices(object): raise ValueError( "Choices can't take a list of length %s, only 2 or 3" % len(choice) - ) + ) else: store((choice, choice, choice)) - def __len__(self): return len(self._doubles) - def __iter__(self): return iter(self._doubles) + def __reversed__(self): + return reversed(self._doubles) def __getattr__(self, attname): try: @@ -113,11 +111,9 @@ class Choices(object): except KeyError: raise AttributeError(attname) - def __getitem__(self, key): return self._display_map[key] - def __add__(self, other): if isinstance(other, self.__class__): other = other._triples @@ -125,29 +121,24 @@ class Choices(object): other = list(other) return Choices(*(self._triples + other)) - def __radd__(self, other): # radd is never called for matching types, so we don't check here other = list(other) return Choices(*(other + self._triples)) - def __eq__(self, other): if isinstance(other, self.__class__): return self._triples == other._triples return False - def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, ', '.join(("%s" % repr(i) for i in self._triples)) - ) - + ) def __contains__(self, item): return item in self._db_values - def __deepcopy__(self, memo): return self.__class__(*copy.deepcopy(self._triples, memo)) diff --git a/model_utils/fields.py b/model_utils/fields.py index f308706..2799eac 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -32,6 +32,11 @@ class AutoLastModifiedField(AutoCreatedField): """ def pre_save(self, model_instance, add): value = now() + if not model_instance.pk: + for field in model_instance._meta.get_fields(): + if isinstance(field, AutoCreatedField): + value = getattr(model_instance, field.name) + break setattr(model_instance, self.attname, value) return value @@ -141,6 +146,7 @@ SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2) _excerpt_field_name = lambda name: '_%s_excerpt' % name + def get_excerpt(content): excerpt = [] default_excerpt = [] @@ -156,6 +162,7 @@ def get_excerpt(content): return '\n'.join(default_excerpt) + @python_2_unicode_compatible class SplitText(object): def __init__(self, instance, field_name, excerpt_field_name): @@ -166,11 +173,13 @@ class SplitText(object): self.excerpt_field_name = excerpt_field_name # content is read/write - def _get_content(self): + @property + def content(self): return self.instance.__dict__[self.field_name] - def _set_content(self, val): + + @content.setter + def content(self, val): setattr(self.instance, self.field_name, val) - content = property(_get_content, _set_content) # excerpt is a read only property def _get_excerpt(self): @@ -185,6 +194,7 @@ class SplitText(object): def __str__(self): return self.content + class SplitDescriptor(object): def __init__(self, field): self.field = field @@ -205,6 +215,7 @@ class SplitDescriptor(object): else: obj.__dict__[self.field.name] = value + class SplitField(models.TextField): def __init__(self, *args, **kwargs): # for South FakeORM compatibility: the frozen version of a diff --git a/model_utils/locale/cs/LC_MESSAGES/django.mo b/model_utils/locale/cs/LC_MESSAGES/django.mo new file mode 100644 index 0000000..758c32d Binary files /dev/null and b/model_utils/locale/cs/LC_MESSAGES/django.mo differ diff --git a/model_utils/locale/cs/LC_MESSAGES/django.po b/model_utils/locale/cs/LC_MESSAGES/django.po new file mode 100644 index 0000000..eae5e9a --- /dev/null +++ b/model_utils/locale/cs/LC_MESSAGES/django.po @@ -0,0 +1,46 @@ +# Czech translations of django-model-utils +# +# This file is distributed under the same license as the django-model-utils package. +# +# Translators: +# ------------ +# Václav Dohnal , 2018. +# +msgid "" +msgstr "" +"Project-Id-Version: django-model-utils\n" +"Report-Msgid-Bugs-To: https://github.com/jazzband/django-model-utils/issues\n" +"POT-Creation-Date: 2018-05-04 13:40+0200\n" +"PO-Revision-Date: 2018-05-04 13:46+0200\n" +"Language: cs\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n==1) ? 0 : (n>=2 && n<=4) ? 1 : 2;\n" +"Last-Translator: Václav Dohnal \n" +"Language-Team: N/A\n" +"X-Generator: Poedit 2.0.7\n" + +#: .\models.py:24 +msgid "created" +msgstr "vytvořeno" + +#: .\models.py:25 +msgid "modified" +msgstr "upraveno" + +#: .\models.py:37 +msgid "start" +msgstr "začátek" + +#: .\models.py:38 +msgid "end" +msgstr "konec" + +#: .\models.py:53 +msgid "status" +msgstr "stav" + +#: .\models.py:54 +msgid "status changed" +msgstr "změna stavu" diff --git a/model_utils/locale/ru/LC_MESSAGES/django.mo b/model_utils/locale/ru/LC_MESSAGES/django.mo new file mode 100644 index 0000000..8edb2d8 Binary files /dev/null and b/model_utils/locale/ru/LC_MESSAGES/django.mo differ diff --git a/model_utils/locale/ru/LC_MESSAGES/django.po b/model_utils/locale/ru/LC_MESSAGES/django.po new file mode 100644 index 0000000..bd5d90c --- /dev/null +++ b/model_utils/locale/ru/LC_MESSAGES/django.po @@ -0,0 +1,43 @@ +# This file is distributed under the same license as the django-model-utils package. +# +# Translators: +# Arseny Sysolyatin , 2017. +msgid "" +msgstr "" +"Project-Id-Version: django-model-utils\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2017-05-22 19:46+0300\n" +"PO-Revision-Date: 2017-05-22 19:46+0300\n" +"Last-Translator: Arseny Sysolyatin \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=4; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n" +"%10<=4 && (n%100<12 || n%100>14) ? 1 : n%10==0 || (n%10>=5 && n%10<=9) || (n" +"%100>=11 && n%100<=14)? 2 : 3);\n" + +#: models.py:24 +msgid "created" +msgstr "создано" + +#: models.py:25 +msgid "modified" +msgstr "изменено" + +#: models.py:37 +msgid "start" +msgstr "начало" + +#: models.py:38 +msgid "end" +msgstr "конец" + +#: models.py:53 +msgid "status" +msgstr "статус" + +#: models.py:54 +msgid "status changed" +msgstr "статус изменен" diff --git a/model_utils/managers.py b/model_utils/managers.py index e8c5029..15b5c7c 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -3,18 +3,17 @@ import django from django.db import models from django.db.models.fields.related import OneToOneField, OneToOneRel from django.db.models.query import QuerySet -try: - from django.db.models.query import BaseIterable, ModelIterable -except ImportError: - # Django 1.8 does not have iterable classes - BaseIterable = object +from django.db.models.query import ModelIterable from django.core.exceptions import ObjectDoesNotExist from django.db.models.constants import LOOKUP_SEP from django.utils.six import string_types +from django.db import connection +from django.db.models.sql.datastructures import Join -class InheritanceIterable(BaseIterable): + +class InheritanceIterable(ModelIterable): def __iter__(self): queryset = self.queryset iter = ModelIterable(queryset) @@ -48,11 +47,10 @@ class InheritanceIterable(BaseIterable): class InheritanceQuerySetMixin(object): def __init__(self, *args, **kwargs): super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs) - if django.VERSION > (1, 8): - self._iterable_class = InheritanceIterable + self._iterable_class = InheritanceIterable def select_subclasses(self, *subclasses): - levels = self._get_maximum_depth() + levels = None calculated_subclasses = self._get_subclasses_recurse( self.model, levels=levels) # if none were passed in, we can just short circuit and select all @@ -77,7 +75,7 @@ class InheritanceQuerySetMixin(object): raise ValueError( '%r is not in the discovered subclasses, tried: %s' % ( subclass, ', '.join(calculated_subclasses)) - ) + ) subclasses = verified_subclasses # workaround https://code.djangoproject.com/ticket/16855 @@ -99,16 +97,16 @@ class InheritanceQuerySetMixin(object): def _clone(self, klass=None, setup=False, **kwargs): if django.VERSION >= (2, 0): - return super(InheritanceQuerySetMixin, self)._clone() + qs = super(InheritanceQuerySetMixin, self)._clone() + for name in ['subclasses', '_annotated']: + if hasattr(self, name): + setattr(qs, name, getattr(self, name)) + return qs for name in ['subclasses', '_annotated']: if hasattr(self, name): kwargs[name] = getattr(self, name) - if django.VERSION < (1, 9): - kwargs['klass'] = klass - kwargs['setup'] = setup - return super(InheritanceQuerySetMixin, self)._clone(**kwargs) def annotate(self, *args, **kwargs): @@ -151,19 +149,16 @@ class InheritanceQuerySetMixin(object): recursively, returning a `list` of strings representing the relations for select_related """ - if django.VERSION < (1, 8): - related_objects = model._meta.get_all_related_objects() - else: - related_objects = [ - f for f in model._meta.get_fields() - if isinstance(f, OneToOneRel)] + related_objects = [ + f for f in model._meta.get_fields() + if isinstance(f, OneToOneRel)] rels = [ rel for rel in related_objects if isinstance(rel.field, OneToOneField) and issubclass(rel.field.model, model) and model is not rel.field.model - ] + ] subclasses = [] if levels: @@ -193,16 +188,10 @@ class InheritanceQuerySetMixin(object): if levels: levels -= 1 while parent_link is not None: - if django.VERSION < (1, 9): - related = parent_link.rel - else: - related = parent_link.remote_field + related = parent_link.remote_field ancestry.insert(0, related.get_accessor_name()) if levels or levels is None: - if django.VERSION < (1, 8): - parent_model = related.parent_model - else: - parent_model = related.model + parent_model = related.model parent_link = parent_model._meta.get_ancestor_link( self.model) else: @@ -230,17 +219,6 @@ class InheritanceQuerySetMixin(object): def get_subclass(self, *args, **kwargs): return self.select_subclasses().get(*args, **kwargs) - def _get_maximum_depth(self): - """ - Under Django versions < 1.6, to avoid triggering - https://code.djangoproject.com/ticket/16572 we can only look - as far as children. - """ - levels = None - if django.VERSION < (1, 6, 0): - levels = 1 - return levels - class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): pass @@ -326,3 +304,111 @@ class SoftDeletableManagerMixin(object): class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager): pass + + +class JoinQueryset(models.QuerySet): + + def get_quoted_query(self, query): + query, params = query.sql_with_params() + + # Put additional quotes around string. + params = [ + '\'{}\''.format(p) + if isinstance(p, str) else p + for p in params + ] + + # Cast list of parameters to tuple because I got + # "not enough format characters" otherwise. + params = tuple(params) + return query % params + + def join(self, qs=None): + ''' + Join one queryset together with another using a temporary table. If + no queryset is used, it will use the current queryset and join that + to itself. + + `Join` either uses the current queryset and effectively does a self-join to + create a new limited queryset OR it uses a querset given by the user. + + The model of a given queryset needs to contain a valid foreign key to + the current queryset to perform a join. A new queryset is then created. + ''' + to_field = 'id' + + if qs: + fk = [ + fk for fk in qs.model._meta.fields + if getattr(fk, 'related_model', None) == self.model + ] + fk = fk[0] if fk else None + model_set = '{}_set'.format(self.model.__name__.lower()) + key = fk or getattr(qs.model, model_set, None) + + if not key: + raise ValueError('QuerySet is not related to current model') + + try: + fk_column = key.column + except AttributeError: + fk_column = 'id' + to_field = key.field.column + + qs = qs.only(fk_column) + # if we give a qs we need to keep the model qs to not lose anything + new_qs = self + else: + fk_column = 'id' + qs = self.only(fk_column) + new_qs = self.model.objects.all() + + TABLE_NAME = 'temp_stuff' + query = self.get_quoted_query(qs.query) + sql = ''' + DROP TABLE IF EXISTS {table_name}; + DROP INDEX IF EXISTS {table_name}_id; + CREATE TEMPORARY TABLE {table_name} AS {query}; + CREATE INDEX {table_name}_{fk_column} ON {table_name} ({fk_column}); + '''.format(table_name=TABLE_NAME, fk_column=fk_column, query=str(query)) + + with connection.cursor() as cursor: + cursor.execute(sql) + + class TempModel(models.Model): + temp_key = models.ForeignKey( + self.model, + on_delete=models.DO_NOTHING, + db_column=fk_column, + to_field=to_field + ) + + class Meta: + managed = False + db_table = TABLE_NAME + + conn = Join( + table_name=TempModel._meta.db_table, + parent_alias=new_qs.query.get_initial_alias(), + table_alias=None, + join_type='INNER JOIN', + join_field=self.model.tempmodel_set.rel, + nullable=False + ) + new_qs.query.join(conn, reuse=None) + return new_qs + + +class JoinManagerMixin(object): + """ + Manager that adds a method join. This method allows you to join two + querysets together. + """ + _queryset_class = JoinQueryset + + def get_queryset(self): + return self._queryset_class(model=self.model, using=self._db) + + +class JoinManager(JoinManagerMixin, models.Manager): + pass diff --git a/model_utils/models.py b/model_utils/models.py index c679fc6..2f21695 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -123,6 +123,7 @@ class SoftDeletableModel(models.Model): abstract = True objects = SoftDeletableManager() + all_objects = models.Manager() def delete(self, using=None, soft=True, *args, **kwargs): """ diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 93e4a5d..837e1ce 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -29,12 +29,73 @@ class DescriptorMixin(object): return self.field_name +class DescriptorWrapper(object): + + def __init__(self, field_name, descriptor, tracker_attname): + self.field_name = field_name + self.descriptor = descriptor + self.tracker_attname = tracker_attname + + def __get__(self, instance, owner): + if instance is None: + return self + was_deferred = self.field_name in instance.get_deferred_fields() + value = self.descriptor.__get__(instance, owner) + if was_deferred: + tracker_instance = getattr(instance, self.tracker_attname) + tracker_instance.saved_data[self.field_name] = deepcopy(value) + return value + + def __set__(self, instance, value): + initialized = hasattr(instance, '_instance_intialized') + was_deferred = self.field_name in instance.get_deferred_fields() + + # Sentinel attribute to detect whether we are already trying to + # set the attribute higher up the stack. This prevents infinite + # recursion when retrieving deferred values from the database. + recursion_sentinel_attname = '_setting_' + self.field_name + already_setting = hasattr(instance, recursion_sentinel_attname) + + if initialized and was_deferred and not already_setting: + setattr(instance, recursion_sentinel_attname, True) + try: + # Retrieve the value to set the saved_data value. + # This will undefer the field + getattr(instance, self.field_name) + finally: + instance.__dict__.pop(recursion_sentinel_attname, None) + if hasattr(self.descriptor, '__set__'): + self.descriptor.__set__(instance, value) + else: + instance.__dict__[self.field_name] = value + + @staticmethod + def cls_for_descriptor(descriptor): + if hasattr(descriptor, '__delete__'): + return FullDescriptorWrapper + else: + return DescriptorWrapper + + +class FullDescriptorWrapper(DescriptorWrapper): + """ + Wrapper for descriptors with all three descriptor methods. + """ + def __delete__(self, obj): + self.descriptor.__delete__(obj) + + class FieldInstanceTracker(object): def __init__(self, instance, fields, field_map): self.instance = instance self.fields = fields self.field_map = field_map - self.init_deferred_fields() + if django.VERSION < (1, 10): + self.init_deferred_fields() + + @property + def deferred_fields(self): + return self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields() def get_field_value(self, field): return getattr(self.instance, self.field_map[field]) @@ -54,10 +115,11 @@ class FieldInstanceTracker(object): def current(self, fields=None): """Returns dict of current values for all tracked fields""" if fields is None: - if self.instance._deferred_fields: + deferred_fields = self.deferred_fields + if deferred_fields: fields = [ field for field in self.fields - if field not in self.instance._deferred_fields + if field not in deferred_fields ] else: fields = self.fields @@ -67,12 +129,31 @@ class FieldInstanceTracker(object): def has_changed(self, field): """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: + # deferred fields haven't changed + if field in self.deferred_fields and field not in self.instance.__dict__: + return False return self.previous(field) != self.get_field_value(field) else: raise FieldError('field "%s" not tracked' % field) def previous(self, field): """Returns currently saved value of given field""" + + # handle deferred fields that have not yet been loaded from the database + if self.instance.pk and field in self.deferred_fields and field not in self.saved_data: + + # if the field has not been assigned locally, simply fetch and un-defer the value + if field not in self.instance.__dict__: + self.get_field_value(field) + + # if the field has been assigned locally, store the local value, fetch the database value, + # store database value to saved_data, and restore the local value + else: + current_value = self.get_field_value(field) + self.instance.refresh_from_db(fields=[field]) + self.saved_data[field] = deepcopy(self.get_field_value(field)) + setattr(self.instance, self.field_map[field], current_value) + return self.saved_data.get(field) def changed(self): @@ -97,35 +178,15 @@ class FieldInstanceTracker(object): def _get_field_name(self): return self.field.name - if django.VERSION >= (1, 8): - self.instance._deferred_fields = self.instance.get_deferred_fields() - for field in self.instance._deferred_fields: - if django.VERSION >= (1, 10): - field_obj = getattr(self.instance.__class__, field) - else: - field_obj = self.instance.__class__.__dict__.get(field) - if isinstance(field_obj, FileDescriptor): - field_tracker = FileDescriptorTracker(field_obj.field) - setattr(self.instance.__class__, field, field_tracker) - else: - field_tracker = DeferredAttributeTracker( - field_obj.field_name, None) - setattr(self.instance.__class__, field, field_tracker) - else: - for field in self.fields: - field_obj = self.instance.__class__.__dict__.get(field) - if isinstance(field_obj, DeferredAttribute): - self.instance._deferred_fields.add(field) - - # Django 1.4 - if django.VERSION >= (1, 5): - model = None - else: - model = field_obj.model_ref() - - field_tracker = DeferredAttributeTracker( - field_obj.field_name, model) - setattr(self.instance.__class__, field, field_tracker) + self.instance._deferred_fields = self.instance.get_deferred_fields() + for field in self.instance._deferred_fields: + field_obj = self.instance.__class__.__dict__.get(field) + if isinstance(field_obj, FileDescriptor): + field_tracker = FileDescriptorTracker(field_obj.field) + setattr(self.instance.__class__, field, field_tracker) + else: + field_tracker = DeferredAttributeTracker(field, type(self.instance)) + setattr(self.instance.__class__, field, field_tracker) class FieldTracker(object): @@ -146,12 +207,19 @@ class FieldTracker(object): def contribute_to_class(self, cls, name): self.name = name self.attname = '_%s' % name + self.patch_save(cls) models.signals.class_prepared.connect(self.finalize_class, sender=cls) def finalize_class(self, sender, **kwargs): if self.fields is None: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) + if django.VERSION >= (1, 10): + for field_name in self.fields: + descriptor = getattr(sender, field_name) + wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) + wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) + setattr(sender, field_name, wrapped_descriptor) self.field_map = self.get_field_map(sender) models.signals.post_init.connect(self.initialize_tracker) self.model_class = sender @@ -163,12 +231,13 @@ class FieldTracker(object): tracker = self.tracker_class(instance, self.fields, self.field_map) setattr(instance, self.attname, tracker) tracker.set_saved_fields() - self.patch_save(instance) + instance._instance_intialized = True - def patch_save(self, instance): - original_save = instance.save - def save(**kwargs): - ret = original_save(**kwargs) + def patch_save(self, model): + original_save = model.save + + def save(instance, *args, **kwargs): + ret = original_save(instance, *args, **kwargs) update_fields = kwargs.get('update_fields') if not update_fields and update_fields is not None: # () or [] fields = update_fields @@ -183,7 +252,8 @@ class FieldTracker(object): fields=fields ) return ret - instance.save = save + + model.save = save def __get__(self, instance, owner): if instance is None: diff --git a/requirements-test.txt b/requirements-test.txt index 493f267..fa21abd 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,2 +1,3 @@ pytest==3.3.1 pytest-django==3.1.2 +psycopg2==2.7.6.1 diff --git a/tests/models.py b/tests/models.py index a65d499..888aba5 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,13 +1,19 @@ from __future__ import unicode_literals, absolute_import +import django from django.db import models +from django.db.models.query_utils import DeferredAttribute from django.db.models import Manager from django.utils.encoding import python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ from model_utils import Choices from model_utils.fields import SplitField, MonitorField, StatusField -from model_utils.managers import QueryManager, InheritanceManager +from model_utils.managers import ( + QueryManager, + InheritanceManager, + JoinManagerMixin +) from model_utils.models import ( SoftDeletableModel, StatusModel, @@ -36,9 +42,6 @@ class InheritanceManagerTestParent(models.Model): on_delete=models.CASCADE) objects = InheritanceManager() - def __unicode__(self): - return unicode(self.pk) - def __str__(self): return "%s(%s)" % ( self.__class__.__name__[len('InheritanceManagerTest'):], @@ -331,3 +334,62 @@ class CustomSoftDelete(SoftDeletableModel): is_read = models.BooleanField(default=False) objects = CustomSoftDeleteManager() + + +class StringyDescriptor(object): + """ + Descriptor that returns a string version of the underlying integer value. + """ + def __init__(self, name): + self.name = name + + def __get__(self, obj, cls=None): + if obj is None: + return self + if self.name in obj.get_deferred_fields(): + # This queries the database, and sets the value on the instance. + if django.VERSION < (2, 1): + DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls) + else: + DeferredAttribute(field_name=self.name).__get__(obj, cls) + return str(obj.__dict__[self.name]) + + def __set__(self, obj, value): + obj.__dict__[self.name] = int(value) + + def __delete__(self, obj): + del obj.__dict__[self.name] + + +class CustomDescriptorField(models.IntegerField): + def contribute_to_class(self, cls, name, **kwargs): + super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs) + setattr(cls, name, StringyDescriptor(name)) + + +class ModelWithCustomDescriptor(models.Model): + custom_field = CustomDescriptorField() + tracked_custom_field = CustomDescriptorField() + regular_field = models.IntegerField() + tracked_regular_field = models.IntegerField() + + tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field']) + + +class JoinManager(JoinManagerMixin, models.Manager): + pass + + +class BoxJoinModel(models.Model): + name = models.CharField(max_length=32) + objects = JoinManager() + + +class JoinItemForeignKey(models.Model): + weight = models.IntegerField() + belonging = models.ForeignKey( + BoxJoinModel, + null=True, + on_delete=models.CASCADE + ) + objects = JoinManager() diff --git a/tests/settings.py b/tests/settings.py index 8817e83..e34c891 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,10 +1,22 @@ +import os + INSTALLED_APPS = ( 'model_utils', 'tests', ) DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3' - } + "default": { + "ENGINE": "django.db.backends.postgresql_psycopg2", + "NAME": os.environ.get("DJANGO_DATABASE_NAME_POSTGRES", "modelutils"), + "USER": os.environ.get("DJANGO_DATABASE_USER_POSTGRES", 'postgres'), + "PASSWORD": os.environ.get("DJANGO_DATABASE_PASSWORD_POSTGRES", ""), + "HOST": os.environ.get("DJANGO_DATABASE_HOST_POSTGRES", ""), + }, } SECRET_KEY = 'dummy' + +CACHES = { + 'default': { + 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + } +} diff --git a/tests/test_choices.py b/tests/test_choices.py index a503405..986670c 100644 --- a/tests/test_choices.py +++ b/tests/test_choices.py @@ -16,7 +16,12 @@ class ChoicesTests(TestCase): self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') def test_iteration(self): - self.assertEqual(tuple(self.STATUS), (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) + self.assertEqual(tuple(self.STATUS), + (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) + + def test_reversed(self): + self.assertEqual(tuple(reversed(self.STATUS)), + (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) def test_len(self): self.assertEqual(len(self.STATUS), 2) @@ -78,8 +83,15 @@ class LabelChoicesTests(ChoicesTests): self.assertEqual(tuple(self.STATUS), ( ('DRAFT', 'is draft'), ('PUBLISHED', 'is published'), - ('DELETED', 'DELETED')) - ) + ('DELETED', 'DELETED'), + )) + + def test_reversed(self): + self.assertEqual(tuple(reversed(self.STATUS)), ( + ('DELETED', 'DELETED'), + ('PUBLISHED', 'is published'), + ('DRAFT', 'is draft'), + )) def test_indexing(self): self.assertEqual(self.STATUS['PUBLISHED'], 'is published') @@ -169,7 +181,15 @@ class IdentifierChoicesTests(ChoicesTests): self.assertEqual(tuple(self.STATUS), ( (0, 'is draft'), (1, 'is published'), - (2, 'is deleted'))) + (2, 'is deleted'), + )) + + def test_reversed(self): + self.assertEqual(tuple(reversed(self.STATUS)), ( + (2, 'is deleted'), + (1, 'is published'), + (0, 'is draft'), + )) def test_indexing(self): self.assertEqual(self.STATUS[1], 'is published') diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index 00c28b3..5b5d5c2 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -1,12 +1,11 @@ from __future__ import unicode_literals -from unittest import skipUnless - import django from django.core.exceptions import FieldError from django.test import TestCase - +from django.core.cache import cache from model_utils import FieldTracker +from model_utils.tracker import DescriptorWrapper from tests.models import ( Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple, InheritedTracked, TrackedFileField, @@ -74,7 +73,7 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertChanged(name=None, number=None) self.instance.name = '' self.assertChanged(name=None, number=None) - self.instance.mutable = [1,2,3] + self.instance.mutable = [1, 2, 3] self.assertChanged(name=None, number=None, mutable=None) def test_pre_save_has_changed(self): @@ -83,9 +82,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertHasChanged(name=True, number=False, mutable=False) self.instance.number = 7 self.assertHasChanged(name=True, number=True) - self.instance.mutable = [1,2,3] + self.instance.mutable = [1, 2, 3] self.assertHasChanged(name=True, number=True, mutable=True) + def test_save_with_args(self): + self.instance.number = 1 + self.instance.save(False, False, None, None) + self.assertChanged() + def test_first_save(self): self.assertHasChanged(name=True, number=False, mutable=False) self.assertPrevious(name=None, number=None, mutable=None) @@ -93,22 +97,22 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertChanged(name=None) self.instance.name = 'retro' self.instance.number = 4 - self.instance.mutable = [1,2,3] + self.instance.mutable = [1, 2, 3] self.assertHasChanged(name=True, number=True, mutable=True) self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) self.assertChanged(name=None, number=None, mutable=None) self.instance.save(update_fields=[]) self.assertHasChanged(name=True, number=True, mutable=True) self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) self.assertChanged(name=None, number=None, mutable=None) with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) def test_post_save_has_changed(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertHasChanged(name=False, number=False, mutable=False) self.instance.name = 'new age' self.assertHasChanged(name=True, number=False) @@ -120,14 +124,14 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertHasChanged(name=False, number=True, mutable=True) def test_post_save_previous(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.instance.name = 'new age' - self.assertPrevious(name='retro', number=4, mutable=[1,2,3]) + self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) self.instance.mutable[1] = 4 - self.assertPrevious(name='retro', number=4, mutable=[1,2,3]) + self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) def test_post_save_changed(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertChanged() self.instance.name = 'new age' self.assertChanged(name='retro') @@ -136,8 +140,8 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.instance.name = 'retro' self.assertChanged(number=4) self.instance.mutable[1] = 4 - self.assertChanged(number=4, mutable=[1,2,3]) - self.instance.mutable = [1,2,3] + self.assertChanged(number=4, mutable=[1, 2, 3]) + self.instance.mutable = [1, 2, 3] self.assertChanged(number=4) def test_current(self): @@ -146,29 +150,29 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.assertCurrent(id=None, name='new age', number=None, mutable=None) self.instance.number = 8 self.assertCurrent(id=None, name='new age', number=8, mutable=None) - self.instance.mutable = [1,2,3] - self.assertCurrent(id=None, name='new age', number=8, mutable=[1,2,3]) + self.instance.mutable = [1, 2, 3] + self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 2, 3]) self.instance.mutable[1] = 4 - self.assertCurrent(id=None, name='new age', number=8, mutable=[1,4,3]) + self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 4, 3]) self.instance.save() - self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1,4,3]) + self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3]) def test_update_fields(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) self.assertChanged() self.instance.name = 'new age' self.instance.number = 8 - self.instance.mutable = [4,5,6] - self.assertChanged(name='retro', number=4, mutable=[1,2,3]) + self.instance.mutable = [4, 5, 6] + self.assertChanged(name='retro', number=4, mutable=[1, 2, 3]) self.instance.save(update_fields=[]) - self.assertChanged(name='retro', number=4, mutable=[1,2,3]) + self.assertChanged(name='retro', number=4, mutable=[1, 2, 3]) self.instance.save(update_fields=['name']) in_db = self.tracked_class.objects.get(id=self.instance.id) self.assertEqual(in_db.name, self.instance.name) self.assertNotEqual(in_db.number, self.instance.number) - self.assertChanged(number=4, mutable=[1,2,3]) + self.assertChanged(number=4, mutable=[1, 2, 3]) self.instance.save(update_fields=['number']) - self.assertChanged(mutable=[1,2,3]) + self.assertChanged(mutable=[1, 2, 3]) self.instance.save(update_fields=['mutable']) self.assertChanged() in_db = self.tracked_class.objects.get(id=self.instance.id) @@ -180,20 +184,65 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): self.instance.name = 'new age' self.instance.number = 1 self.instance.save() - item = list(self.tracked_class.objects.only('name').all())[0] - self.assertTrue(item._deferred_fields) + item = self.tracked_class.objects.only('name').first() + if django.VERSION >= (1, 10): + self.assertTrue(item.get_deferred_fields()) + else: + self.assertTrue(item._deferred_fields) - self.assertEqual(item.tracker.previous('number'), None) - self.assertTrue('number' in item._deferred_fields) + # has_changed() returns False for deferred fields, without un-deferring them. + # Use an if because ModelTracked doesn't support has_changed() in this case. + if self.tracked_class == Tracked: + self.assertFalse(item.tracker.has_changed('number')) + if django.VERSION >= (1, 10): + self.assertIsInstance(item.__class__.number, DescriptorWrapper) + self.assertTrue('number' in item.get_deferred_fields()) + else: + self.assertTrue('number' in item._deferred_fields) + # previous() un-defers field and returns value + self.assertEqual(item.tracker.previous('number'), 1) + if django.VERSION >= (1, 10): + self.assertNotIn('number', item.get_deferred_fields()) + else: + self.assertNotIn('number', item._deferred_fields) + + # examining a deferred field un-defers it + item = self.tracked_class.objects.only('name').first() self.assertEqual(item.number, 1) - self.assertTrue('number' not in item._deferred_fields) + if django.VERSION >= (1, 10): + self.assertTrue('number' not in item.get_deferred_fields()) + else: + self.assertTrue('number' not in item._deferred_fields) self.assertEqual(item.tracker.previous('number'), 1) self.assertFalse(item.tracker.has_changed('number')) + # has_changed() returns correct values after deferred field is examined + self.assertFalse(item.tracker.has_changed('number')) item.number = 2 self.assertTrue(item.tracker.has_changed('number')) + # previous() returns correct value after deferred field is examined + self.assertEqual(item.tracker.previous('number'), 1) + + # assigning to a deferred field un-defers it + # Use an if because ModelTracked doesn't handle this case. + if self.tracked_class == Tracked: + + item = self.tracked_class.objects.only('name').first() + item.number = 2 + + # previous() fetches correct value from database after deferred field is assigned + self.assertEqual(item.tracker.previous('number'), 1) + + # database fetch of previous() value doesn't affect current value + self.assertEqual(item.number, 2) + + # has_changed() returns correct values after deferred field is assigned + self.assertTrue(item.tracker.has_changed('number')) + item.number = 1 + self.assertFalse(item.tracker.has_changed('number')) + class FieldTrackerMultipleInstancesTests(TestCase): @@ -595,6 +644,16 @@ class ModelTrackerTests(FieldTrackerTests): tracked_class = ModelTracked + def test_cache_compatible(self): + cache.set('key', self.instance) + instance = cache.get('key') + instance.number = 1 + instance.name = 'cached' + instance.save() + self.assertChanged() + instance.number = 2 + self.assertHasChanged(number=True) + def test_pre_save_changed(self): self.assertChanged() self.instance.name = 'new age' @@ -603,7 +662,7 @@ class ModelTrackerTests(FieldTrackerTests): self.assertChanged() self.instance.name = '' self.assertChanged() - self.instance.mutable = [1,2,3] + self.instance.mutable = [1, 2, 3] self.assertChanged() def test_first_save(self): @@ -613,16 +672,16 @@ class ModelTrackerTests(FieldTrackerTests): self.assertChanged() self.instance.name = 'retro' self.instance.number = 4 - self.instance.mutable = [1,2,3] + self.instance.mutable = [1, 2, 3] self.assertHasChanged(name=True, number=True, mutable=True) self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) self.assertChanged() self.instance.save(update_fields=[]) self.assertHasChanged(name=True, number=True, mutable=True) self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) self.assertChanged() with self.assertRaises(ValueError): self.instance.save(update_fields=['number']) diff --git a/tests/test_fields/test_status_field.py b/tests/test_fields/test_status_field.py index 5f077da..dc0f223 100644 --- a/tests/test_fields/test_status_field.py +++ b/tests/test_fields/test_status_field.py @@ -6,7 +6,7 @@ from model_utils.fields import StatusField from tests.models import ( Article, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled, StatusFieldChoicesName, - ) +) class StatusFieldTests(TestCase): diff --git a/tests/test_inheritance_iterable.py b/tests/test_inheritance_iterable.py new file mode 100644 index 0000000..884b763 --- /dev/null +++ b/tests/test_inheritance_iterable.py @@ -0,0 +1,22 @@ +from __future__ import unicode_literals + +from unittest import skipIf + +import django +from django.test import TestCase +from django.db.models import Prefetch + +from tests.models import InheritanceManagerTestParent, InheritanceManagerTestChild1 + + +class InheritanceIterableTest(TestCase): + @skipIf(django.VERSION[:2] == (1, 10), "Django 1.10 expects ModelIterable not a subclass of it") + def test_prefetch(self): + qs = InheritanceManagerTestChild1.objects.all().prefetch_related( + Prefetch( + 'normal_field', + queryset=InheritanceManagerTestParent.objects.all(), + to_attr='normal_field_prefetched' + ) + ) + self.assertEquals(qs.count(), 0) diff --git a/tests/test_managers/test_inheritance_manager.py b/tests/test_managers/test_inheritance_manager.py index 4509175..d2b8b4f 100644 --- a/tests/test_managers/test_inheritance_manager.py +++ b/tests/test_managers/test_inheritance_manager.py @@ -6,11 +6,12 @@ import django from django.db import models from django.test import TestCase -from tests.models import (InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1, - InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent, - InheritanceManagerTestChild1, - InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3 - ) +from tests.models import ( + InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1, + InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent, + InheritanceManagerTestChild1, + InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3 +) class InheritanceManagerTests(TestCase): @@ -115,9 +116,6 @@ class InheritanceManagerTests(TestCase): "inheritancemanagertestchild2").get(pk=self.child1.pk) obj.inheritancemanagertestchild1 - def test_version_determining_any_depth(self): - self.assertIsNone(self.get_manager().all()._get_maximum_depth()) - def test_manually_specifying_parent_fk_including_grandchildren(self): """ given a Model which inherits from another Model, but also declares @@ -125,10 +123,16 @@ class InheritanceManagerTests(TestCase): ensure that the relation names and subclasses are obtained correctly. """ child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.all().select_subclasses() + qs = InheritanceManagerTestParent.objects.all() + results = qs.select_subclasses().order_by('pk') - expected_objs = [self.child1, self.child2, self.grandchild1, - self.grandchild1_2, child3] + expected_objs = [ + self.child1, + self.child2, + self.grandchild1, + self.grandchild1_2, + child3 + ] self.assertEqual(list(results), expected_objs) expected_related_names = [ @@ -148,7 +152,8 @@ class InheritanceManagerTests(TestCase): """ related_name = 'manual_onetoone' child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.all().select_subclasses(related_name) + qs = InheritanceManagerTestParent.objects.all() + results = qs.select_subclasses(related_name).order_by('pk') expected_objs = [InheritanceManagerTestParent(pk=self.child1.pk), InheritanceManagerTestParent(pk=self.child2.pk), @@ -180,27 +185,26 @@ class InheritanceManagerTests(TestCase): # No argument to select_subclasses objs_1 = list( - self.get_manager(). - select_subclasses(). - values_list('id') + self.get_manager() + .select_subclasses() + .values_list('id') ) # String argument to select_subclasses objs_2 = list( - self.get_manager(). - select_subclasses( + self.get_manager() + .select_subclasses( "inheritancemanagertestchild2" - ). - values_list('id') + ) + .values_list('id') ) # String argument to select_subclasses objs_3 = list( - self.get_manager(). - select_subclasses( + self.get_manager() + .select_subclasses( InheritanceManagerTestChild2 - ). - values_list('id') + ).values_list('id') ) assert all(( @@ -392,14 +396,16 @@ class InheritanceManagerUsingModelsTests(TestCase): """ child3 = InheritanceManagerTestChild3.objects.create() results = InheritanceManagerTestParent.objects.all().select_subclasses( - InheritanceManagerTestChild3) + InheritanceManagerTestChild3).order_by('pk') - expected_objs = [InheritanceManagerTestParent(pk=self.parent1.pk), - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - child3] + expected_objs = [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + child3 + ] self.assertEqual(list(results), expected_objs) expected_related_names = ['manual_onetoone'] @@ -454,3 +460,7 @@ class InheritanceManagerRelatedTests(InheritanceManagerTests): qs = InheritanceManagerTestParent.objects.annotate( test_count=models.Count('id')).select_subclasses() self.assertEqual(qs.get(id=self.child1.id).test_count, 1) + + def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self): + qs = InheritanceManagerTestParent.objects.select_subclasses() + self.assertEqual(qs.subclasses, qs._clone().subclasses) diff --git a/tests/test_managers/test_join_manager.py b/tests/test_managers/test_join_manager.py new file mode 100644 index 0000000..b8a8131 --- /dev/null +++ b/tests/test_managers/test_join_manager.py @@ -0,0 +1,38 @@ + +from django.test import TestCase + +from tests.models import JoinItemForeignKey, BoxJoinModel + + +class JoinManagerTest(TestCase): + def setUp(self): + for i in range(20): + BoxJoinModel.objects.create(name='name_{i}'.format(i=i)) + + JoinItemForeignKey.objects.create( + weight=10, belonging=BoxJoinModel.objects.get(name='name_1') + ) + JoinItemForeignKey.objects.create(weight=20) + + def test_self_join(self): + a_slice = BoxJoinModel.objects.all()[0:10] + with self.assertNumQueries(1): + result = a_slice.join() + self.assertEquals(result.count(), 10) + + def test_self_join_with_where_statement(self): + qs = BoxJoinModel.objects.filter(name='name_1') + result = qs.join() + self.assertEquals(result.count(), 1) + + def test_join_with_other_qs(self): + item_qs = JoinItemForeignKey.objects.filter(weight=10) + boxes = BoxJoinModel.objects.all().join(qs=item_qs) + self.assertEquals(boxes.count(), 1) + self.assertEquals(boxes[0].name, 'name_1') + + def test_reverse_join(self): + box_qs = BoxJoinModel.objects.filter(name='name_1') + items = JoinItemForeignKey.objects.all().join(box_qs) + self.assertEquals(items.count(), 1) + self.assertEquals(items[0].weight, 10) diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py new file mode 100644 index 0000000..ea8e3bd --- /dev/null +++ b/tests/test_models/test_deferred_fields.py @@ -0,0 +1,71 @@ +from __future__ import unicode_literals + +import django +from django.test import TestCase + +from tests.models import ModelWithCustomDescriptor + + +class CustomDescriptorTests(TestCase): + def setUp(self): + self.instance = ModelWithCustomDescriptor.objects.create( + custom_field='1', + tracked_custom_field='1', + regular_field=1, + tracked_regular_field=1, + ) + + def test_custom_descriptor_works(self): + instance = self.instance + self.assertEqual(instance.custom_field, '1') + self.assertEqual(instance.__dict__['custom_field'], 1) + self.assertEqual(instance.regular_field, 1) + instance.custom_field = 2 + self.assertEqual(instance.custom_field, '2') + self.assertEqual(instance.__dict__['custom_field'], 2) + instance.save() + instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk) + self.assertEqual(instance.custom_field, '2') + self.assertEqual(instance.__dict__['custom_field'], 2) + + def test_deferred(self): + instance = ModelWithCustomDescriptor.objects.only('id').get( + pk=self.instance.pk) + if django.VERSION >= (1, 10): + self.assertIn('custom_field', instance.get_deferred_fields()) + else: + self.assertIn('custom_field', instance._deferred_fields) + self.assertEqual(instance.custom_field, '1') + if django.VERSION >= (1, 10): + self.assertNotIn('custom_field', instance.get_deferred_fields()) + else: + self.assertNotIn('custom_field', instance._deferred_fields) + self.assertEqual(instance.regular_field, 1) + self.assertEqual(instance.tracked_custom_field, '1') + self.assertEqual(instance.tracked_regular_field, 1) + + self.assertFalse(instance.tracker.has_changed('tracked_custom_field')) + self.assertFalse(instance.tracker.has_changed('tracked_regular_field')) + + instance.tracked_custom_field = 2 + instance.tracked_regular_field = 2 + self.assertTrue(instance.tracker.has_changed('tracked_custom_field')) + self.assertTrue(instance.tracker.has_changed('tracked_regular_field')) + instance.save() + + instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk) + self.assertEqual(instance.custom_field, '1') + self.assertEqual(instance.regular_field, 1) + self.assertEqual(instance.tracked_custom_field, '2') + self.assertEqual(instance.tracked_regular_field, 2) + + instance = ModelWithCustomDescriptor.objects.only('id').get(pk=instance.pk) + if django.VERSION >= (1, 10): + # This fails on 1.8 and 1.9, which is a bug in the deferred field + # implementation on those versions. + instance.tracked_custom_field = 3 + self.assertEqual(instance.tracked_custom_field, '3') + self.assertTrue(instance.tracker.has_changed('tracked_custom_field')) + del instance.tracked_custom_field + self.assertEqual(instance.tracked_custom_field, '2') + self.assertFalse(instance.tracker.has_changed('tracked_custom_field')) diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py index f660936..6950dbf 100644 --- a/tests/test_models/test_status_model.py +++ b/tests/test_models/test_status_model.py @@ -18,7 +18,7 @@ class StatusModelTests(TestCase): c1 = self.model.objects.create() self.assertTrue(c1.status_changed, datetime(2016, 1, 1)) - c2 = self.model.objects.create() + self.model.objects.create() self.assertEqual(self.model.active.count(), 2) self.assertEqual(self.model.deleted.count(), 0) diff --git a/tests/test_models/test_timestamped_model.py b/tests/test_models/test_timestamped_model.py index 8760411..cac07f3 100644 --- a/tests/test_models/test_timestamped_model.py +++ b/tests/test_models/test_timestamped_model.py @@ -15,6 +15,13 @@ class TimeStampedModelTests(TestCase): t1 = TimeStamp.objects.create() self.assertEqual(t1.created, datetime(2016, 1, 1)) + def test_created_sets_modified(self): + ''' + Ensure that on creation that modifed is set exactly equal to created. + ''' + t1 = TimeStamp.objects.create() + self.assertEqual(t1.created, t1.modified) + def test_modified(self): with freeze_time(datetime(2016, 1, 1)): t1 = TimeStamp.objects.create() diff --git a/tox.ini b/tox.ini index fe42bb1..c092160 100644 --- a/tox.ini +++ b/tox.ini @@ -1,24 +1,44 @@ [tox] envlist = - py27-django{18,19,110,111} - py34-django{18,19,110,111,200} - py35-django{18,19,110,111,200,trunk} - py36-django{111,200,trunk} + py27-django{19,110,111} + py34-django{19,110,111,200} + py35-django{19,110,111,200,201,trunk} + py36-django{111,200,201,trunk} + flake8 [testenv] deps = - django18: Django>=1.8,<1.9 django19: Django>=1.9,<1.10 django110: Django>=1.10,<1.11 django111: Django>=1.11,<1.12 django200: Django>=2.0,<2.1 + django201: Django>=2.1,<2.2 djangotrunk: https://github.com/django/django/archive/master.tar.gz freezegun == 0.3.8 -rrequirements-test.txt pytest-cov ignore_outcome = djangotrunk: True +passenv = + CI + TRAVIS + TRAVIS_* commands = pip install -e . py.test {posargs} + +[testenv:flake8] +basepython = + python3.6 +deps = + flake8 +commands = + flake8 model_utils tests + +[flake8] +ignore = + E731 ; do not assign a lambda expression, use a def + W503 ; line break before binary operator + E402 ; module level import not at top of file + E501 ; line too long