diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index b2b492f..d81f6cc 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -816,21 +816,21 @@ class ReadTestCase(TestUtilsMixin, FilteredTransactionTestCase): with self.assertRaises(TransactionManagementError): list(Test.objects.select_for_update()) - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): data1 = list(Test.objects.select_for_update()) self.assertListEqual(data1, [self.t1, self.t2]) self.assertListEqual([t.name for t in data1], ['test1', 'test2']) - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): data2 = list(Test.objects.select_for_update()) self.assertListEqual(data2, [self.t1, self.t2]) self.assertListEqual([t.name for t in data2], ['test1', 'test2']) - with self.assertNumQueries(4 if DJANGO_VERSION >= (4, 2) else 2): + with self.assertNumQueries(2): with transaction.atomic(): data3 = list(Test.objects.select_for_update()) data4 = list(Test.objects.select_for_update()) diff --git a/cachalot/tests/thread_safety.py b/cachalot/tests/thread_safety.py index 96400a2..cbd23dd 100644 --- a/cachalot/tests/thread_safety.py +++ b/cachalot/tests/thread_safety.py @@ -1,11 +1,10 @@ from threading import Thread -from django import VERSION as DJANGO_VERSION from django.db import connection, transaction -from django.test import TransactionTestCase, skipUnlessDBFeature +from django.test import skipUnlessDBFeature from .models import Test -from .test_utils import TestUtilsMixin +from .test_utils import TestUtilsMixin, FilteredTransactionTestCase class TestThread(Thread): @@ -20,7 +19,7 @@ class TestThread(Thread): @skipUnlessDBFeature('test_db_allows_multiple_connections') -class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): +class ThreadSafetyTestCase(TestUtilsMixin, FilteredTransactionTestCase): def test_concurrent_caching(self): t1 = TestThread().start_and_join() t = Test.objects.create(name='test') @@ -30,7 +29,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): self.assertEqual(t2, t) def test_concurrent_caching_during_atomic(self): - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): t1 = TestThread().start_and_join() t = Test.objects.create(name='test') @@ -46,7 +45,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): def test_concurrent_caching_before_and_during_atomic_1(self): t1 = TestThread().start_and_join() - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): t2 = TestThread().start_and_join() t = Test.objects.create(name='test') @@ -61,7 +60,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): def test_concurrent_caching_before_and_during_atomic_2(self): t1 = TestThread().start_and_join() - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): t = Test.objects.create(name='test') t2 = TestThread().start_and_join() @@ -74,7 +73,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): self.assertEqual(data, t) def test_concurrent_caching_during_and_after_atomic_1(self): - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): t1 = TestThread().start_and_join() t = Test.objects.create(name='test') @@ -89,7 +88,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): self.assertEqual(data, t) def test_concurrent_caching_during_and_after_atomic_2(self): - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): t = Test.objects.create(name='test') t1 = TestThread().start_and_join() @@ -104,7 +103,7 @@ class ThreadSafetyTestCase(TestUtilsMixin, TransactionTestCase): self.assertEqual(data, t) def test_concurrent_caching_during_and_after_atomic_3(self): - with self.assertNumQueries(3 if DJANGO_VERSION >= (4, 2) else 1): + with self.assertNumQueries(1): with transaction.atomic(): t1 = TestThread().start_and_join() t = Test.objects.create(name='test')