mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-25 15:50:25 +00:00
Added JSONEncoder().encode(...) check for json-like fields
This commit is contained in:
commit
630741b423
4 changed files with 99 additions and 71 deletions
|
|
@ -6,12 +6,13 @@ Donald Stufft <donald.stufft@gmail.com>
|
|||
Facundo Gaich <facugaich@gmail.com>
|
||||
Felipe Prenholato <philipe.rp@gmail.com>
|
||||
Gregor Müllegger <gregor@muellegger.de>
|
||||
ivirabyan
|
||||
James Oakley <jfunk@funktronics.ca>
|
||||
Jannis Leidel <jannis@leidel.info>
|
||||
Javier García Sogo <jgsogo@gmail.com>
|
||||
Jeff Elmore <jeffelmore.org>
|
||||
Keryn Knight <kerynknight.com>
|
||||
ivirabyan
|
||||
Mikhail Silonov <silonov.pro>
|
||||
Paul McLanahan <paul@mclanahan.net>
|
||||
Rinat Shigapov <rinatshigapov@gmail.com>
|
||||
Ryan Kaskel <dev@ryankaskel.com>
|
||||
|
|
@ -19,4 +20,3 @@ Simon Meers <simon@simonmeers.com>
|
|||
sayane
|
||||
Trey Hunner <trey@treyhunner.com>
|
||||
zyegfryed
|
||||
Mikhail Silonov <silonov.pro>
|
||||
|
|
|
|||
|
|
@ -7,6 +7,11 @@ master (unreleased)
|
|||
* `Choices` now `__contains__` its Python identifier values. Thanks Keryn
|
||||
Knight. (Merge of GH-69).
|
||||
|
||||
=======
|
||||
* Fixed a bug causing ``KeyError`` when saving with the parameter
|
||||
``update_fields`` in which there are untracked fields. Thanks Mikhail
|
||||
Silonov. (Merge of GH-70, fixes GH-71).
|
||||
|
||||
* Added JSON Fields support.
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -106,15 +106,13 @@ class SplitFieldTests(TestCase):
|
|||
|
||||
|
||||
def test_assign_to_excerpt(self):
|
||||
def _invalid_assignment():
|
||||
with self.assertRaises(AttributeError):
|
||||
self.post.body.excerpt = 'this should fail'
|
||||
self.assertRaises(AttributeError, _invalid_assignment)
|
||||
|
||||
|
||||
def test_access_via_class(self):
|
||||
def _invalid_access():
|
||||
with self.assertRaises(AttributeError):
|
||||
Article.body
|
||||
self.assertRaises(AttributeError, _invalid_access)
|
||||
|
||||
|
||||
def test_none(self):
|
||||
|
|
@ -169,7 +167,8 @@ class MonitorFieldTests(TestCase):
|
|||
|
||||
|
||||
def test_no_monitor_arg(self):
|
||||
self.assertRaises(TypeError, MonitorField)
|
||||
with self.assertRaises(TypeError):
|
||||
MonitorField()
|
||||
|
||||
|
||||
class StatusFieldTests(TestCase):
|
||||
|
|
@ -221,7 +220,8 @@ class ChoicesTests(TestCase):
|
|||
|
||||
|
||||
def test_wrong_length_tuple(self):
|
||||
self.assertRaises(ValueError, Choices, ('a',))
|
||||
with self.assertRaises(ValueError):
|
||||
Choices(('a',))
|
||||
|
||||
def test_contains_value(self):
|
||||
self.assertTrue('PUBLISHED' in self.STATUS)
|
||||
|
|
@ -385,23 +385,23 @@ class InheritanceManagerTests(TestCase):
|
|||
)
|
||||
|
||||
|
||||
@skipUnless(django.VERSION >= (1, 6, 0), "test only applies to Django 1.6+")
|
||||
def test_select_specific_grandchildren(self):
|
||||
if django.VERSION >= (1, 6, 0):
|
||||
children = set([
|
||||
InheritanceManagerTestParent(pk=self.child1.pk),
|
||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||
self.grandchild1,
|
||||
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
|
||||
])
|
||||
self.assertEqual(
|
||||
set(
|
||||
self.get_manager().select_subclasses(
|
||||
"inheritancemanagertestchild1__"
|
||||
"inheritancemanagertestgrandchild1"
|
||||
)
|
||||
),
|
||||
children,
|
||||
)
|
||||
children = set([
|
||||
InheritanceManagerTestParent(pk=self.child1.pk),
|
||||
InheritanceManagerTestParent(pk=self.child2.pk),
|
||||
self.grandchild1,
|
||||
InheritanceManagerTestParent(pk=self.grandchild1_2.pk),
|
||||
])
|
||||
self.assertEqual(
|
||||
set(
|
||||
self.get_manager().select_subclasses(
|
||||
"inheritancemanagertestchild1__"
|
||||
"inheritancemanagertestgrandchild1"
|
||||
)
|
||||
),
|
||||
children,
|
||||
)
|
||||
|
||||
|
||||
def test_get_subclass(self):
|
||||
|
|
@ -411,13 +411,11 @@ class InheritanceManagerTests(TestCase):
|
|||
|
||||
|
||||
def test_prior_select_related(self):
|
||||
# Django 1.2 doesn't have assertNumQueries
|
||||
if django.VERSION >= (1, 3):
|
||||
with self.assertNumQueries(1):
|
||||
obj = self.get_manager().select_related(
|
||||
"inheritancemanagertestchild1").select_subclasses(
|
||||
"inheritancemanagertestchild2").get(pk=self.child1.pk)
|
||||
obj.inheritancemanagertestchild1
|
||||
with self.assertNumQueries(1):
|
||||
obj = self.get_manager().select_related(
|
||||
"inheritancemanagertestchild1").select_subclasses(
|
||||
"inheritancemanagertestchild2").get(pk=self.child1.pk)
|
||||
obj.inheritancemanagertestchild1
|
||||
|
||||
|
||||
|
||||
|
|
@ -521,10 +519,9 @@ class TimeFrameManagerAddedTests(TestCase):
|
|||
|
||||
|
||||
def test_conflict_error(self):
|
||||
def _run():
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
class ErrorModel(TimeFramedModel):
|
||||
timeframed = models.BooleanField()
|
||||
self.assertRaises(ImproperlyConfigured, _run)
|
||||
|
||||
|
||||
|
||||
|
|
@ -575,14 +572,13 @@ class StatusManagerAddedTests(TestCase):
|
|||
|
||||
|
||||
def test_conflict_error(self):
|
||||
def _run():
|
||||
with self.assertRaises(ImproperlyConfigured):
|
||||
class ErrorModel(StatusModel):
|
||||
STATUS = (
|
||||
('active', 'active'),
|
||||
('deleted', 'deleted'),
|
||||
)
|
||||
active = models.BooleanField()
|
||||
self.assertRaises(ImproperlyConfigured, _run)
|
||||
|
||||
|
||||
|
||||
|
|
@ -629,9 +625,8 @@ class SouthFreezingTests(TestCase):
|
|||
|
||||
def test_no_excerpt_field_works(self):
|
||||
from .models import NoRendered
|
||||
self.assertRaises(FieldDoesNotExist,
|
||||
NoRendered._meta.get_field,
|
||||
'_body_excerpt')
|
||||
with self.assertRaises(FieldDoesNotExist):
|
||||
NoRendered._meta.get_field('_body_excerpt')
|
||||
|
||||
def test_status_field_no_check_for_status(self):
|
||||
sf = StatusFieldDefaultFilled._meta.get_field('status')
|
||||
|
|
@ -658,9 +653,8 @@ class PassThroughManagerTests(TestCase):
|
|||
def test_manager_only_methods(self):
|
||||
stats = Dude.abiders.get_stats()
|
||||
self.assertEqual(stats['rug_count'], 1)
|
||||
def notonqs():
|
||||
with self.assertRaises(AttributeError):
|
||||
Dude.abiders.all().get_stats()
|
||||
self.assertRaises(AttributeError, notonqs)
|
||||
|
||||
|
||||
def test_queryset_pickling(self):
|
||||
|
|
@ -716,7 +710,8 @@ class FieldTrackerTestCase(TestCase):
|
|||
tracker = kwargs.pop('tracker', self.tracker)
|
||||
for field, value in kwargs.items():
|
||||
if value is None:
|
||||
self.assertRaises(FieldError, tracker.has_changed, field)
|
||||
with self.assertRaises(FieldError):
|
||||
tracker.has_changed(field)
|
||||
else:
|
||||
self.assertEqual(tracker.has_changed(field), value)
|
||||
|
||||
|
|
@ -793,8 +788,8 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.assertPrevious(name=None, number=None)
|
||||
self.assertCurrent(name='retro', number=4, id=None)
|
||||
self.assertChanged(name=None, number=None)
|
||||
self.assertRaises(ValueError, self.instance.save,
|
||||
update_fields=['number'])
|
||||
with self.assertRaises(ValueError):
|
||||
self.instance.save(update_fields=['number'])
|
||||
|
||||
def test_post_save_has_changed(self):
|
||||
self.update_instance(name='retro', number=4)
|
||||
|
|
@ -830,26 +825,26 @@ class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
|
|||
self.instance.save()
|
||||
self.assertCurrent(id=self.instance.id, name='new age', number=8)
|
||||
|
||||
@skipUnless(
|
||||
django.VERSION >= (1, 5, 0), "Django 1.4 doesn't have update_fields")
|
||||
def test_update_fields(self):
|
||||
# Django 1.4 doesn't have update_fields
|
||||
if django.VERSION >= (1, 5, 0):
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
self.instance.number = 8
|
||||
self.assertChanged(name='retro', number=4)
|
||||
self.instance.save(update_fields=[])
|
||||
self.assertChanged(name='retro', number=4)
|
||||
self.instance.save(update_fields=['name'])
|
||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||
self.assertEqual(in_db.name, self.instance.name)
|
||||
self.assertNotEqual(in_db.number, self.instance.number)
|
||||
self.assertChanged(number=4)
|
||||
self.instance.save(update_fields=['number'])
|
||||
self.assertChanged()
|
||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||
self.assertEqual(in_db.name, self.instance.name)
|
||||
self.assertEqual(in_db.number, self.instance.number)
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
self.instance.number = 8
|
||||
self.assertChanged(name='retro', number=4)
|
||||
self.instance.save(update_fields=[])
|
||||
self.assertChanged(name='retro', number=4)
|
||||
self.instance.save(update_fields=['name'])
|
||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||
self.assertEqual(in_db.name, self.instance.name)
|
||||
self.assertNotEqual(in_db.number, self.instance.number)
|
||||
self.assertChanged(number=4)
|
||||
self.instance.save(update_fields=['number'])
|
||||
self.assertChanged()
|
||||
in_db = self.tracked_class.objects.get(id=self.instance.id)
|
||||
self.assertEqual(in_db.name, self.instance.name)
|
||||
self.assertEqual(in_db.number, self.instance.number)
|
||||
|
||||
|
||||
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
||||
|
|
@ -923,6 +918,16 @@ class FieldTrackedModelCustomTests(FieldTrackerTestCase,
|
|||
self.instance.save()
|
||||
self.assertCurrent(name='new age')
|
||||
|
||||
@skipUnless(
|
||||
django.VERSION >= (1, 5, 0), "Django 1.4 doesn't have update_fields")
|
||||
def test_update_fields(self):
|
||||
self.update_instance(name='retro', number=4)
|
||||
self.assertChanged()
|
||||
self.instance.name = 'new age'
|
||||
self.instance.number = 8
|
||||
self.instance.save(update_fields=['name', 'number'])
|
||||
self.assertChanged()
|
||||
|
||||
|
||||
class JSONFieldTrackedModelTests(FieldTrackerTestCase):
|
||||
|
||||
|
|
@ -1141,7 +1146,7 @@ class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
|
|||
|
||||
|
||||
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
|
||||
FieldTrackerCommonTests):
|
||||
FieldTrackerCommonTests):
|
||||
|
||||
tracked_class = TrackedMultiple
|
||||
|
||||
|
|
@ -1305,8 +1310,8 @@ class ModelTrackerTests(FieldTrackerTests):
|
|||
self.assertPrevious(name=None, number=None)
|
||||
self.assertCurrent(name='retro', number=4, id=None)
|
||||
self.assertChanged()
|
||||
self.assertRaises(ValueError, self.instance.save,
|
||||
update_fields=['number'])
|
||||
with self.assertRaises(ValueError):
|
||||
self.instance.save(update_fields=['number'])
|
||||
|
||||
def test_pre_save_has_changed(self):
|
||||
self.assertHasChanged(name=True, number=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from copy import deepcopy
|
||||
from json import JSONEncoder
|
||||
|
||||
from django.db import models
|
||||
from django.core.exceptions import FieldError
|
||||
|
|
@ -72,7 +73,8 @@ class FieldTracker(object):
|
|||
|
||||
def finalize_class(self, sender, **kwargs):
|
||||
if self.fields is None:
|
||||
self.fields = [field.attname for field in sender._meta.local_fields]
|
||||
self.fields = (field.attname for field in sender._meta.local_fields)
|
||||
self.fields = set(self.fields)
|
||||
self.field_map = self.get_field_map(sender)
|
||||
models.signals.post_init.connect(self.initialize_tracker, sender=sender)
|
||||
setattr(sender, self.name, self)
|
||||
|
|
@ -81,22 +83,38 @@ class FieldTracker(object):
|
|||
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
||||
setattr(instance, self.attname, tracker)
|
||||
saved_data = tracker.set_saved_fields()
|
||||
self.prevent_side_effects(saved_data)
|
||||
self.prevent_json_fields_side_effects(saved_data)
|
||||
self.patch_save(instance)
|
||||
|
||||
def patch_save(self, instance):
|
||||
original_save = instance.save
|
||||
def save(**kwargs):
|
||||
ret = original_save(**kwargs)
|
||||
update_fields = kwargs.get('update_fields')
|
||||
if not update_fields and update_fields is not None: # () or []
|
||||
fields = update_fields
|
||||
elif update_fields is None:
|
||||
fields = None
|
||||
else:
|
||||
fields = (
|
||||
field for field in update_fields if
|
||||
field in self.fields
|
||||
)
|
||||
getattr(instance, self.attname).set_saved_fields(
|
||||
fields=kwargs.get('update_fields'))
|
||||
fields=fields
|
||||
)
|
||||
return ret
|
||||
instance.save = save
|
||||
|
||||
def prevent_side_effects(self, saved_data):
|
||||
def prevent_json_fields_side_effects(self, saved_data):
|
||||
for field, field_value in saved_data.items():
|
||||
if isinstance(field_value, dict):
|
||||
saved_data[field] = deepcopy(field_value)
|
||||
if isinstance(field_value, (dict, list, tuple)):
|
||||
try:
|
||||
JSONEncoder().encode(field_value)
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
saved_data[field] = deepcopy(field_value)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
if instance is None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue