diff --git a/tos/compat.py b/tos/compat.py new file mode 100644 index 0000000..11855c9 --- /dev/null +++ b/tos/compat.py @@ -0,0 +1,19 @@ +import django +from django.conf import settings + + +def get_fk_user_model(): + if django.VERSION >= (1, 5): + return settings.AUTH_USER_MODEL + else: + from django.contrib.auth.models import User + return User + + +def get_runtime_user_model(): + if django.VERSION >= (1, 5): + from django.contrib.auth import get_user_model + return get_user_model() + else: + from django.contrib.auth.models import User + return User diff --git a/tos/models.py b/tos/models.py index dbdffe6..1174b7a 100644 --- a/tos/models.py +++ b/tos/models.py @@ -2,14 +2,7 @@ from django.core.exceptions import ValidationError from django.db import models from django.utils.translation import ugettext_lazy as _ -# Django 1.4 compatability -try: - from django.contrib.auth import get_user_model - USER_MODEL = get_user_model() -except ImportError: - from django.contrib.auth.models import User - USER_MODEL = User - +from tos.compat import get_fk_user_model class NoActiveTermsOfService(ValidationError): pass @@ -76,7 +69,7 @@ class TermsOfService(BaseModel): class UserAgreement(BaseModel): terms_of_service = models.ForeignKey(TermsOfService, related_name='terms') - user = models.ForeignKey(USER_MODEL, related_name='user_agreement') + user = models.ForeignKey(get_fk_user_model(), related_name='user_agreement') def __unicode__(self): return u'%s agreed to TOS: %s' % (self.user.username, diff --git a/tos/tests/test_models.py b/tos/tests/test_models.py index 5987939..42d3fc1 100644 --- a/tos/tests/test_models.py +++ b/tos/tests/test_models.py @@ -1,24 +1,24 @@ from django.core.exceptions import ValidationError from django.test import TestCase +from tos.compat import get_runtime_user_model from tos.models import ( TermsOfService, UserAgreement, has_user_agreed_latest_tos, - USER_MODEL ) class TestModels(TestCase): def setUp(self): - self.user1 = USER_MODEL.objects.create_user('user1', + self.user1 = get_runtime_user_model().objects.create_user('user1', 'user1@example.com', 'user1pass') - self.user2 = USER_MODEL.objects.create_user('user2', + self.user2 = get_runtime_user_model().objects.create_user('user2', 'user2@example.com', 'user2pass') - self.user3 = USER_MODEL.objects.create_user('user3', + self.user3 = get_runtime_user_model().objects.create_user('user3', 'user3@example.com', 'user3pass') diff --git a/tos/tests/test_views.py b/tos/tests/test_views.py index 5297965..7418867 100644 --- a/tos/tests/test_views.py +++ b/tos/tests/test_views.py @@ -2,20 +2,15 @@ from django.conf import settings from django.core.urlresolvers import reverse from django.test import TestCase -# Django 1.4 compatability -try: - from django.contrib.auth import get_user_model -except ImportError: - from django.contrib.auth.models import User - get_user_model = lambda: User +from tos.compat import get_runtime_user_model +from tos.models import TermsOfService, UserAgreement, has_user_agreed_latest_tos -from tos.models import TermsOfService, UserAgreement, has_user_agreed_latest_tos, USER_MODEL as USER class TestViews(TestCase): def setUp(self): - self.user1 = USER.objects.create_user('user1', 'user1@example.com', 'user1pass') - self.user2 = USER.objects.create_user('user2', 'user2@example.com', 'user2pass') + self.user1 = get_runtime_user_model().objects.create_user('user1', 'user1@example.com', 'user1pass') + self.user2 = get_runtime_user_model().objects.create_user('user2', 'user2@example.com', 'user2pass') self.tos1 = TermsOfService.objects.create( content="first edition of the terms of service", diff --git a/tos/views.py b/tos/views.py index 0b0bf1f..bcf7b19 100644 --- a/tos/views.py +++ b/tos/views.py @@ -13,16 +13,9 @@ from django.views.decorators.cache import never_cache from django.views.decorators.csrf import csrf_protect from django.utils.translation import ugettext_lazy as _ +from tos.compat import get_runtime_user_model from tos.models import has_user_agreed_latest_tos, TermsOfService, UserAgreement -# Django 1.4 compatability -try: - from django.contrib.auth import get_user_model - USER_MODEL = get_user_model() -except ImportError: - from django.contrib.auth.models import User - USER_MODEL = User - class TosView(TemplateView): template_name = "tos/tos.html" @@ -58,7 +51,7 @@ def check_tos(request, template_name='tos/tos_check.html', tos = TermsOfService.objects.get_current_tos() if request.method == "POST": if request.POST.get("accept", "") == "accept": - user = USER_MODEL.objects.get(pk=request.session['tos_user']) + user = get_runtime_user_model().objects.get(pk=request.session['tos_user']) user.backend = request.session['tos_backend'] # Save the user agreement to the new TOS