diff --git a/model_utils/__init__.py b/model_utils/__init__.py index 18991e7..4b39185 100644 --- a/model_utils/__init__.py +++ b/model_utils/__init__.py @@ -20,14 +20,32 @@ class ChoiceEnum(object): >>> tuple(STATUS) ((0, 'DRAFT'), (1, 'PUBLISHED')) + >>> STATUS = ChoiceEnum(('DRAFT', 'draft'), ('PUBLISHED', 'published')) + >>> STATUS.DRAFT + 0 + >>> tuple(STATUS) + ((0, 'draft'), (1, 'published')) + """ + def __init__(self, *choices): - self._choices = tuple(enumerate(choices)) + self._choices = tuple() + self._iter_choices = tuple() + for i, choice in enumerate(self.equalize(choices)): + self._choices += ((i, choice[0]),) + self._iter_choices += ((i, choice[1]),) self._choice_dict = dict(self._choices) self._reverse_dict = dict(((i[1], i[0]) for i in self._choices)) + def equalize(self, choices): + for choice in choices: + if isinstance(choice, (list, tuple)): + yield choice + else: + yield (choice, choice) + def __iter__(self): - return iter(self._choices) + return iter(self._iter_choices) def __getattr__(self, attname): try: diff --git a/model_utils/tests/tests.py b/model_utils/tests/tests.py index 59d81f6..98f093c 100644 --- a/model_utils/tests/tests.py +++ b/model_utils/tests/tests.py @@ -88,6 +88,18 @@ class ChoiceEnumTests(TestCase): self.assertEquals(tuple(self.STATUS), ((0, 'DRAFT'), (1, 'PUBLISHED'))) +class LabelChoiceEnumTests(ChoiceEnumTests): + def setUp(self): + self.STATUS = ChoiceEnum( + ('DRAFT', 'draft'), + ('PUBLISHED', 'published'), + 'DELETED', + ) + + def test_iteration(self): + self.assertEquals(tuple(self.STATUS), ((0, 'draft'), (1, 'published'), (2, 'DELETED'))) + + class InheritanceCastModelTests(TestCase): def setUp(self): self.parent = InheritParent.objects.create()