diff --git a/cachalot/tests/thread_safety.py b/cachalot/tests/thread_safety.py index cdc32fe..0ca6251 100644 --- a/cachalot/tests/thread_safety.py +++ b/cachalot/tests/thread_safety.py @@ -1,7 +1,8 @@ # coding: utf-8 from __future__ import unicode_literals -from threading import Thread, Event +from threading import Thread, Lock +from time import sleep from django.db import connection, transaction from django.test import TransactionTestCase, skipUnlessDBFeature @@ -12,18 +13,11 @@ from .models import Test class TestThread(Thread): def __init__(self): super(TestThread, self).__init__() - self.event0 = Event() - self.event1 = Event() + self.lock = Lock() - def wait_for_main(self): - self.event1.set() - self.event1.clear() - self.event0.wait(0.5) - - def wait_for_child(self): - self.event0.set() - self.event0.clear() - self.event1.wait(0.5) + def wait(self): + with self.lock: + sleep(0.1) def start(self, n=2): self.n = n @@ -32,7 +26,7 @@ class TestThread(Thread): def run(self): for i in range(1, self.n+1): setattr(self, 't%d' % i, Test.objects.first()) - self.wait_for_main() + self.wait() connection.close() @@ -44,9 +38,9 @@ class ThreadSafetyTestCase(TransactionTestCase): @skipUnlessDBFeature('test_db_allows_multiple_connections') def test_concurrent_caching(self): self.thread.start() - self.thread.wait_for_child() + self.thread.wait() t = Test.objects.create(name='test') - self.thread.wait_for_child() + self.thread.wait() self.assertEqual(self.thread.t1, None) self.assertEqual(self.thread.t2, t) @@ -56,9 +50,9 @@ class ThreadSafetyTestCase(TransactionTestCase): with self.assertNumQueries(1): with transaction.atomic(): self.thread.start() - self.thread.wait_for_child() + self.thread.wait() t = Test.objects.create(name='test') - self.thread.wait_for_child() + self.thread.wait() self.assertEqual(self.thread.t1, None) self.assertEqual(self.thread.t2, None) @@ -70,11 +64,11 @@ class ThreadSafetyTestCase(TransactionTestCase): @skipUnlessDBFeature('test_db_allows_multiple_connections') def test_concurrent_caching_before_and_during_atomic_1(self): self.thread.start() - self.thread.wait_for_child() + self.thread.wait() with self.assertNumQueries(1): with transaction.atomic(): - self.thread.wait_for_child() + self.thread.wait() t = Test.objects.create(name='test') self.assertEqual(self.thread.t1, None) @@ -87,12 +81,12 @@ class ThreadSafetyTestCase(TransactionTestCase): @skipUnlessDBFeature('test_db_allows_multiple_connections') def test_concurrent_caching_before_and_during_atomic_2(self): self.thread.start() - self.thread.wait_for_child() + self.thread.wait() with self.assertNumQueries(1): with transaction.atomic(): t = Test.objects.create(name='test') - self.thread.wait_for_child() + self.thread.wait() self.assertEqual(self.thread.t1, None) self.assertEqual(self.thread.t2, None) @@ -106,10 +100,10 @@ class ThreadSafetyTestCase(TransactionTestCase): with self.assertNumQueries(1): with transaction.atomic(): self.thread.start() - self.thread.wait_for_child() + self.thread.wait() t = Test.objects.create(name='test') - self.thread.wait_for_child() + self.thread.wait() self.assertEqual(self.thread.t1, None) self.assertEqual(self.thread.t2, t) @@ -124,9 +118,9 @@ class ThreadSafetyTestCase(TransactionTestCase): with transaction.atomic(): t = Test.objects.create(name='test') self.thread.start() - self.thread.wait_for_child() + self.thread.wait() - self.thread.wait_for_child() + self.thread.wait() self.assertEqual(self.thread.t1, None) self.assertEqual(self.thread.t2, t) @@ -140,11 +134,11 @@ class ThreadSafetyTestCase(TransactionTestCase): with self.assertNumQueries(1): with transaction.atomic(): self.thread.start(3) - self.thread.wait_for_child() + self.thread.wait() t = Test.objects.create(name='test') - self.thread.wait_for_child() + self.thread.wait() - self.thread.wait_for_child() + self.thread.wait() self.assertEqual(self.thread.t1, None) self.assertEqual(self.thread.t2, None)