Added a FilteredTransactionTestCase, updated tests.

This commit is contained in:
Benedikt Willi 2023-06-06 10:05:44 +02:00
parent beff1e4050
commit fa1fbc1a5d
4 changed files with 97 additions and 95 deletions

View file

@ -21,7 +21,7 @@ from cachalot.cache import cachalot_caches
from ..settings import cachalot_settings
from ..utils import UncachableQuery
from .models import Test, TestChild, TestParent, UnmanagedModel
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
from .tests_decorators import all_final_sql_checks, with_final_sql_check, no_final_sql_check
@ -36,7 +36,7 @@ def is_field_available(name):
return name in fields
class ReadTestCase(TestUtilsMixin, TransactionTestCase):
class ReadTestCase(TestUtilsMixin, FilteredTransactionTestCase):
"""
Tests if every SQL request that only reads data is cached.
@ -896,9 +896,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_query_cached(qs, [self.t2, self.t1])
def test_table_inheritance(self):
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t_child = TestChild.objects.create(name='test_child')
with self.assertNumQueries(1):

View file

@ -1,5 +1,8 @@
from django.core.management.color import no_style
from django.db import connection, transaction
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.test import TransactionTestCase
from django.test.utils import CaptureQueriesContext
from ..utils import _get_tables
from .models import PostgresModel
@ -57,3 +60,61 @@ class TestUtilsMixin:
assert_function(data2, data1)
if result is not None:
assert_function(data2, result)
class FilteredTransactionTestCase(TransactionTestCase):
"""
TransactionTestCase with assertNumQueries that ignores BEGIN, COMMIT and ROLLBACK
queries.
"""
def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
conn = connections[using]
context = FilteredAssertNumQueriesContext(self, num, conn)
if func is None:
return context
with context:
func(*args, **kwargs)
class FilteredAssertNumQueriesContext(CaptureQueriesContext):
"""
Capture queries and assert their number ignoring BEGIN, COMMIT and ROLLBACK queries.
"""
EXCLUDE = ('BEGIN', 'COMMIT', 'ROLLBACK')
def __init__(self, test_case, num, connection):
self.test_case = test_case
self.num = num
super().__init__(connection)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if exc_type is not None:
return
filtered_queries = []
excluded_queries = []
for q in self.captured_queries:
if q['sql'].upper() not in self.EXCLUDE:
filtered_queries.append(q)
else:
excluded_queries.append(q)
executed = len(filtered_queries)
self.test_case.assertEqual(
executed,
self.num,
f"\n{executed} queries executed, {self.num} expected\n" +
"\nCaptured queries were:\n" +
"".join(
f"{i}. {query['sql']}\n"
for i, query in enumerate(filtered_queries, start=1)
) +
"\nCaptured queries, that were excluded:\n" +
"".join(
f"{i}. {query['sql']}\n"
for i, query in enumerate(excluded_queries, start=1)
)
)

View file

@ -1,20 +1,17 @@
from cachalot.transaction import AtomicCache
from django import VERSION as DJANGO_VERSION
from django.contrib.auth.models import User
from django.core.cache import cache
from django.db import transaction, connection, IntegrityError
from django.test import SimpleTestCase, TransactionTestCase, skipUnlessDBFeature
from django.test import SimpleTestCase, skipUnlessDBFeature
from .models import Test
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
class AtomicTestCase(TestUtilsMixin, FilteredTransactionTestCase):
def test_successful_read_atomic(self):
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
@ -24,9 +21,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [])
def test_unsuccessful_read_atomic(self):
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
try:
with transaction.atomic():
data1 = list(Test.objects.all())
@ -44,27 +39,21 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
t1 = Test.objects.create(name='test1')
with self.assertNumQueries(1):
data2 = list(Test.objects.all())
self.assertListEqual(data2, [t1])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
t2 = Test.objects.create(name='test2')
with self.assertNumQueries(1):
data3 = list(Test.objects.all())
self.assertListEqual(data3, [t1, t2])
with self.assertNumQueries(
4 if self.is_sqlite else (5 if DJANGO_VERSION >= (4, 2) else 3)
):
with self.assertNumQueries(3):
with transaction.atomic():
data4 = list(Test.objects.all())
t3 = Test.objects.create(name='test3')
@ -79,9 +68,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
try:
with transaction.atomic():
Test.objects.create(name='test')
@ -96,9 +83,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
Test.objects.get(name='test')
def test_cache_inside_atomic(self):
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
data1 = list(Test.objects.all())
data2 = list(Test.objects.all())
@ -106,9 +91,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [])
def test_invalidation_inside_atomic(self):
with self.assertNumQueries(
4 if self.is_sqlite else (5 if DJANGO_VERSION >= (4, 2) else 3)
):
with self.assertNumQueries(3):
with transaction.atomic():
data1 = list(Test.objects.all())
t = Test.objects.create(name='test')
@ -117,9 +100,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [t])
def test_successful_nested_read_atomic(self):
with self.assertNumQueries(
7 if self.is_sqlite else (8 if DJANGO_VERSION >= (4, 2) else 6)
):
with self.assertNumQueries(6):
with transaction.atomic():
list(Test.objects.all())
with transaction.atomic():
@ -134,9 +115,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
list(User.objects.all())
def test_unsuccessful_nested_read_atomic(self):
with self.assertNumQueries(
6 if self.is_sqlite else (7 if DJANGO_VERSION >= (4, 2) else 5)
):
with self.assertNumQueries(5):
with transaction.atomic():
try:
with transaction.atomic():
@ -149,9 +128,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
list(Test.objects.all())
def test_successful_nested_write_atomic(self):
with self.assertNumQueries(
13 if self.is_sqlite else (14 if DJANGO_VERSION >= (4, 2) else 12)
):
with self.assertNumQueries(12):
with transaction.atomic():
t1 = Test.objects.create(name='test1')
with transaction.atomic():
@ -168,9 +145,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data3, [t1, t2, t3, t4])
def test_unsuccessful_nested_write_atomic(self):
with self.assertNumQueries(
16 if self.is_sqlite else (17 if DJANGO_VERSION >= (4, 2) else 15)
):
with self.assertNumQueries(15):
with transaction.atomic():
t1 = Test.objects.create(name='test1')
try:

View file

@ -1,8 +1,5 @@
import sys
from unittest import skipIf, skipUnless
from django import VERSION as DJANGO_VERSION
from django.contrib.auth.models import User, Permission, Group
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import MultipleObjectsReturned
@ -14,10 +11,10 @@ from django.db.models.expressions import RawSQL
from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import Test, TestParent, TestChild
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
class WriteTestCase(TestUtilsMixin, TransactionTestCase):
class WriteTestCase(TestUtilsMixin, FilteredTransactionTestCase):
"""
Tests if every SQL request writing data is not cached and invalidates the
implied data.
@ -58,9 +55,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t, created = Test.objects.get_or_create(name='test')
self.assertTrue(created)
@ -82,18 +77,14 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(1):
self.assertListEqual(list(Test.objects.all()), [])
with self.assertNumQueries(
5 if self.is_sqlite else (6 if DJANGO_VERSION >= (4, 2) else 4)
):
with self.assertNumQueries(4):
t, created = Test.objects.update_or_create(
name='test', defaults={'public': True})
self.assertTrue(created)
self.assertEqual(t.name, 'test')
self.assertEqual(t.public, True)
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t, created = Test.objects.update_or_create(
name='test', defaults={'public': False})
self.assertFalse(created)
@ -102,9 +93,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
# The number of SQL queries doesnt decrease because update_or_create
# always calls an UPDATE, even when data wasnt changed.
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t, created = Test.objects.update_or_create(
name='test', defaults={'public': False})
self.assertFalse(created)
@ -119,21 +108,17 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)]
Test.objects.bulk_create(unsaved_tests)
self.assertEqual(Test.objects.count(), 10)
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)]
Test.objects.bulk_create(unsaved_tests)
self.assertEqual(Test.objects.count(), 20)
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
data2 = list(Test.objects.all())
self.assertEqual(len(data2), 20)
self.assertListEqual([t.name for t in data2],
@ -174,16 +159,12 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data1, [t1.name, t2.name])
self.assertListEqual(data2, [t1.name])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.bulk_create([Test(name='test%s' % i)
for i in range(2, 11)])
with self.assertNumQueries(1):
self.assertEqual(Test.objects.count(), 10)
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.all().delete()
with self.assertNumQueries(1):
self.assertEqual(Test.objects.count(), 0)
@ -378,9 +359,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data4, [user1, user2])
self.assertListEqual([u.n for u in data4], [1, 0])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.bulk_create([
Test(name='test3', owner=user1),
Test(name='test4', owner=user2),
@ -608,9 +587,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
data2 = list(Test.objects.select_related('owner'))
self.assertListEqual(data2, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.bulk_create([
Test(name='test1', owner=u1),
Test(name='test2', owner=u2),
@ -624,9 +601,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data3[2].owner, u2)
self.assertEqual(data3[3].owner, u1)
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.filter(name__in=['test1', 'test2']).delete()
with self.assertNumQueries(1):
data4 = list(Test.objects.select_related('owner'))
@ -658,12 +633,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data3[0].owner, u)
self.assertListEqual(list(data3[0].owner.groups.all()), [])
with self.assertNumQueries(
8 if self.is_postgresql and DJANGO_VERSION >= (4, 2)
else 4 if self.is_postgresql and DJANGO_VERSION >= (3, 0)
else 4 if self.is_mysql and DJANGO_VERSION >= (3, 0)
else 6
):
with self.assertNumQueries(4):
group = Group.objects.create(name='test_group')
permissions = list(Permission.objects.all()[:5])
group.permissions.add(*permissions)
@ -718,7 +688,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
@skipUnlessDBFeature('has_select_for_update')
def test_invalidate_select_for_update(self):
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
Test.objects.bulk_create([Test(name='test1'), Test(name='test2')])
with self.assertNumQueries(1):
@ -876,9 +846,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertRaises(TestChild.DoesNotExist):
TestChild.objects.get()
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t_child = TestChild.objects.create(name='test_child')
with self.assertNumQueries(1):