diff --git a/model_utils/__init__.py b/model_utils/__init__.py index 2b43cf1..6c31496 100644 --- a/model_utils/__init__.py +++ b/model_utils/__init__.py @@ -65,3 +65,60 @@ class ChoiceEnum(object): ', '.join(("'%s'" % i[1] for i in self._choices))) +class Choices(object): + """ + A class to encapsulate handy functionality for lists of choices + for a Django model field. + + Accepts verbose choice names as arguments, and automatically + assigns numeric keys to them. When iterated over, behaves as the + standard Django choices tuple of two-tuples. + + Attribute access allows conversion of verbose choice name to + choice key, dictionary access the reverse. + + Example: + + >>> STATUS = Choices('DRAFT', 'PUBLISHED') + >>> STATUS.draft + DRAFT + >>> STATUS[1] + 'PUBLISHED' + >>> tuple(STATUS) + (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED')) + + >>> STATUS = Choices(('DRAFT', 'is a draft'), ('PUBLISHED', 'is published')) + >>> STATUS.draft + DRAFT + >>> tuple(STATUS) + (('DRAFT', 'is a draft'), ('PUBLISHED', 'is published')) + + """ + + def __init__(self, *choices): + self._choices = tuple(self.equalize(choices)) + self._choice_dict = dict(self._choices) + self._reverse_dict = dict(((i[0].lower(), i[0]) for i in self._choices)) + + def equalize(self, choices): + for choice in choices: + if isinstance(choice, (list, tuple)): + yield choice + else: + yield (choice, choice) + + def __iter__(self): + return iter(self._choices) + + def __getattr__(self, attname): + try: + return self._reverse_dict[attname] + except KeyError: + raise AttributeError(attname) + + def __getitem__(self, key): + return self._choices[key][0] + + def __repr__(self): + return '%s(%s)' % (self.__class__.__name__, + ', '.join(("'%s'" % i[0] for i in self._choices))) diff --git a/model_utils/fields.py b/model_utils/fields.py index ec24b2d..e0b435b 100644 --- a/model_utils/fields.py +++ b/model_utils/fields.py @@ -3,6 +3,8 @@ from datetime import datetime from django.db import models from django.conf import settings +from model_utils import Choices + class AutoCreatedField(models.DateTimeField): """ A DateTimeField that automatically populates itself at @@ -30,7 +32,7 @@ class AutoLastModifiedField(AutoCreatedField): return value -def _previous_condition(model_instance, attname, add): +def _previous_status(model_instance, attname, add): if add: return None pk_value = getattr(model_instance, model_instance._meta.pk.attname) @@ -40,49 +42,53 @@ def _previous_condition(model_instance, attname, add): return None return getattr(current, attname, None) -class ConditionField(models.PositiveIntegerField): +class StatusField(models.CharField): """ - A PositiveIntegerField that has set conditional choices by default. + A CharField that has set status choices by default. """ + def __init__(self, *args, **kwargs): + kwargs.setdefault('max_length', 100) + super(StatusField, self).__init__(*args, **kwargs) + def contribute_to_class(self, cls, name): if not cls._meta.abstract: - assert not not hasattr(cls, 'CONDITIONS'), "The model '%s' doesn't have conditions set." % cls.__name__ - setattr(self, '_choices', cls.CONDITIONS) - setattr(self, 'default', tuple(cls.CONDITIONS)[0][0]) # sets first as default - super(ConditionField, self).contribute_to_class(cls, name) + assert hasattr(cls, 'STATUS'), "The model '%s' doesn't have status choices set." % cls.__name__ + assert isinstance(cls.STATUS, Choices), "The status choices of model '%s' isn't a subclass of %s" % (cls.__name__, Choices) + setattr(self, '_choices', cls.STATUS) + setattr(self, 'default', tuple(cls.STATUS)[0][0]) # sets first as default + super(StatusField, self).contribute_to_class(cls, name) def pre_save(self, model_instance, add): - previous = _previous_condition(model_instance, 'get_%s_display' % self.attname, add) + previous = _previous_status(model_instance, 'get_%s_display' % self.attname, add) if previous: previous = previous() - setattr(model_instance, 'previous_condition', previous) - return super(ConditionField, self).pre_save(model_instance, add) + setattr(model_instance, 'previous_status', previous) + return super(StatusField, self).pre_save(model_instance, add) -class ConditionModifedField(models.DateTimeField): +class StatusModifedField(models.DateTimeField): def __init__(self, *args, **kwargs): kwargs.setdefault('default', datetime.now) - depends_on = kwargs.pop('depends_on', 'condition') + depends_on = kwargs.pop('depends_on', 'status') if not depends_on: raise TypeError( '%s requires a depends_on parameter' % self.__class__.__name__) self.depends_on = depends_on - super(ConditionModifedField, self).__init__(*args, **kwargs) + super(StatusModifedField, self).__init__(*args, **kwargs) def contribute_to_class(self, cls, name): - #print cls._meta, cls - assert not getattr(cls._meta, "has_condition_modified_field", False), "A model can't have more than one ConditionModifedField." - super(ConditionModifedField, self).contribute_to_class(cls, name) - setattr(cls._meta, "has_condition_modified_field", True) + assert not getattr(cls._meta, "has_status_modified_field", False), "A model can't have more than one StatusModifedField." + super(StatusModifedField, self).contribute_to_class(cls, name) + setattr(cls._meta, "has_status_modified_field", True) def pre_save(self, model_instance, add): value = datetime.now() - previous = _previous_condition(model_instance, self.depends_on, add) + previous = _previous_status(model_instance, self.depends_on, add) current = getattr(model_instance, self.depends_on, None) if (previous and (previous != current)) or (current and not previous): setattr(model_instance, self.attname, value) - return super(ConditionModifedField, self).pre_save(model_instance, add) + return super(StatusModifedField, self).pre_save(model_instance, add) SPLIT_MARKER = getattr(settings, 'SPLIT_MARKER', '') diff --git a/model_utils/models.py b/model_utils/models.py index 6aeb1d2..4f08f0b 100644 --- a/model_utils/models.py +++ b/model_utils/models.py @@ -8,7 +8,7 @@ from django.db.models.fields import FieldDoesNotExist from model_utils.managers import QueryManager from model_utils.fields import AutoCreatedField, AutoLastModifiedField, \ - ConditionField, ConditionModifedField + StatusField, StatusModifedField class InheritanceCastModel(models.Model): """ @@ -63,59 +63,59 @@ class TimeFramedBaseModel(ModelBase): except FieldDoesNotExist: pass cls.add_to_class('timeframed', QueryManager( - (models.Q(starts__lte=datetime.now()) | models.Q(starts__isnull=True)) & - (models.Q(ends__gte=datetime.now()) | models.Q(ends__isnull=True)) + (models.Q(start__lte=datetime.now()) | models.Q(start__isnull=True)) & + (models.Q(end__gte=datetime.now()) | models.Q(end__isnull=True)) )) class TimeFramedModel(models.Model): """ - An abstract base class model that provides ``starts`` - and ``ends`` fields to record a timeframe. + An abstract base class model that provides ``start`` + and ``end`` fields to record a timeframe. """ __metaclass__ = TimeFramedBaseModel - starts = models.DateTimeField(_('starts'), null=True, blank=True) - ends = models.DateTimeField(_('ends'), null=True, blank=True) + start = models.DateTimeField(_('start'), null=True, blank=True) + end = models.DateTimeField(_('end'), null=True, blank=True) class Meta: abstract = True -class ConditionalBaseModel(ModelBase): +class StatusBaseModel(ModelBase): """ - A model base class for the ``ConditionalModel`` to add - a series of model managers for each given condition. + A model base class for the ``StatusModel`` to add + a series of model managers for each given status. """ def _prepare(cls): - super(ConditionalBaseModel, cls)._prepare() - conditions = getattr(cls, 'CONDITIONS', None) - if conditions is None: + super(StatusBaseModel, cls)._prepare() + status = getattr(cls, 'STATUS', ()) + if status is None: return - for value, name in conditions._choices: + for value, name in status: try: cls._meta.get_field(name) raise ValueError("Model %s has a field named '%s' and " - "conflicts with a condition." + "conflicts with a status." % (cls.__name__, name)) except FieldDoesNotExist: pass - cls.add_to_class(name, QueryManager(**{'condition': value})) + cls.add_to_class(value, QueryManager(**{'status': value})) -class ConditionalModel(models.Model): +class StatusModel(models.Model): """ An abstract base class model that provides self-updating - condition fields like ``deleted`` and ``restored``. + status fields like ``deleted`` and ``restored``. """ - __metaclass__ = ConditionalBaseModel + __metaclass__ = StatusBaseModel - condition = ConditionField(_('condition')) - condition_date = ConditionModifedField(_('condition date')) + status = StatusField(_('status')) + status_date = StatusModifedField(_('status date')) def __unicode__(self): - return self.get_condition_display() + return self.get_status_display() class Meta: abstract = True diff --git a/model_utils/tests/models.py b/model_utils/tests/models.py index ae31bfe..8e49695 100644 --- a/model_utils/tests/models.py +++ b/model_utils/tests/models.py @@ -1,10 +1,10 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ -from model_utils.models import InheritanceCastModel, TimeStampedModel, ConditionalModel, TimeFramedModel +from model_utils.models import InheritanceCastModel, TimeStampedModel, StatusModel, TimeFramedModel from model_utils.managers import QueryManager from model_utils.fields import SplitField -from model_utils import ChoiceEnum +from model_utils import Choices class InheritParent(InheritanceCastModel): pass @@ -18,13 +18,14 @@ class TimeStamp(TimeStampedModel): class TimeFrame(TimeFramedModel): pass -class Condition(ConditionalModel): - CONDITIONS = ChoiceEnum( +class Status(StatusModel): + STATUS = Choices( ('active', _('active')), ('deleted', _('deleted')), ('on_hold', _('on hold')), ) + class Post(models.Model): published = models.BooleanField() confirmed = models.BooleanField() diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index dbeaa1f..aeb4c5d 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -5,10 +5,10 @@ from django.conf import settings from django.contrib.contenttypes.models import ContentType from django.db.models.fields import FieldDoesNotExist -from model_utils import ChoiceEnum +from model_utils import ChoiceEnum, Choices from model_utils.fields import get_excerpt from model_utils.tests.models import InheritParent, InheritChild, TimeStamp, \ - Post, Article, Condition, TimeFrame + Post, Article, Status, TimeFrame class GetExcerptTests(TestCase): @@ -106,6 +106,40 @@ class LabelChoiceEnumTests(ChoiceEnumTests): def test_display(self): self.assertEquals(self.STATUS.get_deleted_display(), 'DELETED') +class ChoicesTests(TestCase): + def setUp(self): + self.STATUS = Choices('DRAFT', 'PUBLISHED') + + def test_getattr(self): + self.assertEquals(self.STATUS.draft, 'DRAFT') + + def test_getitem(self): + self.assertEquals(self.STATUS[1], 'PUBLISHED') + + def test_iteration(self): + self.assertEquals(tuple(self.STATUS), (('DRAFT', 'DRAFT'), ('PUBLISHED', 'PUBLISHED'))) + + def test_display(self): + self.assertEquals(self.STATUS.draft, 'DRAFT') + +class LabelChoicesTests(ChoicesTests): + def setUp(self): + self.STATUS = Choices( + ('DRAFT', 'draft'), + ('PUBLISHED', 'published'), + 'DELETED', + ) + + def test_iteration(self): + self.assertEquals(tuple(self.STATUS), ( + ('DRAFT', 'draft'), + ('PUBLISHED', 'published'), + ('DELETED', 'DELETED')) + ) + + def test_display(self): + self.assertEquals(self.STATUS.deleted, 'DELETED') + class InheritanceCastModelTests(TestCase): def setUp(self): self.parent = InheritParent.objects.create() @@ -142,44 +176,49 @@ class TimeFramedModelTests(TestCase): def testCreated(self): now = datetime.now() # objects are out of the timeframe - TimeFrame.objects.create(starts=now+timedelta(days=2)) - TimeFrame.objects.create(ends=now-timedelta(days=1)) + TimeFrame.objects.create(start=now+timedelta(days=2)) + TimeFrame.objects.create(end=now-timedelta(days=1)) self.assertEquals(TimeFrame.timeframed.count(), 0) # objects in the timeframe for various reasons - TimeFrame.objects.create(starts=now-timedelta(days=10)) - TimeFrame.objects.create(ends=now+timedelta(days=2)) - TimeFrame.objects.create(starts=now-timedelta(days=1), ends=now+timedelta(days=1)) + TimeFrame.objects.create(start=now-timedelta(days=10)) + TimeFrame.objects.create(end=now+timedelta(days=2)) + TimeFrame.objects.create(start=now-timedelta(days=1), end=now+timedelta(days=1)) self.assertEquals(TimeFrame.timeframed.count(), 3) -class ConditionalModelTests(TestCase): +class StatusModelTests(TestCase): + def testCreated(self): - c1 = Condition.objects.create() - c2 = Condition.objects.create() - self.assert_(c2.condition_date > c1.condition_date) - self.assertEquals(Condition.active.count(), 2) + c1 = Status.objects.create() + c2 = Status.objects.create() + self.assert_(c2.status_date > c1.status_date) + self.assertEquals(Status.active.count(), 2) + self.assertEquals(Status.deleted.count(), 0) def testModification(self): - t1 = Condition.objects.create() - date_created = t1.condition_date - t1.condition = t1.CONDITIONS.on_hold + t1 = Status.objects.create() + date_created = t1.status_date + t1.status = t1.STATUS.on_hold t1.save() - self.assert_(t1.condition_date > date_created) - date_changed = t1.condition_date + self.assertEquals(Status.active.count(), 0) + self.assertEquals(Status.on_hold.count(), 1) + self.assert_(t1.status_date > date_created) + date_changed = t1.status_date t1.save() - self.assertEquals(t1.condition_date, date_changed) - date_active_again = t1.condition_date - t1.condition = t1.CONDITIONS.active + self.assertEquals(t1.status_date, date_changed) + date_active_again = t1.status_date + t1.status = t1.STATUS.active t1.save() - self.assert_(t1.condition_date > date_active_again) + self.assert_(t1.status_date > date_active_again) def testPreviousConditon(self): - c = Condition.objects.create() - self.assertEquals(c.previous_condition, None) - c.condition = c.CONDITIONS.on_hold - c.save() - self.assertEquals(c.previous_condition, c.CONDITIONS.get_active_display()) + status = Status.objects.create() + self.assertEquals(status.previous_status, None) + status.status = status.STATUS.on_hold + status.save() + self.assertEquals(status.previous_status, status.STATUS.active) + class QueryManagerTests(TestCase): def setUp(self):