diff --git a/tos/tests/templates/index.html b/tos/tests/templates/index.html
index e69de29..9015a7a 100644
--- a/tos/tests/templates/index.html
+++ b/tos/tests/templates/index.html
@@ -0,0 +1 @@
+index
diff --git a/tos/tests/test_middleware.py b/tos/tests/test_middleware.py
index 7979ddb..85a96e9 100644
--- a/tos/tests/test_middleware.py
+++ b/tos/tests/test_middleware.py
@@ -20,6 +20,10 @@ from tos.signal_handlers import invalidate_cached_agreements
class TestMiddleware(TestCase):
def setUp(self):
+ # Clear cache between tests
+ cache = get_cache(getattr(settings, 'TOS_CACHE_NAME', 'default'))
+ cache.clear()
+
# User that has agreed to TOS
self.user1 = get_runtime_user_model().objects.create_user('user1', 'user1@example.com', 'user1pass')
@@ -63,6 +67,26 @@ class TestMiddleware(TestCase):
# Confirm redirects.
self.assertEqual(response.status_code, 302)
+ def test_invalidate_cache_on_accept_fix_redirect_loop(self):
+ """
+ Make sure accepting doesnt send you right back to tos page.
+ """
+ self.assertFalse(UserAgreement.objects.filter(terms_of_service=self.tos1, user=self.user2).exists())
+
+ self.client.login(username='user2', password='user2pass')
+ response = self.client.get(reverse('index'))
+ self.assertRedirects(response, self.redirect_page)
+
+ # Make sure confirm works after middleware redirect.
+ response = self.client.post(reverse('tos_check_tos'), {'accept': 'accept'})
+
+ self.assertTrue(UserAgreement.objects.filter(terms_of_service=self.tos1, user=self.user2).exists())
+
+ response = self.client.get(reverse('index'))
+ self.assertEqual(response.status_code, 200)
+
+ self.assertIn('index', str(response.content))
+
def test_middleware_doesnt_redirect(self):
"""User that has accepted TOS should get 200."""
self.client.login(username='user1', password='user1pass')
diff --git a/tos/views.py b/tos/views.py
index d090d06..d2b83c9 100644
--- a/tos/views.py
+++ b/tos/views.py
@@ -15,10 +15,13 @@ from django.views.decorators.csrf import csrf_protect
from django.views.generic import TemplateView
from django.utils.translation import ugettext_lazy as _
-from tos.compat import get_runtime_user_model, get_request_site
+from tos.compat import get_cache, get_runtime_user_model, get_request_site
from tos.models import has_user_agreed_latest_tos, TermsOfService, UserAgreement
+cache = get_cache(getattr(settings, 'TOS_CACHE_NAME', 'default'))
+
+
class TosView(TemplateView):
template_name = "tos/tos.html"
@@ -59,6 +62,9 @@ def check_tos(request, template_name='tos/tos_check.html',
# Save the user agreement to the new TOS
UserAgreement.objects.get_or_create(terms_of_service=tos, user=user)
+ key_version = cache.get('django:tos:key_version')
+ cache.delete('django:tos:agreed:{0}'.format(user.pk), version=key_version)
+
# Log the user in
auth_login(request, user)