diff --git a/cachalot/tests.py b/cachalot/tests.py index 87b5368..8e7fc8b 100644 --- a/cachalot/tests.py +++ b/cachalot/tests.py @@ -3,6 +3,8 @@ from __future__ import unicode_literals import datetime +from threading import Thread +from time import sleep try: from unittest import skip, skipIf except ImportError: # For Python 2.6 @@ -15,7 +17,7 @@ from django.db import transaction, connection from django.db.models import ( Model, CharField, ForeignKey, BooleanField, DateField, DateTimeField, Count) -from django.test import TestCase +from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature from .settings import cachalot_settings @@ -32,7 +34,7 @@ class Test(Model): ordering = ('name',) -class ReadTestCase(TestCase): +class ReadTestCase(TransactionTestCase): """ Tests if every SQL request that only reads data is cached. @@ -1233,3 +1235,61 @@ class SettingsTestCase(TestCase): list(Test.objects.all()) with self.assertNumQueries(0): list(Test.objects.all()) + + +class TestThread(Thread): + def wait_for_main(self): + self.wait = True + while self.wait: + sleep(0.001) + + def wait_for_child(self): + self.wait = False + while not self.wait: + sleep(0.001) + + def run(self): + self.wait_for_main() + self.t1 = Test.objects.first() + + self.wait_for_main() + self.t2 = Test.objects.first() + self.wait = True + + connection.close() + + +class ThreadSafetyTestCase(TransactionTestCase): + def setUp(self): + self.thread = TestThread() + + def tearDown(self): + if self.thread.is_alive(): + self.thread.join(0) + + @skipUnlessDBFeature('test_db_allows_multiple_connections') + def test_concurrent_caching(self): + self.thread.start() + + self.thread.wait_for_child() + t = Test.objects.create(name='test') + self.thread.wait_for_child() + + self.assertEqual(self.thread.t1, None) + self.assertEqual(self.thread.t2, t) + + @skipUnlessDBFeature('test_db_allows_multiple_connections') + def test_concurrent_caching_during_atomic(self): + self.thread.start() + with self.assertNumQueries(1): + with transaction.atomic(): + self.thread.wait_for_child() + t = Test.objects.create(name='test2') + self.thread.wait_for_child() + + self.assertEqual(self.thread.t1, None) + self.assertEqual(self.thread.t2, None) + + with self.assertNumQueries(1): + data = Test.objects.first() + self.assertEqual(data, t)