diff --git a/.coveragerc b/.coveragerc index 65aaf4d..8708371 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,4 +1,2 @@ [run] -source = model_utils -omit = model_utils/tests/* -branch = 1 +include = model_utils/*.py diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..f384dde --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,14 @@ +## Problem + +Explain the problem you encountered. + +## Environment + +- Django Model Utils version: +- Django version: +- Python version: +- Other libraries used, if any: + +## Code examples + +Give code example that demonstrates the issue, or even better, write new tests that fails because of that issue. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..63db703 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,15 @@ +## Problem + +Explain the problem you are fixing (add the link to the related issue(s), if any). + +## Solution + +Explain the solution that has been implemented, and what has been changed. + +## Commandments + +- [ ] Write PEP8 compliant code. +- [ ] Cover it with tests. +- [ ] Update `CHANGES.rst` file to describe the changes, and quote according issue with `GH-`. +- [ ] Pay attention to backward compatibility, or if it breaks it, explain why. +- [ ] Update documentation (if relevant). diff --git a/.gitignore b/.gitignore index c933d31..5f1c259 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ Django-*.egg htmlcov/ docs/_build/ .idea/ +.eggs/ diff --git a/.travis.yml b/.travis.yml index edaae71..e2fd9a3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,36 +1,25 @@ +sudo: True +dist: xenial language: python - +cache: pip python: - - 2.6 - - 2.7 - - 3.2 - - 3.3 - -env: - - DJANGO=Django==1.4.10 SOUTH=1 - - DJANGO=Django==1.5.5 SOUTH=1 - - DJANGO=Django==1.6.1 SOUTH=1 - - DJANGO=https://github.com/django/django/tarball/master SOUTH=1 - -install: - - pip install $DJANGO - - pip install coverage coveralls - - sh -c "if [ '$SOUTH' = '1' ]; then pip install South==0.8.1; fi" - -script: - - coverage run -a setup.py test - - coverage report - -matrix: - exclude: - - python: 2.6 - env: DJANGO=https://github.com/django/django/tarball/master SOUTH=1 - - python: 3.2 - env: DJANGO=Django==1.4.10 SOUTH=1 - - python: 3.3 - env: DJANGO=Django==1.4.10 SOUTH=1 - include: - - python: 2.7 - env: DJANGO=Django==1.5.5 SOUTH=0 - -after_success: coveralls +- 2.7 +- 3.7 +- 3.6 +install: pip install tox-travis codecov +# positional args ({posargs}) to pass into tox.ini +script: tox -- --cov --cov-append +services: + - postgresql +after_success: codecov +deploy: + provider: pypi + user: jazzband + server: https://jazzband.co/projects/django-model-utils/upload + distributions: sdist bdist_wheel + password: + secure: JxUmEdYS8qT+7xhVyzmVD4Gkwqdz5XKxoUhKP795CWIXoJjtlGszyo6w0XfnFs0epXtd1NuCRXdhea+EqWKFDlQ3Yg7m6Y/yTQV6nMHxCPSvicROho7pAiJmfc/x+rSsPt5ag8av6+S07tOqvMnWBBefYbpHRoel78RXkm9l7Mc= + on: + tags: true + repo: jazzband/django-model-utils + python: 3.6 diff --git a/AUTHORS.rst b/AUTHORS.rst index bacbd67..609d55e 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -1,33 +1,56 @@ -Alejandro Varas -Alex Orange -Andy Freeland -Carl Meyer -Curtis Maloney -Den Lesnov -Donald Stufft -Douglas Meehan -Facundo Gaich -Felipe Prenholato -Filipe Ximenes -Gregor Müllegger -ivirabyan -James Oakley -Jannis Leidel -Javier García Sogo -Jeff Elmore -Keryn Knight -Matthew Schinckel -Michael van Tellingen -Mikhail Silonov -Patryk Zawadzki -Paul McLanahan -Rinat Shigapov -Rodney Folz -rsenkbeil -Ryan Kaskel -Simon Meers -sayane -Tony Aldridge -Travis Swicegood -Trey Hunner -zyegfryed +| ad-m +| Adam Barnes +| Alejandro Varas +| Alex Orange +| Alexey Evseev +| Andy Freeland +| Artis Avotins +| Bram Boogaard +| Carl Meyer +| Curtis Maloney +| Den Lesnov +| Dmytro Kyrychuk +| Donald Stufft +| Douglas Meehan +| Emin Bugra Saral +| Facundo Gaich +| Felipe Prenholato +| Filipe Ximenes +| Gregor Müllegger +| Germano Massullo +| Hanley Hansen +| ivirabyan +| James Oakley +| Jannis Leidel +| Jarek Glowacki +| Javier García Sogo +| Jeff Elmore +| Jonathan Sundqvist +| Keryn Knight +| Martey Dodoo +| Matthew Schinckel +| Michael van Tellingen +| Mike Bryant +| Mikhail Silonov +| Patryk Zawadzki +| Paul McLanahan +| Philipp Steinhardt +| Remy Suen +| Rinat Shigapov +| Rodney Folz +| Romain Garrigues +| rsenkbeil +| Ryan Kaskel +| Simon Meers +| sayane +| Tony Aldridge +| Travis Swicegood +| Trey Hunner +| Karl Wan Nan Wo +| zyegfryed +| Radosław Jan Ganczarek +| Lucas Wiman +| Jack Cushman +| Zach Cheung +| Daniel Andrlik +| marfyl \ No newline at end of file diff --git a/CHANGES.rst b/CHANGES.rst index e39ccd5..abd1a8e 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,8 +1,172 @@ CHANGES ======= -master (unreleased) +3.3.0 (2019.08.19) +------------------ +- Added `Choices.subset`. + +3.2.0 (2019.06.21) ------------------- +- Catch `AttributeError` for deferred abstract fields, fixes GH-331. +- Update documentation to explain usage of `timeframed` model manager, fixes GH-118 +- Honor `OneToOneField.parent_link=False`. +- Fix handling of deferred attributes on Django 1.10+, fixes GH-278 +- Fix `FieldTracker.has_changed()` and `FieldTracker.previous()` to return + correct responses for deferred fields. +- Add Simplified Chinese translations. +- Update AutoLastModifiedField so that at instance creation it will + always be set equal to created to make querying easier. Fixes GH-254 +- Support `reversed` for all kinds of `Choices` objects, fixes GH-309 +- Fix Model instance non picklable GH-330 +- Fix patched `save` in FieldTracker +- Upgrades test requirements (pytest, pytest-django, pytest-cov) and + skips tox test with Python 3.5 and Django (trunk) +- Add UUIDModel and UUIDField support. + +3.1.2 (2018.05.09) +------------------ +* Update InheritanceIterable to inherit from + ModelIterable instead of BaseIterable, fixes GH-277. + +* Add all_objects Manager for 'SoftDeletableModel' to include soft + deleted objects on queries as per issue GH-255 + +3.1.1 (2017.12.17) +------------------ + +- Update classifiers and README via GH-306, fixes GH-305 + +3.1.0 (2017.12.11) +------------------ + +- Support for Django 2.0 via GH-298, fixes GH-297 +- Remove old travis script via GH-300 +- Fix codecov and switch to py.test #301 + +3.0.0 (2017.04.13) +------------------ + +* Drop support for Python 2.6. +* Drop support for Django 1.4, 1.5, 1.6, 1.7. +* Exclude tests from the distribution, fixes GH-258. +* Add support for Django 1.11 GH-269 +* Add a new model to disable pre_save/post_save signals + + +2.6.1 (2017.01.11) +------------------ + +* Fix infinite recursion with multiple `MonitorField` and `defer()` or `only()` + on Django 1.10+. Thanks Romain Garrigues. Merge of GH-242, fixes GH-241. + +* Fix `InheritanceManager` and `SoftDeletableManager` to respect + `self._queryset_class` instead of hardcoding the queryset class. Merge of + GH-250, fixes GH-249. + +* Add mixins for `SoftDeletableQuerySet` and `SoftDeletableManager`, as stated + in the the documentation. + +* Fix `SoftDeletableModel.delete()` to use the correct database connection. + Merge of GH-239. + +* Added boolean keyword argument `soft` to `SoftDeletableModel.delete()` that + revert to default behavior when set to `False`. Merge of GH-240. + +* Enforced default manager in `StatusModel` to avoid manager order issues when + using abstract models that redefine `objects` manager. Merge of GH-253, fixes + GH-251. + + +2.6 (2016.09.19) +---------------- + +* Added `SoftDeletableModel` abstract class, its manageer + `SoftDeletableManager` and queryset `SoftDeletableQuerySet`. + +* Fix issue with field tracker and deferred FileField for Django 1.10. + + +2.5.2 (2016.08.09) +------------------ + +* Include `runtests.py` in sdist. + + +2.5.1 (2016.08.03) +------------------ + +* Fix `InheritanceQuerySet` raising an `AttributeError` exception + under Django 1.9. + +* Django 1.10 support regressed with changes between pre-alpha and final + release; 1.10 currently not supported. + + +2.5 (2016.04.18) +---------------- + +* Drop support for Python 3.2. + +* Add support for Django 1.10 pre-alpha. + +* Track foreign keys on parent models properly when a tracker + is defined on a child model. Fixes GH-214. + + +2.4 (2015.12.03) +---------------- + +* Remove `PassThroughManager`. Use Django's built-in `QuerySet.as_manager()` + and/or `Manager.from_queryset()` utilities instead. + +* Add support for Django 1.9. + + +2.3.1 (2015-07-20) +------------------ + +* Remove all translation-related automation in `setup.py`. Fixes GH-178 and + GH-179. Thanks Joe Weiss, Matt Molyneaux, and others for the reports. + + +2.3 (2015.07.17) +---------------- + +* Keep track of deferred fields on model instance instead of on + FieldInstanceTracker instance. Fixes accessing deferred fields for multiple + instances of a model from the same queryset. Thanks Bram Boogaard. Merge of + GH-151. + +* Fix Django 1.7 migrations compatibility for SplitField. Thanks ad-m. Merge of + GH-157; fixes GH-156. + +* Add German translations. + +* Django 1.8 compatibility. + + +2.2 (2014.07.31) +---------------- + +* Revert GH-130, restoring ability to access ``FieldTracker`` changes in + overridden ``save`` methods or ``post_save`` handlers. This reopens GH-83 + (inability to pickle models with ``FieldTracker``) until a solution can be + found that doesn't break behavior otherwise. Thanks Brian May for the + report. Fixes GH-143. + + +2.1.1 (2014.07.28) +------------------ + +* ASCII-fold all non-ASCII characters in changelog; again. Argh. Apologies to + those whose names are mangled by this change. It seems that distutils makes + it impossible to handle non-ASCII content reliably under Python 3 in a + setup.py long_description, when the system encoding may be ASCII. Thanks + Brian May for the report. Fixes GH-141. + + +2.1.0 (2014.07.25) +------------------ * Add support for Django's built-in migrations to ``MonitorField`` and ``StatusField``. @@ -11,7 +175,7 @@ master (unreleased) ``dir``, allowing `IPython`_ tab completion to be useful. Merge of GH-104, fixes GH-55. -* Add pickle support for models using ``FieldTracker``. Thanks Ondrej Slinták +* Add pickle support for models using ``FieldTracker``. Thanks Ondrej Slintak for the report. Thanks Matthew Schinckel for the fix. Merge of GH-130, fixes GH-83. @@ -253,4 +417,3 @@ master (unreleased) ----- * Added ``QueryManager`` - diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index a789156..23b903b 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -1,32 +1,62 @@ Contributing ============ -Below is a list of tips for submitting issues and pull requests. These are -suggestions and not requirements. +.. image:: https://jazzband.co/static/img/jazzband.svg + :target: https://jazzband.co/ + :alt: Jazzband + +This is a `Jazzband `_ project. By contributing you agree +to abide by the `Contributor Code of Conduct +`_ and follow the `guidelines +`_. + +Below is a list of tips for submitting issues and pull requests. Submitting Issues ----------------- -Issues are often easier to reproduce/resolve when they have: +Issues are easier to reproduce/resolve when they have: - A pull request with a failing test demonstrating the issue - A code example that produces the issue consistently - A traceback (when applicable) + Pull Requests ------------- -When creating a pull request, try to: +When creating a pull request: -- Write tests if applicable -- Note important changes in the `CHANGES`_ file -- Update the documentation if needed +- Write tests +- Note user-facing changes in the `CHANGES`_ file +- Update the documentation - Add yourself to the `AUTHORS`_ file +- If you have added or changed translated strings, run ``make messages`` to + update the ``.po`` translation files, and update translations for any + languages you know. Then run ``make compilemessages`` to compile the ``.mo`` + files. If your pull request leaves some translations incomplete, please + mention that in the pull request and commit message. .. _AUTHORS: AUTHORS.rst .. _CHANGES: CHANGES.rst +Translations +------------ + +If you are able to provide translations for a new language or to update an +existing translation file, make sure to run makemessages beforehand:: + + python django-admin.py makemessages -l ISO_LANGUAGE_CODE + +This command will collect all translation strings from the source directory +and create or update the translation file for the given language. Now open the +translation file (.po) with a text-editor and start editing. +After you finished editing add yourself to the list of translators. +If you have created a new translation, make sure to copy the header from one +of the existing translation files. + + Testing ------- diff --git a/LICENSE.txt b/LICENSE.txt index ab400d7..01e3613 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2009-2013, Carl Meyer and contributors +Copyright (c) 2009-2019, Carl Meyer and contributors All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/MANIFEST.in b/MANIFEST.in index ddc2005..9063cef 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,4 +3,8 @@ include CHANGES.rst include LICENSE.txt include MANIFEST.in include README.rst -include TODO.rst +include requirements*.txt +include Makefile tox.ini +recursive-include model_utils/locale *.po *.mo +graft docs +recursive-include tests *.py diff --git a/Makefile b/Makefile index c736206..4785478 100644 --- a/Makefile +++ b/Makefile @@ -13,3 +13,9 @@ docs: documentation documentation: python setup.py build_sphinx + +messages: + python translations.py make + +compilemessages: + python translations.py compile diff --git a/README.rst b/README.rst index 0f76cf9..490e170 100644 --- a/README.rst +++ b/README.rst @@ -2,44 +2,47 @@ django-model-utils ================== -.. image:: https://secure.travis-ci.org/carljm/django-model-utils.png?branch=master - :target: http://travis-ci.org/carljm/django-model-utils -.. image:: https://coveralls.io/repos/carljm/django-model-utils/badge.png?branch=master - :target: https://coveralls.io/r/carljm/django-model-utils -.. image:: https://pypip.in/v/django-model-utils/badge.png - :target: https://crate.io/packages/django-model-utils +.. image:: https://jazzband.co/static/img/badge.svg + :target: https://jazzband.co/ + :alt: Jazzband +.. image:: https://travis-ci.org/jazzband/django-model-utils.svg?branch=master + :target: https://travis-ci.org/jazzband/django-model-utils +.. image:: https://codecov.io/gh/jazzband/django-model-utils/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jazzband/django-model-utils +.. image:: https://img.shields.io/pypi/v/django-model-utils.svg + :target: https://pypi.python.org/pypi/django-model-utils Django model mixins and utilities. -``django-model-utils`` supports `Django`_ 1.4.10 and later on Python 2.6, 2.7, -3.2, 3.3 and 3.4. +``django-model-utils`` supports `Django`_ 1.11 and 2.1+. .. _Django: http://www.djangoproject.com/ - -Getting Help -============ - -Documentation for django-model-utils is available at https://django-model-utils.readthedocs.org/ - This app is available on `PyPI`_. .. _PyPI: https://pypi.python.org/pypi/django-model-utils/ +Getting Help +============ + +Documentation for django-model-utils is available +https://django-model-utils.readthedocs.io/ + + +Run tests +--------- + +.. code-block + + pip install -e . + py.test Contributing ============ Please file bugs and send pull requests to the `GitHub repository`_ and `issue -tracker`_. - -.. _GitHub repository: https://github.com/carljm/django-model-utils/ -.. _issue tracker: https://github.com/carljm/django-model-utils/issues - -(Until January 2013 django-model-utils primary development was hosted at -`BitBucket`_; the issue tracker there will remain open until all issues and -pull requests tracked in it are closed, but all new issues should be filed at -GitHub.) - -.. _BitBucket: https://bitbucket.org/carljm/django-model-utils/overview +tracker`_. See `CONTRIBUTING.rst`_ for details. +.. _GitHub repository: https://github.com/jazzband/django-model-utils/ +.. _issue tracker: https://github.com/jazzband/django-model-utils/issues +.. _CONTRIBUTING.rst: https://github.com/jazzband/django-model-utils/blob/master/CONTRIBUTING.rst diff --git a/TODO.rst b/TODO.rst deleted file mode 100644 index 4218c78..0000000 --- a/TODO.rst +++ /dev/null @@ -1,4 +0,0 @@ -TODO -==== - -* Switch to proper test skips once Django 1.3 is minimum supported. diff --git a/docs/conf.py b/docs/conf.py index d79adff..9f0c4e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -41,7 +41,7 @@ master_doc = 'index' # General information about the project. project = u'django-model-utils' -copyright = u'2013, Carl Meyer' +copyright = u'2015, Carl Meyer' parent_dir = os.path.dirname(os.path.dirname(__file__)) diff --git a/docs/fields.rst b/docs/fields.rst index 02ca6ef..87c298e 100644 --- a/docs/fields.rst +++ b/docs/fields.rst @@ -154,3 +154,29 @@ If no marker is found in the content, the first two paragraphs (where paragraphs are blocks of text separated by a blank line) are taken to be the excerpt. This number can be customized by setting the ``SPLIT_DEFAULT_PARAGRAPHS`` setting. + + +UUIDField +---------- + +A ``UUIDField`` subclass that provides an UUID field. You can +add this field to any model definition. + +With the param ``primary_key`` you can set if this field is the +primary key for the model, default is True. + +Param ``version`` is an integer that set default UUID version. +Versions 1,3,4 and 5 are supported, default is 4. + +If ``editable`` is set to false the field will not be displayed in the admin +or any other ModelForm, default is False. + + +.. code-block:: python + + from django.db import models + from model_utils.fields import UUIDField + + class MyAppModel(models.Model): + uuid = UUIDField(primary_key=True, version=4, editable=False) + diff --git a/docs/index.rst b/docs/index.rst index 9b6d2bb..411d2c4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,8 +24,8 @@ Contributing Please file bugs and send pull requests to the `GitHub repository`_ and `issue tracker`_. -.. _GitHub repository: https://github.com/carljm/django-model-utils/ -.. _issue tracker: https://github.com/carljm/django-model-utils/issues +.. _GitHub repository: https://github.com/jazzband/django-model-utils/ +.. _issue tracker: https://github.com/jazzband/django-model-utils/issues diff --git a/docs/managers.rst b/docs/managers.rst index 0fb3144..b90f3d4 100644 --- a/docs/managers.rst +++ b/docs/managers.rst @@ -84,15 +84,35 @@ If you don't explicitly call ``select_subclasses()`` or ``get_subclass()``, an ``InheritanceManager`` behaves identically to a normal ``Manager``; so it's safe to use as your default manager for the model. -.. note:: - - Due to `Django bug #16572`_, on Django versions prior to 1.6 - ``InheritanceManager`` only supports a single level of model inheritance; - it won't work for grandchild models. - .. _contributed by Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/ -.. _Django bug #16572: https://code.djangoproject.com/ticket/16572 +JoinManager +----------- + +The ``JoinManager`` will create a temporary table of your current queryset +and join that temporary table with the model of your current queryset. This can +be advantageous if you have to page through your entire DB and using django's +slice mechanism to do that. ``LIMIT .. OFFSET ..`` becomes slower the bigger +offset you use. + +.. code-block:: python + + sliced_qs = Place.objects.all()[2000:2010] + qs = sliced_qs.join() + # qs contains 10 objects, and there will be a much smaller performance hit + # for paging through all of first 2000 objects. + +Alternatively, you can give it a queryset and the manager will create a temporary +table and join that to your current queryset. This can work as a more performant +alternative to using django's ``__in`` as described in the following +(`StackExchange answer`_). + +.. code-block:: python + + big_qs = Restaurant.objects.filter(menu='vegetarian') + qs = Country.objects.filter(country_code='SE').join(big_qs) + +.. _StackExchange answer: https://dba.stackexchange.com/questions/91247/optimizing-a-postgres-query-with-a-large-in .. _QueryManager: @@ -124,98 +144,19 @@ set the ordering of the ``QuerySet`` returned by the ``QueryManager`` by chaining a call to ``.order_by()`` on the ``QueryManager`` (this is not required). +SoftDeletableManager +-------------------- -PassThroughManager ------------------- - -A common "gotcha" when defining methods on a custom manager class is that those -same methods are not automatically also available on the QuerySets returned by -that manager, so are not "chainable". This can be counterintuitive, as most of -the public QuerySet API is mirrored on managers. It is possible to create a -custom Manager that returns QuerySets that have the same additional methods, -but this requires boilerplate code. The ``PassThroughManager`` class -(`contributed by Paul McLanahan`_) removes this boilerplate. - -.. _contributed by Paul McLanahan: http://paulm.us/post/3717466639/passthroughmanager-for-django - -To use ``PassThroughManager``, rather than defining a custom manager with -additional methods, define a custom ``QuerySet`` subclass with the additional -methods you want, and pass that ``QuerySet`` subclass to the -``PassThroughManager.for_queryset_class()`` class method. The returned -``PassThroughManager`` subclass will always return instances of your custom -``QuerySet``, and you can also call methods of your custom ``QuerySet`` -directly on the manager: - -.. code-block:: python - - from datetime import datetime - from django.db import models - from django.db.models.query import QuerySet - from model_utils.managers import PassThroughManager - - class PostQuerySet(QuerySet): - def by_author(self, user): - return self.filter(user=user) - - def published(self): - return self.filter(published__lte=datetime.now()) - - def unpublished(self): - return self.filter(published__gte=datetime.now()) - - - class Post(models.Model): - user = models.ForeignKey(User) - published = models.DateTimeField() - - objects = PassThroughManager.for_queryset_class(PostQuerySet)() - - Post.objects.published() - Post.objects.by_author(user=request.user).unpublished() +Returns only model instances that have the ``is_removed`` field set +to False. Uses ``SoftDeletableQuerySet``, which ensures model instances +won't be removed in bulk, but they will be marked as removed instead. Mixins ------ Each of the above manager classes has a corresponding mixin that can be used to -add functionality to any manager. For example, to create a GeoDjango -``GeoManager`` that includes "pass through" functionality, you can write the -following code: +add functionality to any manager. -.. code-block:: python - - from django.contrib.gis.db import models - from django.contrib.gis.db.models.query import GeoQuerySet - - from model_utils.managers import PassThroughManagerMixin - - class PassThroughGeoManager(PassThroughManagerMixin, models.GeoManager): - pass - - class LocationQuerySet(GeoQuerySet): - def within_boundary(self, geom): - return self.filter(point__within=geom) - - def public(self): - return self.filter(public=True) - - class Location(models.Model): - point = models.PointField() - public = models.BooleanField(default=True) - objects = PassThroughGeoManager.for_queryset_class(LocationQuerySet)() - - Location.objects.public() - Location.objects.within_boundary(geom=geom) - Location.objects.within_boundary(geom=geom).public() - - -Now you have a "pass through manager" that can also take advantage of -GeoDjango's spatial lookups. You can similarly add additional functionality to -any manager by composing that manager with ``InheritanceManagerMixin`` or -``QueryManagerMixin``. - -(Note that any manager class using ``InheritanceManagerMixin`` must return a +Note that any manager class using ``InheritanceManagerMixin`` must return a ``QuerySet`` class using ``InheritanceQuerySetMixin`` from its ``get_queryset`` -method. This means that if composing ``InheritanceManagerMixin`` and -``PassThroughManagerMixin``, the ``QuerySet`` class passed to -``PassThroughManager.for_queryset_class`` must inherit -``InheritanceQuerySetMixin``.) +method. diff --git a/docs/models.rst b/docs/models.rst index 7a05c79..31707f4 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -5,10 +5,41 @@ TimeFramedModel --------------- An abstract base class for any model that expresses a time-range. Adds -``start`` and ``end`` nullable DateTimeFields, and a ``timeframed`` -manager that returns only objects for whom the current date-time lies -within their time range. +``start`` and ``end`` nullable DateTimeFields, and provides a new +``timeframed`` manager on the subclass whose queryset pre-filters results +to only include those which have a ``start`` which is not in the future, +and an ``end`` which is not in the past. If either ``start`` or ``end`` is +``null``, the manager will include it. +.. code-block:: python + + from model_utils.models import TimeFramedModel + from datetime import datetime, timedelta + class Post(TimeFramedModel): + pass + + p = Post() + p.start = datetime.utcnow() - timedelta(days=1) + p.end = datetime.utcnow() + timedelta(days=7) + p.save() + + # this query will return the above Post instance: + Post.timeframed.all() + + p.start = None + p.end = None + p.save() + + # this query will also return the above Post instance, because + # the `start` and/or `end` are NULL. + Post.timeframed.all() + + p.start = datetime.utcnow() + timedelta(days=7) + p.save() + + # this query will NOT return our Post instance, because + # the start date is in the future. + Post.timeframed.all() TimeStampedModel ---------------- @@ -47,3 +78,51 @@ returns objects with that status only: # this query will only return published articles: Article.published.all() + + +SoftDeletableModel +------------------ + +This abstract base class just provides field ``is_removed`` which is +set to True instead of removing the instance. Entities returned in +default manager are limited to not-deleted instances. + + +UUIDModel +------------------ + +This abstract base class provides ``id`` field on any model that inherits from it +which will be the primary key. + +If you dont want to set ``id`` as primary key or change the field name, you can be override it +with our `UUIDField`_ + +Also you can override the default uuid version. Versions 1,3,4 and 5 are now supported. + +.. code-block:: python + + from model_utils.models import UUIDModel + + class MyAppModel(UUIDModel): + pass + + +.. _`UUIDField`: https://github.com/jazzband/django-model-utils/blob/master/docs/fields.rst#uuidfield + + +SaveSignalHandlingModel +----------------------- + +An abstract base class model to pass a parameter ``signals_to_disable`` +to ``save`` method in order to disable signals + +.. code-block:: python + + from model_utils.models import SaveSignalHandlingModel + + class SaveSignalTestModel(SaveSignalHandlingModel): + name = models.CharField(max_length=20) + + obj = SaveSignalTestModel(name='Test') + # Note: If you use `Model.objects.create`, the signals can't be disabled + obj.save(signals_to_disable=['pre_save'] # disable `pre_save` signal diff --git a/docs/setup.rst b/docs/setup.rst index 5621649..1fca10c 100644 --- a/docs/setup.rst +++ b/docs/setup.rst @@ -17,7 +17,7 @@ modify your ``INSTALLED_APPS`` setting. Dependencies ============ -``django-model-utils`` supports `Django`_ 1.4.2 and later on Python 2.6, 2.7, -3.2, and 3.3. +``django-model-utils`` supports `Django`_ 1.8 through 2.1 (latest bugfix +release in each series only) on Python 2.7, 3.4, 3.5 and 3.6. .. _Django: http://www.djangoproject.com/ diff --git a/docs/utilities.rst b/docs/utilities.rst index e8d22d2..3cbbdc7 100644 --- a/docs/utilities.rst +++ b/docs/utilities.rst @@ -84,6 +84,27 @@ instances and other iterable objects that could be converted into Choices: STATUS = GENERIC_CHOICES + [(2, 'featured', _('featured'))] status = models.IntegerField(choices=STATUS, default=STATUS.draft) +Should you wish to provide a subset of choices for a field, for +instance, you have a form class to set some model instance to a failed +state, and only wish to show the user the failed outcomes from which to +select, you can use the ``subset`` method: + +.. code-block:: python + + from model_utils import Choices + + OUTCOMES = Choices( + (0, 'success', _('Successful')), + (1, 'user_cancelled', _('Cancelled by the user')), + (2, 'admin_cancelled', _('Cancelled by an admin')), + ) + FAILED_OUTCOMES = OUTCOMES.subset('user_cancelled', 'admin_cancelled') + +The ``choices`` attribute on the model field can then be set to +``FAILED_OUTCOMES``, thus allowing the subset to be defined in close +proximity to the definition of all the choices, and reused elsewhere as +required. + Field Tracker ============= @@ -150,10 +171,14 @@ Returns the value of the given field during the last save: Returns ``None`` when the model instance isn't saved yet. +If a field is `deferred`_, calling ``previous()`` will load the previous value from the database. + +.. _deferred: https://docs.djangoproject.com/en/2.0/ref/models/querysets/#defer + has_changed ~~~~~~~~~~~ -Returns ``True`` if the given field has changed since the last save: +Returns ``True`` if the given field has changed since the last save. The ``has_changed`` method expects a single field: .. code-block:: pycon @@ -167,6 +192,8 @@ Returns ``True`` if the given field has changed since the last save: The ``has_changed`` method relies on ``previous`` to determine whether a field's values has changed. +If a field is `deferred`_ and has been assigned locally, calling ``has_changed()`` +will load the previous value from the database to perform the comparison. changed ~~~~~~~ diff --git a/model_utils/__init__.py b/model_utils/__init__.py index d3ccdf2..9a87de5 100644 --- a/model_utils/__init__.py +++ b/model_utils/__init__.py @@ -1,4 +1,4 @@ -from .choices import Choices -from .tracker import FieldTracker, ModelTracker +from .choices import Choices # noqa:F401 +from .tracker import FieldTracker, ModelTracker # noqa:F401 -__version__ = '2.0.3.post1' +__version__ = '3.3.0' diff --git a/model_utils/choices.py b/model_utils/choices.py index d48ba90..31d5aa1 100644 --- a/model_utils/choices.py +++ b/model_utils/choices.py @@ -57,7 +57,6 @@ class Choices(object): self._process(choices) - def _store(self, triple, triple_collector, double_collector): self._identifier_map[triple[1]] = triple[0] self._display_map[triple[0]] = triple[2] @@ -65,7 +64,6 @@ class Choices(object): triple_collector.append(triple) double_collector.append((triple[0], triple[2])) - def _process(self, choices, triple_collector=None, double_collector=None): if triple_collector is None: triple_collector = self._triples @@ -94,18 +92,18 @@ class Choices(object): raise ValueError( "Choices can't take a list of length %s, only 2 or 3" % len(choice) - ) + ) else: store((choice, choice, choice)) - def __len__(self): return len(self._doubles) - def __iter__(self): return iter(self._doubles) + def __reversed__(self): + return reversed(self._doubles) def __getattr__(self, attname): try: @@ -113,11 +111,9 @@ class Choices(object): except KeyError: raise AttributeError(attname) - def __getitem__(self, key): return self._display_map[key] - def __add__(self, other): if isinstance(other, self.__class__): other = other._triples @@ -125,29 +121,38 @@ class Choices(object): other = list(other) return Choices(*(self._triples + other)) - def __radd__(self, other): # radd is never called for matching types, so we don't check here other = list(other) return Choices(*(other + self._triples)) - def __eq__(self, other): if isinstance(other, self.__class__): return self._triples == other._triples return False - def __repr__(self): return '%s(%s)' % ( self.__class__.__name__, ', '.join(("%s" % repr(i) for i in self._triples)) - ) - + ) def __contains__(self, item): return item in self._db_values - def __deepcopy__(self, memo): return self.__class__(*copy.deepcopy(self._triples, memo)) + + def subset(self, *new_identifiers): + identifiers = set(self._identifier_map.keys()) + + if not identifiers.issuperset(new_identifiers): + raise ValueError( + 'The following identifiers are not present: %s' % + identifiers.symmetric_difference(new_identifiers), + ) + + return self.__class__(*[ + choice for choice in self._triples + if choice[1] in new_identifiers + ]) diff --git a/model_utils/fields.py b/model_utils/fields.py index 8728d56..989c166 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -1,7 +1,10 @@ from __future__ import unicode_literals +import django +import uuid from django.db import models from django.conf import settings +from django.core.exceptions import ValidationError from django.utils.encoding import python_2_unicode_compatible from django.utils.timezone import now @@ -16,6 +19,7 @@ class AutoCreatedField(models.DateTimeField): By default, sets editable=False, default=datetime.now. """ + def __init__(self, *args, **kwargs): kwargs.setdefault('editable', False) kwargs.setdefault('default', now) @@ -29,8 +33,14 @@ class AutoLastModifiedField(AutoCreatedField): By default, sets editable=False and default=datetime.now. """ + def pre_save(self, model_instance, add): value = now() + if not model_instance.pk: + for field in model_instance._meta.get_fields(): + if isinstance(field, AutoCreatedField): + value = getattr(model_instance, field.name) + break setattr(model_instance, self.attname, value) return value @@ -47,6 +57,7 @@ class StatusField(models.CharField): Also features a ``no_check_for_status`` argument to make sure South can handle this field when it freezes a model. """ + def __init__(self, *args, **kwargs): kwargs.setdefault('max_length', 100) self.check_for_status = not kwargs.pop('no_check_for_status', False) @@ -59,6 +70,8 @@ class StatusField(models.CharField): "To use StatusField, the model '%s' must have a %s choices class attribute." \ % (sender.__name__, self.choices_name) self._choices = getattr(sender, self.choices_name) + if django.VERSION >= (1, 9, 0): + self.choices = self._choices if not self.has_default(): self.default = tuple(getattr(sender, self.choices_name))[0][0] # set first as default @@ -68,6 +81,8 @@ class StatusField(models.CharField): # the STATUS class attr being available), but we need to set some dummy # choices now so the super method will add the get_FOO_display method self._choices = [(0, 'dummy')] + if django.VERSION >= (1, 9, 0): + self.choices = self._choices super(StatusField, self).contribute_to_class(cls, name) def deconstruct(self): @@ -83,6 +98,7 @@ class MonitorField(models.DateTimeField): changes. """ + def __init__(self, *args, **kwargs): kwargs.setdefault('default', now) monitor = kwargs.pop('monitor', None) @@ -105,6 +121,9 @@ class MonitorField(models.DateTimeField): return getattr(instance, self.monitor) def _save_initial(self, sender, instance, **kwargs): + if django.VERSION >= (1, 10) and self.monitor in instance.get_deferred_fields(): + # Fix related to issue #241 to avoid recursive error on double monitor fields + return setattr(instance, self.monitor_attname, self.get_monitored_value(instance)) @@ -120,8 +139,7 @@ class MonitorField(models.DateTimeField): def deconstruct(self): name, path, args, kwargs = super(MonitorField, self).deconstruct() - if self.monitor is not None: - kwargs['monitor'] = self.monitor + kwargs['monitor'] = self.monitor if self.when is not None: kwargs['when'] = self.when return name, path, args, kwargs @@ -132,7 +150,10 @@ SPLIT_MARKER = getattr(settings, 'SPLIT_MARKER', '') # the number of paragraphs after which to split if no marker SPLIT_DEFAULT_PARAGRAPHS = getattr(settings, 'SPLIT_DEFAULT_PARAGRAPHS', 2) -_excerpt_field_name = lambda name: '_%s_excerpt' % name + +def _excerpt_field_name(name): + return '_%s_excerpt' % name + def get_excerpt(content): excerpt = [] @@ -149,6 +170,7 @@ def get_excerpt(content): return '\n'.join(default_excerpt) + @python_2_unicode_compatible class SplitText(object): def __init__(self, instance, field_name, excerpt_field_name): @@ -159,11 +181,13 @@ class SplitText(object): self.excerpt_field_name = excerpt_field_name # content is read/write - def _get_content(self): + @property + def content(self): return self.instance.__dict__[self.field_name] - def _set_content(self, val): + + @content.setter + def content(self, val): setattr(self.instance, self.field_name, val) - content = property(_get_content, _set_content) # excerpt is a read only property def _get_excerpt(self): @@ -178,6 +202,7 @@ class SplitText(object): def __str__(self): return self.content + class SplitDescriptor(object): def __init__(self, field): self.field = field @@ -198,6 +223,7 @@ class SplitDescriptor(object): else: obj.__dict__[self.field.name] = value + class SplitField(models.TextField): def __init__(self, *args, **kwargs): # for South FakeORM compatibility: the frozen version of a @@ -221,7 +247,7 @@ class SplitField(models.TextField): return value.content def value_to_string(self, obj): - value = self._get_val_from_obj(obj) + value = self.value_from_object(obj) return value.content def get_prep_value(self, value): @@ -230,30 +256,53 @@ class SplitField(models.TextField): except AttributeError: return value + def deconstruct(self): + name, path, args, kwargs = super(SplitField, self).deconstruct() + kwargs['no_excerpt_field'] = True + return name, path, args, kwargs -# allow South to handle these fields smoothly -try: - from south.modelsinspector import add_introspection_rules - # For a normal MarkupField, the add_excerpt_field attribute is - # always True, which means no_excerpt_field arg will always be - # True in a frozen MarkupField, which is what we want. - add_introspection_rules(rules=[ - ( - (SplitField,), - [], - {'no_excerpt_field': ('add_excerpt_field', {})} - ), - ( - (MonitorField,), - [], - {'monitor': ('monitor', {})} - ), - ( - (StatusField,), - [], - {'no_check_for_status': ('check_for_status', {})} - ), - ], patterns=['model_utils\.fields\.']) -except ImportError: - pass +class UUIDField(models.UUIDField): + """ + A field for storing universally unique identifiers. Use Python UUID class. + """ + + def __init__(self, primary_key=True, version=4, editable=False, *args, **kwargs): + """ + Parameters + ---------- + primary_key : bool + If True, this field is the primary key for the model. + version : int + An integer that set default UUID version. + editable : bool + If False, the field will not be displayed in the admin or any other ModelForm, + default is false. + + Raises + ------ + ValidationError + UUID version 2 is not supported. + """ + + if version == 2: + raise ValidationError( + 'UUID version 2 is not supported.') + + if version < 1 or version > 5: + raise ValidationError( + 'UUID version is not valid.') + + if version == 1: + default = uuid.uuid1 + elif version == 3: + default = uuid.uuid3 + elif version == 4: + default = uuid.uuid4 + elif version == 5: + default = uuid.uuid5 + + kwargs.setdefault('primary_key', primary_key) + kwargs.setdefault('editable', editable) + kwargs.setdefault('default', default) + super(UUIDField, self).__init__(*args, **kwargs) diff --git a/model_utils/locale/cs/LC_MESSAGES/django.mo b/model_utils/locale/cs/LC_MESSAGES/django.mo new file mode 100644 index 0000000..758c32d Binary files /dev/null and b/model_utils/locale/cs/LC_MESSAGES/django.mo differ diff --git a/model_utils/locale/cs/LC_MESSAGES/django.po b/model_utils/locale/cs/LC_MESSAGES/django.po new file mode 100644 index 0000000..eae5e9a --- /dev/null +++ b/model_utils/locale/cs/LC_MESSAGES/django.po @@ -0,0 +1,46 @@ +# Czech translations of django-model-utils +# +# This file is distributed under the same license as the django-model-utils package. +# +# Translators: +# ------------ +# Václav Dohnal , 2018. +# +msgid "" +msgstr "" +"Project-Id-Version: django-model-utils\n" +"Report-Msgid-Bugs-To: https://github.com/jazzband/django-model-utils/issues\n" +"POT-Creation-Date: 2018-05-04 13:40+0200\n" +"PO-Revision-Date: 2018-05-04 13:46+0200\n" +"Language: cs\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=3; plural=(n==1) ? 0 : (n>=2 && n<=4) ? 1 : 2;\n" +"Last-Translator: Václav Dohnal \n" +"Language-Team: N/A\n" +"X-Generator: Poedit 2.0.7\n" + +#: .\models.py:24 +msgid "created" +msgstr "vytvořeno" + +#: .\models.py:25 +msgid "modified" +msgstr "upraveno" + +#: .\models.py:37 +msgid "start" +msgstr "začátek" + +#: .\models.py:38 +msgid "end" +msgstr "konec" + +#: .\models.py:53 +msgid "status" +msgstr "stav" + +#: .\models.py:54 +msgid "status changed" +msgstr "změna stavu" diff --git a/model_utils/locale/de/LC_MESSAGES/django.mo b/model_utils/locale/de/LC_MESSAGES/django.mo new file mode 100644 index 0000000..7b80928 Binary files /dev/null and b/model_utils/locale/de/LC_MESSAGES/django.mo differ diff --git a/model_utils/locale/de/LC_MESSAGES/django.po b/model_utils/locale/de/LC_MESSAGES/django.po new file mode 100644 index 0000000..342b3cf --- /dev/null +++ b/model_utils/locale/de/LC_MESSAGES/django.po @@ -0,0 +1,53 @@ +# This file is distributed under the same license as the django-model-utils package. +# +# Translators: +# Philipp Steinhardt , 2015. +msgid "" +msgstr "" +"Project-Id-Version: django-model-utils\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2015-07-20 10:17-0600\n" +"PO-Revision-Date: 2015-07-01 10:12+0200\n" +"Last-Translator: Philipp Steinhardt \n" +"Language-Team: \n" +"Language: de\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=2; plural=(n != 1);\n" + +#: models.py:20 +msgid "created" +msgstr "erstellt" + +#: models.py:21 +msgid "modified" +msgstr "bearbeitet" + +#: models.py:33 +msgid "start" +msgstr "Beginn" + +#: models.py:34 +msgid "end" +msgstr "Ende" + +#: models.py:49 +msgid "status" +msgstr "Status" + +#: models.py:50 +msgid "status changed" +msgstr "Status geändert" + +#: tests/models.py:106 tests/models.py:115 tests/models.py:124 +msgid "active" +msgstr "aktiv" + +#: tests/models.py:107 tests/models.py:116 tests/models.py:125 +msgid "deleted" +msgstr "gelöscht" + +#: tests/models.py:108 tests/models.py:117 tests/models.py:126 +msgid "on hold" +msgstr "wartend" diff --git a/model_utils/locale/ru/LC_MESSAGES/django.mo b/model_utils/locale/ru/LC_MESSAGES/django.mo new file mode 100644 index 0000000..8edb2d8 Binary files /dev/null and b/model_utils/locale/ru/LC_MESSAGES/django.mo differ diff --git a/model_utils/locale/ru/LC_MESSAGES/django.po b/model_utils/locale/ru/LC_MESSAGES/django.po new file mode 100644 index 0000000..bd5d90c --- /dev/null +++ b/model_utils/locale/ru/LC_MESSAGES/django.po @@ -0,0 +1,43 @@ +# This file is distributed under the same license as the django-model-utils package. +# +# Translators: +# Arseny Sysolyatin , 2017. +msgid "" +msgstr "" +"Project-Id-Version: django-model-utils\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2017-05-22 19:46+0300\n" +"PO-Revision-Date: 2017-05-22 19:46+0300\n" +"Last-Translator: Arseny Sysolyatin \n" +"Language-Team: \n" +"Language: ru\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=4; plural=(n%10==1 && n%100!=11 ? 0 : n%10>=2 && n" +"%10<=4 && (n%100<12 || n%100>14) ? 1 : n%10==0 || (n%10>=5 && n%10<=9) || (n" +"%100>=11 && n%100<=14)? 2 : 3);\n" + +#: models.py:24 +msgid "created" +msgstr "создано" + +#: models.py:25 +msgid "modified" +msgstr "изменено" + +#: models.py:37 +msgid "start" +msgstr "начало" + +#: models.py:38 +msgid "end" +msgstr "конец" + +#: models.py:53 +msgid "status" +msgstr "статус" + +#: models.py:54 +msgid "status changed" +msgstr "статус изменен" diff --git a/model_utils/locale/zh_Hans/LC_MESSAGES/django.mo b/model_utils/locale/zh_Hans/LC_MESSAGES/django.mo new file mode 100644 index 0000000..6766cc5 Binary files /dev/null and b/model_utils/locale/zh_Hans/LC_MESSAGES/django.mo differ diff --git a/model_utils/locale/zh_Hans/LC_MESSAGES/django.po b/model_utils/locale/zh_Hans/LC_MESSAGES/django.po new file mode 100644 index 0000000..5b132f7 --- /dev/null +++ b/model_utils/locale/zh_Hans/LC_MESSAGES/django.po @@ -0,0 +1,41 @@ +# This file is distributed under the same license as the django-model-utils package. +# +# Translators: +# Zach Cheung , 2018. +msgid "" +msgstr "" +"Project-Id-Version: PACKAGE VERSION\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2018-10-23 15:12+0800\n" +"PO-Revision-Date: 2018-10-23 15:26+0800\n" +"Last-Translator: Zach Cheung \n" +"Language-Team: \n" +"Language: zh_CN\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=UTF-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Plural-Forms: nplurals=1; plural=0;\n" + +#: models.py:24 +msgid "created" +msgstr "创建时间" + +#: models.py:25 +msgid "modified" +msgstr "修改时间" + +#: models.py:37 +msgid "start" +msgstr "开始时间" + +#: models.py:38 +msgid "end" +msgstr "结束时间" + +#: models.py:53 +msgid "status" +msgstr "状态" + +#: models.py:54 +msgid "status changed" +msgstr "状态修改时间" diff --git a/model_utils/managers.py b/model_utils/managers.py index 48911b2..0b4a887 100644 --- a/model_utils/managers.py +++ b/model_utils/managers.py @@ -1,21 +1,56 @@ from __future__ import unicode_literals import django from django.db import models -from django.db.models.fields.related import OneToOneField +from django.db.models.fields.related import OneToOneField, OneToOneRel from django.db.models.query import QuerySet +from django.db.models.query import ModelIterable from django.core.exceptions import ObjectDoesNotExist -try: - from django.db.models.constants import LOOKUP_SEP - from django.utils.six import string_types -except ImportError: # Django < 1.5 - from django.db.models.sql.constants import LOOKUP_SEP - string_types = (basestring,) +from django.db.models.constants import LOOKUP_SEP +from django.utils.six import string_types + +from django.db import connection +from django.db.models.sql.datastructures import Join + + +class InheritanceIterable(ModelIterable): + def __iter__(self): + queryset = self.queryset + iter = ModelIterable(queryset) + if getattr(queryset, 'subclasses', False): + extras = tuple(queryset.query.extra.keys()) + # sort the subclass names longest first, + # so with 'a' and 'a__b' it goes as deep as possible + subclasses = sorted(queryset.subclasses, key=len, reverse=True) + for obj in iter: + sub_obj = None + for s in subclasses: + sub_obj = queryset._get_sub_obj_recurse(obj, s) + if sub_obj: + break + if not sub_obj: + sub_obj = obj + + if getattr(queryset, '_annotated', False): + for k in queryset._annotated: + setattr(sub_obj, k, getattr(obj, k)) + + for k in extras: + setattr(sub_obj, k, getattr(obj, k)) + + yield sub_obj + else: + for obj in iter: + yield obj class InheritanceQuerySetMixin(object): + def __init__(self, *args, **kwargs): + super(InheritanceQuerySetMixin, self).__init__(*args, **kwargs) + self._iterable_class = InheritanceIterable + def select_subclasses(self, *subclasses): - levels = self._get_maximum_depth() + levels = None calculated_subclasses = self._get_subclasses_recurse( self.model, levels=levels) # if none were passed in, we can just short circuit and select all @@ -40,7 +75,7 @@ class InheritanceQuerySetMixin(object): raise ValueError( '%r is not in the discovered subclasses, tried: %s' % ( subclass, ', '.join(calculated_subclasses)) - ) + ) subclasses = verified_subclasses # workaround https://code.djangoproject.com/ticket/16855 @@ -53,21 +88,34 @@ class InheritanceQuerySetMixin(object): new_qs.subclasses = subclasses return new_qs - - def _clone(self, klass=None, setup=False, **kwargs): + def _chain(self, **kwargs): for name in ['subclasses', '_annotated']: if hasattr(self, name): kwargs[name] = getattr(self, name) - return super(InheritanceQuerySetMixin, self)._clone(klass, setup, **kwargs) + return super(InheritanceQuerySetMixin, self)._chain(**kwargs) + + def _clone(self, klass=None, setup=False, **kwargs): + if django.VERSION >= (2, 0): + qs = super(InheritanceQuerySetMixin, self)._clone() + for name in ['subclasses', '_annotated']: + if hasattr(self, name): + setattr(qs, name, getattr(self, name)) + return qs + + for name in ['subclasses', '_annotated']: + if hasattr(self, name): + kwargs[name] = getattr(self, name) + + return super(InheritanceQuerySetMixin, self)._clone(**kwargs) def annotate(self, *args, **kwargs): qset = super(InheritanceQuerySetMixin, self).annotate(*args, **kwargs) qset._annotated = [a.default_alias for a in args] + list(kwargs.keys()) return qset - def iterator(self): + # Maintained for Django 1.8 compatability iter = super(InheritanceQuerySetMixin, self).iterator() if getattr(self, 'subclasses', False): extras = tuple(self.query.extra.keys()) @@ -95,19 +143,24 @@ class InheritanceQuerySetMixin(object): for obj in iter: yield obj - def _get_subclasses_recurse(self, model, levels=None): """ Given a Model class, find all related objects, exploring children recursively, returning a `list` of strings representing the relations for select_related """ + related_objects = [ + f for f in model._meta.get_fields() + if isinstance(f, OneToOneRel)] + rels = [ - rel for rel in model._meta.get_all_related_objects() + rel for rel in related_objects if isinstance(rel.field, OneToOneField) and issubclass(rel.field.model, model) and model is not rel.field.model - ] + and rel.parent_link + ] + subclasses = [] if levels: levels -= 1 @@ -115,11 +168,11 @@ class InheritanceQuerySetMixin(object): if levels or levels is None: for subclass in self._get_subclasses_recurse( rel.field.model, levels=levels): - subclasses.append(rel.get_accessor_name() + LOOKUP_SEP + subclass) + subclasses.append( + rel.get_accessor_name() + LOOKUP_SEP + subclass) subclasses.append(rel.get_accessor_name()) return subclasses - def _get_ancestors_path(self, model, levels=None): """ Serves as an opposite to _get_subclasses_recurse, instead walking from @@ -127,25 +180,33 @@ class InheritanceQuerySetMixin(object): select_related string backwards. """ if not issubclass(model, self.model): - raise ValueError("%r is not a subclass of %r" % (model, self.model)) + raise ValueError( + "%r is not a subclass of %r" % (model, self.model)) ancestry = [] # should be a OneToOneField or None - parent = model._meta.get_ancestor_link(self.model) + parent_link = model._meta.get_ancestor_link(self.model) if levels: levels -= 1 - while parent is not None: - ancestry.insert(0, parent.related.get_accessor_name()) + while parent_link is not None: + related = parent_link.remote_field + ancestry.insert(0, related.get_accessor_name()) if levels or levels is None: - parent = parent.related.parent_model._meta.get_ancestor_link( + parent_model = related.model + parent_link = parent_model._meta.get_ancestor_link( self.model) else: - parent = None + parent_link = None return LOOKUP_SEP.join(ancestry) - def _get_sub_obj_recurse(self, obj, s): rel, _, s = s.partition(LOOKUP_SEP) + + # Django 1.9: If a primitive type gets passed to this recursive function, + # return None as non-models are not part of inheritance. + if not isinstance(obj, models.Model): + return None + try: node = getattr(obj, rel) except ObjectDoesNotExist: @@ -159,16 +220,10 @@ class InheritanceQuerySetMixin(object): def get_subclass(self, *args, **kwargs): return self.select_subclasses().get(*args, **kwargs) - def _get_maximum_depth(self): - """ - Under Django versions < 1.6, to avoid triggering - https://code.djangoproject.com/ticket/16572 we can only look - as far as children. - """ - levels = None - if django.VERSION < (1, 6, 0): - levels = 1 - return levels + +class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): + pass + def instance_of(self, *models): """ @@ -196,12 +251,10 @@ class InheritanceQuerySetMixin(object): return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)]) class InheritanceManagerMixin(object): - use_for_related_fields = True + _queryset_class = InheritanceQuerySet def get_queryset(self): - return InheritanceQuerySet(self.model) - - get_query_set = get_queryset + return self._queryset_class(self.model) def select_subclasses(self, *subclasses): return self.get_queryset().select_subclasses(*subclasses) @@ -212,15 +265,11 @@ class InheritanceManagerMixin(object): def instance_of(self, *models): return self.get_queryset().instance_of(*models) -class InheritanceQuerySet(InheritanceQuerySetMixin, QuerySet): - pass - class InheritanceManager(InheritanceManagerMixin, models.Manager): pass class QueryManagerMixin(object): - use_for_related_fields = True def __init__(self, *args, **kwargs): if args: @@ -235,103 +284,159 @@ class QueryManagerMixin(object): return self def get_queryset(self): - try: - qs = super(QueryManagerMixin, self).get_queryset().filter(self._q) - except AttributeError: - qs = super(QueryManagerMixin, self).get_query_set().filter(self._q) + qs = super(QueryManagerMixin, self).get_queryset().filter(self._q) if self._order_by is not None: return qs.order_by(*self._order_by) return qs - get_query_set = get_queryset - class QueryManager(QueryManagerMixin, models.Manager): pass -class PassThroughManagerMixin(object): +class SoftDeletableQuerySetMixin(object): """ - A mixin that enables you to call custom QuerySet methods from your manager. + QuerySet for SoftDeletableModel. Instead of removing instance sets + its ``is_removed`` field to True. """ - # pickling causes recursion errors - _deny_methods = ['__getstate__', '__setstate__', '__getinitargs__', - '__getnewargs__', '__copy__', '__deepcopy__', '_db', - '__slots__'] - - def __init__(self, queryset_cls=None): - self._queryset_cls = queryset_cls - super(PassThroughManagerMixin, self).__init__() - - def __getattr__(self, name): - if name in self._deny_methods: - raise AttributeError(name) - if django.VERSION < (1, 6, 0): - return getattr(self.get_query_set(), name) - return getattr(self.get_queryset(), name) - - def __dir__(self): + def delete(self): """ - Allow introspection via dir() and ipythonesque tab-discovery. - - We do dir(type(self)) because to do dir(self) would be a recursion - error. - We call dir(self.get_query_set()) because it is possible that the - queryset returned by get_query_set() is interesting, even if - self._queryset_cls is None. + Soft delete objects from queryset (set their ``is_removed`` + field to True) """ - my_values = frozenset(dir(type(self))) - my_values |= frozenset(dir(self.get_query_set())) - return list(my_values) - - def get_queryset(self): - try: - qs = super(PassThroughManagerMixin, self).get_queryset() - except AttributeError: - qs = super(PassThroughManagerMixin, self).get_query_set() - if self._queryset_cls is not None: - qs = qs._clone(klass=self._queryset_cls) - return qs - - get_query_set = get_queryset - - @classmethod - def for_queryset_class(cls, queryset_cls): - return create_pass_through_manager_for_queryset_class(cls, queryset_cls) + self.update(is_removed=True) -class PassThroughManager(PassThroughManagerMixin, models.Manager): - """ - Inherit from this Manager to enable you to call any methods from your - custom QuerySet class from your manager. Simply define your QuerySet - class, and return an instance of it from your manager's `get_queryset` - method. - - Alternately, if you don't need any extra methods on your manager that - aren't on your QuerySet, then just pass your QuerySet class to the - ``for_queryset_class`` class method. - - class PostQuerySet(QuerySet): - def enabled(self): - return self.filter(disabled=False) - - class Post(models.Model): - objects = PassThroughManager.for_queryset_class(PostQuerySet)() - - """ +class SoftDeletableQuerySet(SoftDeletableQuerySetMixin, QuerySet): pass -def create_pass_through_manager_for_queryset_class(base, queryset_cls): - class _PassThroughManager(base): - def __init__(self, *args, **kwargs): - return super(_PassThroughManager, self).__init__(*args, **kwargs) +class SoftDeletableManagerMixin(object): + """ + Manager that limits the queryset by default to show only not removed + instances of model. + """ + _queryset_class = SoftDeletableQuerySet - def get_queryset(self): - qs = super(_PassThroughManager, self).get_queryset() - return qs._clone(klass=queryset_cls) + def get_queryset(self): + """ + Return queryset limited to not removed entries. + """ + kwargs = {'model': self.model, 'using': self._db} + if hasattr(self, '_hints'): + kwargs['hints'] = self._hints - get_query_set = get_queryset + return self._queryset_class(**kwargs).filter(is_removed=False) - return _PassThroughManager + +class SoftDeletableManager(SoftDeletableManagerMixin, models.Manager): + pass + + +class JoinQueryset(models.QuerySet): + + def get_quoted_query(self, query): + query, params = query.sql_with_params() + + # Put additional quotes around string. + params = [ + '\'{}\''.format(p) + if isinstance(p, str) else p + for p in params + ] + + # Cast list of parameters to tuple because I got + # "not enough format characters" otherwise. + params = tuple(params) + return query % params + + def join(self, qs=None): + ''' + Join one queryset together with another using a temporary table. If + no queryset is used, it will use the current queryset and join that + to itself. + + `Join` either uses the current queryset and effectively does a self-join to + create a new limited queryset OR it uses a querset given by the user. + + The model of a given queryset needs to contain a valid foreign key to + the current queryset to perform a join. A new queryset is then created. + ''' + to_field = 'id' + + if qs: + fk = [ + fk for fk in qs.model._meta.fields + if getattr(fk, 'related_model', None) == self.model + ] + fk = fk[0] if fk else None + model_set = '{}_set'.format(self.model.__name__.lower()) + key = fk or getattr(qs.model, model_set, None) + + if not key: + raise ValueError('QuerySet is not related to current model') + + try: + fk_column = key.column + except AttributeError: + fk_column = 'id' + to_field = key.field.column + + qs = qs.only(fk_column) + # if we give a qs we need to keep the model qs to not lose anything + new_qs = self + else: + fk_column = 'id' + qs = self.only(fk_column) + new_qs = self.model.objects.all() + + TABLE_NAME = 'temp_stuff' + query = self.get_quoted_query(qs.query) + sql = ''' + DROP TABLE IF EXISTS {table_name}; + DROP INDEX IF EXISTS {table_name}_id; + CREATE TEMPORARY TABLE {table_name} AS {query}; + CREATE INDEX {table_name}_{fk_column} ON {table_name} ({fk_column}); + '''.format(table_name=TABLE_NAME, fk_column=fk_column, query=str(query)) + + with connection.cursor() as cursor: + cursor.execute(sql) + + class TempModel(models.Model): + temp_key = models.ForeignKey( + self.model, + on_delete=models.DO_NOTHING, + db_column=fk_column, + to_field=to_field + ) + + class Meta: + managed = False + db_table = TABLE_NAME + + conn = Join( + table_name=TempModel._meta.db_table, + parent_alias=new_qs.query.get_initial_alias(), + table_alias=None, + join_type='INNER JOIN', + join_field=self.model.tempmodel_set.rel, + nullable=False + ) + new_qs.query.join(conn, reuse=None) + return new_qs + + +class JoinManagerMixin(object): + """ + Manager that adds a method join. This method allows you to join two + querysets together. + """ + _queryset_class = JoinQueryset + + def get_queryset(self): + return self._queryset_class(model=self.model, using=self._db) + + +class JoinManager(JoinManagerMixin, models.Manager): + pass diff --git a/model_utils/models.py b/model_utils/models.py index 8030bea..6083015 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -1,14 +1,28 @@ from __future__ import unicode_literals -from django.db import models -from django.utils.translation import ugettext_lazy as _ -from django.db.models.fields import FieldDoesNotExist +import django from django.core.exceptions import ImproperlyConfigured -from django.utils.timezone import now +from django.db import models, transaction, router +from django.db.models.signals import post_save, pre_save +from django.utils.translation import ugettext_lazy as _ -from model_utils.managers import QueryManager -from model_utils.fields import AutoCreatedField, AutoLastModifiedField, \ - StatusField, MonitorField +from model_utils.fields import ( + AutoCreatedField, + AutoLastModifiedField, + StatusField, + MonitorField, + UUIDField, +) +from model_utils.managers import ( + QueryManager, + SoftDeletableManager, +) + +if django.VERSION >= (1, 9, 0): + from django.db.models.functions import Now + now = Now() +else: + from django.utils.timezone import now class TimeStampedModel(models.Model): @@ -36,6 +50,7 @@ class TimeFramedModel(models.Model): class Meta: abstract = True + class StatusModel(models.Model): """ An abstract base class model with a ``status`` field that @@ -51,6 +66,7 @@ class StatusModel(models.Model): class Meta: abstract = True + def add_status_query_managers(sender, **kwargs): """ Add a Querymanager for each status item dynamically. @@ -58,17 +74,25 @@ def add_status_query_managers(sender, **kwargs): """ if not issubclass(sender, StatusModel): return + + if django.VERSION >= (1, 10): + # First, get current manager name... + default_manager = sender._meta.default_manager + for value, display in getattr(sender, 'STATUS', ()): - try: - sender._meta.get_field(value) - raise ImproperlyConfigured("StatusModel: Model '%s' has a field " - "named '%s' which conflicts with a " - "status of the same name." - % (sender.__name__, value)) - except FieldDoesNotExist: - pass + if _field_exists(sender, value): + raise ImproperlyConfigured( + "StatusModel: Model '%s' has a field named '%s' which " + "conflicts with a status of the same name." + % (sender.__name__, value) + ) sender.add_to_class(value, QueryManager(status=value)) + if django.VERSION >= (1, 10): + # ...then, put it back, as add_to_class is modifying the default manager! + sender._meta.default_manager_name = default_manager.name + + def add_timeframed_query_manager(sender, **kwargs): """ Add a QueryManager for a specific timeframe. @@ -76,19 +100,120 @@ def add_timeframed_query_manager(sender, **kwargs): """ if not issubclass(sender, TimeFramedModel): return - try: - sender._meta.get_field('timeframed') - raise ImproperlyConfigured("Model '%s' has a field named " - "'timeframed' which conflicts with " - "the TimeFramedModel manager." - % sender.__name__) - except FieldDoesNotExist: - pass + if _field_exists(sender, 'timeframed'): + raise ImproperlyConfigured( + "Model '%s' has a field named 'timeframed' " + "which conflicts with the TimeFramedModel manager." + % sender.__name__ + ) sender.add_to_class('timeframed', QueryManager( - (models.Q(start__lte=now) | models.Q(start__isnull=True)) & - (models.Q(end__gte=now) | models.Q(end__isnull=True)) + (models.Q(start__lte=now) | models.Q(start__isnull=True)) + & (models.Q(end__gte=now) | models.Q(end__isnull=True)) )) models.signals.class_prepared.connect(add_status_query_managers) models.signals.class_prepared.connect(add_timeframed_query_manager) + + +def _field_exists(model_class, field_name): + return field_name in [f.attname for f in model_class._meta.local_fields] + + +class SoftDeletableModel(models.Model): + """ + An abstract base class model with a ``is_removed`` field that + marks entries that are not going to be used anymore, but are + kept in db for any reason. + Default manager returns only not-removed entries. + """ + is_removed = models.BooleanField(default=False) + + class Meta: + abstract = True + + objects = SoftDeletableManager() + all_objects = models.Manager() + + def delete(self, using=None, soft=True, *args, **kwargs): + """ + Soft delete object (set its ``is_removed`` field to True). + Actually delete object if setting ``soft`` to False. + """ + if soft: + self.is_removed = True + self.save(using=using) + else: + return super(SoftDeletableModel, self).delete(using=using, *args, **kwargs) + + +class UUIDModel(models.Model): + """ + This abstract base class provides id field on any model that inherits from it + which will be the primary key. + """ + id = UUIDField( + primary_key=True, + version=4, + editable=False, + ) + + class Meta: + abstract = True + + +class SaveSignalHandlingModel(models.Model): + """ + An abstract base class model to pass a parameter ``signals_to_disable`` + to ``save`` method in order to disable signals + """ + class Meta: + abstract = True + + def save(self, signals_to_disable=None, *args, **kwargs): + """ + Add an extra parameters to hold which signals to disable + If empty, nothing will change + """ + + self.signals_to_disable = signals_to_disable or [] + + super(SaveSignalHandlingModel, self).save(*args, **kwargs) + + def save_base(self, raw=False, force_insert=False, + force_update=False, using=None, update_fields=None): + """ + Copied from base class for a minor change. + This is an ugly overwriting but since Django's ``save_base`` method + does not differ between versions 1.8 and 1.10, + that way of implementing wouldn't harm the flow + """ + using = using or router.db_for_write(self.__class__, instance=self) + assert not (force_insert and (force_update or update_fields)) + assert update_fields is None or len(update_fields) > 0 + cls = origin = self.__class__ + + if cls._meta.proxy: + cls = cls._meta.concrete_model + meta = cls._meta + if not meta.auto_created and 'pre_save' not in self.signals_to_disable: + pre_save.send( + sender=origin, instance=self, raw=raw, using=using, + update_fields=update_fields, + ) + with transaction.atomic(using=using, savepoint=False): + if not raw: + self._save_parents(cls, using, update_fields) + updated = self._save_table(raw, cls, force_insert, force_update, using, update_fields) + + self._state.db = using + self._state.adding = False + + if not meta.auto_created and 'post_save' not in self.signals_to_disable: + post_save.send( + sender=origin, instance=self, created=(not updated), + update_fields=update_fields, raw=raw, using=using, + ) + + # Empty the signals in case it might be used somewhere else in future + self.signals_to_disable = [] diff --git a/model_utils/tests/fields.py b/model_utils/tests/fields.py deleted file mode 100644 index 3f1503a..0000000 --- a/model_utils/tests/fields.py +++ /dev/null @@ -1,26 +0,0 @@ -from django.db import models -from django.utils.six import with_metaclass, string_types - - -class MutableField(with_metaclass(models.SubfieldBase, models.TextField)): - - def to_python(self, value): - if value == '': - return None - - try: - if isinstance(value, string_types): - return [int(i) for i in value.split(',')] - except ValueError: - pass - - return value - - def get_db_prep_save(self, value, connection): - if value is None: - return '' - - if isinstance(value, list): - value = ','.join((str(i) for i in value)) - - return super(MutableField, self).get_db_prep_save(value, connection) diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py deleted file mode 100644 index f887b57..0000000 --- a/model_utils/tests/tests.py +++ /dev/null @@ -1,1971 +0,0 @@ -from __future__ import unicode_literals - -from datetime import datetime, timedelta -import pickle -try: - from unittest import skipUnless -except ImportError: # Python 2.6 - from django.utils.unittest import skipUnless - -import django -from django.db import models -from django.db.models.fields import FieldDoesNotExist -from django.utils.six import text_type -from django.core.exceptions import ImproperlyConfigured, FieldError -from django.core.management import call_command -from django.test import TestCase - -from model_utils import Choices, FieldTracker -from model_utils.fields import get_excerpt, MonitorField, StatusField -from model_utils.managers import QueryManager -from model_utils.models import StatusModel, TimeFramedModel -from model_utils.tests.models import ( - InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1, - InheritanceManagerTestGrandChild1_2, - InheritanceManagerTestParent, InheritanceManagerTestChild1, - InheritanceManagerTestChild2, TimeStamp, Post, Article, Status, - StatusPlainTuple, TimeFrame, Monitored, MonitorWhen, MonitorWhenEmpty, StatusManagerAdded, - TimeFrameManagerAdded, Dude, SplitFieldAbstractParent, Car, Spot, - ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple, InheritedModelTracked, - Tracked, TrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple, - InheritedTracked, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled, - InheritanceManagerTestChild3, StatusFieldChoicesName) - - -class MigrationsTests(TestCase): - @skipUnless(django.VERSION >= (1, 7, 0), "test only applies to Django 1.7+") - def test_makemigrations(self): - call_command('makemigrations', dry_run=True) - - -class GetExcerptTests(TestCase): - def test_split(self): - e = get_excerpt("some content\n\n\n\nsome more") - self.assertEqual(e, 'some content\n') - - - def test_auto_split(self): - e = get_excerpt("para one\n\npara two\n\npara three") - self.assertEqual(e, 'para one\n\npara two') - - - def test_middle_of_para(self): - e = get_excerpt("some text\n\nmore text") - self.assertEqual(e, 'some text') - - - def test_middle_of_line(self): - e = get_excerpt("some text more text") - self.assertEqual(e, "some text more text") - - - -class SplitFieldTests(TestCase): - full_text = 'summary\n\n\n\nmore' - excerpt = 'summary\n' - - - def setUp(self): - self.post = Article.objects.create( - title='example post', body=self.full_text) - - - def test_unicode_content(self): - self.assertEqual(text_type(self.post.body), self.full_text) - - - def test_excerpt(self): - self.assertEqual(self.post.body.excerpt, self.excerpt) - - - def test_content(self): - self.assertEqual(self.post.body.content, self.full_text) - - - def test_has_more(self): - self.assertTrue(self.post.body.has_more) - - - def test_not_has_more(self): - post = Article.objects.create(title='example 2', - body='some text\n\nsome more\n') - self.assertFalse(post.body.has_more) - - - def test_load_back(self): - post = Article.objects.get(pk=self.post.pk) - self.assertEqual(post.body.content, self.post.body.content) - self.assertEqual(post.body.excerpt, self.post.body.excerpt) - - - def test_assign_to_body(self): - new_text = 'different\n\n\n\nother' - self.post.body = new_text - self.post.save() - self.assertEqual(text_type(self.post.body), new_text) - - - def test_assign_to_content(self): - new_text = 'different\n\n\n\nother' - self.post.body.content = new_text - self.post.save() - self.assertEqual(text_type(self.post.body), new_text) - - - def test_assign_to_excerpt(self): - with self.assertRaises(AttributeError): - self.post.body.excerpt = 'this should fail' - - - def test_access_via_class(self): - with self.assertRaises(AttributeError): - Article.body - - - def test_none(self): - a = Article(title='Some Title', body=None) - self.assertEqual(a.body, None) - - - def test_assign_splittext(self): - a = Article(title='Some Title') - a.body = self.post.body - self.assertEqual(a.body.excerpt, 'summary\n') - - - def test_value_to_string(self): - f = self.post._meta.get_field('body') - self.assertEqual(f.value_to_string(self.post), self.full_text) - - - def test_abstract_inheritance(self): - class Child(SplitFieldAbstractParent): - pass - - self.assertEqual( - [f.name for f in Child._meta.fields], - ["id", "content", "_content_excerpt"]) - - - -class MonitorFieldTests(TestCase): - def setUp(self): - self.instance = Monitored(name='Charlie') - self.created = self.instance.name_changed - - - def test_save_no_change(self): - self.instance.save() - self.assertEqual(self.instance.name_changed, self.created) - - - def test_save_changed(self): - self.instance.name = 'Maria' - self.instance.save() - self.assertTrue(self.instance.name_changed > self.created) - - - def test_double_save(self): - self.instance.name = 'Jose' - self.instance.save() - changed = self.instance.name_changed - self.instance.save() - self.assertEqual(self.instance.name_changed, changed) - - - def test_no_monitor_arg(self): - with self.assertRaises(TypeError): - MonitorField() - - - -class MonitorWhenFieldTests(TestCase): - """ - Will record changes only when name is 'Jose' or 'Maria' - """ - def setUp(self): - self.instance = MonitorWhen(name='Charlie') - self.created = self.instance.name_changed - - - def test_save_no_change(self): - self.instance.save() - self.assertEqual(self.instance.name_changed, self.created) - - - def test_save_changed_to_Jose(self): - self.instance.name = 'Jose' - self.instance.save() - self.assertTrue(self.instance.name_changed > self.created) - - - def test_save_changed_to_Maria(self): - self.instance.name = 'Maria' - self.instance.save() - self.assertTrue(self.instance.name_changed > self.created) - - - def test_save_changed_to_Pedro(self): - self.instance.name = 'Pedro' - self.instance.save() - self.assertEqual(self.instance.name_changed, self.created) - - - def test_double_save(self): - self.instance.name = 'Jose' - self.instance.save() - changed = self.instance.name_changed - self.instance.save() - self.assertEqual(self.instance.name_changed, changed) - - - -class MonitorWhenEmptyFieldTests(TestCase): - """ - Monitor should never be updated id when is an empty list. - """ - def setUp(self): - self.instance = MonitorWhenEmpty(name='Charlie') - self.created = self.instance.name_changed - - - def test_save_no_change(self): - self.instance.save() - self.assertEqual(self.instance.name_changed, self.created) - - - def test_save_changed_to_Jose(self): - self.instance.name = 'Jose' - self.instance.save() - self.assertEqual(self.instance.name_changed, self.created) - - - def test_save_changed_to_Maria(self): - self.instance.name = 'Maria' - self.instance.save() - self.assertEqual(self.instance.name_changed, self.created) - - - -class StatusFieldTests(TestCase): - - def test_status_with_default_filled(self): - instance = StatusFieldDefaultFilled() - self.assertEqual(instance.status, instance.STATUS.yes) - - def test_status_with_default_not_filled(self): - instance = StatusFieldDefaultNotFilled() - self.assertEqual(instance.status, instance.STATUS.no) - - def test_no_check_for_status(self): - field = StatusField(no_check_for_status=True) - # this model has no STATUS attribute, so checking for it would error - field.prepare_class(Article) - - def test_get_status_display(self): - instance = StatusFieldDefaultFilled() - self.assertEqual(instance.get_status_display(), "Yes") - - def test_choices_name(self): - StatusFieldChoicesName() - - -class ChoicesTests(TestCase): - def setUp(self): - self.STATUS = Choices('DRAFT', 'PUBLISHED') - - - def test_getattr(self): - self.assertEqual(self.STATUS.DRAFT, 'DRAFT') - - - def test_indexing(self): - self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') - - - def test_iteration(self): - self.assertEqual(tuple(self.STATUS), (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) - - - def test_len(self): - self.assertEqual(len(self.STATUS), 2) - - - def test_repr(self): - self.assertEqual(repr(self.STATUS), "Choices" + repr(( - ('DRAFT', 'DRAFT', 'DRAFT'), - ('PUBLISHED', 'PUBLISHED', 'PUBLISHED'), - ))) - - - def test_wrong_length_tuple(self): - with self.assertRaises(ValueError): - Choices(('a',)) - - - def test_contains_value(self): - self.assertTrue('PUBLISHED' in self.STATUS) - self.assertTrue('DRAFT' in self.STATUS) - - - def test_doesnt_contain_value(self): - self.assertFalse('UNPUBLISHED' in self.STATUS) - - def test_deepcopy(self): - import copy - self.assertEqual(list(self.STATUS), - list(copy.deepcopy(self.STATUS))) - - - def test_equality(self): - self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) - - - def test_inequality(self): - self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) - self.assertNotEqual(self.STATUS, Choices('DRAFT')) - - - def test_composability(self): - self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) - self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) - self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) - - - def test_option_groups(self): - c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) - self.assertEqual( - list(c), - [ - ('group a', [('one', 'one'), ('two', 'two')]), - ('group b', [('three', 'three')]), - ], - ) - - -class LabelChoicesTests(ChoicesTests): - def setUp(self): - self.STATUS = Choices( - ('DRAFT', 'is draft'), - ('PUBLISHED', 'is published'), - 'DELETED', - ) - - - def test_iteration(self): - self.assertEqual(tuple(self.STATUS), ( - ('DRAFT', 'is draft'), - ('PUBLISHED', 'is published'), - ('DELETED', 'DELETED')) - ) - - - def test_indexing(self): - self.assertEqual(self.STATUS['PUBLISHED'], 'is published') - - - def test_default(self): - self.assertEqual(self.STATUS.DELETED, 'DELETED') - - - def test_provided(self): - self.assertEqual(self.STATUS.DRAFT, 'DRAFT') - - - def test_len(self): - self.assertEqual(len(self.STATUS), 3) - - - def test_equality(self): - self.assertEqual(self.STATUS, Choices( - ('DRAFT', 'is draft'), - ('PUBLISHED', 'is published'), - 'DELETED', - )) - - - def test_inequality(self): - self.assertNotEqual(self.STATUS, [ - ('DRAFT', 'is draft'), - ('PUBLISHED', 'is published'), - 'DELETED' - ]) - self.assertNotEqual(self.STATUS, Choices('DRAFT')) - - - def test_repr(self): - self.assertEqual(repr(self.STATUS), "Choices" + repr(( - ('DRAFT', 'DRAFT', 'is draft'), - ('PUBLISHED', 'PUBLISHED', 'is published'), - ('DELETED', 'DELETED', 'DELETED'), - ))) - - - def test_contains_value(self): - self.assertTrue('PUBLISHED' in self.STATUS) - self.assertTrue('DRAFT' in self.STATUS) - # This should be True, because both the display value - # and the internal representation are both DELETED. - self.assertTrue('DELETED' in self.STATUS) - - - def test_doesnt_contain_value(self): - self.assertFalse('UNPUBLISHED' in self.STATUS) - - - def test_doesnt_contain_display_value(self): - self.assertFalse('is draft' in self.STATUS) - - - def test_composability(self): - self.assertEqual( - Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'), - self.STATUS - ) - - self.assertEqual( - (('DRAFT', 'is draft',),) + Choices(('PUBLISHED', 'is published'), 'DELETED'), - self.STATUS - ) - - self.assertEqual( - Choices(('DRAFT', 'is draft',)) + (('PUBLISHED', 'is published'), 'DELETED'), - self.STATUS - ) - - - def test_option_groups(self): - c = Choices( - ('group a', [(1, 'one'), (2, 'two')]), - ['group b', ((3, 'three'),)] - ) - self.assertEqual( - list(c), - [ - ('group a', [(1, 'one'), (2, 'two')]), - ('group b', [(3, 'three')]), - ], - ) - - - -class IdentifierChoicesTests(ChoicesTests): - def setUp(self): - self.STATUS = Choices( - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published'), - (2, 'DELETED', 'is deleted')) - - - def test_iteration(self): - self.assertEqual(tuple(self.STATUS), ( - (0, 'is draft'), - (1, 'is published'), - (2, 'is deleted'))) - - - def test_indexing(self): - self.assertEqual(self.STATUS[1], 'is published') - - - def test_getattr(self): - self.assertEqual(self.STATUS.DRAFT, 0) - - - def test_len(self): - self.assertEqual(len(self.STATUS), 3) - - - def test_repr(self): - self.assertEqual(repr(self.STATUS), "Choices" + repr(( - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published'), - (2, 'DELETED', 'is deleted'), - ))) - - - def test_contains_value(self): - self.assertTrue(0 in self.STATUS) - self.assertTrue(1 in self.STATUS) - self.assertTrue(2 in self.STATUS) - - - def test_doesnt_contain_value(self): - self.assertFalse(3 in self.STATUS) - - - def test_doesnt_contain_display_value(self): - self.assertFalse('is draft' in self.STATUS) - - - def test_doesnt_contain_python_attr(self): - self.assertFalse('PUBLISHED' in self.STATUS) - - - def test_equality(self): - self.assertEqual(self.STATUS, Choices( - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published'), - (2, 'DELETED', 'is deleted') - )) - - - def test_inequality(self): - self.assertNotEqual(self.STATUS, [ - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published'), - (2, 'DELETED', 'is deleted') - ]) - self.assertNotEqual(self.STATUS, Choices('DRAFT')) - - - def test_composability(self): - self.assertEqual( - Choices( - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published') - ) + Choices( - (2, 'DELETED', 'is deleted'), - ), - self.STATUS - ) - - self.assertEqual( - Choices( - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published') - ) + ( - (2, 'DELETED', 'is deleted'), - ), - self.STATUS - ) - - self.assertEqual( - ( - (0, 'DRAFT', 'is draft'), - (1, 'PUBLISHED', 'is published') - ) + Choices( - (2, 'DELETED', 'is deleted'), - ), - self.STATUS - ) - - - def test_option_groups(self): - c = Choices( - ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), - ['group b', ((3, 'THREE', 'three'),)] - ) - self.assertEqual( - list(c), - [ - ('group a', [(1, 'one'), (2, 'two')]), - ('group b', [(3, 'three')]), - ], - ) - - -class InheritanceManagerTests(TestCase): - def setUp(self): - self.child1 = InheritanceManagerTestChild1.objects.create() - self.child2 = InheritanceManagerTestChild2.objects.create() - self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() - self.grandchild1_2 = \ - InheritanceManagerTestGrandChild1_2.objects.create() - - - def get_manager(self): - return InheritanceManagerTestParent.objects - - - def test_normal(self): - children = set([ - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - ]) - self.assertEqual(set(self.get_manager().all()), children) - - - def test_select_all_subclasses(self): - children = set([self.child1, self.child2]) - if django.VERSION >= (1, 6, 0): - children.add(self.grandchild1) - children.add(self.grandchild1_2) - else: - children.add(InheritanceManagerTestChild1(pk=self.grandchild1.pk)) - children.add(InheritanceManagerTestChild1(pk=self.grandchild1_2.pk)) - self.assertEqual( - set(self.get_manager().select_subclasses()), children) - - - def test_select_subclasses_invalid_relation(self): - """ - If an invalid relation string is provided, we can provide the user - with a list which is valid, rather than just have the select_related() - raise an AttributeError further in. - """ - regex = '^.+? is not in the discovered subclasses, tried:.+$' - with self.assertRaisesRegexp(ValueError, regex): - self.get_manager().select_subclasses('user') - - - def test_select_specific_subclasses(self): - children = set([ - self.child1, - InheritanceManagerTestParent(pk=self.child2.pk), - InheritanceManagerTestChild1(pk=self.grandchild1.pk), - InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), - ]) - self.assertEqual( - set( - self.get_manager().select_subclasses( - "inheritancemanagertestchild1") - ), - children, - ) - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_select_specific_grandchildren(self): - children = set([ - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - self.grandchild1, - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - ]) - self.assertEqual( - set( - self.get_manager().select_subclasses( - "inheritancemanagertestchild1__inheritancemanagertestgrandchild1" - ) - ), - children, - ) - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_children_and_grandchildren(self): - children = set([ - self.child1, - InheritanceManagerTestParent(pk=self.child2.pk), - self.grandchild1, - InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), - ]) - self.assertEqual( - set( - self.get_manager().select_subclasses( - "inheritancemanagertestchild1", - "inheritancemanagertestchild1__inheritancemanagertestgrandchild1" - ) - ), - children, - ) - - - def test_get_subclass(self): - self.assertEqual( - self.get_manager().get_subclass(pk=self.child1.pk), - self.child1) - - - def test_get_subclass_on_queryset(self): - self.assertEqual( - self.get_manager().all().get_subclass(pk=self.child1.pk), - self.child1) - - - def test_prior_select_related(self): - with self.assertNumQueries(1): - obj = self.get_manager().select_related( - "inheritancemanagertestchild1").select_subclasses( - "inheritancemanagertestchild2").get(pk=self.child1.pk) - obj.inheritancemanagertestchild1 - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_version_determining_any_depth(self): - self.assertIsNone(self.get_manager().all()._get_maximum_depth()) - - - @skipUnless(django.VERSION < (1, 6, 0), "test only applies to Django < 1.6") - def test_version_determining_only_child_depth(self): - self.assertEqual(1, self.get_manager().all()._get_maximum_depth()) - - - @skipUnless(django.VERSION < (1, 6, 0), "test only applies to Django < 1.6") - def test_manually_specifying_parent_fk_only_children(self): - """ - given a Model which inherits from another Model, but also declares - the OneToOne link manually using `related_name` and `parent_link`, - ensure that the relation names and subclasses are obtained correctly. - """ - child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.all().select_subclasses() - - expected_objs = [self.child1, self.child2, - InheritanceManagerTestChild1(pk=self.grandchild1.pk), - InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), - child3] - self.assertEqual(list(results), expected_objs) - - expected_related_names = [ - 'inheritancemanagertestchild1', - 'inheritancemanagertestchild2', - 'manual_onetoone', # this was set via parent_link & related_name - ] - self.assertEqual(set(results.subclasses), - set(expected_related_names)) - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_manually_specifying_parent_fk_including_grandchildren(self): - """ - given a Model which inherits from another Model, but also declares - the OneToOne link manually using `related_name` and `parent_link`, - ensure that the relation names and subclasses are obtained correctly. - """ - child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.all().select_subclasses() - - expected_objs = [self.child1, self.child2, self.grandchild1, - self.grandchild1_2, child3] - self.assertEqual(list(results), expected_objs) - - expected_related_names = [ - 'inheritancemanagertestchild1__inheritancemanagertestgrandchild1', - 'inheritancemanagertestchild1__inheritancemanagertestgrandchild1_2', - 'inheritancemanagertestchild1', - 'inheritancemanagertestchild2', - 'manual_onetoone', # this was set via parent_link & related_name - ] - self.assertEqual(set(results.subclasses), - set(expected_related_names)) - - - def test_manually_specifying_parent_fk_single_subclass(self): - """ - Using a string related_name when the relation is manually defined - instead of implicit should still work in the same way. - """ - related_name = 'manual_onetoone' - child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.all().select_subclasses(related_name) - - expected_objs = [InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - child3] - self.assertEqual(list(results), expected_objs) - expected_related_names = [related_name] - self.assertEqual(set(results.subclasses), - set(expected_related_names)) - - -class InheritanceManagerUsingModelsTests(TestCase): - - def setUp(self): - self.parent1 = InheritanceManagerTestParent.objects.create() - self.child1 = InheritanceManagerTestChild1.objects.create() - self.child2 = InheritanceManagerTestChild2.objects.create() - self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() - self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create() - - - def test_select_subclass_by_child_model(self): - """ - Confirm that passing a child model works the same as passing the - select_related manually - """ - objs = InheritanceManagerTestParent.objects.select_subclasses( - "inheritancemanagertestchild1").order_by('pk') - objsmodels = InheritanceManagerTestParent.objects.select_subclasses( - InheritanceManagerTestChild1).order_by('pk') - self.assertEqual(objs.subclasses, objsmodels.subclasses) - self.assertEqual(list(objs), list(objsmodels)) - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_select_subclass_by_grandchild_model(self): - """ - Confirm that passing a grandchild model works the same as passing the - select_related manually - """ - objs = InheritanceManagerTestParent.objects.select_subclasses( - "inheritancemanagertestchild1__inheritancemanagertestgrandchild1")\ - .order_by('pk') - objsmodels = InheritanceManagerTestParent.objects.select_subclasses( - InheritanceManagerTestGrandChild1).order_by('pk') - self.assertEqual(objs.subclasses, objsmodels.subclasses) - self.assertEqual(list(objs), list(objsmodels)) - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_selecting_all_subclasses_specifically_grandchildren(self): - """ - A bare select_subclasses() should achieve the same results as doing - select_subclasses and specifying all possible subclasses. - This test checks grandchildren, so only works on 1.6>= - """ - objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk') - objsmodels = InheritanceManagerTestParent.objects.select_subclasses( - InheritanceManagerTestChild1, InheritanceManagerTestChild2, - InheritanceManagerTestChild3, - InheritanceManagerTestGrandChild1, - InheritanceManagerTestGrandChild1_2).order_by('pk') - self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) - self.assertEqual(list(objs), list(objsmodels)) - - - def test_selecting_all_subclasses_specifically_children(self): - """ - A bare select_subclasses() should achieve the same results as doing - select_subclasses and specifying all possible subclasses. - - Note: This is sort of the same test as - `test_selecting_all_subclasses_specifically_grandchildren` but it - specifically switches what models are used because that happens - behind the scenes in a bare select_subclasses(), so we need to - emulate it. - """ - objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk') - - if django.VERSION >= (1, 6, 0): - models = (InheritanceManagerTestChild1, - InheritanceManagerTestChild2, - InheritanceManagerTestChild3, - InheritanceManagerTestGrandChild1, - InheritanceManagerTestGrandChild1_2) - else: - models = (InheritanceManagerTestChild1, - InheritanceManagerTestChild2, - InheritanceManagerTestChild3) - - objsmodels = InheritanceManagerTestParent.objects.select_subclasses( - *models).order_by('pk') - # order shouldn't matter, I don't think, as long as the resulting - # queryset (when cast to a list) is the same. - self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) - self.assertEqual(list(objs), list(objsmodels)) - - - def test_select_subclass_just_self(self): - """ - Passing in the same model as the manager/queryset is bound against - (ie: the root parent) should have no effect on the result set. - """ - objsmodels = InheritanceManagerTestParent.objects.select_subclasses( - InheritanceManagerTestParent).order_by('pk') - self.assertEqual([], objsmodels.subclasses) - self.assertEqual(list(objsmodels), [ - InheritanceManagerTestParent(pk=self.parent1.pk), - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - ]) - - - def test_select_subclass_invalid_related_model(self): - """ - Confirming that giving a stupid model doesn't work. - """ - from django.contrib.auth.models import User - regex = '^.+? is not a subclass of .+$' - with self.assertRaisesRegexp(ValueError, regex): - InheritanceManagerTestParent.objects.select_subclasses( - User).order_by('pk') - - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_mixing_strings_and_classes_with_grandchildren(self): - """ - Given arguments consisting of both strings and model classes, - ensure the right resolutions take place, accounting for the extra - depth (grandchildren etc) 1.6> allows. - """ - objs = InheritanceManagerTestParent.objects.select_subclasses( - "inheritancemanagertestchild2", - InheritanceManagerTestGrandChild1_2).order_by('pk') - expecting = ['inheritancemanagertestchild1__inheritancemanagertestgrandchild1_2', - 'inheritancemanagertestchild2'] - self.assertEqual(set(objs.subclasses), set(expecting)) - expecting2 = [ - InheritanceManagerTestParent(pk=self.parent1.pk), - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestChild2(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestGrandChild1_2(pk=self.grandchild1_2.pk), - ] - self.assertEqual(list(objs), expecting2) - - - def test_mixing_strings_and_classes_with_children(self): - """ - Given arguments consisting of both strings and model classes, - ensure the right resolutions take place, walking down as far as - children. - """ - objs = InheritanceManagerTestParent.objects.select_subclasses( - "inheritancemanagertestchild2", - InheritanceManagerTestChild1).order_by('pk') - expecting = ['inheritancemanagertestchild1', - 'inheritancemanagertestchild2'] - - self.assertEqual(set(objs.subclasses), set(expecting)) - expecting2 = [ - InheritanceManagerTestParent(pk=self.parent1.pk), - InheritanceManagerTestChild1(pk=self.child1.pk), - InheritanceManagerTestChild2(pk=self.child2.pk), - InheritanceManagerTestChild1(pk=self.grandchild1.pk), - InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), - ] - self.assertEqual(list(objs), expecting2) - - - def test_duplications(self): - """ - Check that even if the same thing is provided as a string and a model - that the right results are retrieved. - """ - # mixing strings and models which evaluate to the same thing is fine. - objs = InheritanceManagerTestParent.objects.select_subclasses( - "inheritancemanagertestchild2", - InheritanceManagerTestChild2).order_by('pk') - self.assertEqual(list(objs), [ - InheritanceManagerTestParent(pk=self.parent1.pk), - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestChild2(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - ]) - - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_child_doesnt_accidentally_get_parent(self): - """ - Given a Child model which also has an InheritanceManager, - none of the returned objects should be Parent objects. - """ - objs = InheritanceManagerTestChild1.objects.select_subclasses( - InheritanceManagerTestGrandChild1).order_by('pk') - self.assertEqual([ - InheritanceManagerTestChild1(pk=self.child1.pk), - InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk), - InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), - ], list(objs)) - - - def test_manually_specifying_parent_fk_only_specific_child(self): - """ - given a Model which inherits from another Model, but also declares - the OneToOne link manually using `related_name` and `parent_link`, - ensure that the relation names and subclasses are obtained correctly. - """ - child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.all().select_subclasses( - InheritanceManagerTestChild3) - - expected_objs = [InheritanceManagerTestParent(pk=self.parent1.pk), - InheritanceManagerTestParent(pk=self.child1.pk), - InheritanceManagerTestParent(pk=self.child2.pk), - InheritanceManagerTestParent(pk=self.grandchild1.pk), - InheritanceManagerTestParent(pk=self.grandchild1_2.pk), - child3] - self.assertEqual(list(results), expected_objs) - - expected_related_names = ['manual_onetoone'] - self.assertEqual(set(results.subclasses), - set(expected_related_names)) - - def test_extras_descend(self): - """ - Ensure that extra(select=) values are copied onto sub-classes. - """ - results = InheritanceManagerTestParent.objects.select_subclasses().extra( - select={'foo': 'id + 1'} - ) - self.assertTrue(all(result.foo == (result.id + 1) for result in results)) - - def test_limit_to_specific_subclass(self): - child3 = InheritanceManagerTestChild3.objects.create() - results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3) - - self.assertEqual([child3], list(results)) - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_limit_to_specific_grandchild_class(self): - grandchild1 = InheritanceManagerTestGrandChild1.objects.get() - results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1) - - self.assertEqual([grandchild1], list(results)) - - def test_limit_to_child_fetches_grandchildren_as_child_class(self): - # Not sure if this is the desired behaviour...? - children = InheritanceManagerTestChild1.objects.all() - - results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1) - - self.assertEqual(set(children), set(results)) - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_can_fetch_limited_class_grandchildren(self): - # Not sure if this is the desired behaviour...? - children = InheritanceManagerTestChild1.objects.select_subclasses() - - results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1).select_subclasses() - - self.assertEqual(set(children), set(results)) - - def test_selecting_multiple_instance_classes(self): - child3 = InheritanceManagerTestChild3.objects.create() - children1 = InheritanceManagerTestChild1.objects.all() - - results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestChild1) - - self.assertEqual(set([child3] + list(children1)), set(results)) - - @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") - def test_selecting_multiple_instance_classes_including_grandchildren(self): - child3 = InheritanceManagerTestChild3.objects.create() - grandchild1 = InheritanceManagerTestGrandChild1.objects.get() - - results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestGrandChild1).select_subclasses() - - self.assertEqual(set([child3, grandchild1]), set(results)) - - def test_select_subclasses_interaction_with_instance_of(self): - child3 = InheritanceManagerTestChild3.objects.create() - - results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3) - - self.assertEqual(set([child3]), set(results)) - - -class InheritanceManagerRelatedTests(InheritanceManagerTests): - def setUp(self): - self.related = InheritanceManagerTestRelated.objects.create() - self.child1 = InheritanceManagerTestChild1.objects.create( - related=self.related) - self.child2 = InheritanceManagerTestChild2.objects.create( - related=self.related) - self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related) - self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related) - - - def get_manager(self): - return self.related.imtests - - - def test_get_method_with_select_subclasses(self): - self.assertEqual( - InheritanceManagerTestParent.objects.select_subclasses().get( - id=self.child1.id), - self.child1) - - - def test_annotate_with_select_subclasses(self): - qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( - models.Count('id')) - self.assertEqual(qs.get(id=self.child1.id).id__count, 1) - - - def test_annotate_with_named_arguments_with_select_subclasses(self): - qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( - test_count=models.Count('id')) - self.assertEqual(qs.get(id=self.child1.id).test_count, 1) - - - def test_annotate_before_select_subclasses(self): - qs = InheritanceManagerTestParent.objects.annotate( - models.Count('id')).select_subclasses() - self.assertEqual(qs.get(id=self.child1.id).id__count, 1) - - - def test_annotate_with_named_arguments_before_select_subclasses(self): - qs = InheritanceManagerTestParent.objects.annotate( - test_count=models.Count('id')).select_subclasses() - self.assertEqual(qs.get(id=self.child1.id).test_count, 1) - - - -class TimeStampedModelTests(TestCase): - def test_created(self): - t1 = TimeStamp.objects.create() - t2 = TimeStamp.objects.create() - self.assertTrue(t2.created > t1.created) - - - def test_modified(self): - t1 = TimeStamp.objects.create() - t2 = TimeStamp.objects.create() - t1.save() - self.assertTrue(t2.modified < t1.modified) - - - -class TimeFramedModelTests(TestCase): - def setUp(self): - self.now = datetime.now() - - - def test_not_yet_begun(self): - TimeFrame.objects.create(start=self.now+timedelta(days=2)) - self.assertEqual(TimeFrame.timeframed.count(), 0) - - - def test_finished(self): - TimeFrame.objects.create(end=self.now-timedelta(days=1)) - self.assertEqual(TimeFrame.timeframed.count(), 0) - - - def test_no_end(self): - TimeFrame.objects.create(start=self.now-timedelta(days=10)) - self.assertEqual(TimeFrame.timeframed.count(), 1) - - - def test_no_start(self): - TimeFrame.objects.create(end=self.now+timedelta(days=2)) - self.assertEqual(TimeFrame.timeframed.count(), 1) - - - def test_within_range(self): - TimeFrame.objects.create(start=self.now-timedelta(days=1), - end=self.now+timedelta(days=1)) - self.assertEqual(TimeFrame.timeframed.count(), 1) - - - -class TimeFrameManagerAddedTests(TestCase): - def test_manager_available(self): - self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager)) - - - def test_conflict_error(self): - with self.assertRaises(ImproperlyConfigured): - class ErrorModel(TimeFramedModel): - timeframed = models.BooleanField() - - - -class StatusModelTests(TestCase): - def setUp(self): - self.model = Status - self.on_hold = Status.STATUS.on_hold - self.active = Status.STATUS.active - - - def test_created(self): - c1 = self.model.objects.create() - c2 = self.model.objects.create() - self.assertTrue(c2.status_changed > c1.status_changed) - self.assertEqual(self.model.active.count(), 2) - self.assertEqual(self.model.deleted.count(), 0) - - - def test_modification(self): - t1 = self.model.objects.create() - date_created = t1.status_changed - t1.status = self.on_hold - t1.save() - self.assertEqual(self.model.active.count(), 0) - self.assertEqual(self.model.on_hold.count(), 1) - self.assertTrue(t1.status_changed > date_created) - date_changed = t1.status_changed - t1.save() - self.assertEqual(t1.status_changed, date_changed) - date_active_again = t1.status_changed - t1.status = self.active - t1.save() - self.assertTrue(t1.status_changed > date_active_again) - - - -class StatusModelPlainTupleTests(StatusModelTests): - def setUp(self): - self.model = StatusPlainTuple - self.on_hold = StatusPlainTuple.STATUS[2][0] - self.active = StatusPlainTuple.STATUS[0][0] - - - -class StatusManagerAddedTests(TestCase): - def test_manager_available(self): - self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager)) - - - def test_conflict_error(self): - with self.assertRaises(ImproperlyConfigured): - class ErrorModel(StatusModel): - STATUS = ( - ('active', 'Is Active'), - ('deleted', 'Is Deleted'), - ) - active = models.BooleanField() - - - -class QueryManagerTests(TestCase): - def setUp(self): - data = ((True, True, 0), - (True, False, 4), - (False, False, 2), - (False, True, 3), - (True, True, 1), - (True, False, 5)) - for p, c, o in data: - Post.objects.create(published=p, confirmed=c, order=o) - - - def test_passing_kwargs(self): - qs = Post.public.all() - self.assertEqual([p.order for p in qs], [0, 1, 4, 5]) - - - def test_passing_Q(self): - qs = Post.public_confirmed.all() - self.assertEqual([p.order for p in qs], [0, 1]) - - - def test_ordering(self): - qs = Post.public_reversed.all() - self.assertEqual([p.order for p in qs], [5, 4, 1, 0]) - - - -try: - from south.modelsinspector import introspector -except ImportError: - introspector = None - -@skipUnless(introspector, 'South is not installed') -class SouthFreezingTests(TestCase): - def test_introspector_adds_no_excerpt_field(self): - mf = Article._meta.get_field('body') - args, kwargs = introspector(mf) - self.assertEqual(kwargs['no_excerpt_field'], 'True') - - - def test_no_excerpt_field_works(self): - from .models import NoRendered - with self.assertRaises(FieldDoesNotExist): - NoRendered._meta.get_field('_body_excerpt') - - def test_status_field_no_check_for_status(self): - sf = StatusFieldDefaultFilled._meta.get_field('status') - args, kwargs = introspector(sf) - self.assertEqual(kwargs['no_check_for_status'], 'True') - - - -class PassThroughManagerTests(TestCase): - def setUp(self): - Dude.objects.create(name='The Dude', abides=True, has_rug=False) - Dude.objects.create(name='His Dudeness', abides=False, has_rug=True) - Dude.objects.create(name='Duder', abides=False, has_rug=False) - Dude.objects.create(name='El Duderino', abides=True, has_rug=True) - - - def test_chaining(self): - self.assertEqual(Dude.objects.by_name('Duder').count(), 1) - self.assertEqual(Dude.objects.all().by_name('Duder').count(), 1) - self.assertEqual(Dude.abiders.rug_positive().count(), 1) - self.assertEqual(Dude.abiders.all().rug_positive().count(), 1) - - - def test_manager_only_methods(self): - stats = Dude.abiders.get_stats() - self.assertEqual(stats['rug_count'], 1) - with self.assertRaises(AttributeError): - Dude.abiders.all().get_stats() - - - def test_queryset_pickling(self): - qs = Dude.objects.all() - saltyqs = pickle.dumps(qs) - unqs = pickle.loads(saltyqs) - self.assertEqual(unqs.by_name('The Dude').count(), 1) - - - def test_queryset_not_available_on_related_manager(self): - dude = Dude.objects.by_name('Duder').get() - Car.objects.create(name='Ford', owner=dude) - self.assertFalse(hasattr(dude.cars_owned, 'by_name')) - - - def test_using_dir(self): - # make sure introspecing via dir() doesn't actually cause queries, - # just as a sanity check. - with self.assertNumQueries(0): - querysets_to_dir = ( - Dude.objects, - Dude.objects.by_name('Duder'), - Dude.objects.all().by_name('Duder'), - Dude.abiders, - Dude.abiders.rug_positive(), - Dude.abiders.all().rug_positive() - ) - for qs in querysets_to_dir: - self.assertTrue('by_name' in dir(qs)) - self.assertTrue('abiding' in dir(qs)) - self.assertTrue('rug_positive' in dir(qs)) - self.assertTrue('rug_negative' in dir(qs)) - # some standard qs methods - self.assertTrue('count' in dir(qs)) - self.assertTrue('order_by' in dir(qs)) - self.assertTrue('select_related' in dir(qs)) - # make sure it's been de-duplicated - self.assertEqual(1, dir(qs).count('distinct')) - - # manager only method. - self.assertTrue('get_stats' in dir(Dude.abiders)) - # manager only method shouldn't appear on the non AbidingManager - self.assertFalse('get_stats' in dir(Dude.objects)) - # standard manager methods - self.assertTrue('get_query_set' in dir(Dude.abiders)) - self.assertTrue('contribute_to_class' in dir(Dude.abiders)) - - - -class CreatePassThroughManagerTests(TestCase): - def setUp(self): - self.dude = Dude.objects.create(name='El Duderino') - self.other_dude = Dude.objects.create(name='Das Dude') - - def test_reverse_manager(self): - Spot.objects.create( - name='The Crib', owner=self.dude, closed=True, secure=True, - secret=False) - self.assertEqual(self.dude.spots_owned.closed().count(), 1) - Spot.objects.create( - name='The Crux', owner=self.other_dude, closed=True, secure=True, - secret=False - ) - self.assertEqual(self.dude.spots_owned.closed().all().count(), 1) - self.assertEqual(self.dude.spots_owned.closed().count(), 1) - - def test_related_queryset_pickling(self): - Spot.objects.create( - name='The Crib', owner=self.dude, closed=True, secure=True, - secret=False) - qs = self.dude.spots_owned.closed() - pickled_qs = pickle.dumps(qs) - unpickled_qs = pickle.loads(pickled_qs) - self.assertEqual(unpickled_qs.secured().count(), 1) - - def test_related_queryset_superclass_method(self): - Spot.objects.create( - name='The Crib', owner=self.dude, closed=True, secure=True, - secret=False) - Spot.objects.create( - name='The Secret Crib', owner=self.dude, closed=False, secure=True, - secret=True) - self.assertEqual(self.dude.spots_owned.count(), 1) - - def test_related_manager_create(self): - self.dude.spots_owned.create(name='The Crib', closed=True, secure=True) - - -class FieldTrackerTestCase(TestCase): - - tracker = None - - def assertHasChanged(self, **kwargs): - tracker = kwargs.pop('tracker', self.tracker) - for field, value in kwargs.items(): - if value is None: - with self.assertRaises(FieldError): - tracker.has_changed(field) - else: - self.assertEqual(tracker.has_changed(field), value) - - def assertPrevious(self, **kwargs): - tracker = kwargs.pop('tracker', self.tracker) - for field, value in kwargs.items(): - self.assertEqual(tracker.previous(field), value) - - def assertChanged(self, **kwargs): - tracker = kwargs.pop('tracker', self.tracker) - self.assertEqual(tracker.changed(), kwargs) - - def assertCurrent(self, **kwargs): - tracker = kwargs.pop('tracker', self.tracker) - self.assertEqual(tracker.current(), kwargs) - - def update_instance(self, **kwargs): - for field, value in kwargs.items(): - setattr(self.instance, field, value) - self.instance.save() - - -class FieldTrackerCommonTests(object): - - def test_pre_save_previous(self): - self.assertPrevious(name=None, number=None) - self.instance.name = 'new age' - self.instance.number = 8 - self.assertPrevious(name=None, number=None) - - -class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): - - tracked_class = Tracked - - def setUp(self): - self.instance = self.tracked_class() - self.tracker = self.instance.tracker - - def test_descriptor(self): - self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) - - def test_pre_save_changed(self): - self.assertChanged(name=None) - self.instance.name = 'new age' - self.assertChanged(name=None) - self.instance.number = 8 - self.assertChanged(name=None, number=None) - self.instance.name = '' - self.assertChanged(name=None, number=None) - self.instance.mutable = [1,2,3] - self.assertChanged(name=None, number=None, mutable=None) - - def test_pre_save_has_changed(self): - self.assertHasChanged(name=True, number=False, mutable=False) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=False, mutable=False) - self.instance.number = 7 - self.assertHasChanged(name=True, number=True) - self.instance.mutable = [1,2,3] - self.assertHasChanged(name=True, number=True, mutable=True) - - def test_first_save(self): - self.assertHasChanged(name=True, number=False, mutable=False) - self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='', number=None, id=None, mutable=None) - self.assertChanged(name=None) - self.instance.name = 'retro' - self.instance.number = 4 - self.instance.mutable = [1,2,3] - self.assertHasChanged(name=True, number=True, mutable=True) - self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) - self.assertChanged(name=None, number=None, mutable=None) - # Django 1.4 doesn't have update_fields - if django.VERSION >= (1, 5, 0): - self.instance.save(update_fields=[]) - self.assertHasChanged(name=True, number=True, mutable=True) - self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) - self.assertChanged(name=None, number=None, mutable=None) - with self.assertRaises(ValueError): - self.instance.save(update_fields=['number']) - - def test_post_save_has_changed(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) - self.assertHasChanged(name=False, number=False, mutable=False) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=False) - self.instance.number = 8 - self.assertHasChanged(name=True, number=True) - self.instance.mutable[1] = 4 - self.assertHasChanged(name=True, number=True, mutable=True) - self.instance.name = 'retro' - self.assertHasChanged(name=False, number=True, mutable=True) - - def test_post_save_previous(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) - self.instance.name = 'new age' - self.assertPrevious(name='retro', number=4, mutable=[1,2,3]) - self.instance.mutable[1] = 4 - self.assertPrevious(name='retro', number=4, mutable=[1,2,3]) - - def test_post_save_changed(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) - self.assertChanged() - self.instance.name = 'new age' - self.assertChanged(name='retro') - self.instance.number = 8 - self.assertChanged(name='retro', number=4) - self.instance.name = 'retro' - self.assertChanged(number=4) - self.instance.mutable[1] = 4 - self.assertChanged(number=4, mutable=[1,2,3]) - self.instance.mutable = [1,2,3] - self.assertChanged(number=4) - - def test_current(self): - self.assertCurrent(id=None, name='', number=None, mutable=None) - self.instance.name = 'new age' - self.assertCurrent(id=None, name='new age', number=None, mutable=None) - self.instance.number = 8 - self.assertCurrent(id=None, name='new age', number=8, mutable=None) - self.instance.mutable = [1,2,3] - self.assertCurrent(id=None, name='new age', number=8, mutable=[1,2,3]) - self.instance.mutable[1] = 4 - self.assertCurrent(id=None, name='new age', number=8, mutable=[1,4,3]) - self.instance.save() - self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1,4,3]) - - @skipUnless( - django.VERSION >= (1, 5, 0), "Django 1.4 doesn't have update_fields") - def test_update_fields(self): - self.update_instance(name='retro', number=4, mutable=[1,2,3]) - self.assertChanged() - self.instance.name = 'new age' - self.instance.number = 8 - self.instance.mutable = [4,5,6] - self.assertChanged(name='retro', number=4, mutable=[1,2,3]) - self.instance.save(update_fields=[]) - self.assertChanged(name='retro', number=4, mutable=[1,2,3]) - self.instance.save(update_fields=['name']) - in_db = self.tracked_class.objects.get(id=self.instance.id) - self.assertEqual(in_db.name, self.instance.name) - self.assertNotEqual(in_db.number, self.instance.number) - self.assertChanged(number=4, mutable=[1,2,3]) - self.instance.save(update_fields=['number']) - self.assertChanged(mutable=[1,2,3]) - self.instance.save(update_fields=['mutable']) - self.assertChanged() - in_db = self.tracked_class.objects.get(id=self.instance.id) - self.assertEqual(in_db.name, self.instance.name) - self.assertEqual(in_db.number, self.instance.number) - self.assertEqual(in_db.mutable, self.instance.mutable) - - def test_with_deferred(self): - self.instance.name = 'new age' - self.instance.number = 1 - self.instance.save() - item = list(self.tracked_class.objects.only('name').all())[0] - self.assertTrue(item.tracker.deferred_fields) - - self.assertEqual(item.tracker.previous('number'), None) - self.assertTrue('number' in item.tracker.deferred_fields) - - self.assertEqual(item.number, 1) - self.assertTrue('number' not in item.tracker.deferred_fields) - self.assertEqual(item.tracker.previous('number'), 1) - self.assertFalse(item.tracker.has_changed('number')) - - item.number = 2 - self.assertTrue(item.tracker.has_changed('number')) - - def test_can_pickle_objects(self): - pickle.dumps(self.instance) - - -class FieldTrackedModelCustomTests(FieldTrackerTestCase, - FieldTrackerCommonTests): - - tracked_class = TrackedNotDefault - - def setUp(self): - self.instance = self.tracked_class() - self.tracker = self.instance.name_tracker - - def test_pre_save_changed(self): - self.assertChanged(name=None) - self.instance.name = 'new age' - self.assertChanged(name=None) - self.instance.number = 8 - self.assertChanged(name=None) - self.instance.name = '' - self.assertChanged(name=None) - - def test_first_save(self): - self.assertHasChanged(name=True, number=None) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='') - self.assertChanged(name=None) - self.instance.name = 'retro' - self.instance.number = 4 - self.assertHasChanged(name=True, number=None) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='retro') - self.assertChanged(name=None) - - def test_pre_save_has_changed(self): - self.assertHasChanged(name=True, number=None) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=None) - self.instance.number = 7 - self.assertHasChanged(name=True, number=None) - - def test_post_save_has_changed(self): - self.update_instance(name='retro', number=4) - self.assertHasChanged(name=False, number=None) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=None) - self.instance.number = 8 - self.assertHasChanged(name=True, number=None) - self.instance.name = 'retro' - self.assertHasChanged(name=False, number=None) - - def test_post_save_previous(self): - self.update_instance(name='retro', number=4) - self.instance.name = 'new age' - self.assertPrevious(name='retro', number=None) - - def test_post_save_changed(self): - self.update_instance(name='retro', number=4) - self.assertChanged() - self.instance.name = 'new age' - self.assertChanged(name='retro') - self.instance.number = 8 - self.assertChanged(name='retro') - self.instance.name = 'retro' - self.assertChanged() - - def test_current(self): - self.assertCurrent(name='') - self.instance.name = 'new age' - self.assertCurrent(name='new age') - self.instance.number = 8 - self.assertCurrent(name='new age') - self.instance.save() - self.assertCurrent(name='new age') - - @skipUnless( - django.VERSION >= (1, 5, 0), "Django 1.4 doesn't have update_fields") - def test_update_fields(self): - self.update_instance(name='retro', number=4) - self.assertChanged() - self.instance.name = 'new age' - self.instance.number = 8 - self.instance.save(update_fields=['name', 'number']) - self.assertChanged() - - -class FieldTrackedModelAttributeTests(FieldTrackerTestCase): - - tracked_class = TrackedNonFieldAttr - - def setUp(self): - self.instance = self.tracked_class() - self.tracker = self.instance.tracker - - def test_previous(self): - self.assertPrevious(rounded=None) - self.instance.number = 7.5 - self.assertPrevious(rounded=None) - self.instance.save() - self.assertPrevious(rounded=8) - self.instance.number = 7.2 - self.assertPrevious(rounded=8) - self.instance.save() - self.assertPrevious(rounded=7) - - def test_has_changed(self): - self.assertHasChanged(rounded=False) - self.instance.number = 7.5 - self.assertHasChanged(rounded=True) - self.instance.save() - self.assertHasChanged(rounded=False) - self.instance.number = 7.2 - self.assertHasChanged(rounded=True) - self.instance.number = 7.8 - self.assertHasChanged(rounded=False) - - def test_changed(self): - self.assertChanged() - self.instance.number = 7.5 - self.assertPrevious(rounded=None) - self.instance.save() - self.assertPrevious() - self.instance.number = 7.8 - self.assertPrevious() - self.instance.number = 7.2 - self.assertPrevious(rounded=8) - self.instance.save() - self.assertPrevious() - - def test_current(self): - self.assertCurrent(rounded=None) - self.instance.number = 7.5 - self.assertCurrent(rounded=8) - self.instance.save() - self.assertCurrent(rounded=8) - - -class FieldTrackedModelMultiTests(FieldTrackerTestCase, - FieldTrackerCommonTests): - - tracked_class = TrackedMultiple - - def setUp(self): - self.instance = self.tracked_class() - self.trackers = [self.instance.name_tracker, - self.instance.number_tracker] - - def test_pre_save_changed(self): - self.tracker = self.instance.name_tracker - self.assertChanged(name=None) - self.instance.name = 'new age' - self.assertChanged(name=None) - self.instance.number = 8 - self.assertChanged(name=None) - self.instance.name = '' - self.assertChanged(name=None) - self.tracker = self.instance.number_tracker - self.assertChanged(number=None) - self.instance.name = 'new age' - self.assertChanged(number=None) - self.instance.number = 8 - self.assertChanged(number=None) - - def test_pre_save_has_changed(self): - self.tracker = self.instance.name_tracker - self.assertHasChanged(name=True, number=None) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=None) - self.tracker = self.instance.number_tracker - self.assertHasChanged(name=None, number=False) - self.instance.name = 'new age' - self.assertHasChanged(name=None, number=False) - - def test_pre_save_previous(self): - for tracker in self.trackers: - self.tracker = tracker - super(FieldTrackedModelMultiTests, self).test_pre_save_previous() - - def test_post_save_has_changed(self): - self.update_instance(name='retro', number=4) - self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) - self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) - self.instance.name = 'new age' - self.assertHasChanged(tracker=self.trackers[0], name=True, number=None) - self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) - self.instance.number = 8 - self.assertHasChanged(tracker=self.trackers[0], name=True, number=None) - self.assertHasChanged(tracker=self.trackers[1], name=None, number=True) - self.instance.name = 'retro' - self.instance.number = 4 - self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) - self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) - - def test_post_save_previous(self): - self.update_instance(name='retro', number=4) - self.instance.name = 'new age' - self.instance.number = 8 - self.assertPrevious(tracker=self.trackers[0], name='retro', number=None) - self.assertPrevious(tracker=self.trackers[1], name=None, number=4) - - def test_post_save_changed(self): - self.update_instance(name='retro', number=4) - self.assertChanged(tracker=self.trackers[0]) - self.assertChanged(tracker=self.trackers[1]) - self.instance.name = 'new age' - self.assertChanged(tracker=self.trackers[0], name='retro') - self.assertChanged(tracker=self.trackers[1]) - self.instance.number = 8 - self.assertChanged(tracker=self.trackers[0], name='retro') - self.assertChanged(tracker=self.trackers[1], number=4) - self.instance.name = 'retro' - self.instance.number = 4 - self.assertChanged(tracker=self.trackers[0]) - self.assertChanged(tracker=self.trackers[1]) - - def test_current(self): - self.assertCurrent(tracker=self.trackers[0], name='') - self.assertCurrent(tracker=self.trackers[1], number=None) - self.instance.name = 'new age' - self.assertCurrent(tracker=self.trackers[0], name='new age') - self.assertCurrent(tracker=self.trackers[1], number=None) - self.instance.number = 8 - self.assertCurrent(tracker=self.trackers[0], name='new age') - self.assertCurrent(tracker=self.trackers[1], number=8) - self.instance.save() - self.assertCurrent(tracker=self.trackers[0], name='new age') - self.assertCurrent(tracker=self.trackers[1], number=8) - - -class FieldTrackerForeignKeyTests(FieldTrackerTestCase): - - fk_class = Tracked - tracked_class = TrackedFK - - def setUp(self): - self.old_fk = self.fk_class.objects.create(number=8) - self.instance = self.tracked_class.objects.create(fk=self.old_fk) - - def test_default(self): - self.tracker = self.instance.tracker - self.assertChanged() - self.assertPrevious() - self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) - self.assertChanged(fk_id=self.old_fk.id) - self.assertPrevious(fk_id=self.old_fk.id) - self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) - - def test_custom(self): - self.tracker = self.instance.custom_tracker - self.assertChanged() - self.assertPrevious() - self.assertCurrent(fk_id=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) - self.assertChanged(fk_id=self.old_fk.id) - self.assertPrevious(fk_id=self.old_fk.id) - self.assertCurrent(fk_id=self.instance.fk_id) - - def test_custom_without_id(self): - with self.assertNumQueries(1): - self.tracked_class.objects.get() - self.tracker = self.instance.custom_tracker_without_id - self.assertChanged() - self.assertPrevious() - self.assertCurrent(fk=self.old_fk.id) - self.instance.fk = self.fk_class.objects.create(number=8) - self.assertChanged(fk=self.old_fk.id) - self.assertPrevious(fk=self.old_fk.id) - self.assertCurrent(fk=self.instance.fk_id) - - -class InheritedFieldTrackerTests(FieldTrackerTests): - - tracked_class = InheritedTracked - - def test_child_fields_not_tracked(self): - self.name2 = 'test' - self.assertEqual(self.tracker.previous('name2'), None) - self.assertRaises(FieldError, self.tracker.has_changed, 'name2') - - -class ModelTrackerTests(FieldTrackerTests): - - tracked_class = ModelTracked - - def test_pre_save_changed(self): - self.assertChanged() - self.instance.name = 'new age' - self.assertChanged() - self.instance.number = 8 - self.assertChanged() - self.instance.name = '' - self.assertChanged() - self.instance.mutable = [1,2,3] - self.assertChanged() - - def test_first_save(self): - self.assertHasChanged(name=True, number=True, mutable=True) - self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='', number=None, id=None, mutable=None) - self.assertChanged() - self.instance.name = 'retro' - self.instance.number = 4 - self.instance.mutable = [1,2,3] - self.assertHasChanged(name=True, number=True, mutable=True) - self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) - self.assertChanged() - # Django 1.4 doesn't have update_fields - if django.VERSION >= (1, 5, 0): - self.instance.save(update_fields=[]) - self.assertHasChanged(name=True, number=True, mutable=True) - self.assertPrevious(name=None, number=None, mutable=None) - self.assertCurrent(name='retro', number=4, id=None, mutable=[1,2,3]) - self.assertChanged() - with self.assertRaises(ValueError): - self.instance.save(update_fields=['number']) - - def test_pre_save_has_changed(self): - self.assertHasChanged(name=True, number=True) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=True) - self.instance.number = 7 - self.assertHasChanged(name=True, number=True) - - -class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests): - - tracked_class = ModelTrackedNotDefault - - def test_first_save(self): - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='') - self.assertChanged() - self.instance.name = 'retro' - self.instance.number = 4 - self.assertHasChanged(name=True, number=True) - self.assertPrevious(name=None, number=None) - self.assertCurrent(name='retro') - self.assertChanged() - - def test_pre_save_has_changed(self): - self.assertHasChanged(name=True, number=True) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=True) - self.instance.number = 7 - self.assertHasChanged(name=True, number=True) - - def test_pre_save_changed(self): - self.assertChanged() - self.instance.name = 'new age' - self.assertChanged() - self.instance.number = 8 - self.assertChanged() - self.instance.name = '' - self.assertChanged() - - -class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests): - - tracked_class = ModelTrackedMultiple - - def test_pre_save_has_changed(self): - self.tracker = self.instance.name_tracker - self.assertHasChanged(name=True, number=True) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=True) - self.tracker = self.instance.number_tracker - self.assertHasChanged(name=True, number=True) - self.instance.name = 'new age' - self.assertHasChanged(name=True, number=True) - - def test_pre_save_changed(self): - self.tracker = self.instance.name_tracker - self.assertChanged() - self.instance.name = 'new age' - self.assertChanged() - self.instance.number = 8 - self.assertChanged() - self.instance.name = '' - self.assertChanged() - self.tracker = self.instance.number_tracker - self.assertChanged() - self.instance.name = 'new age' - self.assertChanged() - self.instance.number = 8 - self.assertChanged() - - -class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): - - fk_class = ModelTracked - tracked_class = ModelTrackedFK - - def test_custom_without_id(self): - with self.assertNumQueries(2): - self.tracked_class.objects.get() - self.tracker = self.instance.custom_tracker_without_id - self.assertChanged() - self.assertPrevious() - self.assertCurrent(fk=self.old_fk) - self.instance.fk = self.fk_class.objects.create(number=8) - self.assertNotEqual(self.instance.fk, self.old_fk) - self.assertChanged(fk=self.old_fk) - self.assertPrevious(fk=self.old_fk) - self.assertCurrent(fk=self.instance.fk) - - -class InheritedModelTrackerTests(ModelTrackerTests): - - tracked_class = InheritedModelTracked - - def test_child_fields_not_tracked(self): - self.name2 = 'test' - self.assertEqual(self.tracker.previous('name2'), None) - self.assertTrue(self.tracker.has_changed('name2')) diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 021b12f..5c7c7d8 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -2,18 +2,103 @@ from __future__ import unicode_literals from copy import deepcopy +import django from django.core.exceptions import FieldError from django.db import models +from django.db.models.fields.files import FileDescriptor from django.db.models.query_utils import DeferredAttribute -from django.db.models.signals import post_save -from django.dispatch import receiver + + +class DescriptorMixin(object): + tracker_instance = None + + def __get__(self, instance, owner): + if instance is None: + return self + was_deferred = False + field_name = self._get_field_name() + if field_name in instance._deferred_fields: + instance._deferred_fields.remove(field_name) + was_deferred = True + value = super(DescriptorMixin, self).__get__(instance, owner) + if was_deferred: + self.tracker_instance.saved_data[field_name] = deepcopy(value) + return value + + def _get_field_name(self): + return self.field_name + + +class DescriptorWrapper(object): + + def __init__(self, field_name, descriptor, tracker_attname): + self.field_name = field_name + self.descriptor = descriptor + self.tracker_attname = tracker_attname + + def __get__(self, instance, owner): + if instance is None: + return self + was_deferred = self.field_name in instance.get_deferred_fields() + try: + value = self.descriptor.__get__(instance, owner) + except AttributeError: + value = self.descriptor + if was_deferred: + tracker_instance = getattr(instance, self.tracker_attname) + tracker_instance.saved_data[self.field_name] = deepcopy(value) + return value + + def __set__(self, instance, value): + initialized = hasattr(instance, '_instance_intialized') + was_deferred = self.field_name in instance.get_deferred_fields() + + # Sentinel attribute to detect whether we are already trying to + # set the attribute higher up the stack. This prevents infinite + # recursion when retrieving deferred values from the database. + recursion_sentinel_attname = '_setting_' + self.field_name + already_setting = hasattr(instance, recursion_sentinel_attname) + + if initialized and was_deferred and not already_setting: + setattr(instance, recursion_sentinel_attname, True) + try: + # Retrieve the value to set the saved_data value. + # This will undefer the field + getattr(instance, self.field_name) + finally: + instance.__dict__.pop(recursion_sentinel_attname, None) + if hasattr(self.descriptor, '__set__'): + self.descriptor.__set__(instance, value) + else: + instance.__dict__[self.field_name] = value + + @staticmethod + def cls_for_descriptor(descriptor): + if hasattr(descriptor, '__delete__'): + return FullDescriptorWrapper + else: + return DescriptorWrapper + + +class FullDescriptorWrapper(DescriptorWrapper): + """ + Wrapper for descriptors with all three descriptor methods. + """ + def __delete__(self, obj): + self.descriptor.__delete__(obj) + class FieldInstanceTracker(object): def __init__(self, instance, fields, field_map): self.instance = instance self.fields = fields self.field_map = field_map - self.init_deferred_fields() + if django.VERSION < (1, 10): + self.init_deferred_fields() + + @property + def deferred_fields(self): + return self.instance._deferred_fields if django.VERSION < (1, 10) else self.instance.get_deferred_fields() def get_field_value(self, field): return getattr(self.instance, self.field_map[field]) @@ -33,10 +118,11 @@ class FieldInstanceTracker(object): def current(self, fields=None): """Returns dict of current values for all tracked fields""" if fields is None: - if self.deferred_fields: + deferred_fields = self.deferred_fields + if deferred_fields: fields = [ field for field in self.fields - if field not in self.deferred_fields + if field not in deferred_fields ] else: fields = self.fields @@ -46,12 +132,31 @@ class FieldInstanceTracker(object): def has_changed(self, field): """Returns ``True`` if field has changed from currently saved value""" if field in self.fields: + # deferred fields haven't changed + if field in self.deferred_fields and field not in self.instance.__dict__: + return False return self.previous(field) != self.get_field_value(field) else: raise FieldError('field "%s" not tracked' % field) def previous(self, field): """Returns currently saved value of given field""" + + # handle deferred fields that have not yet been loaded from the database + if self.instance.pk and field in self.deferred_fields and field not in self.saved_data: + + # if the field has not been assigned locally, simply fetch and un-defer the value + if field not in self.instance.__dict__: + self.get_field_value(field) + + # if the field has been assigned locally, store the local value, fetch the database value, + # store database value to saved_data, and restore the local value + else: + current_value = self.get_field_value(field) + self.instance.refresh_from_db(fields=[field]) + self.saved_data[field] = deepcopy(self.get_field_value(field)) + setattr(self.instance, self.field_map[field], current_value) + return self.saved_data.get(field) def changed(self): @@ -63,32 +168,27 @@ class FieldInstanceTracker(object): ) def init_deferred_fields(self): - self.deferred_fields = [] - if not self.instance._deferred: + self.instance._deferred_fields = set() + if hasattr(self.instance, '_deferred') and not self.instance._deferred: return - class DeferredAttributeTracker(DeferredAttribute): - def __get__(field, instance, owner): - data = instance.__dict__ - if data.get(field.field_name, field) is field: - self.deferred_fields.remove(field.field_name) - value = super(DeferredAttributeTracker, field).__get__( - instance, owner) - self.saved_data[field.field_name] = deepcopy(value) - return data[field.field_name] + class DeferredAttributeTracker(DescriptorMixin, DeferredAttribute): + tracker_instance = self - for field in self.fields: + class FileDescriptorTracker(DescriptorMixin, FileDescriptor): + tracker_instance = self + + def _get_field_name(self): + return self.field.name + + self.instance._deferred_fields = self.instance.get_deferred_fields() + for field in self.instance._deferred_fields: field_obj = self.instance.__class__.__dict__.get(field) - if isinstance(field_obj, DeferredAttribute): - self.deferred_fields.append(field) - - # Django 1.4 - model = None - if hasattr(field_obj, 'model_ref'): - model = field_obj.model_ref() - - field_tracker = DeferredAttributeTracker( - field_obj.field_name, model) + if isinstance(field_obj, FileDescriptor): + field_tracker = FileDescriptorTracker(field_obj.field) + setattr(self.instance.__class__, field, field_tracker) + else: + field_tracker = DeferredAttributeTracker(field, type(self.instance)) setattr(self.instance.__class__, field, field_tracker) @@ -102,7 +202,7 @@ class FieldTracker(object): def get_field_map(self, cls): """Returns dict mapping fields names to model attribute names""" field_map = dict((field, field) for field in self.fields) - all_fields = dict((f.name, f.attname) for f in cls._meta.local_fields) + all_fields = dict((f.name, f.attname) for f in cls._meta.fields) field_map.update(**dict((k, v) for (k, v) in all_fields.items() if k in field_map)) return field_map @@ -114,22 +214,34 @@ class FieldTracker(object): def finalize_class(self, sender, **kwargs): if self.fields is None: - self.fields = (field.attname for field in sender._meta.local_fields) + self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) + if django.VERSION >= (1, 10): + for field_name in self.fields: + descriptor = getattr(sender, field_name) + wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) + wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) + setattr(sender, field_name, wrapped_descriptor) self.field_map = self.get_field_map(sender) models.signals.post_init.connect(self.initialize_tracker) self.model_class = sender setattr(sender, self.name, self) - - # Rather than patch the save method on the instance, - # we can observe the post_save signal on the class. - @receiver(post_save, sender=None, weak=False) - def handler(sender, instance, **kwargs): - if not isinstance(instance, self.model_class): - return - + self.patch_save(sender) + + def initialize_tracker(self, sender, instance, **kwargs): + if not isinstance(instance, self.model_class): + return # Only init instances of given model (including children) + tracker = self.tracker_class(instance, self.fields, self.field_map) + setattr(instance, self.attname, tracker) + tracker.set_saved_fields() + instance._instance_intialized = True + + def patch_save(self, model): + original_save = model.save + + def save(instance, *args, **kwargs): + ret = original_save(instance, *args, **kwargs) update_fields = kwargs.get('update_fields') - if not update_fields and update_fields is not None: # () or [] fields = update_fields elif update_fields is None: @@ -139,19 +251,13 @@ class FieldTracker(object): field for field in update_fields if field in self.fields ) - getattr(instance, self.attname).set_saved_fields( fields=fields ) - + return ret + + model.save = save - def initialize_tracker(self, sender, instance, **kwargs): - if not isinstance(instance, self.model_class): - return # Only init instances of given model (including children) - tracker = self.tracker_class(instance, self.fields, self.field_map) - setattr(instance, self.attname, tracker) - tracker.set_saved_fields() - def __get__(self, instance, owner): if instance is None: return self diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..f190999 --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,4 @@ +pytest==4.5.0 +pytest-django==3.4.7 +psycopg2==2.7.6.1 +pytest-cov==2.7.1 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f7c07a6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +# Dependencies for development of django-model-utils + +tox +sphinx +twine +freezegun diff --git a/runtests.py b/runtests.py deleted file mode 100755 index bb86dc7..0000000 --- a/runtests.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python - -import os, sys - -from django.conf import settings -import django - - -DEFAULT_SETTINGS = dict( - INSTALLED_APPS=( - 'model_utils', - 'model_utils.tests', - ), - DATABASES={ - "default": { - "ENGINE": "django.db.backends.sqlite3" - } - }, - ) - - -def runtests(): - if not settings.configured: - settings.configure(**DEFAULT_SETTINGS) - - # Compatibility with Django 1.7's stricter initialization - if hasattr(django, 'setup'): - django.setup() - - parent = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, parent) - - try: - from django.test.runner import DiscoverRunner - runner_class = DiscoverRunner - test_args = ['model_utils.tests'] - except ImportError: - from django.test.simple import DjangoTestSuiteRunner - runner_class = DjangoTestSuiteRunner - test_args = ['tests'] - - failures = runner_class( - verbosity=1, interactive=True, failfast=False).run_tests(test_args) - sys.exit(failures) - - -if __name__ == '__main__': - runtests() diff --git a/setup.cfg b/setup.cfg index 7d5a6f7..6058f77 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,5 +3,9 @@ source-dir = docs/ build-dir = docs/_build all_files = 1 +[tool:pytest] +django_find_project = false +DJANGO_SETTINGS_MODULE = tests.settings + [wheel] universal = 1 diff --git a/setup.py b/setup.py index 25fca37..5b61beb 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,22 @@ -from os.path import join +import os from setuptools import setup, find_packages -long_description = (open('README.rst').read() + - open('CHANGES.rst').read() + - open('TODO.rst').read()) +def long_desc(root_path): + FILES = ['README.rst', 'CHANGES.rst'] + for filename in FILES: + filepath = os.path.realpath(os.path.join(root_path, filename)) + if os.path.isfile(filepath): + with open(filepath, mode='r') as f: + yield f.read() -def get_version(): - with open(join('model_utils', '__init__.py')) as f: +HERE = os.path.abspath(os.path.dirname(__file__)) +long_description = "\n\n".join(long_desc(HERE)) + + +def get_version(root_path): + with open(os.path.join(root_path, 'model_utils', '__init__.py')) as f: for line in f: if line.startswith('__version__ ='): return line.split('=')[1].strip().strip('"\'') @@ -16,14 +24,16 @@ def get_version(): setup( name='django-model-utils', - version=get_version(), + version=get_version(HERE), + license="BSD", description='Django model mixins and utilities', long_description=long_description, author='Carl Meyer', author_email='carl@oddbird.net', - url='https://github.com/carljm/django-model-utils/', - packages=find_packages(), - install_requires=['Django>=1.4.2'], + maintainer='JazzBand', + url='https://github.com/jazzband/django-model-utils/', + packages=find_packages(exclude=['tests*']), + install_requires=['Django>=1.11'], classifiers=[ 'Development Status :: 5 - Production/Stable', 'Environment :: Web Environment', @@ -31,14 +41,20 @@ setup( 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', 'Programming Language :: Python', - 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.2', - 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.6', 'Framework :: Django', + 'Framework :: Django :: 2.1', + 'Framework :: Django :: 1.11', + 'Framework :: Django :: 2.2', ], zip_safe=False, - tests_require=["Django>=1.4.2"], - test_suite='runtests.runtests' + tests_require=['Django>=1.1.11'], + package_data={ + 'model_utils': [ + 'locale/*/LC_MESSAGES/django.po', 'locale/*/LC_MESSAGES/django.mo' + ], + }, ) diff --git a/model_utils/tests/__init__.py b/tests/__init__.py similarity index 100% rename from model_utils/tests/__init__.py rename to tests/__init__.py diff --git a/tests/fields.py b/tests/fields.py new file mode 100644 index 0000000..7c29aa4 --- /dev/null +++ b/tests/fields.py @@ -0,0 +1,43 @@ +import django +from django.db import models +from django.utils.six import with_metaclass, string_types + + +def mutable_from_db(value): + if value == '': + return None + try: + if isinstance(value, string_types): + return [int(i) for i in value.split(',')] + except ValueError: + pass + return value + + +def mutable_to_db(value): + if value is None: + return '' + if isinstance(value, list): + value = ','.join((str(i) for i in value)) + return str(value) + + +if django.VERSION >= (1, 9, 0): + class MutableField(models.TextField): + def to_python(self, value): + return mutable_from_db(value) + + def from_db_value(self, value, expression, connection, context): + return mutable_from_db(value) + + def get_db_prep_save(self, value, connection): + value = super(MutableField, self).get_db_prep_save(value, connection) + return mutable_to_db(value) +else: + class MutableField(with_metaclass(models.SubfieldBase, models.TextField)): + def to_python(self, value): + return mutable_from_db(value) + + def get_db_prep_save(self, value, connection): + value = mutable_to_db(value) + return super(MutableField, self).get_db_prep_save(value, connection) diff --git a/tests/managers.py b/tests/managers.py new file mode 100644 index 0000000..4a055f2 --- /dev/null +++ b/tests/managers.py @@ -0,0 +1,15 @@ +from __future__ import unicode_literals, absolute_import + +from model_utils.managers import SoftDeletableQuerySet, SoftDeletableManager + + +class CustomSoftDeleteQuerySet(SoftDeletableQuerySet): + def only_read(self): + return self.filter(is_read=True) + + +class CustomSoftDeleteManager(SoftDeletableManager): + _queryset_class = CustomSoftDeleteQuerySet + + def only_read(self): + return self.get_queryset().only_read() diff --git a/model_utils/tests/models.py b/tests/models.py similarity index 55% rename from model_utils/tests/models.py rename to tests/models.py index ea46d0f..9c5e374 100644 --- a/model_utils/tests/models.py +++ b/tests/models.py @@ -1,42 +1,59 @@ -from __future__ import unicode_literals +from __future__ import unicode_literals, absolute_import +import django from django.db import models +from django.db.models.query_utils import DeferredAttribute +from django.db.models import Manager from django.utils.encoding import python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ -from model_utils.models import TimeStampedModel, StatusModel, TimeFramedModel -from model_utils.tracker import FieldTracker, ModelTracker -from model_utils.managers import QueryManager, InheritanceManager, PassThroughManager -from model_utils.fields import SplitField, MonitorField, StatusField -from model_utils.tests.fields import MutableField from model_utils import Choices - +from model_utils.fields import ( + SplitField, + MonitorField, + StatusField, + UUIDField, +) +from model_utils.managers import ( + QueryManager, + InheritanceManager, + JoinManagerMixin +) +from model_utils.models import ( + SoftDeletableModel, + StatusModel, + TimeFramedModel, + TimeStampedModel, + UUIDModel, + SaveSignalHandlingModel, +) +from tests.fields import MutableField +from tests.managers import CustomSoftDeleteManager +from model_utils.tracker import FieldTracker, ModelTracker class InheritanceManagerTestRelated(models.Model): pass - @python_2_unicode_compatible class InheritanceManagerTestParent(models.Model): # FileField is just a handy descriptor-using field. Refs #6. non_related_field_using_descriptor = models.FileField(upload_to="test") related = models.ForeignKey( - InheritanceManagerTestRelated, related_name="imtests", null=True) + InheritanceManagerTestRelated, related_name="imtests", null=True, + on_delete=models.CASCADE) normal_field = models.TextField() - related_self = models.OneToOneField("self", related_name="imtests_self", null=True) + related_self = models.OneToOneField( + "self", related_name="imtests_self", null=True, + on_delete=models.CASCADE) objects = InheritanceManager() - def __unicode__(self): - return unicode(self.pk) - def __str__(self): return "%s(%s)" % ( self.__class__.__name__[len('InheritanceManagerTest'):], self.pk, - ) - + ) class InheritanceManagerTestChild1(InheritanceManagerTestParent): @@ -45,61 +62,69 @@ class InheritanceManagerTestChild1(InheritanceManagerTestParent): objects = InheritanceManager() - class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1): text_field = models.TextField() - class InheritanceManagerTestGrandChild1_2(InheritanceManagerTestChild1): text_field = models.TextField() - class InheritanceManagerTestChild2(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") normal_field_2 = models.TextField() - class InheritanceManagerTestChild3(InheritanceManagerTestParent): parent_ptr = models.OneToOneField( InheritanceManagerTestParent, related_name='manual_onetoone', - parent_link=True) + parent_link=True, on_delete=models.CASCADE) + + +class InheritanceManagerTestChild4(InheritanceManagerTestParent): + other_onetoone = models.OneToOneField( + InheritanceManagerTestParent, related_name='non_inheritance_relation', + parent_link=False, on_delete=models.CASCADE) + # The following is needed because of that Django bug: + # https://code.djangoproject.com/ticket/29998 + parent_ptr = models.OneToOneField( + InheritanceManagerTestParent, related_name='child4_onetoone', + parent_link=True, on_delete=models.CASCADE) class TimeStamp(TimeStampedModel): pass - class TimeFrame(TimeFramedModel): pass - class TimeFrameManagerAdded(TimeFramedModel): pass - class Monitored(models.Model): name = models.CharField(max_length=25) name_changed = MonitorField(monitor="name") - class MonitorWhen(models.Model): name = models.CharField(max_length=25) name_changed = MonitorField(monitor="name", when=["Jose", "Maria"]) - class MonitorWhenEmpty(models.Model): name = models.CharField(max_length=25) name_changed = MonitorField(monitor="name", when=[]) +class DoubleMonitored(models.Model): + name = models.CharField(max_length=25) + name_changed = MonitorField(monitor="name") + name2 = models.CharField(max_length=25) + name_changed2 = MonitorField(monitor="name2") + class Status(StatusModel): STATUS = Choices( @@ -109,7 +134,6 @@ class Status(StatusModel): ) - class StatusPlainTuple(StatusModel): STATUS = ( ("active", _("active")), @@ -118,7 +142,6 @@ class StatusPlainTuple(StatusModel): ) - class StatusManagerAdded(StatusModel): STATUS = ( ("active", _("active")), @@ -127,6 +150,25 @@ class StatusManagerAdded(StatusModel): ) +class StatusCustomManager(Manager): + pass + + +class AbstractStatusCustomManager(StatusModel): + STATUS = Choices( + ("first_choice", _("First choice")), + ("second_choice", _("Second choice")), + ) + + objects = StatusCustomManager() + + class Meta: + abstract = True + + +class StatusCustomManager(AbstractStatusCustomManager): + title = models.CharField(max_length=50) + class Post(models.Model): published = models.BooleanField(default=False) @@ -135,30 +177,26 @@ class Post(models.Model): objects = models.Manager() public = QueryManager(published=True) - public_confirmed = QueryManager(models.Q(published=True) & - models.Q(confirmed=True)) + public_confirmed = QueryManager( + models.Q(published=True) & models.Q(confirmed=True)) public_reversed = QueryManager(published=True).order_by("-order") class Meta: ordering = ("order",) - class Article(models.Model): title = models.CharField(max_length=50) body = SplitField() - class SplitFieldAbstractParent(models.Model): content = SplitField() - class Meta: abstract = True - class NoRendered(models.Model): """ Test that the no_excerpt_field keyword arg works. This arg should @@ -168,29 +206,24 @@ class NoRendered(models.Model): body = SplitField(no_excerpt_field=True) - class AuthorMixin(object): def by_author(self, name): return self.filter(author=name) - class PublishedMixin(object): def published(self): return self.filter(published=True) - def unpublished(self): return self.filter(published=False) - class ByAuthorQuerySet(models.query.QuerySet, AuthorMixin): pass - class FeaturedManager(models.Manager): def get_queryset(self): kwargs = {} @@ -198,95 +231,42 @@ class FeaturedManager(models.Manager): kwargs["using"] = self._db return ByAuthorQuerySet(self.model, **kwargs).filter(feature=True) - get_query_set = get_queryset +class AbstractTracked(models.Model): + number = 1 -class DudeQuerySet(models.query.QuerySet): - def abiding(self): - return self.filter(abides=True) - - def rug_positive(self): - return self.filter(has_rug=True) - - def rug_negative(self): - return self.filter(has_rug=False) - - def by_name(self, name): - return self.filter(name__iexact=name) - - - -class AbidingManager(PassThroughManager): - def get_queryset(self): - return DudeQuerySet(self.model).abiding() - - get_query_set = get_queryset - - def get_stats(self): - return { - "abiding_count": self.count(), - "rug_count": self.rug_positive().count(), - } - - - -class Dude(models.Model): - abides = models.BooleanField(default=True) - name = models.CharField(max_length=20) - has_rug = models.BooleanField(default=False) - - objects = PassThroughManager(DudeQuerySet) - abiders = AbidingManager() - - -class Car(models.Model): - name = models.CharField(max_length=20) - owner = models.ForeignKey(Dude, related_name='cars_owned') - - objects = PassThroughManager(DudeQuerySet) - - -class SpotManager(PassThroughManager): - def get_queryset(self): - return super(SpotManager, self).get_queryset().filter(secret=False) - - get_query_set = get_queryset - - -class SpotQuerySet(models.query.QuerySet): - def closed(self): - return self.filter(closed=True) - - def secured(self): - return self.filter(secure=True) - - -class Spot(models.Model): - name = models.CharField(max_length=20) - secure = models.BooleanField(default=True) - closed = models.BooleanField(default=False) - secret = models.BooleanField(default=False) - owner = models.ForeignKey(Dude, related_name='spots_owned') - - objects = SpotManager.for_queryset_class(SpotQuerySet)() + class Meta: + abstract = True class Tracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() - mutable = MutableField() + mutable = MutableField(default=None) tracker = FieldTracker() + def save(self, *args, **kwargs): + """ No-op save() to ensure that FieldTracker.patch_save() works. """ + super(Tracked, self).save(*args, **kwargs) + class TrackedFK(models.Model): - fk = models.ForeignKey('Tracked') + fk = models.ForeignKey('Tracked', on_delete=models.CASCADE) tracker = FieldTracker() custom_tracker = FieldTracker(fields=['fk_id']) custom_tracker_without_id = FieldTracker(fields=['fk']) +class TrackedAbstract(AbstractTracked): + name = models.CharField(max_length=20) + number = models.IntegerField() + mutable = MutableField(default=None) + + tracker = FieldTracker() + + class TrackedNotDefault(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() @@ -312,20 +292,31 @@ class TrackedMultiple(models.Model): number_tracker = FieldTracker(fields=['number']) +class TrackedFileField(models.Model): + some_file = models.FileField(upload_to='test_location') + + tracker = FieldTracker() + + class InheritedTracked(Tracked): name2 = models.CharField(max_length=20) +class InheritedTrackedFK(TrackedFK): + custom_tracker = FieldTracker(fields=['fk_id']) + custom_tracker_without_id = FieldTracker(fields=['fk']) + + class ModelTracked(models.Model): name = models.CharField(max_length=20) number = models.IntegerField() - mutable = MutableField() + mutable = MutableField(default=None) tracker = ModelTracker() class ModelTrackedFK(models.Model): - fk = models.ForeignKey('ModelTracked') + fk = models.ForeignKey('ModelTracked', on_delete=models.CASCADE) tracker = ModelTracker() custom_tracker = ModelTracker(fields=['fk_id']) @@ -346,6 +337,7 @@ class ModelTrackedMultiple(models.Model): name_tracker = ModelTracker(fields=['name']) number_tracker = ModelTracker(fields=['number']) + class InheritedModelTracked(ModelTracked): name2 = models.CharField(max_length=20) @@ -363,3 +355,91 @@ class StatusFieldDefaultNotFilled(models.Model): class StatusFieldChoicesName(models.Model): NAMED_STATUS = Choices((0, "no", "No"), (1, "yes", "Yes")) status = StatusField(choices_name='NAMED_STATUS') + + +class SoftDeletable(SoftDeletableModel): + """ + Test model with additional manager for full access to model + instances. + """ + name = models.CharField(max_length=20) + + all_objects = models.Manager() + + +class CustomSoftDelete(SoftDeletableModel): + is_read = models.BooleanField(default=False) + + objects = CustomSoftDeleteManager() + + +class StringyDescriptor(object): + """ + Descriptor that returns a string version of the underlying integer value. + """ + + def __init__(self, name): + self.name = name + + def __get__(self, obj, cls=None): + if obj is None: + return self + if self.name in obj.get_deferred_fields(): + # This queries the database, and sets the value on the instance. + if django.VERSION < (2, 1): + DeferredAttribute(field_name=self.name, model=cls).__get__(obj, cls) + else: + DeferredAttribute(field_name=self.name).__get__(obj, cls) + return str(obj.__dict__[self.name]) + + def __set__(self, obj, value): + obj.__dict__[self.name] = int(value) + + def __delete__(self, obj): + del obj.__dict__[self.name] + + +class CustomDescriptorField(models.IntegerField): + def contribute_to_class(self, cls, name, **kwargs): + super(CustomDescriptorField, self).contribute_to_class(cls, name, **kwargs) + setattr(cls, name, StringyDescriptor(name)) + + +class ModelWithCustomDescriptor(models.Model): + custom_field = CustomDescriptorField() + tracked_custom_field = CustomDescriptorField() + regular_field = models.IntegerField() + tracked_regular_field = models.IntegerField() + + tracker = FieldTracker(fields=['tracked_custom_field', 'tracked_regular_field']) + + +class JoinManager(JoinManagerMixin, models.Manager): + pass + + +class BoxJoinModel(models.Model): + name = models.CharField(max_length=32) + objects = JoinManager() + + +class JoinItemForeignKey(models.Model): + weight = models.IntegerField() + belonging = models.ForeignKey( + BoxJoinModel, + null=True, + on_delete=models.CASCADE + ) + objects = JoinManager() + + +class CustomUUIDModel(UUIDModel): + pass + + +class CustomNotPrimaryUUIDModel(models.Model): + uuid = UUIDField(primary_key=False) + + +class SaveSignalHandlingTestModel(SaveSignalHandlingModel): + name = models.CharField(max_length=20) diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000..e34c891 --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,22 @@ +import os + +INSTALLED_APPS = ( + 'model_utils', + 'tests', +) +DATABASES = { + "default": { + "ENGINE": "django.db.backends.postgresql_psycopg2", + "NAME": os.environ.get("DJANGO_DATABASE_NAME_POSTGRES", "modelutils"), + "USER": os.environ.get("DJANGO_DATABASE_USER_POSTGRES", 'postgres'), + "PASSWORD": os.environ.get("DJANGO_DATABASE_PASSWORD_POSTGRES", ""), + "HOST": os.environ.get("DJANGO_DATABASE_HOST_POSTGRES", ""), + }, +} +SECRET_KEY = 'dummy' + +CACHES = { + 'default': { + 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + } +} diff --git a/tests/signals.py b/tests/signals.py new file mode 100644 index 0000000..6b86ca2 --- /dev/null +++ b/tests/signals.py @@ -0,0 +1,6 @@ +def pre_save_test(instance, *args, **kwargs): + instance.pre_save_runned = True + + +def post_save_test(instance, created, *args, **kwargs): + instance.post_save_runned = True diff --git a/tests/test_choices.py b/tests/test_choices.py new file mode 100644 index 0000000..cb5bec9 --- /dev/null +++ b/tests/test_choices.py @@ -0,0 +1,308 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from model_utils import Choices + + +class ChoicesTests(TestCase): + def setUp(self): + self.STATUS = Choices('DRAFT', 'PUBLISHED') + + def test_getattr(self): + self.assertEqual(self.STATUS.DRAFT, 'DRAFT') + + def test_indexing(self): + self.assertEqual(self.STATUS['PUBLISHED'], 'PUBLISHED') + + def test_iteration(self): + self.assertEqual(tuple(self.STATUS), + (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) + + def test_reversed(self): + self.assertEqual(tuple(reversed(self.STATUS)), + (('PUBLISHED', 'PUBLISHED'), ('DRAFT', 'DRAFT'))) + + def test_len(self): + self.assertEqual(len(self.STATUS), 2) + + def test_repr(self): + self.assertEqual(repr(self.STATUS), "Choices" + repr(( + ('DRAFT', 'DRAFT', 'DRAFT'), + ('PUBLISHED', 'PUBLISHED', 'PUBLISHED'), + ))) + + def test_wrong_length_tuple(self): + with self.assertRaises(ValueError): + Choices(('a',)) + + def test_contains_value(self): + self.assertTrue('PUBLISHED' in self.STATUS) + self.assertTrue('DRAFT' in self.STATUS) + + def test_doesnt_contain_value(self): + self.assertFalse('UNPUBLISHED' in self.STATUS) + + def test_deepcopy(self): + import copy + self.assertEqual(list(self.STATUS), + list(copy.deepcopy(self.STATUS))) + + def test_equality(self): + self.assertEqual(self.STATUS, Choices('DRAFT', 'PUBLISHED')) + + def test_inequality(self): + self.assertNotEqual(self.STATUS, ['DRAFT', 'PUBLISHED']) + self.assertNotEqual(self.STATUS, Choices('DRAFT')) + + def test_composability(self): + self.assertEqual(Choices('DRAFT') + Choices('PUBLISHED'), self.STATUS) + self.assertEqual(Choices('DRAFT') + ('PUBLISHED',), self.STATUS) + self.assertEqual(('DRAFT',) + Choices('PUBLISHED'), self.STATUS) + + def test_option_groups(self): + c = Choices(('group a', ['one', 'two']), ['group b', ('three',)]) + self.assertEqual( + list(c), + [ + ('group a', [('one', 'one'), ('two', 'two')]), + ('group b', [('three', 'three')]), + ], + ) + + +class LabelChoicesTests(ChoicesTests): + def setUp(self): + self.STATUS = Choices( + ('DRAFT', 'is draft'), + ('PUBLISHED', 'is published'), + 'DELETED', + ) + + def test_iteration(self): + self.assertEqual(tuple(self.STATUS), ( + ('DRAFT', 'is draft'), + ('PUBLISHED', 'is published'), + ('DELETED', 'DELETED'), + )) + + def test_reversed(self): + self.assertEqual(tuple(reversed(self.STATUS)), ( + ('DELETED', 'DELETED'), + ('PUBLISHED', 'is published'), + ('DRAFT', 'is draft'), + )) + + def test_indexing(self): + self.assertEqual(self.STATUS['PUBLISHED'], 'is published') + + def test_default(self): + self.assertEqual(self.STATUS.DELETED, 'DELETED') + + def test_provided(self): + self.assertEqual(self.STATUS.DRAFT, 'DRAFT') + + def test_len(self): + self.assertEqual(len(self.STATUS), 3) + + def test_equality(self): + self.assertEqual(self.STATUS, Choices( + ('DRAFT', 'is draft'), + ('PUBLISHED', 'is published'), + 'DELETED', + )) + + def test_inequality(self): + self.assertNotEqual(self.STATUS, [ + ('DRAFT', 'is draft'), + ('PUBLISHED', 'is published'), + 'DELETED' + ]) + self.assertNotEqual(self.STATUS, Choices('DRAFT')) + + def test_repr(self): + self.assertEqual(repr(self.STATUS), "Choices" + repr(( + ('DRAFT', 'DRAFT', 'is draft'), + ('PUBLISHED', 'PUBLISHED', 'is published'), + ('DELETED', 'DELETED', 'DELETED'), + ))) + + def test_contains_value(self): + self.assertTrue('PUBLISHED' in self.STATUS) + self.assertTrue('DRAFT' in self.STATUS) + # This should be True, because both the display value + # and the internal representation are both DELETED. + self.assertTrue('DELETED' in self.STATUS) + + def test_doesnt_contain_value(self): + self.assertFalse('UNPUBLISHED' in self.STATUS) + + def test_doesnt_contain_display_value(self): + self.assertFalse('is draft' in self.STATUS) + + def test_composability(self): + self.assertEqual( + Choices(('DRAFT', 'is draft',)) + Choices(('PUBLISHED', 'is published'), 'DELETED'), + self.STATUS + ) + + self.assertEqual( + (('DRAFT', 'is draft',),) + Choices(('PUBLISHED', 'is published'), 'DELETED'), + self.STATUS + ) + + self.assertEqual( + Choices(('DRAFT', 'is draft',)) + (('PUBLISHED', 'is published'), 'DELETED'), + self.STATUS + ) + + def test_option_groups(self): + c = Choices( + ('group a', [(1, 'one'), (2, 'two')]), + ['group b', ((3, 'three'),)] + ) + self.assertEqual( + list(c), + [ + ('group a', [(1, 'one'), (2, 'two')]), + ('group b', [(3, 'three')]), + ], + ) + + +class IdentifierChoicesTests(ChoicesTests): + def setUp(self): + self.STATUS = Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published'), + (2, 'DELETED', 'is deleted')) + + def test_iteration(self): + self.assertEqual(tuple(self.STATUS), ( + (0, 'is draft'), + (1, 'is published'), + (2, 'is deleted'), + )) + + def test_reversed(self): + self.assertEqual(tuple(reversed(self.STATUS)), ( + (2, 'is deleted'), + (1, 'is published'), + (0, 'is draft'), + )) + + def test_indexing(self): + self.assertEqual(self.STATUS[1], 'is published') + + def test_getattr(self): + self.assertEqual(self.STATUS.DRAFT, 0) + + def test_len(self): + self.assertEqual(len(self.STATUS), 3) + + def test_repr(self): + self.assertEqual(repr(self.STATUS), "Choices" + repr(( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published'), + (2, 'DELETED', 'is deleted'), + ))) + + def test_contains_value(self): + self.assertTrue(0 in self.STATUS) + self.assertTrue(1 in self.STATUS) + self.assertTrue(2 in self.STATUS) + + def test_doesnt_contain_value(self): + self.assertFalse(3 in self.STATUS) + + def test_doesnt_contain_display_value(self): + self.assertFalse('is draft' in self.STATUS) + + def test_doesnt_contain_python_attr(self): + self.assertFalse('PUBLISHED' in self.STATUS) + + def test_equality(self): + self.assertEqual(self.STATUS, Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published'), + (2, 'DELETED', 'is deleted') + )) + + def test_inequality(self): + self.assertNotEqual(self.STATUS, [ + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published'), + (2, 'DELETED', 'is deleted') + ]) + self.assertNotEqual(self.STATUS, Choices('DRAFT')) + + def test_composability(self): + self.assertEqual( + Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published') + ) + Choices( + (2, 'DELETED', 'is deleted'), + ), + self.STATUS + ) + + self.assertEqual( + Choices( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published') + ) + ( + (2, 'DELETED', 'is deleted'), + ), + self.STATUS + ) + + self.assertEqual( + ( + (0, 'DRAFT', 'is draft'), + (1, 'PUBLISHED', 'is published') + ) + Choices( + (2, 'DELETED', 'is deleted'), + ), + self.STATUS + ) + + def test_option_groups(self): + c = Choices( + ('group a', [(1, 'ONE', 'one'), (2, 'TWO', 'two')]), + ['group b', ((3, 'THREE', 'three'),)] + ) + self.assertEqual( + list(c), + [ + ('group a', [(1, 'one'), (2, 'two')]), + ('group b', [(3, 'three')]), + ], + ) + + +class SubsetChoicesTest(TestCase): + + def setUp(self): + self.choices = Choices( + (0, 'a', 'A'), + (1, 'b', 'B'), + ) + + def test_nonexistent_identifiers_raise(self): + with self.assertRaises(ValueError): + self.choices.subset('a', 'c') + + def test_solo_nonexistent_identifiers_raise(self): + with self.assertRaises(ValueError): + self.choices.subset('c') + + def test_empty_subset_passes(self): + subset = self.choices.subset() + + self.assertEqual(subset, Choices()) + + def test_subset_returns_correct_subset(self): + subset = self.choices.subset('a') + + self.assertEqual(subset, Choices((0, 'a', 'A'))) diff --git a/tests/test_fields/__init__.py b/tests/test_fields/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py new file mode 100644 index 0000000..83cde07 --- /dev/null +++ b/tests/test_fields/test_field_tracker.py @@ -0,0 +1,792 @@ +from __future__ import unicode_literals + +import django +from django.core.exceptions import FieldError +from django.test import TestCase +from django.core.cache import cache +from model_utils import FieldTracker +from model_utils.tracker import DescriptorWrapper +from tests.models import ( + Tracked, TrackedFK, InheritedTrackedFK, TrackedNotDefault, TrackedNonFieldAttr, TrackedMultiple, + InheritedTracked, TrackedFileField, TrackedAbstract, + ModelTracked, ModelTrackedFK, ModelTrackedNotDefault, ModelTrackedMultiple, InheritedModelTracked, +) + + +class FieldTrackerTestCase(TestCase): + + tracker = None + + def assertHasChanged(self, **kwargs): + tracker = kwargs.pop('tracker', self.tracker) + for field, value in kwargs.items(): + if value is None: + with self.assertRaises(FieldError): + tracker.has_changed(field) + else: + self.assertEqual(tracker.has_changed(field), value) + + def assertPrevious(self, **kwargs): + tracker = kwargs.pop('tracker', self.tracker) + for field, value in kwargs.items(): + self.assertEqual(tracker.previous(field), value) + + def assertChanged(self, **kwargs): + tracker = kwargs.pop('tracker', self.tracker) + self.assertEqual(tracker.changed(), kwargs) + + def assertCurrent(self, **kwargs): + tracker = kwargs.pop('tracker', self.tracker) + self.assertEqual(tracker.current(), kwargs) + + def update_instance(self, **kwargs): + for field, value in kwargs.items(): + setattr(self.instance, field, value) + self.instance.save() + + +class FieldTrackerCommonTests(object): + + def test_pre_save_previous(self): + self.assertPrevious(name=None, number=None) + self.instance.name = 'new age' + self.instance.number = 8 + self.assertPrevious(name=None, number=None) + + +class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): + + tracked_class = Tracked + + def setUp(self): + self.instance = self.tracked_class() + self.tracker = self.instance.tracker + + def test_descriptor(self): + self.assertTrue(isinstance(self.tracked_class.tracker, FieldTracker)) + + def test_pre_save_changed(self): + self.assertChanged(name=None) + self.instance.name = 'new age' + self.assertChanged(name=None) + self.instance.number = 8 + self.assertChanged(name=None, number=None) + self.instance.name = '' + self.assertChanged(name=None, number=None) + self.instance.mutable = [1, 2, 3] + self.assertChanged(name=None, number=None, mutable=None) + + def test_pre_save_has_changed(self): + self.assertHasChanged(name=True, number=False, mutable=False) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=False, mutable=False) + self.instance.number = 7 + self.assertHasChanged(name=True, number=True) + self.instance.mutable = [1, 2, 3] + self.assertHasChanged(name=True, number=True, mutable=True) + + def test_save_with_args(self): + self.instance.number = 1 + self.instance.save(False, False, None, None) + self.assertChanged() + + def test_first_save(self): + self.assertHasChanged(name=True, number=False, mutable=False) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='', number=None, id=None, mutable=None) + self.assertChanged(name=None) + self.instance.name = 'retro' + self.instance.number = 4 + self.instance.mutable = [1, 2, 3] + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) + self.assertChanged(name=None, number=None, mutable=None) + + self.instance.save(update_fields=[]) + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) + self.assertChanged(name=None, number=None, mutable=None) + with self.assertRaises(ValueError): + self.instance.save(update_fields=['number']) + + def test_post_save_has_changed(self): + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) + self.assertHasChanged(name=False, number=False, mutable=False) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=False) + self.instance.number = 8 + self.assertHasChanged(name=True, number=True) + self.instance.mutable[1] = 4 + self.assertHasChanged(name=True, number=True, mutable=True) + self.instance.name = 'retro' + self.assertHasChanged(name=False, number=True, mutable=True) + + def test_post_save_previous(self): + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) + self.instance.name = 'new age' + self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) + self.instance.mutable[1] = 4 + self.assertPrevious(name='retro', number=4, mutable=[1, 2, 3]) + + def test_post_save_changed(self): + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) + self.assertChanged() + self.instance.name = 'new age' + self.assertChanged(name='retro') + self.instance.number = 8 + self.assertChanged(name='retro', number=4) + self.instance.name = 'retro' + self.assertChanged(number=4) + self.instance.mutable[1] = 4 + self.assertChanged(number=4, mutable=[1, 2, 3]) + self.instance.mutable = [1, 2, 3] + self.assertChanged(number=4) + + def test_current(self): + self.assertCurrent(id=None, name='', number=None, mutable=None) + self.instance.name = 'new age' + self.assertCurrent(id=None, name='new age', number=None, mutable=None) + self.instance.number = 8 + self.assertCurrent(id=None, name='new age', number=8, mutable=None) + self.instance.mutable = [1, 2, 3] + self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 2, 3]) + self.instance.mutable[1] = 4 + self.assertCurrent(id=None, name='new age', number=8, mutable=[1, 4, 3]) + self.instance.save() + self.assertCurrent(id=self.instance.id, name='new age', number=8, mutable=[1, 4, 3]) + + def test_update_fields(self): + self.update_instance(name='retro', number=4, mutable=[1, 2, 3]) + self.assertChanged() + self.instance.name = 'new age' + self.instance.number = 8 + self.instance.mutable = [4, 5, 6] + self.assertChanged(name='retro', number=4, mutable=[1, 2, 3]) + self.instance.save(update_fields=[]) + self.assertChanged(name='retro', number=4, mutable=[1, 2, 3]) + self.instance.save(update_fields=['name']) + in_db = self.tracked_class.objects.get(id=self.instance.id) + self.assertEqual(in_db.name, self.instance.name) + self.assertNotEqual(in_db.number, self.instance.number) + self.assertChanged(number=4, mutable=[1, 2, 3]) + self.instance.save(update_fields=['number']) + self.assertChanged(mutable=[1, 2, 3]) + self.instance.save(update_fields=['mutable']) + self.assertChanged() + in_db = self.tracked_class.objects.get(id=self.instance.id) + self.assertEqual(in_db.name, self.instance.name) + self.assertEqual(in_db.number, self.instance.number) + self.assertEqual(in_db.mutable, self.instance.mutable) + + def test_with_deferred(self): + self.instance.name = 'new age' + self.instance.number = 1 + self.instance.save() + item = self.tracked_class.objects.only('name').first() + if django.VERSION >= (1, 10): + self.assertTrue(item.get_deferred_fields()) + else: + self.assertTrue(item._deferred_fields) + + # has_changed() returns False for deferred fields, without un-deferring them. + # Use an if because ModelTracked doesn't support has_changed() in this case. + if self.tracked_class == Tracked: + self.assertFalse(item.tracker.has_changed('number')) + if django.VERSION >= (1, 10): + self.assertIsInstance(item.__class__.number, DescriptorWrapper) + self.assertTrue('number' in item.get_deferred_fields()) + else: + self.assertTrue('number' in item._deferred_fields) + + # previous() un-defers field and returns value + self.assertEqual(item.tracker.previous('number'), 1) + if django.VERSION >= (1, 10): + self.assertNotIn('number', item.get_deferred_fields()) + else: + self.assertNotIn('number', item._deferred_fields) + + # examining a deferred field un-defers it + item = self.tracked_class.objects.only('name').first() + self.assertEqual(item.number, 1) + if django.VERSION >= (1, 10): + self.assertTrue('number' not in item.get_deferred_fields()) + else: + self.assertTrue('number' not in item._deferred_fields) + self.assertEqual(item.tracker.previous('number'), 1) + self.assertFalse(item.tracker.has_changed('number')) + + # has_changed() returns correct values after deferred field is examined + self.assertFalse(item.tracker.has_changed('number')) + item.number = 2 + self.assertTrue(item.tracker.has_changed('number')) + + # previous() returns correct value after deferred field is examined + self.assertEqual(item.tracker.previous('number'), 1) + + # assigning to a deferred field un-defers it + # Use an if because ModelTracked doesn't handle this case. + if self.tracked_class == Tracked: + + item = self.tracked_class.objects.only('name').first() + item.number = 2 + + # previous() fetches correct value from database after deferred field is assigned + self.assertEqual(item.tracker.previous('number'), 1) + + # database fetch of previous() value doesn't affect current value + self.assertEqual(item.number, 2) + + # has_changed() returns correct values after deferred field is assigned + self.assertTrue(item.tracker.has_changed('number')) + item.number = 1 + self.assertFalse(item.tracker.has_changed('number')) + + +class FieldTrackerMultipleInstancesTests(TestCase): + + def test_with_deferred_fields_access_multiple(self): + Tracked.objects.create(pk=1, name='foo', number=1) + Tracked.objects.create(pk=2, name='bar', number=2) + + queryset = Tracked.objects.only('id') + + for instance in queryset: + instance.name + + +class FieldTrackedModelCustomTests(FieldTrackerTestCase, + FieldTrackerCommonTests): + + tracked_class = TrackedNotDefault + + def setUp(self): + self.instance = self.tracked_class() + self.tracker = self.instance.name_tracker + + def test_pre_save_changed(self): + self.assertChanged(name=None) + self.instance.name = 'new age' + self.assertChanged(name=None) + self.instance.number = 8 + self.assertChanged(name=None) + self.instance.name = '' + self.assertChanged(name=None) + + def test_first_save(self): + self.assertHasChanged(name=True, number=None) + self.assertPrevious(name=None, number=None) + self.assertCurrent(name='') + self.assertChanged(name=None) + self.instance.name = 'retro' + self.instance.number = 4 + self.assertHasChanged(name=True, number=None) + self.assertPrevious(name=None, number=None) + self.assertCurrent(name='retro') + self.assertChanged(name=None) + + def test_pre_save_has_changed(self): + self.assertHasChanged(name=True, number=None) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=None) + self.instance.number = 7 + self.assertHasChanged(name=True, number=None) + + def test_post_save_has_changed(self): + self.update_instance(name='retro', number=4) + self.assertHasChanged(name=False, number=None) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=None) + self.instance.number = 8 + self.assertHasChanged(name=True, number=None) + self.instance.name = 'retro' + self.assertHasChanged(name=False, number=None) + + def test_post_save_previous(self): + self.update_instance(name='retro', number=4) + self.instance.name = 'new age' + self.assertPrevious(name='retro', number=None) + + def test_post_save_changed(self): + self.update_instance(name='retro', number=4) + self.assertChanged() + self.instance.name = 'new age' + self.assertChanged(name='retro') + self.instance.number = 8 + self.assertChanged(name='retro') + self.instance.name = 'retro' + self.assertChanged() + + def test_current(self): + self.assertCurrent(name='') + self.instance.name = 'new age' + self.assertCurrent(name='new age') + self.instance.number = 8 + self.assertCurrent(name='new age') + self.instance.save() + self.assertCurrent(name='new age') + + def test_update_fields(self): + self.update_instance(name='retro', number=4) + self.assertChanged() + self.instance.name = 'new age' + self.instance.number = 8 + self.instance.save(update_fields=['name', 'number']) + self.assertChanged() + + +class FieldTrackedModelAttributeTests(FieldTrackerTestCase): + + tracked_class = TrackedNonFieldAttr + + def setUp(self): + self.instance = self.tracked_class() + self.tracker = self.instance.tracker + + def test_previous(self): + self.assertPrevious(rounded=None) + self.instance.number = 7.5 + self.assertPrevious(rounded=None) + self.instance.save() + self.assertPrevious(rounded=8) + self.instance.number = 7.2 + self.assertPrevious(rounded=8) + self.instance.save() + self.assertPrevious(rounded=7) + + def test_has_changed(self): + self.assertHasChanged(rounded=False) + self.instance.number = 7.5 + self.assertHasChanged(rounded=True) + self.instance.save() + self.assertHasChanged(rounded=False) + self.instance.number = 7.2 + self.assertHasChanged(rounded=True) + self.instance.number = 7.8 + self.assertHasChanged(rounded=False) + + def test_changed(self): + self.assertChanged() + self.instance.number = 7.5 + self.assertPrevious(rounded=None) + self.instance.save() + self.assertPrevious() + self.instance.number = 7.8 + self.assertPrevious() + self.instance.number = 7.2 + self.assertPrevious(rounded=8) + self.instance.save() + self.assertPrevious() + + def test_current(self): + self.assertCurrent(rounded=None) + self.instance.number = 7.5 + self.assertCurrent(rounded=8) + self.instance.save() + self.assertCurrent(rounded=8) + + +class FieldTrackedModelMultiTests(FieldTrackerTestCase, + FieldTrackerCommonTests): + + tracked_class = TrackedMultiple + + def setUp(self): + self.instance = self.tracked_class() + self.trackers = [self.instance.name_tracker, + self.instance.number_tracker] + + def test_pre_save_changed(self): + self.tracker = self.instance.name_tracker + self.assertChanged(name=None) + self.instance.name = 'new age' + self.assertChanged(name=None) + self.instance.number = 8 + self.assertChanged(name=None) + self.instance.name = '' + self.assertChanged(name=None) + self.tracker = self.instance.number_tracker + self.assertChanged(number=None) + self.instance.name = 'new age' + self.assertChanged(number=None) + self.instance.number = 8 + self.assertChanged(number=None) + + def test_pre_save_has_changed(self): + self.tracker = self.instance.name_tracker + self.assertHasChanged(name=True, number=None) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=None) + self.tracker = self.instance.number_tracker + self.assertHasChanged(name=None, number=False) + self.instance.name = 'new age' + self.assertHasChanged(name=None, number=False) + + def test_pre_save_previous(self): + for tracker in self.trackers: + self.tracker = tracker + super(FieldTrackedModelMultiTests, self).test_pre_save_previous() + + def test_post_save_has_changed(self): + self.update_instance(name='retro', number=4) + self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) + self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) + self.instance.name = 'new age' + self.assertHasChanged(tracker=self.trackers[0], name=True, number=None) + self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) + self.instance.number = 8 + self.assertHasChanged(tracker=self.trackers[0], name=True, number=None) + self.assertHasChanged(tracker=self.trackers[1], name=None, number=True) + self.instance.name = 'retro' + self.instance.number = 4 + self.assertHasChanged(tracker=self.trackers[0], name=False, number=None) + self.assertHasChanged(tracker=self.trackers[1], name=None, number=False) + + def test_post_save_previous(self): + self.update_instance(name='retro', number=4) + self.instance.name = 'new age' + self.instance.number = 8 + self.assertPrevious(tracker=self.trackers[0], name='retro', number=None) + self.assertPrevious(tracker=self.trackers[1], name=None, number=4) + + def test_post_save_changed(self): + self.update_instance(name='retro', number=4) + self.assertChanged(tracker=self.trackers[0]) + self.assertChanged(tracker=self.trackers[1]) + self.instance.name = 'new age' + self.assertChanged(tracker=self.trackers[0], name='retro') + self.assertChanged(tracker=self.trackers[1]) + self.instance.number = 8 + self.assertChanged(tracker=self.trackers[0], name='retro') + self.assertChanged(tracker=self.trackers[1], number=4) + self.instance.name = 'retro' + self.instance.number = 4 + self.assertChanged(tracker=self.trackers[0]) + self.assertChanged(tracker=self.trackers[1]) + + def test_current(self): + self.assertCurrent(tracker=self.trackers[0], name='') + self.assertCurrent(tracker=self.trackers[1], number=None) + self.instance.name = 'new age' + self.assertCurrent(tracker=self.trackers[0], name='new age') + self.assertCurrent(tracker=self.trackers[1], number=None) + self.instance.number = 8 + self.assertCurrent(tracker=self.trackers[0], name='new age') + self.assertCurrent(tracker=self.trackers[1], number=8) + self.instance.save() + self.assertCurrent(tracker=self.trackers[0], name='new age') + self.assertCurrent(tracker=self.trackers[1], number=8) + + +class FieldTrackerForeignKeyTests(FieldTrackerTestCase): + + fk_class = Tracked + tracked_class = TrackedFK + + def setUp(self): + self.old_fk = self.fk_class.objects.create(number=8) + self.instance = self.tracked_class.objects.create(fk=self.old_fk) + + def test_default(self): + self.tracker = self.instance.tracker + self.assertChanged() + self.assertPrevious() + self.assertCurrent(id=self.instance.id, fk_id=self.old_fk.id) + self.instance.fk = self.fk_class.objects.create(number=8) + self.assertChanged(fk_id=self.old_fk.id) + self.assertPrevious(fk_id=self.old_fk.id) + self.assertCurrent(id=self.instance.id, fk_id=self.instance.fk_id) + + def test_custom(self): + self.tracker = self.instance.custom_tracker + self.assertChanged() + self.assertPrevious() + self.assertCurrent(fk_id=self.old_fk.id) + self.instance.fk = self.fk_class.objects.create(number=8) + self.assertChanged(fk_id=self.old_fk.id) + self.assertPrevious(fk_id=self.old_fk.id) + self.assertCurrent(fk_id=self.instance.fk_id) + + def test_custom_without_id(self): + with self.assertNumQueries(1): + self.tracked_class.objects.get() + self.tracker = self.instance.custom_tracker_without_id + self.assertChanged() + self.assertPrevious() + self.assertCurrent(fk=self.old_fk.id) + self.instance.fk = self.fk_class.objects.create(number=8) + self.assertChanged(fk=self.old_fk.id) + self.assertPrevious(fk=self.old_fk.id) + self.assertCurrent(fk=self.instance.fk_id) + + +class InheritedFieldTrackerTests(FieldTrackerTests): + + tracked_class = InheritedTracked + + def test_child_fields_not_tracked(self): + self.name2 = 'test' + self.assertEqual(self.tracker.previous('name2'), None) + self.assertRaises(FieldError, self.tracker.has_changed, 'name2') + + +class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests): + + tracked_class = InheritedTrackedFK + + +class FieldTrackerFileFieldTests(FieldTrackerTestCase): + + tracked_class = TrackedFileField + + def setUp(self): + self.instance = self.tracked_class() + self.tracker = self.instance.tracker + self.some_file = 'something.txt' + self.another_file = 'another.txt' + + def test_pre_save_changed(self): + self.assertChanged(some_file=None) + self.instance.some_file = self.some_file + self.assertChanged(some_file=None) + + def test_pre_save_has_changed(self): + self.assertHasChanged(some_file=True) + self.instance.some_file = self.some_file + self.assertHasChanged(some_file=True) + + def test_pre_save_previous(self): + self.assertPrevious(some_file=None) + self.instance.some_file = self.some_file + self.assertPrevious(some_file=None) + + def test_post_save_changed(self): + self.update_instance(some_file=self.some_file) + self.assertChanged() + previous_file = self.instance.some_file + self.instance.some_file = self.another_file + self.assertChanged(some_file=previous_file) + # test deferred file field + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertChanged(tracker=deferred_instance.tracker) + + previous_file = deferred_instance.some_file + deferred_instance.some_file = self.another_file + self.assertChanged( + tracker=deferred_instance.tracker, + some_file=previous_file, + ) + + def test_post_save_has_changed(self): + self.update_instance(some_file=self.some_file) + self.assertHasChanged(some_file=False) + self.instance.some_file = self.another_file + self.assertHasChanged(some_file=True) + + # test deferred file field + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertHasChanged( + tracker=deferred_instance.tracker, + some_file=False, + ) + + deferred_instance.some_file = self.another_file + self.assertHasChanged( + tracker=deferred_instance.tracker, + some_file=True, + ) + + def test_post_save_previous(self): + self.update_instance(some_file=self.some_file) + previous_file = self.instance.some_file + self.instance.some_file = self.another_file + self.assertPrevious(some_file=previous_file) + + # test deferred file field + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertPrevious( + tracker=deferred_instance.tracker, + some_file=previous_file, + ) + + deferred_instance.some_file = self.another_file + self.assertPrevious( + tracker=deferred_instance.tracker, + some_file=previous_file, + ) + + def test_current(self): + self.assertCurrent(some_file=self.instance.some_file, id=None) + self.instance.some_file = self.some_file + self.assertCurrent(some_file=self.instance.some_file, id=None) + + # test deferred file field + self.instance.save() + deferred_instance = self.tracked_class.objects.defer('some_file')[0] + deferred_instance.some_file # access field to fetch from database + self.assertCurrent( + some_file=self.instance.some_file, + id=self.instance.id, + ) + + self.instance.some_file = self.another_file + self.assertCurrent( + some_file=self.instance.some_file, + id=self.instance.id, + ) + + +class ModelTrackerTests(FieldTrackerTests): + + tracked_class = ModelTracked + + def test_cache_compatible(self): + cache.set('key', self.instance) + instance = cache.get('key') + instance.number = 1 + instance.name = 'cached' + instance.save() + self.assertChanged() + instance.number = 2 + self.assertHasChanged(number=True) + + def test_pre_save_changed(self): + self.assertChanged() + self.instance.name = 'new age' + self.assertChanged() + self.instance.number = 8 + self.assertChanged() + self.instance.name = '' + self.assertChanged() + self.instance.mutable = [1, 2, 3] + self.assertChanged() + + def test_first_save(self): + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='', number=None, id=None, mutable=None) + self.assertChanged() + self.instance.name = 'retro' + self.instance.number = 4 + self.instance.mutable = [1, 2, 3] + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) + self.assertChanged() + + self.instance.save(update_fields=[]) + self.assertHasChanged(name=True, number=True, mutable=True) + self.assertPrevious(name=None, number=None, mutable=None) + self.assertCurrent(name='retro', number=4, id=None, mutable=[1, 2, 3]) + self.assertChanged() + with self.assertRaises(ValueError): + self.instance.save(update_fields=['number']) + + def test_pre_save_has_changed(self): + self.assertHasChanged(name=True, number=True) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=True) + self.instance.number = 7 + self.assertHasChanged(name=True, number=True) + + +class ModelTrackedModelCustomTests(FieldTrackedModelCustomTests): + + tracked_class = ModelTrackedNotDefault + + def test_first_save(self): + self.assertHasChanged(name=True, number=True) + self.assertPrevious(name=None, number=None) + self.assertCurrent(name='') + self.assertChanged() + self.instance.name = 'retro' + self.instance.number = 4 + self.assertHasChanged(name=True, number=True) + self.assertPrevious(name=None, number=None) + self.assertCurrent(name='retro') + self.assertChanged() + + def test_pre_save_has_changed(self): + self.assertHasChanged(name=True, number=True) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=True) + self.instance.number = 7 + self.assertHasChanged(name=True, number=True) + + def test_pre_save_changed(self): + self.assertChanged() + self.instance.name = 'new age' + self.assertChanged() + self.instance.number = 8 + self.assertChanged() + self.instance.name = '' + self.assertChanged() + + +class ModelTrackedModelMultiTests(FieldTrackedModelMultiTests): + + tracked_class = ModelTrackedMultiple + + def test_pre_save_has_changed(self): + self.tracker = self.instance.name_tracker + self.assertHasChanged(name=True, number=True) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=True) + self.tracker = self.instance.number_tracker + self.assertHasChanged(name=True, number=True) + self.instance.name = 'new age' + self.assertHasChanged(name=True, number=True) + + def test_pre_save_changed(self): + self.tracker = self.instance.name_tracker + self.assertChanged() + self.instance.name = 'new age' + self.assertChanged() + self.instance.number = 8 + self.assertChanged() + self.instance.name = '' + self.assertChanged() + self.tracker = self.instance.number_tracker + self.assertChanged() + self.instance.name = 'new age' + self.assertChanged() + self.instance.number = 8 + self.assertChanged() + + +class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests): + + fk_class = ModelTracked + tracked_class = ModelTrackedFK + + def test_custom_without_id(self): + with self.assertNumQueries(2): + self.tracked_class.objects.get() + self.tracker = self.instance.custom_tracker_without_id + self.assertChanged() + self.assertPrevious() + self.assertCurrent(fk=self.old_fk) + self.instance.fk = self.fk_class.objects.create(number=8) + self.assertNotEqual(self.instance.fk, self.old_fk) + self.assertChanged(fk=self.old_fk) + self.assertPrevious(fk=self.old_fk) + self.assertCurrent(fk=self.instance.fk) + + +class InheritedModelTrackerTests(ModelTrackerTests): + + tracked_class = InheritedModelTracked + + def test_child_fields_not_tracked(self): + self.name2 = 'test' + self.assertEqual(self.tracker.previous('name2'), None) + self.assertTrue(self.tracker.has_changed('name2')) + + +class AbstractModelTrackerTests(FieldTrackerTestCase): + + tracked_class = TrackedAbstract diff --git a/tests/test_fields/test_monitor_field.py b/tests/test_fields/test_monitor_field.py new file mode 100644 index 0000000..6c5792e --- /dev/null +++ b/tests/test_fields/test_monitor_field.py @@ -0,0 +1,120 @@ +from __future__ import unicode_literals + +from datetime import datetime + +from freezegun import freeze_time + +from django.test import TestCase + +from model_utils.fields import MonitorField +from tests.models import Monitored, MonitorWhen, MonitorWhenEmpty, DoubleMonitored + + +class MonitorFieldTests(TestCase): + def setUp(self): + with freeze_time(datetime(2016, 1, 1, 10, 0, 0)): + self.instance = Monitored(name='Charlie') + self.created = self.instance.name_changed + + def test_save_no_change(self): + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + def test_save_changed(self): + with freeze_time(datetime(2016, 1, 1, 12, 0, 0)): + self.instance.name = 'Maria' + self.instance.save() + self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0)) + + def test_double_save(self): + self.instance.name = 'Jose' + self.instance.save() + changed = self.instance.name_changed + self.instance.save() + self.assertEqual(self.instance.name_changed, changed) + + def test_no_monitor_arg(self): + with self.assertRaises(TypeError): + MonitorField() + + +class MonitorWhenFieldTests(TestCase): + """ + Will record changes only when name is 'Jose' or 'Maria' + """ + def setUp(self): + with freeze_time(datetime(2016, 1, 1, 10, 0, 0)): + self.instance = MonitorWhen(name='Charlie') + self.created = self.instance.name_changed + + def test_save_no_change(self): + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + def test_save_changed_to_Jose(self): + with freeze_time(datetime(2016, 1, 1, 12, 0, 0)): + self.instance.name = 'Jose' + self.instance.save() + self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0)) + + def test_save_changed_to_Maria(self): + with freeze_time(datetime(2016, 1, 1, 12, 0, 0)): + self.instance.name = 'Maria' + self.instance.save() + self.assertEqual(self.instance.name_changed, datetime(2016, 1, 1, 12, 0, 0)) + + def test_save_changed_to_Pedro(self): + self.instance.name = 'Pedro' + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + def test_double_save(self): + self.instance.name = 'Jose' + self.instance.save() + changed = self.instance.name_changed + self.instance.save() + self.assertEqual(self.instance.name_changed, changed) + + +class MonitorWhenEmptyFieldTests(TestCase): + """ + Monitor should never be updated id when is an empty list. + """ + def setUp(self): + self.instance = MonitorWhenEmpty(name='Charlie') + self.created = self.instance.name_changed + + def test_save_no_change(self): + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + def test_save_changed_to_Jose(self): + self.instance.name = 'Jose' + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + def test_save_changed_to_Maria(self): + self.instance.name = 'Maria' + self.instance.save() + self.assertEqual(self.instance.name_changed, self.created) + + +class MonitorDoubleFieldTests(TestCase): + + def setUp(self): + DoubleMonitored.objects.create(name='Charlie', name2='Charlie2') + + def test_recursion_error_with_only(self): + # Any field passed to only() is generating a recursion error + list(DoubleMonitored.objects.only('id')) + + def test_recursion_error_with_defer(self): + # Only monitored fields passed to defer() are failing + list(DoubleMonitored.objects.defer('name')) + + def test_monitor_still_works_with_deferred_fields_filtered_out_of_save_initial(self): + obj = DoubleMonitored.objects.defer('name').get(name='Charlie') + with freeze_time("2016-12-01"): + obj.name = 'Charlie2' + obj.save() + self.assertEqual(obj.name_changed, datetime(2016, 12, 1)) diff --git a/tests/test_fields/test_split_field.py b/tests/test_fields/test_split_field.py new file mode 100644 index 0000000..dfde85f --- /dev/null +++ b/tests/test_fields/test_split_field.py @@ -0,0 +1,78 @@ +from __future__ import unicode_literals + +from django.utils.six import text_type +from django.test import TestCase + +from tests.models import Article, SplitFieldAbstractParent + + +class SplitFieldTests(TestCase): + full_text = 'summary\n\n\n\nmore' + excerpt = 'summary\n' + + def setUp(self): + self.post = Article.objects.create( + title='example post', body=self.full_text) + + def test_unicode_content(self): + self.assertEqual(text_type(self.post.body), self.full_text) + + def test_excerpt(self): + self.assertEqual(self.post.body.excerpt, self.excerpt) + + def test_content(self): + self.assertEqual(self.post.body.content, self.full_text) + + def test_has_more(self): + self.assertTrue(self.post.body.has_more) + + def test_not_has_more(self): + post = Article.objects.create(title='example 2', + body='some text\n\nsome more\n') + self.assertFalse(post.body.has_more) + + def test_load_back(self): + post = Article.objects.get(pk=self.post.pk) + self.assertEqual(post.body.content, self.post.body.content) + self.assertEqual(post.body.excerpt, self.post.body.excerpt) + + def test_assign_to_body(self): + new_text = 'different\n\n\n\nother' + self.post.body = new_text + self.post.save() + self.assertEqual(text_type(self.post.body), new_text) + + def test_assign_to_content(self): + new_text = 'different\n\n\n\nother' + self.post.body.content = new_text + self.post.save() + self.assertEqual(text_type(self.post.body), new_text) + + def test_assign_to_excerpt(self): + with self.assertRaises(AttributeError): + self.post.body.excerpt = 'this should fail' + + def test_access_via_class(self): + with self.assertRaises(AttributeError): + Article.body + + def test_none(self): + a = Article(title='Some Title', body=None) + self.assertEqual(a.body, None) + + def test_assign_splittext(self): + a = Article(title='Some Title') + a.body = self.post.body + self.assertEqual(a.body.excerpt, 'summary\n') + + def test_value_to_string(self): + f = self.post._meta.get_field('body') + self.assertEqual(f.value_to_string(self.post), self.full_text) + + def test_abstract_inheritance(self): + class Child(SplitFieldAbstractParent): + pass + + self.assertEqual( + [f.name for f in Child._meta.fields], + ["id", "content", "_content_excerpt"]) diff --git a/tests/test_fields/test_status_field.py b/tests/test_fields/test_status_field.py new file mode 100644 index 0000000..dc0f223 --- /dev/null +++ b/tests/test_fields/test_status_field.py @@ -0,0 +1,32 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from model_utils.fields import StatusField +from tests.models import ( + Article, StatusFieldDefaultFilled, StatusFieldDefaultNotFilled, + StatusFieldChoicesName, +) + + +class StatusFieldTests(TestCase): + + def test_status_with_default_filled(self): + instance = StatusFieldDefaultFilled() + self.assertEqual(instance.status, instance.STATUS.yes) + + def test_status_with_default_not_filled(self): + instance = StatusFieldDefaultNotFilled() + self.assertEqual(instance.status, instance.STATUS.no) + + def test_no_check_for_status(self): + field = StatusField(no_check_for_status=True) + # this model has no STATUS attribute, so checking for it would error + field.prepare_class(Article) + + def test_get_status_display(self): + instance = StatusFieldDefaultFilled() + self.assertEqual(instance.get_status_display(), "Yes") + + def test_choices_name(self): + StatusFieldChoicesName() diff --git a/tests/test_fields/test_uuid_field.py b/tests/test_fields/test_uuid_field.py new file mode 100644 index 0000000..3a6c739 --- /dev/null +++ b/tests/test_fields/test_uuid_field.py @@ -0,0 +1,40 @@ +from __future__ import unicode_literals + +import uuid + +from django.core.exceptions import ValidationError +from django.test import TestCase + +from model_utils.fields import UUIDField + + +class UUIDFieldTests(TestCase): + + def test_uuid_version_default(self): + instance = UUIDField() + self.assertEqual(instance.default, uuid.uuid4) + + def test_uuid_version_1(self): + instance = UUIDField(version=1) + self.assertEqual(instance.default, uuid.uuid1) + + def test_uuid_version_2_error(self): + self.assertRaises(ValidationError, UUIDField, 'version', 2) + + def test_uuid_version_3(self): + instance = UUIDField(version=3) + self.assertEqual(instance.default, uuid.uuid3) + + def test_uuid_version_4(self): + instance = UUIDField(version=4) + self.assertEqual(instance.default, uuid.uuid4) + + def test_uuid_version_5(self): + instance = UUIDField(version=5) + self.assertEqual(instance.default, uuid.uuid5) + + def test_uuid_version_bellow_min(self): + self.assertRaises(ValidationError, UUIDField, 'version', 0) + + def test_uuid_version_above_max(self): + self.assertRaises(ValidationError, UUIDField, 'version', 6) diff --git a/tests/test_inheritance_iterable.py b/tests/test_inheritance_iterable.py new file mode 100644 index 0000000..884b763 --- /dev/null +++ b/tests/test_inheritance_iterable.py @@ -0,0 +1,22 @@ +from __future__ import unicode_literals + +from unittest import skipIf + +import django +from django.test import TestCase +from django.db.models import Prefetch + +from tests.models import InheritanceManagerTestParent, InheritanceManagerTestChild1 + + +class InheritanceIterableTest(TestCase): + @skipIf(django.VERSION[:2] == (1, 10), "Django 1.10 expects ModelIterable not a subclass of it") + def test_prefetch(self): + qs = InheritanceManagerTestChild1.objects.all().prefetch_related( + Prefetch( + 'normal_field', + queryset=InheritanceManagerTestParent.objects.all(), + to_attr='normal_field_prefetched' + ) + ) + self.assertEquals(qs.count(), 0) diff --git a/tests/test_managers/__init__.py b/tests/test_managers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_managers/test_inheritance_manager.py b/tests/test_managers/test_inheritance_manager.py new file mode 100644 index 0000000..374693e --- /dev/null +++ b/tests/test_managers/test_inheritance_manager.py @@ -0,0 +1,524 @@ +from __future__ import unicode_literals + +from unittest import skipUnless + +import django +from django.db import models +from django.test import TestCase + +from tests.models import ( + InheritanceManagerTestRelated, InheritanceManagerTestGrandChild1, + InheritanceManagerTestGrandChild1_2, InheritanceManagerTestParent, + InheritanceManagerTestChild1, + InheritanceManagerTestChild2, TimeFrame, InheritanceManagerTestChild3, + InheritanceManagerTestChild4, +) + + +class InheritanceManagerTests(TestCase): + def setUp(self): + self.child1 = InheritanceManagerTestChild1.objects.create() + self.child2 = InheritanceManagerTestChild2.objects.create() + self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() + self.grandchild1_2 = \ + InheritanceManagerTestGrandChild1_2.objects.create() + + def get_manager(self): + return InheritanceManagerTestParent.objects + + def test_normal(self): + children = set([ + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + ]) + self.assertEqual(set(self.get_manager().all()), children) + + def test_select_all_subclasses(self): + children = set([self.child1, self.child2]) + children.add(self.grandchild1) + children.add(self.grandchild1_2) + self.assertEqual( + set(self.get_manager().select_subclasses()), children) + + def test_select_subclasses_invalid_relation(self): + """ + If an invalid relation string is provided, we can provide the user + with a list which is valid, rather than just have the select_related() + raise an AttributeError further in. + """ + regex = '^.+? is not in the discovered subclasses, tried:.+$' + with self.assertRaisesRegexp(ValueError, regex): + self.get_manager().select_subclasses('user') + + def test_select_specific_subclasses(self): + children = set([ + self.child1, + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestChild1(pk=self.grandchild1.pk), + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ]) + self.assertEqual( + set( + self.get_manager().select_subclasses( + "inheritancemanagertestchild1") + ), + children, + ) + + def test_select_specific_grandchildren(self): + children = set([ + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + self.grandchild1, + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + ]) + self.assertEqual( + set( + self.get_manager().select_subclasses( + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1" + ) + ), + children, + ) + + def test_children_and_grandchildren(self): + children = set([ + self.child1, + InheritanceManagerTestParent(pk=self.child2.pk), + self.grandchild1, + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ]) + self.assertEqual( + set( + self.get_manager().select_subclasses( + "inheritancemanagertestchild1", + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1" + ) + ), + children, + ) + + def test_get_subclass(self): + self.assertEqual( + self.get_manager().get_subclass(pk=self.child1.pk), + self.child1) + + def test_get_subclass_on_queryset(self): + self.assertEqual( + self.get_manager().all().get_subclass(pk=self.child1.pk), + self.child1) + + def test_prior_select_related(self): + with self.assertNumQueries(1): + obj = self.get_manager().select_related( + "inheritancemanagertestchild1").select_subclasses( + "inheritancemanagertestchild2").get(pk=self.child1.pk) + obj.inheritancemanagertestchild1 + + def test_manually_specifying_parent_fk_including_grandchildren(self): + """ + given a Model which inherits from another Model, but also declares + the OneToOne link manually using `related_name` and `parent_link`, + ensure that the relation names and subclasses are obtained correctly. + """ + child3 = InheritanceManagerTestChild3.objects.create() + qs = InheritanceManagerTestParent.objects.all() + results = qs.select_subclasses().order_by('pk') + + expected_objs = [ + self.child1, + self.child2, + self.grandchild1, + self.grandchild1_2, + child3 + ] + self.assertEqual(list(results), expected_objs) + + expected_related_names = [ + 'inheritancemanagertestchild1__inheritancemanagertestgrandchild1', + 'inheritancemanagertestchild1__inheritancemanagertestgrandchild1_2', + 'inheritancemanagertestchild1', + 'inheritancemanagertestchild2', + 'manual_onetoone', # this was set via parent_link & related_name + 'child4_onetoone', + ] + self.assertEqual(set(results.subclasses), + set(expected_related_names)) + + def test_manually_specifying_parent_fk_single_subclass(self): + """ + Using a string related_name when the relation is manually defined + instead of implicit should still work in the same way. + """ + related_name = 'manual_onetoone' + child3 = InheritanceManagerTestChild3.objects.create() + qs = InheritanceManagerTestParent.objects.all() + results = qs.select_subclasses(related_name).order_by('pk') + + expected_objs = [InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + child3] + self.assertEqual(list(results), expected_objs) + expected_related_names = [related_name] + self.assertEqual(set(results.subclasses), + set(expected_related_names)) + + def test_filter_on_values_queryset(self): + queryset = InheritanceManagerTestChild1.objects.values('id').filter(pk=self.child1.pk) + self.assertEqual(list(queryset), [{'id': self.child1.pk}]) + + @skipUnless(django.VERSION >= (1, 9, 0), "test only applies to Django 1.9+") + def test_dj19_values_list_on_select_subclasses(self): + """ + Using `select_subclasses` in conjunction with `values_list()` raised an + exception in `_get_sub_obj_recurse()` because the result of `values_list()` + is either a `tuple` or primitive objects if `flat=True` is specified, + because no type checking was done prior to fetching child nodes. + + Django versions below 1.9 are not affected by this bug. + """ + + # Querysets are cast to lists to force immediate evaluation. + # No exceptions must be thrown. + + # No argument to select_subclasses + objs_1 = list( + self.get_manager() + .select_subclasses() + .values_list('id') + ) + + # String argument to select_subclasses + objs_2 = list( + self.get_manager() + .select_subclasses( + "inheritancemanagertestchild2" + ) + .values_list('id') + ) + + # String argument to select_subclasses + objs_3 = list( + self.get_manager() + .select_subclasses( + InheritanceManagerTestChild2 + ).values_list('id') + ) + + assert all(( + isinstance(objs_1, list), + isinstance(objs_2, list), + isinstance(objs_3, list), + )) + + assert objs_1 == objs_2 == objs_3 + + +class InheritanceManagerUsingModelsTests(TestCase): + def setUp(self): + self.parent1 = InheritanceManagerTestParent.objects.create() + self.child1 = InheritanceManagerTestChild1.objects.create() + self.child2 = InheritanceManagerTestChild2.objects.create() + self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create() + self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create() + + def test_select_subclass_by_child_model(self): + """ + Confirm that passing a child model works the same as passing the + select_related manually + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild1").order_by('pk') + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestChild1).order_by('pk') + self.assertEqual(objs.subclasses, objsmodels.subclasses) + self.assertEqual(list(objs), list(objsmodels)) + + def test_select_subclass_by_grandchild_model(self): + """ + Confirm that passing a grandchild model works the same as passing the + select_related manually + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild1__inheritancemanagertestgrandchild1") \ + .order_by('pk') + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestGrandChild1).order_by('pk') + self.assertEqual(objs.subclasses, objsmodels.subclasses) + self.assertEqual(list(objs), list(objsmodels)) + + def test_selecting_all_subclasses_specifically_grandchildren(self): + """ + A bare select_subclasses() should achieve the same results as doing + select_subclasses and specifying all possible subclasses. + This test checks grandchildren, so only works on 1.6>= + """ + objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk') + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestChild1, InheritanceManagerTestChild2, + InheritanceManagerTestChild3, InheritanceManagerTestChild4, + InheritanceManagerTestGrandChild1, + InheritanceManagerTestGrandChild1_2).order_by('pk') + self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) + self.assertEqual(list(objs), list(objsmodels)) + + def test_selecting_all_subclasses_specifically_children(self): + """ + A bare select_subclasses() should achieve the same results as doing + select_subclasses and specifying all possible subclasses. + + Note: This is sort of the same test as + `test_selecting_all_subclasses_specifically_grandchildren` but it + specifically switches what models are used because that happens + behind the scenes in a bare select_subclasses(), so we need to + emulate it. + """ + objs = InheritanceManagerTestParent.objects.select_subclasses().order_by('pk') + + models = (InheritanceManagerTestChild1, + InheritanceManagerTestChild2, + InheritanceManagerTestChild3, + InheritanceManagerTestChild4, + InheritanceManagerTestGrandChild1, + InheritanceManagerTestGrandChild1_2) + + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + *models).order_by('pk') + # order shouldn't matter, I don't think, as long as the resulting + # queryset (when cast to a list) is the same. + self.assertEqual(set(objs.subclasses), set(objsmodels.subclasses)) + self.assertEqual(list(objs), list(objsmodels)) + + def test_select_subclass_just_self(self): + """ + Passing in the same model as the manager/queryset is bound against + (ie: the root parent) should have no effect on the result set. + """ + objsmodels = InheritanceManagerTestParent.objects.select_subclasses( + InheritanceManagerTestParent).order_by('pk') + self.assertEqual([], objsmodels.subclasses) + self.assertEqual(list(objsmodels), [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + ]) + + def test_select_subclass_invalid_related_model(self): + """ + Confirming that giving a stupid model doesn't work. + """ + regex = '^.+? is not a subclass of .+$' + with self.assertRaisesRegexp(ValueError, regex): + InheritanceManagerTestParent.objects.select_subclasses( + TimeFrame).order_by('pk') + + def test_mixing_strings_and_classes_with_grandchildren(self): + """ + Given arguments consisting of both strings and model classes, + ensure the right resolutions take place, accounting for the extra + depth (grandchildren etc) 1.6> allows. + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild2", + InheritanceManagerTestGrandChild1_2).order_by('pk') + expecting = ['inheritancemanagertestchild1__inheritancemanagertestgrandchild1_2', + 'inheritancemanagertestchild2'] + self.assertEqual(set(objs.subclasses), set(expecting)) + expecting2 = [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestChild2(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestGrandChild1_2(pk=self.grandchild1_2.pk), + ] + self.assertEqual(list(objs), expecting2) + + def test_mixing_strings_and_classes_with_children(self): + """ + Given arguments consisting of both strings and model classes, + ensure the right resolutions take place, walking down as far as + children. + """ + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild2", + InheritanceManagerTestChild1).order_by('pk') + expecting = ['inheritancemanagertestchild1', + 'inheritancemanagertestchild2'] + + self.assertEqual(set(objs.subclasses), set(expecting)) + expecting2 = [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestChild1(pk=self.child1.pk), + InheritanceManagerTestChild2(pk=self.child2.pk), + InheritanceManagerTestChild1(pk=self.grandchild1.pk), + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ] + self.assertEqual(list(objs), expecting2) + + def test_duplications(self): + """ + Check that even if the same thing is provided as a string and a model + that the right results are retrieved. + """ + # mixing strings and models which evaluate to the same thing is fine. + objs = InheritanceManagerTestParent.objects.select_subclasses( + "inheritancemanagertestchild2", + InheritanceManagerTestChild2).order_by('pk') + self.assertEqual(list(objs), [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestChild2(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + ]) + + def test_child_doesnt_accidentally_get_parent(self): + """ + Given a Child model which also has an InheritanceManager, + none of the returned objects should be Parent objects. + """ + objs = InheritanceManagerTestChild1.objects.select_subclasses( + InheritanceManagerTestGrandChild1).order_by('pk') + self.assertEqual([ + InheritanceManagerTestChild1(pk=self.child1.pk), + InheritanceManagerTestGrandChild1(pk=self.grandchild1.pk), + InheritanceManagerTestChild1(pk=self.grandchild1_2.pk), + ], list(objs)) + + def test_manually_specifying_parent_fk_only_specific_child(self): + """ + given a Model which inherits from another Model, but also declares + the OneToOne link manually using `related_name` and `parent_link`, + ensure that the relation names and subclasses are obtained correctly. + """ + child3 = InheritanceManagerTestChild3.objects.create() + results = InheritanceManagerTestParent.objects.all().select_subclasses( + InheritanceManagerTestChild3).order_by('pk') + + expected_objs = [ + InheritanceManagerTestParent(pk=self.parent1.pk), + InheritanceManagerTestParent(pk=self.child1.pk), + InheritanceManagerTestParent(pk=self.child2.pk), + InheritanceManagerTestParent(pk=self.grandchild1.pk), + InheritanceManagerTestParent(pk=self.grandchild1_2.pk), + child3 + ] + self.assertEqual(list(results), expected_objs) + + expected_related_names = ['manual_onetoone'] + self.assertEqual(set(results.subclasses), + set(expected_related_names)) + + def test_extras_descend(self): + """ + Ensure that extra(select=) values are copied onto sub-classes. + """ + results = InheritanceManagerTestParent.objects.select_subclasses().extra( + select={'foo': 'id + 1'} + ) + self.assertTrue(all(result.foo == (result.id + 1) for result in results)) + + def test_limit_to_specific_subclass(self): + child3 = InheritanceManagerTestChild3.objects.create() + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3) + + self.assertEqual([child3], list(results)) + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_limit_to_specific_grandchild_class(self): + grandchild1 = InheritanceManagerTestGrandChild1.objects.get() + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestGrandChild1) + + self.assertEqual([grandchild1], list(results)) + + def test_limit_to_child_fetches_grandchildren_as_child_class(self): + # Not sure if this is the desired behaviour...? + children = InheritanceManagerTestChild1.objects.all() + + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1) + + self.assertEqual(set(children), set(results)) + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_can_fetch_limited_class_grandchildren(self): + # Not sure if this is the desired behaviour...? + children = InheritanceManagerTestChild1.objects.select_subclasses() + + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild1).select_subclasses() + + self.assertEqual(set(children), set(results)) + + def test_selecting_multiple_instance_classes(self): + child3 = InheritanceManagerTestChild3.objects.create() + children1 = InheritanceManagerTestChild1.objects.all() + + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestChild1) + + self.assertEqual(set([child3] + list(children1)), set(results)) + + @skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+") + def test_selecting_multiple_instance_classes_including_grandchildren(self): + child3 = InheritanceManagerTestChild3.objects.create() + grandchild1 = InheritanceManagerTestGrandChild1.objects.get() + + results = InheritanceManagerTestParent.objects.instance_of(InheritanceManagerTestChild3, InheritanceManagerTestGrandChild1).select_subclasses() + + self.assertEqual(set([child3, grandchild1]), set(results)) + + def test_select_subclasses_interaction_with_instance_of(self): + child3 = InheritanceManagerTestChild3.objects.create() + + results = InheritanceManagerTestParent.objects.select_subclasses(InheritanceManagerTestChild1).instance_of(InheritanceManagerTestChild3) + + self.assertEqual(set([child3]), set(results)) + + + +class InheritanceManagerRelatedTests(InheritanceManagerTests): + def setUp(self): + self.related = InheritanceManagerTestRelated.objects.create() + self.child1 = InheritanceManagerTestChild1.objects.create( + related=self.related) + self.child2 = InheritanceManagerTestChild2.objects.create( + related=self.related) + self.grandchild1 = InheritanceManagerTestGrandChild1.objects.create(related=self.related) + self.grandchild1_2 = InheritanceManagerTestGrandChild1_2.objects.create(related=self.related) + + def get_manager(self): + return self.related.imtests + + def test_get_method_with_select_subclasses(self): + self.assertEqual( + InheritanceManagerTestParent.objects.select_subclasses().get( + id=self.child1.id), + self.child1) + + def test_annotate_with_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( + models.Count('id')) + self.assertEqual(qs.get(id=self.child1.id).id__count, 1) + + def test_annotate_with_named_arguments_with_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.select_subclasses().annotate( + test_count=models.Count('id')) + self.assertEqual(qs.get(id=self.child1.id).test_count, 1) + + def test_annotate_before_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.annotate( + models.Count('id')).select_subclasses() + self.assertEqual(qs.get(id=self.child1.id).id__count, 1) + + def test_annotate_with_named_arguments_before_select_subclasses(self): + qs = InheritanceManagerTestParent.objects.annotate( + test_count=models.Count('id')).select_subclasses() + self.assertEqual(qs.get(id=self.child1.id).test_count, 1) + + def test_clone_when_inheritance_queryset_selects_subclasses_should_clone_them_too(self): + qs = InheritanceManagerTestParent.objects.select_subclasses() + self.assertEqual(qs.subclasses, qs._clone().subclasses) diff --git a/tests/test_managers/test_join_manager.py b/tests/test_managers/test_join_manager.py new file mode 100644 index 0000000..b8a8131 --- /dev/null +++ b/tests/test_managers/test_join_manager.py @@ -0,0 +1,38 @@ + +from django.test import TestCase + +from tests.models import JoinItemForeignKey, BoxJoinModel + + +class JoinManagerTest(TestCase): + def setUp(self): + for i in range(20): + BoxJoinModel.objects.create(name='name_{i}'.format(i=i)) + + JoinItemForeignKey.objects.create( + weight=10, belonging=BoxJoinModel.objects.get(name='name_1') + ) + JoinItemForeignKey.objects.create(weight=20) + + def test_self_join(self): + a_slice = BoxJoinModel.objects.all()[0:10] + with self.assertNumQueries(1): + result = a_slice.join() + self.assertEquals(result.count(), 10) + + def test_self_join_with_where_statement(self): + qs = BoxJoinModel.objects.filter(name='name_1') + result = qs.join() + self.assertEquals(result.count(), 1) + + def test_join_with_other_qs(self): + item_qs = JoinItemForeignKey.objects.filter(weight=10) + boxes = BoxJoinModel.objects.all().join(qs=item_qs) + self.assertEquals(boxes.count(), 1) + self.assertEquals(boxes[0].name, 'name_1') + + def test_reverse_join(self): + box_qs = BoxJoinModel.objects.filter(name='name_1') + items = JoinItemForeignKey.objects.all().join(box_qs) + self.assertEquals(items.count(), 1) + self.assertEquals(items[0].weight, 10) diff --git a/tests/test_managers/test_query_manager.py b/tests/test_managers/test_query_manager.py new file mode 100644 index 0000000..dd539b6 --- /dev/null +++ b/tests/test_managers/test_query_manager.py @@ -0,0 +1,29 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from tests.models import Post + + +class QueryManagerTests(TestCase): + def setUp(self): + data = ((True, True, 0), + (True, False, 4), + (False, False, 2), + (False, True, 3), + (True, True, 1), + (True, False, 5)) + for p, c, o in data: + Post.objects.create(published=p, confirmed=c, order=o) + + def test_passing_kwargs(self): + qs = Post.public.all() + self.assertEqual([p.order for p in qs], [0, 1, 4, 5]) + + def test_passing_Q(self): + qs = Post.public_confirmed.all() + self.assertEqual([p.order for p in qs], [0, 1]) + + def test_ordering(self): + qs = Post.public_reversed.all() + self.assertEqual([p.order for p in qs], [5, 4, 1, 0]) diff --git a/tests/test_managers/test_softdelete_manager.py b/tests/test_managers/test_softdelete_manager.py new file mode 100644 index 0000000..4ae5475 --- /dev/null +++ b/tests/test_managers/test_softdelete_manager.py @@ -0,0 +1,28 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from tests.models import CustomSoftDelete + + +class CustomSoftDeleteManagerTests(TestCase): + + def test_custom_manager_empty(self): + qs = CustomSoftDelete.objects.only_read() + self.assertEqual(qs.count(), 0) + + def test_custom_qs_empty(self): + qs = CustomSoftDelete.objects.all().only_read() + self.assertEqual(qs.count(), 0) + + def test_is_read(self): + for is_read in [True, False, True, False]: + CustomSoftDelete.objects.create(is_read=is_read) + qs = CustomSoftDelete.objects.only_read() + self.assertEqual(qs.count(), 2) + + def test_is_read_removed(self): + for is_read, is_removed in [(True, True), (True, False), (False, False), (False, True)]: + CustomSoftDelete.objects.create(is_read=is_read, is_removed=is_removed) + qs = CustomSoftDelete.objects.only_read() + self.assertEqual(qs.count(), 1) diff --git a/tests/test_managers/test_status_manager.py b/tests/test_managers/test_status_manager.py new file mode 100644 index 0000000..593a547 --- /dev/null +++ b/tests/test_managers/test_status_manager.py @@ -0,0 +1,23 @@ +from __future__ import unicode_literals + +from django.db import models +from django.core.exceptions import ImproperlyConfigured +from django.test import TestCase + +from model_utils.managers import QueryManager +from model_utils.models import StatusModel +from tests.models import StatusManagerAdded + + +class StatusManagerAddedTests(TestCase): + def test_manager_available(self): + self.assertTrue(isinstance(StatusManagerAdded.active, QueryManager)) + + def test_conflict_error(self): + with self.assertRaises(ImproperlyConfigured): + class ErrorModel(StatusModel): + STATUS = ( + ('active', 'Is Active'), + ('deleted', 'Is Deleted'), + ) + active = models.BooleanField() diff --git a/tests/test_miscellaneous.py b/tests/test_miscellaneous.py new file mode 100644 index 0000000..2f34fbb --- /dev/null +++ b/tests/test_miscellaneous.py @@ -0,0 +1,29 @@ +from __future__ import unicode_literals + +from django.core.management import call_command +from django.test import TestCase + +from model_utils.fields import get_excerpt + + +class MigrationsTests(TestCase): + def test_makemigrations(self): + call_command('makemigrations', dry_run=True) + + +class GetExcerptTests(TestCase): + def test_split(self): + e = get_excerpt("some content\n\n\n\nsome more") + self.assertEqual(e, 'some content\n') + + def test_auto_split(self): + e = get_excerpt("para one\n\npara two\n\npara three") + self.assertEqual(e, 'para one\n\npara two') + + def test_middle_of_para(self): + e = get_excerpt("some text\n\nmore text") + self.assertEqual(e, 'some text') + + def test_middle_of_line(self): + e = get_excerpt("some text more text") + self.assertEqual(e, "some text more text") diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models/test_deferred_fields.py b/tests/test_models/test_deferred_fields.py new file mode 100644 index 0000000..ea8e3bd --- /dev/null +++ b/tests/test_models/test_deferred_fields.py @@ -0,0 +1,71 @@ +from __future__ import unicode_literals + +import django +from django.test import TestCase + +from tests.models import ModelWithCustomDescriptor + + +class CustomDescriptorTests(TestCase): + def setUp(self): + self.instance = ModelWithCustomDescriptor.objects.create( + custom_field='1', + tracked_custom_field='1', + regular_field=1, + tracked_regular_field=1, + ) + + def test_custom_descriptor_works(self): + instance = self.instance + self.assertEqual(instance.custom_field, '1') + self.assertEqual(instance.__dict__['custom_field'], 1) + self.assertEqual(instance.regular_field, 1) + instance.custom_field = 2 + self.assertEqual(instance.custom_field, '2') + self.assertEqual(instance.__dict__['custom_field'], 2) + instance.save() + instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk) + self.assertEqual(instance.custom_field, '2') + self.assertEqual(instance.__dict__['custom_field'], 2) + + def test_deferred(self): + instance = ModelWithCustomDescriptor.objects.only('id').get( + pk=self.instance.pk) + if django.VERSION >= (1, 10): + self.assertIn('custom_field', instance.get_deferred_fields()) + else: + self.assertIn('custom_field', instance._deferred_fields) + self.assertEqual(instance.custom_field, '1') + if django.VERSION >= (1, 10): + self.assertNotIn('custom_field', instance.get_deferred_fields()) + else: + self.assertNotIn('custom_field', instance._deferred_fields) + self.assertEqual(instance.regular_field, 1) + self.assertEqual(instance.tracked_custom_field, '1') + self.assertEqual(instance.tracked_regular_field, 1) + + self.assertFalse(instance.tracker.has_changed('tracked_custom_field')) + self.assertFalse(instance.tracker.has_changed('tracked_regular_field')) + + instance.tracked_custom_field = 2 + instance.tracked_regular_field = 2 + self.assertTrue(instance.tracker.has_changed('tracked_custom_field')) + self.assertTrue(instance.tracker.has_changed('tracked_regular_field')) + instance.save() + + instance = ModelWithCustomDescriptor.objects.get(pk=instance.pk) + self.assertEqual(instance.custom_field, '1') + self.assertEqual(instance.regular_field, 1) + self.assertEqual(instance.tracked_custom_field, '2') + self.assertEqual(instance.tracked_regular_field, 2) + + instance = ModelWithCustomDescriptor.objects.only('id').get(pk=instance.pk) + if django.VERSION >= (1, 10): + # This fails on 1.8 and 1.9, which is a bug in the deferred field + # implementation on those versions. + instance.tracked_custom_field = 3 + self.assertEqual(instance.tracked_custom_field, '3') + self.assertTrue(instance.tracker.has_changed('tracked_custom_field')) + del instance.tracked_custom_field + self.assertEqual(instance.tracked_custom_field, '2') + self.assertFalse(instance.tracker.has_changed('tracked_custom_field')) diff --git a/tests/test_models/test_savesignalhandling_model.py b/tests/test_models/test_savesignalhandling_model.py new file mode 100644 index 0000000..6af0820 --- /dev/null +++ b/tests/test_models/test_savesignalhandling_model.py @@ -0,0 +1,44 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from tests.models import SaveSignalHandlingTestModel +from tests.signals import pre_save_test, post_save_test +from django.db.models.signals import pre_save, post_save + + +class SaveSignalHandlingModelTests(TestCase): + + def test_pre_save(self): + pre_save.connect(pre_save_test, sender=SaveSignalHandlingTestModel) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'pre_save_runned') + obj.name = 'Test A' + obj.save() + self.assertEqual(obj.name, 'Test A') + self.assertTrue(hasattr(obj, 'pre_save_runned')) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'pre_save_runned') + obj.name = 'Test B' + obj.save(signals_to_disable=['pre_save']) + self.assertEqual(obj.name, 'Test B') + self.assertFalse(hasattr(obj, 'pre_save_runned')) + + def test_post_save(self): + post_save.connect(post_save_test, sender=SaveSignalHandlingTestModel) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'post_save_runned') + obj.name = 'Test A' + obj.save() + self.assertEqual(obj.name, 'Test A') + self.assertTrue(hasattr(obj, 'post_save_runned')) + + obj = SaveSignalHandlingTestModel.objects.create(name='Test') + delattr(obj, 'post_save_runned') + obj.name = 'Test B' + obj.save(signals_to_disable=['post_save']) + self.assertEqual(obj.name, 'Test B') + self.assertFalse(hasattr(obj, 'post_save_runned')) diff --git a/tests/test_models/test_softdeletable_model.py b/tests/test_models/test_softdeletable_model.py new file mode 100644 index 0000000..5f06fd3 --- /dev/null +++ b/tests/test_models/test_softdeletable_model.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals + +from django.db.utils import ConnectionDoesNotExist +from django.test import TestCase + +from tests.models import SoftDeletable + + +class SoftDeletableModelTests(TestCase): + def test_can_only_see_not_removed_entries(self): + SoftDeletable.objects.create(name='a', is_removed=True) + SoftDeletable.objects.create(name='b', is_removed=False) + + queryset = SoftDeletable.objects.all() + + self.assertEqual(queryset.count(), 1) + self.assertEqual(queryset[0].name, 'b') + + def test_instance_cannot_be_fully_deleted(self): + instance = SoftDeletable.objects.create(name='a') + + instance.delete() + + self.assertEqual(SoftDeletable.objects.count(), 0) + self.assertEqual(SoftDeletable.all_objects.count(), 1) + + def test_instance_cannot_be_fully_deleted_via_queryset(self): + SoftDeletable.objects.create(name='a') + + SoftDeletable.objects.all().delete() + + self.assertEqual(SoftDeletable.objects.count(), 0) + self.assertEqual(SoftDeletable.all_objects.count(), 1) + + def test_delete_instance_no_connection(self): + obj = SoftDeletable.objects.create(name='a') + + self.assertRaises(ConnectionDoesNotExist, obj.delete, using='other') + + def test_instance_purge(self): + instance = SoftDeletable.objects.create(name='a') + + instance.delete(soft=False) + + self.assertEqual(SoftDeletable.objects.count(), 0) + self.assertEqual(SoftDeletable.all_objects.count(), 0) + + def test_instance_purge_no_connection(self): + instance = SoftDeletable.objects.create(name='a') + + self.assertRaises(ConnectionDoesNotExist, instance.delete, + using='other', soft=False) diff --git a/tests/test_models/test_status_model.py b/tests/test_models/test_status_model.py new file mode 100644 index 0000000..6950dbf --- /dev/null +++ b/tests/test_models/test_status_model.py @@ -0,0 +1,70 @@ +from datetime import datetime + +from freezegun import freeze_time + +from django.test.testcases import TestCase + +from tests.models import Status, StatusPlainTuple, StatusCustomManager + + +class StatusModelTests(TestCase): + def setUp(self): + self.model = Status + self.on_hold = Status.STATUS.on_hold + self.active = Status.STATUS.active + + def test_created(self): + with freeze_time(datetime(2016, 1, 1)): + c1 = self.model.objects.create() + self.assertTrue(c1.status_changed, datetime(2016, 1, 1)) + + self.model.objects.create() + self.assertEqual(self.model.active.count(), 2) + self.assertEqual(self.model.deleted.count(), 0) + + def test_modification(self): + t1 = self.model.objects.create() + date_created = t1.status_changed + t1.status = self.on_hold + t1.save() + self.assertEqual(self.model.active.count(), 0) + self.assertEqual(self.model.on_hold.count(), 1) + self.assertTrue(t1.status_changed > date_created) + date_changed = t1.status_changed + t1.save() + self.assertEqual(t1.status_changed, date_changed) + date_active_again = t1.status_changed + t1.status = self.active + t1.save() + self.assertTrue(t1.status_changed > date_active_again) + + +class StatusModelPlainTupleTests(StatusModelTests): + def setUp(self): + self.model = StatusPlainTuple + self.on_hold = StatusPlainTuple.STATUS[2][0] + self.active = StatusPlainTuple.STATUS[0][0] + + +class StatusModelDefaultManagerTests(TestCase): + + def test_default_manager_is_not_status_model_generated_ones(self): + # Regression test for GH-251 + # The logic behind order for managers seems to have changed in Django 1.10 + # and affects default manager. + # This code was previously failing because the first custom manager (which filters + # with first Choice value, here 'first_choice') generated by StatusModel was + # considered as default manager... + # This situation only happens when we define a model inheriting from an "abstract" + # class which defines an "objects" manager. + + StatusCustomManager.objects.create(status='first_choice') + StatusCustomManager.objects.create(status='second_choice') + StatusCustomManager.objects.create(status='second_choice') + + # ...which made this count() equal to 1 (only 1 element with status='first_choice')... + self.assertEqual(StatusCustomManager._default_manager.count(), 3) + + # ...and this one equal to 0, because of 2 successive filters of 'first_choice' + # (default manager) and 'second_choice' (explicit filter below). + self.assertEqual(StatusCustomManager._default_manager.filter(status='second_choice').count(), 2) diff --git a/tests/test_models/test_timeframed_model.py b/tests/test_models/test_timeframed_model.py new file mode 100644 index 0000000..dccc5a7 --- /dev/null +++ b/tests/test_models/test_timeframed_model.py @@ -0,0 +1,47 @@ +from __future__ import unicode_literals + +from datetime import datetime, timedelta + +from django.db import models +from django.core.exceptions import ImproperlyConfigured +from django.test import TestCase + +from model_utils.managers import QueryManager +from model_utils.models import TimeFramedModel +from tests.models import TimeFrame, TimeFrameManagerAdded + + +class TimeFramedModelTests(TestCase): + def setUp(self): + self.now = datetime.now() + + def test_not_yet_begun(self): + TimeFrame.objects.create(start=self.now + timedelta(days=2)) + self.assertEqual(TimeFrame.timeframed.count(), 0) + + def test_finished(self): + TimeFrame.objects.create(end=self.now - timedelta(days=1)) + self.assertEqual(TimeFrame.timeframed.count(), 0) + + def test_no_end(self): + TimeFrame.objects.create(start=self.now - timedelta(days=10)) + self.assertEqual(TimeFrame.timeframed.count(), 1) + + def test_no_start(self): + TimeFrame.objects.create(end=self.now + timedelta(days=2)) + self.assertEqual(TimeFrame.timeframed.count(), 1) + + def test_within_range(self): + TimeFrame.objects.create(start=self.now - timedelta(days=1), + end=self.now + timedelta(days=1)) + self.assertEqual(TimeFrame.timeframed.count(), 1) + + +class TimeFrameManagerAddedTests(TestCase): + def test_manager_available(self): + self.assertTrue(isinstance(TimeFrameManagerAdded.timeframed, QueryManager)) + + def test_conflict_error(self): + with self.assertRaises(ImproperlyConfigured): + class ErrorModel(TimeFramedModel): + timeframed = models.BooleanField() diff --git a/tests/test_models/test_timestamped_model.py b/tests/test_models/test_timestamped_model.py new file mode 100644 index 0000000..cac07f3 --- /dev/null +++ b/tests/test_models/test_timestamped_model.py @@ -0,0 +1,32 @@ +from __future__ import unicode_literals + +from datetime import datetime + +from freezegun import freeze_time + +from django.test import TestCase + +from tests.models import TimeStamp + + +class TimeStampedModelTests(TestCase): + def test_created(self): + with freeze_time(datetime(2016, 1, 1)): + t1 = TimeStamp.objects.create() + self.assertEqual(t1.created, datetime(2016, 1, 1)) + + def test_created_sets_modified(self): + ''' + Ensure that on creation that modifed is set exactly equal to created. + ''' + t1 = TimeStamp.objects.create() + self.assertEqual(t1.created, t1.modified) + + def test_modified(self): + with freeze_time(datetime(2016, 1, 1)): + t1 = TimeStamp.objects.create() + + with freeze_time(datetime(2016, 1, 2)): + t1.save() + + self.assertEqual(t1.modified, datetime(2016, 1, 2)) diff --git a/tests/test_models/test_uuid_model.py b/tests/test_models/test_uuid_model.py new file mode 100644 index 0000000..5559159 --- /dev/null +++ b/tests/test_models/test_uuid_model.py @@ -0,0 +1,20 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from tests.models import CustomUUIDModel, CustomNotPrimaryUUIDModel + + +class UUIDFieldTests(TestCase): + + def test_uuid_model_with_uuid_field_as_primary_key(self): + instance = CustomUUIDModel() + instance.save() + self.assertEqual(instance.id.__class__.__name__, 'UUID') + self.assertEqual(instance.id, instance.pk) + + def test_uuid_model_with_uuid_field_as_not_primary_key(self): + instance = CustomNotPrimaryUUIDModel() + instance.save() + self.assertEqual(instance.uuid.__class__.__name__, 'UUID') + self.assertNotEqual(instance.uuid, instance.pk) diff --git a/tox.ini b/tox.ini index b754bc7..ca2f2e3 100644 --- a/tox.ini +++ b/tox.ini @@ -1,109 +1,40 @@ [tox] envlist = - py26-1.4, py26-1.5, py26-1.6, - py27-1.4, py27-1.5, py27-1.6, py27-trunk, py27-1.5-nosouth, - py32-1.5, py32-1.6, py32-trunk, - py33-1.5, py33-1.6, py33-trunk + py27-django{111} + py37-django{202,201} + py36-django{111,202,201,trunk} + flake8 [testenv] deps = - South == 0.8.1 - coverage == 3.6 -commands = coverage run -a setup.py test + django111: Django>=1.11,<1.12 + django202: Django>=2.2,<3.0 + django201: Django>=2.1,<2.2 + djangotrunk: https://github.com/django/django/archive/master.tar.gz + freezegun == 0.3.8 + -rrequirements-test.txt +ignore_outcome = + djangotrunk: True +passenv = + CI + TRAVIS + TRAVIS_* -[testenv:py26-1.4] -basepython = python2.6 -deps = - Django == 1.4.10 - South == 0.7.6 - coverage == 3.6 +commands = + pip install -e . + py.test {posargs} -[testenv:py26-1.5] -basepython = python2.6 +[testenv:flake8] +basepython = + python3.6 deps = - Django == 1.5.5 - South == 0.8.1 - coverage == 3.6 + flake8 +commands = + flake8 model_utils tests -[testenv:py26-1.6] -basepython = python2.6 -deps = - https://github.com/django/django/tarball/stable/1.6.x - South == 0.8.1 - coverage == 3.6 - -[testenv:py27-1.4] -basepython = python2.7 -deps = - Django == 1.4.10 - South == 0.8.1 - coverage == 3.6 - -[testenv:py27-1.5] -basepython = python2.7 -deps = - Django == 1.5.5 - South == 0.8.1 - coverage == 3.6 - -[testenv:py27-1.6] -basepython = python2.7 -deps = - Django == 1.6.1 - South == 0.8.1 - coverage == 3.6 - -[testenv:py27-trunk] -basepython = python2.7 -deps = - https://github.com/django/django/tarball/master - South == 0.8.1 - coverage == 3.6 - -[testenv:py27-1.5-nosouth] -basepython = python2.7 -deps = - Django == 1.5.5 - coverage == 3.6 - -[testenv:py32-1.5] -basepython = python3.2 -deps = - Django == 1.5.5 - South == 0.8.1 - coverage == 3.6 - -[testenv:py32-1.6] -basepython = python3.2 -deps = - Django == 1.6.1 - South == 0.8.1 - coverage == 3.6 - -[testenv:py32-trunk] -basepython = python3.2 -deps = - https://github.com/django/django/tarball/master - South == 0.8.1 - coverage == 3.6 - -[testenv:py33-1.5] -basepython = python3.3 -deps = - Django == 1.5.5 - South == 0.8.1 - coverage == 3.6 - -[testenv:py33-1.6] -basepython = python3.3 -deps = - Django == 1.6.1 - South == 0.8.1 - coverage == 3.6 - -[testenv:py33-trunk] -basepython = python3.3 -deps = - https://github.com/django/django/tarball/master - South == 0.8.1 - coverage == 3.6 +[flake8] +ignore = + E731 ; do not assign a lambda expression, use a def + W503 ; line break before binary operator + E402 ; module level import not at top of file + E501 ; line too long diff --git a/translations.py b/translations.py new file mode 100755 index 0000000..8eebf95 --- /dev/null +++ b/translations.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +import os +import sys + +from django.conf import settings +import django + +DEFAULT_SETTINGS = dict( + INSTALLED_APPS=( + 'model_utils', + 'tests', + ), + DATABASES={ + "default": { + "ENGINE": "django.db.backends.sqlite3" + } + }, + SILENCED_SYSTEM_CHECKS=["1_7.W001"], +) + + +def run(command): + if not settings.configured: + settings.configure(**DEFAULT_SETTINGS) + + # Compatibility with Django 1.7's stricter initialization + if hasattr(django, 'setup'): + django.setup() + + parent = os.path.dirname(os.path.abspath(__file__)) + appdir = os.path.join(parent, 'model_utils') + os.chdir(appdir) + + from django.core.management import call_command + + call_command('%smessages' % command) + + +if __name__ == '__main__': + if (len(sys.argv)) < 2 or (sys.argv[1] not in {'make', 'compile'}): + print("Run `translations.py make` or `translations.py compile`.") + sys.exit(1) + run(sys.argv[1])