From fa1fbc1a5d7e1dbdfdf60e8562967ee624c935ea Mon Sep 17 00:00:00 2001 From: Benedikt Willi Date: Tue, 6 Jun 2023 10:05:44 +0200 Subject: [PATCH] Added a FilteredTransactionTestCase, updated tests. --- cachalot/tests/read.py | 8 ++--- cachalot/tests/test_utils.py | 63 ++++++++++++++++++++++++++++++++- cachalot/tests/transaction.py | 55 ++++++++--------------------- cachalot/tests/write.py | 66 +++++++++-------------------------- 4 files changed, 97 insertions(+), 95 deletions(-) diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index badeec2..b2b492f 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -21,7 +21,7 @@ from cachalot.cache import cachalot_caches from ..settings import cachalot_settings from ..utils import UncachableQuery from .models import Test, TestChild, TestParent, UnmanagedModel -from .test_utils import TestUtilsMixin +from .test_utils import TestUtilsMixin, FilteredTransactionTestCase from .tests_decorators import all_final_sql_checks, with_final_sql_check, no_final_sql_check @@ -36,7 +36,7 @@ def is_field_available(name): return name in fields -class ReadTestCase(TestUtilsMixin, TransactionTestCase): +class ReadTestCase(TestUtilsMixin, FilteredTransactionTestCase): """ Tests if every SQL request that only reads data is cached. @@ -896,9 +896,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase): self.assert_query_cached(qs, [self.t2, self.t1]) def test_table_inheritance(self): - with self.assertNumQueries( - 3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2) - ): + with self.assertNumQueries(2): t_child = TestChild.objects.create(name='test_child') with self.assertNumQueries(1): diff --git a/cachalot/tests/test_utils.py b/cachalot/tests/test_utils.py index cb21774..decb7ce 100644 --- a/cachalot/tests/test_utils.py +++ b/cachalot/tests/test_utils.py @@ -1,5 +1,8 @@ from django.core.management.color import no_style -from django.db import connection, transaction +from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction + +from django.test import TransactionTestCase +from django.test.utils import CaptureQueriesContext from ..utils import _get_tables from .models import PostgresModel @@ -57,3 +60,61 @@ class TestUtilsMixin: assert_function(data2, data1) if result is not None: assert_function(data2, result) + +class FilteredTransactionTestCase(TransactionTestCase): + """ + TransactionTestCase with assertNumQueries that ignores BEGIN, COMMIT and ROLLBACK + queries. + """ + def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs): + conn = connections[using] + + context = FilteredAssertNumQueriesContext(self, num, conn) + if func is None: + return context + + with context: + func(*args, **kwargs) + + +class FilteredAssertNumQueriesContext(CaptureQueriesContext): + """ + Capture queries and assert their number ignoring BEGIN, COMMIT and ROLLBACK queries. + """ + EXCLUDE = ('BEGIN', 'COMMIT', 'ROLLBACK') + + def __init__(self, test_case, num, connection): + self.test_case = test_case + self.num = num + super().__init__(connection) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if exc_type is not None: + return + + filtered_queries = [] + excluded_queries = [] + for q in self.captured_queries: + if q['sql'].upper() not in self.EXCLUDE: + filtered_queries.append(q) + else: + excluded_queries.append(q) + + executed = len(filtered_queries) + + self.test_case.assertEqual( + executed, + self.num, + f"\n{executed} queries executed, {self.num} expected\n" + + "\nCaptured queries were:\n" + + "".join( + f"{i}. {query['sql']}\n" + for i, query in enumerate(filtered_queries, start=1) + ) + + "\nCaptured queries, that were excluded:\n" + + "".join( + f"{i}. {query['sql']}\n" + for i, query in enumerate(excluded_queries, start=1) + ) + ) diff --git a/cachalot/tests/transaction.py b/cachalot/tests/transaction.py index ceb7e55..38b3710 100644 --- a/cachalot/tests/transaction.py +++ b/cachalot/tests/transaction.py @@ -1,20 +1,17 @@ from cachalot.transaction import AtomicCache -from django import VERSION as DJANGO_VERSION from django.contrib.auth.models import User from django.core.cache import cache from django.db import transaction, connection, IntegrityError -from django.test import SimpleTestCase, TransactionTestCase, skipUnlessDBFeature +from django.test import SimpleTestCase, skipUnlessDBFeature from .models import Test -from .test_utils import TestUtilsMixin +from .test_utils import TestUtilsMixin, FilteredTransactionTestCase -class AtomicTestCase(TestUtilsMixin, TransactionTestCase): +class AtomicTestCase(TestUtilsMixin, FilteredTransactionTestCase): def test_successful_read_atomic(self): - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): with transaction.atomic(): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) @@ -24,9 +21,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): self.assertListEqual(data2, []) def test_unsuccessful_read_atomic(self): - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): try: with transaction.atomic(): data1 = list(Test.objects.all()) @@ -44,27 +39,21 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): with transaction.atomic(): t1 = Test.objects.create(name='test1') with self.assertNumQueries(1): data2 = list(Test.objects.all()) self.assertListEqual(data2, [t1]) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): with transaction.atomic(): t2 = Test.objects.create(name='test2') with self.assertNumQueries(1): data3 = list(Test.objects.all()) self.assertListEqual(data3, [t1, t2]) - with self.assertNumQueries( - 4 if self.is_sqlite else (5 if DJANGO_VERSION >= (4, 2) else 3) - ): + with self.assertNumQueries(3): with transaction.atomic(): data4 = list(Test.objects.all()) t3 = Test.objects.create(name='test3') @@ -79,9 +68,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): try: with transaction.atomic(): Test.objects.create(name='test') @@ -96,9 +83,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): Test.objects.get(name='test') def test_cache_inside_atomic(self): - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): with transaction.atomic(): data1 = list(Test.objects.all()) data2 = list(Test.objects.all()) @@ -106,9 +91,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): self.assertListEqual(data2, []) def test_invalidation_inside_atomic(self): - with self.assertNumQueries( - 4 if self.is_sqlite else (5 if DJANGO_VERSION >= (4, 2) else 3) - ): + with self.assertNumQueries(3): with transaction.atomic(): data1 = list(Test.objects.all()) t = Test.objects.create(name='test') @@ -117,9 +100,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): self.assertListEqual(data2, [t]) def test_successful_nested_read_atomic(self): - with self.assertNumQueries( - 7 if self.is_sqlite else (8 if DJANGO_VERSION >= (4, 2) else 6) - ): + with self.assertNumQueries(6): with transaction.atomic(): list(Test.objects.all()) with transaction.atomic(): @@ -134,9 +115,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): list(User.objects.all()) def test_unsuccessful_nested_read_atomic(self): - with self.assertNumQueries( - 6 if self.is_sqlite else (7 if DJANGO_VERSION >= (4, 2) else 5) - ): + with self.assertNumQueries(5): with transaction.atomic(): try: with transaction.atomic(): @@ -149,9 +128,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): list(Test.objects.all()) def test_successful_nested_write_atomic(self): - with self.assertNumQueries( - 13 if self.is_sqlite else (14 if DJANGO_VERSION >= (4, 2) else 12) - ): + with self.assertNumQueries(12): with transaction.atomic(): t1 = Test.objects.create(name='test1') with transaction.atomic(): @@ -168,9 +145,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase): self.assertListEqual(data3, [t1, t2, t3, t4]) def test_unsuccessful_nested_write_atomic(self): - with self.assertNumQueries( - 16 if self.is_sqlite else (17 if DJANGO_VERSION >= (4, 2) else 15) - ): + with self.assertNumQueries(15): with transaction.atomic(): t1 = Test.objects.create(name='test1') try: diff --git a/cachalot/tests/write.py b/cachalot/tests/write.py index 24d87f3..dd6459a 100644 --- a/cachalot/tests/write.py +++ b/cachalot/tests/write.py @@ -1,8 +1,5 @@ -import sys - from unittest import skipIf, skipUnless -from django import VERSION as DJANGO_VERSION from django.contrib.auth.models import User, Permission, Group from django.contrib.contenttypes.models import ContentType from django.core.exceptions import MultipleObjectsReturned @@ -14,10 +11,10 @@ from django.db.models.expressions import RawSQL from django.test import TransactionTestCase, skipUnlessDBFeature from .models import Test, TestParent, TestChild -from .test_utils import TestUtilsMixin +from .test_utils import TestUtilsMixin, FilteredTransactionTestCase -class WriteTestCase(TestUtilsMixin, TransactionTestCase): +class WriteTestCase(TestUtilsMixin, FilteredTransactionTestCase): """ Tests if every SQL request writing data is not cached and invalidates the implied data. @@ -58,9 +55,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries( - 3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2) - ): + with self.assertNumQueries(2): t, created = Test.objects.get_or_create(name='test') self.assertTrue(created) @@ -82,18 +77,14 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): with self.assertNumQueries(1): self.assertListEqual(list(Test.objects.all()), []) - with self.assertNumQueries( - 5 if self.is_sqlite else (6 if DJANGO_VERSION >= (4, 2) else 4) - ): + with self.assertNumQueries(4): t, created = Test.objects.update_or_create( name='test', defaults={'public': True}) self.assertTrue(created) self.assertEqual(t.name, 'test') self.assertEqual(t.public, True) - with self.assertNumQueries( - 3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2) - ): + with self.assertNumQueries(2): t, created = Test.objects.update_or_create( name='test', defaults={'public': False}) self.assertFalse(created) @@ -102,9 +93,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): # The number of SQL queries doesn’t decrease because update_or_create # always calls an UPDATE, even when data wasn’t changed. - with self.assertNumQueries( - 3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2) - ): + with self.assertNumQueries(2): t, created = Test.objects.update_or_create( name='test', defaults={'public': False}) self.assertFalse(created) @@ -119,21 +108,17 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): data1 = list(Test.objects.all()) self.assertListEqual(data1, []) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)] Test.objects.bulk_create(unsaved_tests) self.assertEqual(Test.objects.count(), 10) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)] Test.objects.bulk_create(unsaved_tests) self.assertEqual(Test.objects.count(), 20) - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): data2 = list(Test.objects.all()) self.assertEqual(len(data2), 20) self.assertListEqual([t.name for t in data2], @@ -174,16 +159,12 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): self.assertListEqual(data1, [t1.name, t2.name]) self.assertListEqual(data2, [t1.name]) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): Test.objects.bulk_create([Test(name='test%s' % i) for i in range(2, 11)]) with self.assertNumQueries(1): self.assertEqual(Test.objects.count(), 10) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): Test.objects.all().delete() with self.assertNumQueries(1): self.assertEqual(Test.objects.count(), 0) @@ -378,9 +359,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): self.assertListEqual(data4, [user1, user2]) self.assertListEqual([u.n for u in data4], [1, 0]) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): Test.objects.bulk_create([ Test(name='test3', owner=user1), Test(name='test4', owner=user2), @@ -608,9 +587,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): data2 = list(Test.objects.select_related('owner')) self.assertListEqual(data2, []) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): Test.objects.bulk_create([ Test(name='test1', owner=u1), Test(name='test2', owner=u2), @@ -624,9 +601,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): self.assertEqual(data3[2].owner, u2) self.assertEqual(data3[3].owner, u1) - with self.assertNumQueries( - 2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1) - ): + with self.assertNumQueries(1): Test.objects.filter(name__in=['test1', 'test2']).delete() with self.assertNumQueries(1): data4 = list(Test.objects.select_related('owner')) @@ -658,12 +633,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): self.assertEqual(data3[0].owner, u) self.assertListEqual(list(data3[0].owner.groups.all()), []) - with self.assertNumQueries( - 8 if self.is_postgresql and DJANGO_VERSION >= (4, 2) - else 4 if self.is_postgresql and DJANGO_VERSION >= (3, 0) - else 4 if self.is_mysql and DJANGO_VERSION >= (3, 0) - else 6 - ): + with self.assertNumQueries(4): group = Group.objects.create(name='test_group') permissions = list(Permission.objects.all()[:5]) group.permissions.add(*permissions) @@ -718,7 +688,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): @skipUnlessDBFeature('has_select_for_update') def test_invalidate_select_for_update(self): - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): Test.objects.bulk_create([Test(name='test1'), Test(name='test2')]) with self.assertNumQueries(1): @@ -876,9 +846,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase): with self.assertRaises(TestChild.DoesNotExist): TestChild.objects.get() - with self.assertNumQueries( - 3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2) - ): + with self.assertNumQueries(2): t_child = TestChild.objects.create(name='test_child') with self.assertNumQueries(1):