diff --git a/cachalot/tests.py b/cachalot/tests.py index 726971c..9613223 100644 --- a/cachalot/tests.py +++ b/cachalot/tests.py @@ -1252,11 +1252,14 @@ class TestThread(Thread): while not self.wait and not self.exit: sleep(0.001) + def start(self, n=2): + self.n = n + super(TestThread, self).start() + def run(self): - self.t1 = Test.objects.first() - self.wait_for_main() - self.t2 = Test.objects.first() - self.wait_for_main() + for i in range(1, self.n+1): + setattr(self, 't%d' % i, Test.objects.first()) + self.wait_for_main() connection.close() @@ -1363,3 +1366,22 @@ class ThreadSafetyTestCase(TransactionTestCase): with self.assertNumQueries(0): data = Test.objects.first() self.assertEqual(data, t) + + @skipUnlessDBFeature('test_db_allows_multiple_connections') + def test_concurrent_caching_during_and_after_atomic_3(self): + with self.assertNumQueries(1): + with transaction.atomic(): + self.thread.start(3) + self.thread.wait_for_child() + t = Test.objects.create(name='test') + self.thread.wait_for_child() + + self.thread.wait_for_child() + + self.assertEqual(self.thread.t1, None) + self.assertEqual(self.thread.t2, None) + self.assertEqual(self.thread.t3, t) + + with self.assertNumQueries(0): + data = Test.objects.first() + self.assertEqual(data, t)