diff --git a/eav/exceptions.py b/eav/exceptions.py new file mode 100644 index 0000000..924a8c0 --- /dev/null +++ b/eav/exceptions.py @@ -0,0 +1,2 @@ +class IllegalAssignmentException(Exception): + pass diff --git a/eav/models.py b/eav/models.py index 6d72d9c..25639d6 100644 --- a/eav/models.py +++ b/eav/models.py @@ -9,6 +9,7 @@ This module defines the four concrete, non-abstract models: Along with the :class:`Entity` helper class. ''' +from copy import copy from django.conf import settings from django.contrib.contenttypes import fields as generic @@ -18,6 +19,7 @@ from django.db import models from django.utils import timezone from django.utils.translation import ugettext_lazy as _ +from .exceptions import IllegalAssignmentException from .fields import EavDatatypeField, EavSlugField from .validators import * @@ -440,20 +442,10 @@ class Entity(object): The helper class that will be attached to any entity registered with eav. ''' - @staticmethod - def post_save_handler(sender, *args, **kwargs): - ''' - Post save handler attached to self.model. Calls :meth:`save` when - the model instance we are attached to is saved. - ''' - instance = kwargs['instance'] - entity = getattr(instance, instance._eav_config_cls.eav_attr) - entity.save() - @staticmethod def pre_save_handler(sender, *args, **kwargs): ''' - Pre save handler attached to self.model. Called before the + Pre save handler attached to self.instance. Called before the model instance we are attached to is saved. This allows us to call :meth:`validate_attributes` before the entity is saved. ''' @@ -461,12 +453,22 @@ class Entity(object): entity = getattr(kwargs['instance'], instance._eav_config_cls.eav_attr) entity.validate_attributes() + @staticmethod + def post_save_handler(sender, *args, **kwargs): + ''' + Post save handler attached to self.instance. Calls :meth:`save` when + the model instance we are attached to is saved. + ''' + instance = kwargs['instance'] + entity = getattr(instance, instance._eav_config_cls.eav_attr) + entity.save() + def __init__(self, instance): ''' - Set self.model equal to the instance of the model that we're attached + Set self.instance equal to the instance of the model that we're attached to. Also, store the content type of that instance. ''' - self.model = instance + self.instance = instance self.ct = ContentType.objects.get_for_model(instance) def __getattr__(self, name): @@ -487,7 +489,7 @@ class Entity(object): except Attribute.DoesNotExist: raise AttributeError( _('%(obj)s has no EAV attribute named %(attr)s') - % dict(obj = self.model, attr = name) + % dict(obj = self.instance, attr = name) ) try: @@ -502,7 +504,7 @@ class Entity(object): Return a query set of all :class:`Attribute` objects that can be set for this entity. ''' - return self.model._eav_config_cls.get_attributes().order_by('display_order') + return self.instance._eav_config_cls.get_attributes().order_by('display_order') def _hasattr(self, attribute_slug): ''' @@ -527,30 +529,32 @@ class Entity(object): for attribute in self.get_all_attributes(): if self._hasattr(attribute.slug): attribute_value = self._getattr(attribute.slug) - attribute.save_value(self.model, attribute_value) + attribute.save_value(self.instance, attribute_value) def validate_attributes(self): ''' Called before :meth:`save`, first validate all the entity values to make sure they can be created / saved cleanly. - - Raise ``ValidationError`` if they can't be. + Raises ``ValidationError`` if they can't be. ''' values_dict = self.get_values_dict() for attribute in self.get_all_attributes(): value = None + # Value was assigned to this instance. if self._hasattr(attribute.slug): value = self._getattr(attribute.slug) + values_dict.pop(attribute.slug, None) + # Otherwise try pre-loaded from DB. else: - value = values_dict.get(attribute.slug, None) + value = values_dict.pop(attribute.slug, None) if value is None: if attribute.required: - raise ValidationError(_( - '{} EAV field cannot be blank'.format(attribute.slug) - )) + raise ValidationError( + _('{} EAV field cannot be blank'.format(attribute.slug)) + ) else: try: attribute.validate_value(value) @@ -560,28 +564,32 @@ class Entity(object): % dict(attr = attribute.slug, err = e) ) + illegal = values_dict or ( + self.get_object_attributes() - self.get_all_attribute_slugs()) + + if illegal: + raise IllegalAssignmentException( + 'Instance of the class {} cannot have values for attributes: {}.' + .format(self.instance.__class__, ', '.join(illegal)) + ) + def get_values_dict(self): - values_dict = dict() - - for value in self.get_values(): - values_dict[value.attribute.slug] = value.value - - return values_dict + return {v.attribute.slug: v.value for v in self.get_values()} def get_values(self): ''' - Get all set :class:`Value` objects for self.model + Get all set :class:`Value` objects for self.instance ''' return Value.objects.filter( - entity_ct=self.ct, - entity_id=self.model.pk + entity_ct = self.ct, + entity_id = self.instance.pk ).select_related() def get_all_attribute_slugs(self): ''' Returns a list of slugs for all attributes available to this entity. ''' - return self.get_all_attributes().values_list('slug', flat=True) + return set(self.get_all_attributes().values_list('slug', flat=True)) def get_attribute_by_slug(self, slug): ''' @@ -595,6 +603,13 @@ class Entity(object): ''' return self.get_values().get(attribute=attribute) + def get_object_attributes(self): + ''' + Returns entity instance attributes, except for + ``instance`` and ``ct`` which are used internally. + ''' + return set(copy(self.__dict__).keys()) - set(['instance', 'ct']) + def __iter__(self): ''' Iterate over set eav values. This would allow you to do:: diff --git a/eav/registry.py b/eav/registry.py index 73a97f8..eef4c24 100644 --- a/eav/registry.py +++ b/eav/registry.py @@ -87,7 +87,7 @@ class Registry(object): @staticmethod def attach_eav_attr(sender, *args, **kwargs): ''' - Attache EAV Entity toolkit to an instance after init. + Attach EAV Entity toolkit to an instance after init. ''' instance = kwargs['instance'] config_cls = instance.__class__._eav_config_cls @@ -131,19 +131,22 @@ class Registry(object): def _attach_signals(self): ''' - Attach all signals for eav + Attach pre- and post- save signals from model class + to Entity helper. This way, Entity instance will be + able to prepare and clean-up before and after creation / + update of the user's model class instance. ''' - post_init.connect(Registry.attach_eav_attr, sender=self.model_cls) - pre_save.connect(Entity.pre_save_handler, sender=self.model_cls) - post_save.connect(Entity.post_save_handler, sender=self.model_cls) + post_init.connect(Registry.attach_eav_attr, sender = self.model_cls) + pre_save.connect(Entity.pre_save_handler, sender = self.model_cls) + post_save.connect(Entity.post_save_handler, sender = self.model_cls) def _detach_signals(self): ''' - Detach all signals for eav + Detach all signals for eav. ''' - post_init.disconnect(Registry.attach_eav_attr, sender=self.model_cls) - pre_save.disconnect(Entity.pre_save_handler, sender=self.model_cls) - post_save.disconnect(Entity.post_save_handler, sender=self.model_cls) + post_init.disconnect(Registry.attach_eav_attr, sender = self.model_cls) + pre_save.disconnect(Entity.pre_save_handler, sender = self.model_cls) + post_save.disconnect(Entity.post_save_handler, sender = self.model_cls) def _attach_generic_relation(self): ''' diff --git a/tests/attributes.py b/tests/attributes.py index 8ac5f51..d276c67 100644 --- a/tests/attributes.py +++ b/tests/attributes.py @@ -4,6 +4,7 @@ from django.test import TestCase import eav from eav.models import Attribute, Value from eav.registry import EavConfig +from eav.exceptions import IllegalAssignmentException from .models import Encounter, Patient @@ -54,7 +55,6 @@ class Attributes(TestCase): p.eav.height = 2.3 p.save() e.eav_field.age = 4 - e.eav_field.height = 4.5 e.save() self.assertEqual(Value.objects.count(), 3) p = Patient.objects.get(name='Jon') @@ -62,4 +62,19 @@ class Attributes(TestCase): self.assertEqual(p.eav.height, 2.3) e = Encounter.objects.get(num=1) self.assertEqual(e.eav_field.age, 4) - self.assertFalse(hasattr(e.eav_field, 'height')) + + def test_illegal_assignemnt(self): + class EncounterEavConfig(EavConfig): + @classmethod + def get_attributes(cls): + return Attribute.objects.filter(datatype=Attribute.TYPE_INT) + + eav.unregister(Encounter) + eav.register(Encounter, EncounterEavConfig) + + p = Patient.objects.create(name='Jon') + e = Encounter.objects.create(patient=p, num=1) + + with self.assertRaises(IllegalAssignmentException): + e.eav.color = 'red' + e.save()