From 9ad5582f09c2555e4dffbf893cf22670d9244c87 Mon Sep 17 00:00:00 2001 From: Bertrand Bordage Date: Sun, 19 Oct 2014 22:51:59 +0200 Subject: [PATCH] Takes SQLite BEGINs into account in TransactionTestCases. --- cachalot/tests.py | 148 +++++++++++++++++++++++++++++----------------- 1 file changed, 94 insertions(+), 54 deletions(-) diff --git a/cachalot/tests.py b/cachalot/tests.py index 593d665..91a35eb 100644 --- a/cachalot/tests.py +++ b/cachalot/tests.py @@ -579,21 +579,23 @@ class WriteTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): t1 = Test.objects.create(name='test1') - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t2 = Test.objects.create(name='test2') with self.assertNumQueries(1): data2 = list(Test.objects.all()) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t3 = Test.objects.create(name='test3') with self.assertNumQueries(1): data3 = list(Test.objects.all()) self.assertListEqual(data2, [t1, t2]) self.assertListEqual(data3, [t1, t2, t3]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t3_copy = Test.objects.create(name='test3') self.assertNotEqual(t3_copy, t3) with self.assertNumQueries(1): @@ -609,10 +611,9 @@ class WriteTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - # get_or_create has to try to find the object, then create it - # inside a transaction. - # This triggers 2 queries: SELECT & UPDATE - with self.assertNumQueries(2): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(3 if is_sqlite else 2): t, created = Test.objects.get_or_create(name='test') self.assertTrue(created) @@ -639,12 +640,14 @@ class WriteTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)] Test.objects.bulk_create(unsaved_tests) self.assertEqual(Test.objects.count(), 10) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)] Test.objects.bulk_create(unsaved_tests) self.assertEqual(Test.objects.count(), 20) @@ -656,12 +659,14 @@ class WriteTestCase(TransactionTestCase): ['test%02d' % (i // 2) for i in range(2, 22)]) def test_update(self): - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): t = Test.objects.create(name='test1') with self.assertNumQueries(1): t1 = Test.objects.get() - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t.name = 'test2' t.save() with self.assertNumQueries(1): @@ -669,33 +674,35 @@ class WriteTestCase(TransactionTestCase): self.assertEqual(t1.name, 'test1') self.assertEqual(t2.name, 'test2') - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.update(name='test3') with self.assertNumQueries(1): t3 = Test.objects.get() self.assertEqual(t3.name, 'test3') def test_delete(self): - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): t1 = Test.objects.create(name='test1') - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t2 = Test.objects.create(name='test2') with self.assertNumQueries(1): data1 = list(Test.objects.values_list('name', flat=True)) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t2.delete() with self.assertNumQueries(1): data2 = list(Test.objects.values_list('name', flat=True)) self.assertListEqual(data1, [t1.name, t2.name]) self.assertListEqual(data2, [t1.name]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.bulk_create([Test(name='test%s' % i) for i in range(2, 11)]) with self.assertNumQueries(1): self.assertEqual(Test.objects.count(), 10) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.all().delete() with self.assertNumQueries(1): self.assertEqual(Test.objects.count(), 0) @@ -706,7 +713,9 @@ class WriteTestCase(TransactionTestCase): Test.objects.create(name='test') - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): self.assertTrue(Test.objects.create()) def test_invalidate_count(self): @@ -764,22 +773,24 @@ class WriteTestCase(TransactionTestCase): with self.assertNumQueries(1): self.assertEqual(User.objects.aggregate(n=Count('test'))['n'], 0) - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): u = User.objects.create_user('test') with self.assertNumQueries(1): self.assertEqual(User.objects.aggregate(n=Count('test'))['n'], 0) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.create(name='test1') with self.assertNumQueries(1): self.assertEqual(User.objects.aggregate(n=Count('test'))['n'], 0) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.create(name='test2', owner=u) with self.assertNumQueries(1): self.assertEqual(User.objects.aggregate(n=Count('test'))['n'], 1) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.create(name='test3') with self.assertNumQueries(1): self.assertEqual(User.objects.aggregate(n=Count('test'))['n'], 1) @@ -789,13 +800,15 @@ class WriteTestCase(TransactionTestCase): data1 = list(User.objects.annotate(n=Count('test')).order_by('pk')) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.create(name='test1') with self.assertNumQueries(1): data2 = list(User.objects.annotate(n=Count('test')).order_by('pk')) self.assertListEqual(data2, []) - with self.assertNumQueries(2): + with self.assertNumQueries(4 if is_sqlite else 2): user1 = User.objects.create_user('user1') user2 = User.objects.create_user('user2') with self.assertNumQueries(1): @@ -803,14 +816,14 @@ class WriteTestCase(TransactionTestCase): self.assertListEqual(data3, [user1, user2]) self.assertListEqual([u.n for u in data3], [0, 0]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.create(name='test2', owner=user1) with self.assertNumQueries(1): data4 = list(User.objects.annotate(n=Count('test')).order_by('pk')) self.assertListEqual(data4, [user1, user2]) self.assertListEqual([u.n for u in data4], [1, 0]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.bulk_create([ Test(name='test3', owner=user1), Test(name='test4', owner=user2), @@ -879,14 +892,16 @@ class WriteTestCase(TransactionTestCase): data1 = list(Test.objects.select_related('owner')) self.assertListEqual(data1, []) - with self.assertNumQueries(2): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(4 if is_sqlite else 2): u1 = User.objects.create_user('test1') u2 = User.objects.create_user('test2') with self.assertNumQueries(1): data2 = list(Test.objects.select_related('owner')) self.assertListEqual(data2, []) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.bulk_create([ Test(name='test1', owner=u1), Test(name='test2', owner=u2), @@ -900,7 +915,7 @@ class WriteTestCase(TransactionTestCase): self.assertEqual(data3[2].owner, u2) self.assertEqual(data3[3].owner, u1) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): Test.objects.filter(name__in=['test1', 'test2']).delete() with self.assertNumQueries(1): data4 = list(Test.objects.select_related('owner')) @@ -908,6 +923,7 @@ class WriteTestCase(TransactionTestCase): self.assertEqual(data4[1].owner, u1) def test_invalidate_prefetch_related(self): + is_sqlite = connection.vendor == 'sqlite' is_mysql = connection.vendor == 'mysql' with self.assertNumQueries(1): @@ -915,7 +931,7 @@ class WriteTestCase(TransactionTestCase): .prefetch_related('owner__groups__permissions')) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t1 = Test.objects.create(name='test1') with self.assertNumQueries(1): data2 = list(Test.objects.select_related('owner') @@ -923,7 +939,7 @@ class WriteTestCase(TransactionTestCase): self.assertListEqual(data2, [t1]) self.assertEqual(data2[0].owner, None) - with self.assertNumQueries(2): + with self.assertNumQueries(4 if is_sqlite else 2): u = User.objects.create_user('user') t1.owner = u t1.save() @@ -934,7 +950,7 @@ class WriteTestCase(TransactionTestCase): self.assertEqual(data3[0].owner, u) self.assertListEqual(list(data3[0].owner.groups.all()), []) - with self.assertNumQueries(6): + with self.assertNumQueries(9 if is_sqlite else 6): group = Group.objects.create(name='test_group') permissions = list(Permission.objects.all()[:5]) group.permissions.add(*permissions) @@ -950,7 +966,7 @@ class WriteTestCase(TransactionTestCase): self.assertListEqual(list(groups[0].permissions.all()), permissions) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t2 = Test.objects.create(name='test2') with self.assertNumQueries(3 if is_mysql else 1): data5 = list(Test.objects.select_related('owner') @@ -964,13 +980,13 @@ class WriteTestCase(TransactionTestCase): for p in g.permissions.all()] self.assertListEqual(data5_permissions, permissions) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): permissions[0].save() with self.assertNumQueries(2 if is_mysql else 1): list(Test.objects.select_related('owner') .prefetch_related('owner__groups__permissions')) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): group.name = 'modified_test_group' group.save() with self.assertNumQueries(2): @@ -979,7 +995,7 @@ class WriteTestCase(TransactionTestCase): g = list(data6[0].owner.groups.all())[0] self.assertEqual(g.name, 'modified_test_group') - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): User.objects.update(username='modified_user') with self.assertNumQueries(3 if is_mysql else 2): @@ -996,26 +1012,28 @@ class WriteTestCase(TransactionTestCase): pass def test_invalidate_extra_tables(self): - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): User.objects.create_user('user1') with self.assertNumQueries(1): data1 = list(Test.objects.all().extra(tables=['auth_user'])) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t1 = Test.objects.create(name='test1') with self.assertNumQueries(1): data2 = list(Test.objects.all().extra(tables=['auth_user'])) self.assertListEqual(data2, [t1]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t2 = Test.objects.create(name='test2') with self.assertNumQueries(1): data3 = list(Test.objects.all().extra(tables=['auth_user'])) self.assertListEqual(data3, [t1, t2]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): User.objects.create_user('user2') with self.assertNumQueries(1): data4 = list(Test.objects.all().extra(tables=['auth_user'])) @@ -1028,7 +1046,9 @@ class WriteTestCase(TransactionTestCase): class AtomicTestCase(TransactionTestCase): def test_successful_read_atomic(self): - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): with transaction.atomic(): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) @@ -1038,7 +1058,9 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data2, []) def test_unsuccessful_read_atomic(self): - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): try: with transaction.atomic(): data1 = list(Test.objects.all()) @@ -1056,21 +1078,23 @@ class AtomicTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): with transaction.atomic(): t1 = Test.objects.create(name='test1') with self.assertNumQueries(1): data2 = list(Test.objects.all()) self.assertListEqual(data2, [t1]) - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): with transaction.atomic(): t2 = Test.objects.create(name='test2') with self.assertNumQueries(1): data3 = list(Test.objects.all()) self.assertListEqual(data3, [t1, t2]) - with self.assertNumQueries(3): + with self.assertNumQueries(4 if is_sqlite else 3): with transaction.atomic(): data4 = list(Test.objects.all()) t3 = Test.objects.create(name='test3') @@ -1085,7 +1109,9 @@ class AtomicTestCase(TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): try: with transaction.atomic(): Test.objects.create(name='test') @@ -1100,7 +1126,9 @@ class AtomicTestCase(TransactionTestCase): Test.objects.get(name='test') def test_cache_inside_atomic(self): - with self.assertNumQueries(1): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(2 if is_sqlite else 1): with transaction.atomic(): data1 = list(Test.objects.all()) data2 = list(Test.objects.all()) @@ -1108,7 +1136,9 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data2, []) def test_invalidation_inside_atomic(self): - with self.assertNumQueries(3): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(4 if is_sqlite else 3): with transaction.atomic(): data1 = list(Test.objects.all()) t = Test.objects.create(name='test') @@ -1117,7 +1147,9 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data2, [t]) def test_successful_nested_read_atomic(self): - with self.assertNumQueries(6): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(7 if is_sqlite else 6): with transaction.atomic(): list(Test.objects.all()) with transaction.atomic(): @@ -1132,7 +1164,9 @@ class AtomicTestCase(TransactionTestCase): list(User.objects.all()) def test_unsuccessful_nested_read_atomic(self): - with self.assertNumQueries(4): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(5 if is_sqlite else 4): with transaction.atomic(): try: with transaction.atomic(): @@ -1145,7 +1179,9 @@ class AtomicTestCase(TransactionTestCase): list(Test.objects.all()) def test_successful_nested_write_atomic(self): - with self.assertNumQueries(12): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(13 if is_sqlite else 12): with transaction.atomic(): t1 = Test.objects.create(name='test1') with transaction.atomic(): @@ -1162,7 +1198,9 @@ class AtomicTestCase(TransactionTestCase): self.assertListEqual(data3, [t1, t2, t3, t4]) def test_unsuccessful_nested_write_atomic(self): - with self.assertNumQueries(12): + is_sqlite = connection.vendor == 'sqlite' + + with self.assertNumQueries(13 if is_sqlite else 12): with transaction.atomic(): t1 = Test.objects.create(name='test1') try: @@ -1212,8 +1250,10 @@ class SettingsTestCase(TransactionTestCase): with self.assertNumQueries(0): list(Test.objects.all()) + is_sqlite = connection.vendor == 'sqlite' + with cachalot_settings(CACHALOT_ENABLED=False): - with self.assertNumQueries(1): + with self.assertNumQueries(2 if is_sqlite else 1): t = Test.objects.create(name='test') with self.assertNumQueries(1): data = list(Test.objects.all())