diff --git a/tos/middleware.py b/tos/middleware.py index f0006cf..fbadb82 100644 --- a/tos/middleware.py +++ b/tos/middleware.py @@ -1,6 +1,6 @@ from django import VERSION as DJANGO_VERSION from django.conf import settings -from django.contrib.auth import SESSION_KEY as session_key +from django.contrib.auth import REDIRECT_FIELD_NAME, SESSION_KEY as session_key from django.core.urlresolvers import reverse from django.http import HttpResponseRedirect from django.utils import deprecation @@ -68,7 +68,11 @@ class UserAgreementMiddleware(deprecation.MiddlewareMixin if DJANGO_VERSION >= ( request.session['tos_user'] = user_id request.session['tos_backend'] = user_auth_backend - response = HttpResponseRedirect(tos_check_url) + response = HttpResponseRedirect('{0}?{1}={2}'.format( + tos_check_url, + REDIRECT_FIELD_NAME, + request.path_info, + )) add_never_cache_headers(response) return response diff --git a/tos/tests/test_middleware.py b/tos/tests/test_middleware.py index 3724556..83fbbe0 100644 --- a/tos/tests/test_middleware.py +++ b/tos/tests/test_middleware.py @@ -1,4 +1,5 @@ from django.conf import settings +from django.contrib.auth import REDIRECT_FIELD_NAME from django.core.urlresolvers import reverse from django.test import TestCase from django.test.utils import override_settings @@ -20,6 +21,7 @@ class TestMiddleware(TestCase): # User that has not yet agreed to TOS self.user2 = get_runtime_user_model().objects.create_user('user2', 'user2@example.com', 'user2pass') + self.user3 = get_runtime_user_model().objects.create_user('user3', 'user3@example.com', 'user3pass') self.tos1 = TermsOfService.objects.create( content="first edition of the terms of service", @@ -36,6 +38,12 @@ class TestMiddleware(TestCase): user=self.user1 ) + self.redirect_page = '{0}?{1}={2}'.format( + reverse('tos_check_tos'), + REDIRECT_FIELD_NAME, + reverse('index'), + ) + def test_middleware_redirects(self): """ User that hasn't accepted TOS should be redirected to confirm. Also make sure @@ -43,7 +51,7 @@ class TestMiddleware(TestCase): """ self.client.login(username='user2', password='user2pass') response = self.client.get(reverse('index')) - self.assertRedirects(response, reverse('tos_check_tos')) + self.assertRedirects(response, self.redirect_page) # Make sure confirm works after middleware redirect. response = self.client.post(reverse('tos_check_tos'), {'accept': 'accept'}) @@ -60,3 +68,17 @@ class TestMiddleware(TestCase): def test_anonymous_user_200(self): response = self.client.get(reverse('index')) self.assertEqual(response.status_code, 200) + + def test_accept_after_middleware_redirects_properly(self): + self.client.login(username='user3', password='user3pass') + + response = self.client.get(reverse('index'), follow=True) + + self.assertRedirects(response, self.redirect_page) + + # Agree + response = self.client.post(self.redirect_page, {'accept': 'accept'}) + + # Confirm redirects back to the index page + self.assertEqual(response.status_code, 302) + self.assertEqual(response.url.replace('http://testserver', ''), str(reverse('index')))