Andrew wang/dj4.2 Improve assertNumQueries by filtering out transaction-related queries (#238)

This commit is contained in:
Benedikt Willi 2023-06-09 10:06:38 +02:00 committed by GitHub
parent beff1e4050
commit 7e254ca10a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 110 additions and 109 deletions

View file

@ -21,7 +21,7 @@ from cachalot.cache import cachalot_caches
from ..settings import cachalot_settings
from ..utils import UncachableQuery
from .models import Test, TestChild, TestParent, UnmanagedModel
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
from .tests_decorators import all_final_sql_checks, with_final_sql_check, no_final_sql_check
@ -36,7 +36,7 @@ def is_field_available(name):
return name in fields
class ReadTestCase(TestUtilsMixin, TransactionTestCase):
class ReadTestCase(TestUtilsMixin, FilteredTransactionTestCase):
"""
Tests if every SQL request that only reads data is cached.
@ -816,21 +816,21 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertRaises(TransactionManagementError):
list(Test.objects.select_for_update())
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
data1 = list(Test.objects.select_for_update())
self.assertListEqual(data1, [self.t1, self.t2])
self.assertListEqual([t.name for t in data1],
['test1', 'test2'])
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
data2 = list(Test.objects.select_for_update())
self.assertListEqual(data2, [self.t1, self.t2])
self.assertListEqual([t.name for t in data2],
['test1', 'test2'])
with self.assertNumQueries(4 if DJANGO_VERSION >= (4, 2) else 2):
with self.assertNumQueries(2):
with transaction.atomic():
data3 = list(Test.objects.select_for_update())
data4 = list(Test.objects.select_for_update())
@ -896,9 +896,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_query_cached(qs, [self.t2, self.t1])
def test_table_inheritance(self):
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t_child = TestChild.objects.create(name='test_child')
with self.assertNumQueries(1):

View file

@ -1,5 +1,8 @@
from django.core.management.color import no_style
from django.db import connection, transaction
from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction
from django.test import TransactionTestCase
from django.test.utils import CaptureQueriesContext
from ..utils import _get_tables
from .models import PostgresModel
@ -57,3 +60,61 @@ class TestUtilsMixin:
assert_function(data2, data1)
if result is not None:
assert_function(data2, result)
class FilteredTransactionTestCase(TransactionTestCase):
"""
TransactionTestCase with assertNumQueries that ignores BEGIN, COMMIT and ROLLBACK
queries.
"""
def assertNumQueries(self, num, func=None, *args, using=DEFAULT_DB_ALIAS, **kwargs):
conn = connections[using]
context = FilteredAssertNumQueriesContext(self, num, conn)
if func is None:
return context
with context:
func(*args, **kwargs)
class FilteredAssertNumQueriesContext(CaptureQueriesContext):
"""
Capture queries and assert their number ignoring BEGIN, COMMIT and ROLLBACK queries.
"""
EXCLUDE = ('BEGIN', 'COMMIT', 'ROLLBACK')
def __init__(self, test_case, num, connection):
self.test_case = test_case
self.num = num
super().__init__(connection)
def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if exc_type is not None:
return
filtered_queries = []
excluded_queries = []
for q in self.captured_queries:
if q['sql'].upper() not in self.EXCLUDE:
filtered_queries.append(q)
else:
excluded_queries.append(q)
executed = len(filtered_queries)
self.test_case.assertEqual(
executed,
self.num,
f"\n{executed} queries executed on {self.connection.vendor}, {self.num} expected\n" +
"\nCaptured queries were:\n" +
"".join(
f"{i}. {query['sql']}\n"
for i, query in enumerate(filtered_queries, start=1)
) +
"\nCaptured queries, that were excluded:\n" +
"".join(
f"{i}. {query['sql']}\n"
for i, query in enumerate(excluded_queries, start=1)
)
)

View file

@ -1,11 +1,10 @@
from threading import Thread
from django import VERSION as DJANGO_VERSION
from django.db import connection, transaction
from django.test import TransactionTestCase, skipUnlessDBFeature
from django.test import skipUnlessDBFeature
from .models import Test
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
class TestThread(Thread):
@ -20,7 +19,7 @@ class TestThread(Thread):
@skipUnlessDBFeature('test_db_allows_multiple_connections')
class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
class ThreadSafetyTestCase(TestUtilsMixin, FilteredTransactionTestCase):
def test_concurrent_caching(self):
t1 = TestThread().start_and_join()
t = Test.objects.create(name='test')
@ -30,7 +29,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(t2, t)
def test_concurrent_caching_during_atomic(self):
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
t1 = TestThread().start_and_join()
t = Test.objects.create(name='test')
@ -46,7 +45,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
def test_concurrent_caching_before_and_during_atomic_1(self):
t1 = TestThread().start_and_join()
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
t2 = TestThread().start_and_join()
t = Test.objects.create(name='test')
@ -61,7 +60,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
def test_concurrent_caching_before_and_during_atomic_2(self):
t1 = TestThread().start_and_join()
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
t = Test.objects.create(name='test')
t2 = TestThread().start_and_join()
@ -74,7 +73,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data, t)
def test_concurrent_caching_during_and_after_atomic_1(self):
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
t1 = TestThread().start_and_join()
t = Test.objects.create(name='test')
@ -89,7 +88,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data, t)
def test_concurrent_caching_during_and_after_atomic_2(self):
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
t = Test.objects.create(name='test')
t1 = TestThread().start_and_join()
@ -104,7 +103,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data, t)
def test_concurrent_caching_during_and_after_atomic_3(self):
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
with transaction.atomic():
t1 = TestThread().start_and_join()
t = Test.objects.create(name='test')

View file

@ -1,20 +1,17 @@
from cachalot.transaction import AtomicCache
from django import VERSION as DJANGO_VERSION
from django.contrib.auth.models import User
from django.core.cache import cache
from django.db import transaction, connection, IntegrityError
from django.test import SimpleTestCase, TransactionTestCase, skipUnlessDBFeature
from django.test import SimpleTestCase, skipUnlessDBFeature
from .models import Test
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
class AtomicTestCase(TestUtilsMixin, FilteredTransactionTestCase):
def test_successful_read_atomic(self):
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
@ -24,9 +21,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [])
def test_unsuccessful_read_atomic(self):
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
try:
with transaction.atomic():
data1 = list(Test.objects.all())
@ -44,27 +39,21 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
t1 = Test.objects.create(name='test1')
with self.assertNumQueries(1):
data2 = list(Test.objects.all())
self.assertListEqual(data2, [t1])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
t2 = Test.objects.create(name='test2')
with self.assertNumQueries(1):
data3 = list(Test.objects.all())
self.assertListEqual(data3, [t1, t2])
with self.assertNumQueries(
4 if self.is_sqlite else (5 if DJANGO_VERSION >= (4, 2) else 3)
):
with self.assertNumQueries(3):
with transaction.atomic():
data4 = list(Test.objects.all())
t3 = Test.objects.create(name='test3')
@ -79,9 +68,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
try:
with transaction.atomic():
Test.objects.create(name='test')
@ -96,9 +83,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
Test.objects.get(name='test')
def test_cache_inside_atomic(self):
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
with transaction.atomic():
data1 = list(Test.objects.all())
data2 = list(Test.objects.all())
@ -106,9 +91,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [])
def test_invalidation_inside_atomic(self):
with self.assertNumQueries(
4 if self.is_sqlite else (5 if DJANGO_VERSION >= (4, 2) else 3)
):
with self.assertNumQueries(3):
with transaction.atomic():
data1 = list(Test.objects.all())
t = Test.objects.create(name='test')
@ -117,9 +100,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [t])
def test_successful_nested_read_atomic(self):
with self.assertNumQueries(
7 if self.is_sqlite else (8 if DJANGO_VERSION >= (4, 2) else 6)
):
with self.assertNumQueries(6):
with transaction.atomic():
list(Test.objects.all())
with transaction.atomic():
@ -134,9 +115,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
list(User.objects.all())
def test_unsuccessful_nested_read_atomic(self):
with self.assertNumQueries(
6 if self.is_sqlite else (7 if DJANGO_VERSION >= (4, 2) else 5)
):
with self.assertNumQueries(5):
with transaction.atomic():
try:
with transaction.atomic():
@ -149,9 +128,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
list(Test.objects.all())
def test_successful_nested_write_atomic(self):
with self.assertNumQueries(
13 if self.is_sqlite else (14 if DJANGO_VERSION >= (4, 2) else 12)
):
with self.assertNumQueries(12):
with transaction.atomic():
t1 = Test.objects.create(name='test1')
with transaction.atomic():
@ -168,9 +145,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data3, [t1, t2, t3, t4])
def test_unsuccessful_nested_write_atomic(self):
with self.assertNumQueries(
16 if self.is_sqlite else (17 if DJANGO_VERSION >= (4, 2) else 15)
):
with self.assertNumQueries(15):
with transaction.atomic():
t1 = Test.objects.create(name='test1')
try:

View file

@ -1,8 +1,5 @@
import sys
from unittest import skipIf, skipUnless
from django import VERSION as DJANGO_VERSION
from django.contrib.auth.models import User, Permission, Group
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import MultipleObjectsReturned
@ -14,10 +11,10 @@ from django.db.models.expressions import RawSQL
from django.test import TransactionTestCase, skipUnlessDBFeature
from .models import Test, TestParent, TestChild
from .test_utils import TestUtilsMixin
from .test_utils import TestUtilsMixin, FilteredTransactionTestCase
class WriteTestCase(TestUtilsMixin, TransactionTestCase):
class WriteTestCase(TestUtilsMixin, FilteredTransactionTestCase):
"""
Tests if every SQL request writing data is not cached and invalidates the
implied data.
@ -58,9 +55,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t, created = Test.objects.get_or_create(name='test')
self.assertTrue(created)
@ -82,18 +77,14 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(1):
self.assertListEqual(list(Test.objects.all()), [])
with self.assertNumQueries(
5 if self.is_sqlite else (6 if DJANGO_VERSION >= (4, 2) else 4)
):
with self.assertNumQueries(4):
t, created = Test.objects.update_or_create(
name='test', defaults={'public': True})
self.assertTrue(created)
self.assertEqual(t.name, 'test')
self.assertEqual(t.public, True)
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t, created = Test.objects.update_or_create(
name='test', defaults={'public': False})
self.assertFalse(created)
@ -102,9 +93,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
# The number of SQL queries doesnt decrease because update_or_create
# always calls an UPDATE, even when data wasnt changed.
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t, created = Test.objects.update_or_create(
name='test', defaults={'public': False})
self.assertFalse(created)
@ -119,21 +108,17 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)]
Test.objects.bulk_create(unsaved_tests)
self.assertEqual(Test.objects.count(), 10)
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
unsaved_tests = [Test(name='test%02d' % i) for i in range(1, 11)]
Test.objects.bulk_create(unsaved_tests)
self.assertEqual(Test.objects.count(), 20)
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
data2 = list(Test.objects.all())
self.assertEqual(len(data2), 20)
self.assertListEqual([t.name for t in data2],
@ -174,16 +159,12 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data1, [t1.name, t2.name])
self.assertListEqual(data2, [t1.name])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.bulk_create([Test(name='test%s' % i)
for i in range(2, 11)])
with self.assertNumQueries(1):
self.assertEqual(Test.objects.count(), 10)
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.all().delete()
with self.assertNumQueries(1):
self.assertEqual(Test.objects.count(), 0)
@ -378,9 +359,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data4, [user1, user2])
self.assertListEqual([u.n for u in data4], [1, 0])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.bulk_create([
Test(name='test3', owner=user1),
Test(name='test4', owner=user2),
@ -608,9 +587,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
data2 = list(Test.objects.select_related('owner'))
self.assertListEqual(data2, [])
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.bulk_create([
Test(name='test1', owner=u1),
Test(name='test2', owner=u2),
@ -624,9 +601,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data3[2].owner, u2)
self.assertEqual(data3[3].owner, u1)
with self.assertNumQueries(
2 if self.is_sqlite else (3 if DJANGO_VERSION >= (4, 2) else 1)
):
with self.assertNumQueries(1):
Test.objects.filter(name__in=['test1', 'test2']).delete()
with self.assertNumQueries(1):
data4 = list(Test.objects.select_related('owner'))
@ -658,12 +633,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data3[0].owner, u)
self.assertListEqual(list(data3[0].owner.groups.all()), [])
with self.assertNumQueries(
8 if self.is_postgresql and DJANGO_VERSION >= (4, 2)
else 4 if self.is_postgresql and DJANGO_VERSION >= (3, 0)
else 4 if self.is_mysql and DJANGO_VERSION >= (3, 0)
else 6
):
with self.assertNumQueries(4):
group = Group.objects.create(name='test_group')
permissions = list(Permission.objects.all()[:5])
group.permissions.add(*permissions)
@ -718,7 +688,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
@skipUnlessDBFeature('has_select_for_update')
def test_invalidate_select_for_update(self):
with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1):
with self.assertNumQueries(1):
Test.objects.bulk_create([Test(name='test1'), Test(name='test2')])
with self.assertNumQueries(1):
@ -876,9 +846,7 @@ class WriteTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertRaises(TestChild.DoesNotExist):
TestChild.objects.get()
with self.assertNumQueries(
3 if self.is_sqlite else (4 if DJANGO_VERSION >= (4, 2) else 2)
):
with self.assertNumQueries(2):
t_child = TestChild.objects.create(name='test_child')
with self.assertNumQueries(1):

View file

@ -3,7 +3,7 @@ envlist =
py{37,38,39,310}-django3.2-{sqlite3,postgresql,mysql}-{redis,memcached,pylibmc,locmem,filebased},
py{38,39,310}-django4.1-{sqlite3,postgresql,mysql}-{redis,memcached,pylibmc,locmem,filebased},
py{38,39,310,311}-django4.2-{sqlite3,postgresql,mysql}-{redis,memcached,pylibmc,locmem,filebased},
py{38,39,310,311}-djangomain-{sqlite3,postgresql,mysql}-{redis,memcached,pylibmc,locmem,filebased},
py{310,311}-djangomain-{sqlite3,postgresql,mysql}-{redis,memcached,pylibmc,locmem,filebased},
[testenv]
passenv = *