Add final SQL check when looking up involved tables (#199)

This commit is contained in:
Dominik Bartenstein 2021-12-27 18:30:44 +01:00 committed by GitHub
parent f1087da6f9
commit 434a5759de
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 454 additions and 64 deletions

View file

@ -1,6 +1,12 @@
Whats new in django-cachalot?
==============================
2.4.6
-----
- Add final SQL check to include potentially overlooked tables when looking up involved tables (#199)
- Add ``CACHALOT_FINAL_SQL_CHECK`` for enabling Final SQL check
2.4.5
-----

View file

@ -61,6 +61,7 @@ class Settings(object):
CACHALOT_ADDITIONAL_TABLES = ()
CACHALOT_QUERY_KEYGEN = 'cachalot.utils.get_query_cache_key'
CACHALOT_TABLE_KEYGEN = 'cachalot.utils.get_table_cache_key'
CACHALOT_FINAL_SQL_CHECK = False
@classmethod
def add_converter(cls, setting):

View file

@ -4,7 +4,7 @@ from django.dispatch import receiver
from ..settings import cachalot_settings
from .read import ReadTestCase, ParameterTypeTestCase
from .write import WriteTestCase, DatabaseCommandTestCase
from .transaction import AtomicTestCase
from .transaction import AtomicCacheTestCase, AtomicTestCase
from .thread_safety import ThreadSafetyTestCase
from .multi_db import MultiDatabaseTestCase
from .settings import SettingsTestCase

View file

@ -45,6 +45,10 @@ class TestParent(Model):
class TestChild(TestParent):
"""
A OneToOneField to TestParent is automatically added here.
https://docs.djangoproject.com/en/3.2/topics/db/models/#multi-table-inheritance
"""
public = BooleanField(default=False)
permissions = ManyToManyField('auth.Permission', blank=True)

View file

@ -5,17 +5,18 @@ from unittest import skipUnless
from django.contrib.postgres.functions import TransactionNow
from django.db import connection
from django.test import TransactionTestCase, override_settings
from psycopg2.extras import NumericRange, DateRange, DateTimeTZRange
from psycopg2.extras import DateRange, DateTimeTZRange, NumericRange
from pytz import timezone
from ..utils import UncachableQuery
from .api import invalidate
from .models import PostgresModel, Test
from .test_utils import TestUtilsMixin
from .tests_decorators import all_final_sql_checks, no_final_sql_check, with_final_sql_check
# FIXME: Add tests for aggregations.
def is_pg_field_available(name):
fields = []
try:
@ -91,14 +92,18 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
self.obj1.save()
self.obj2.save()
@all_final_sql_checks
def test_unaccent(self):
Test.objects.create(name='Clémentine')
Test.objects.create(name='Clementine')
obj1 = Test.objects.create(name='Clémentine')
obj2 = Test.objects.create(name='Clementine')
qs = (Test.objects.filter(name__unaccent='Clémentine')
.values_list('name', flat=True))
self.assert_tables(qs, Test)
self.assert_query_cached(qs, ['Clementine', 'Clémentine'])
obj1.delete()
obj2.delete()
@all_final_sql_checks
def test_int_array(self):
with self.assertNumQueries(1):
data1 = [o.int_array for o in PostgresModel.objects.all()]
@ -145,6 +150,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, PostgresModel)
self.assert_query_cached(qs, [[1, 2, 3]])
@all_final_sql_checks
def test_hstore(self):
with self.assertNumQueries(1):
data1 = [o.hstore for o in PostgresModel.objects.all()]
@ -198,6 +204,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, PostgresModel)
self.assert_query_cached(qs, [{'a': '1', 'b': '2'}])
@all_final_sql_checks
@skipUnless(is_pg_field_available("JSONField"),
"JSONField was removed in Dj 4.0")
def test_json(self):
@ -309,6 +316,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(list(qs.all()),
[self.obj1.json, self.obj2.json])
@all_final_sql_checks
def test_int_range(self):
with self.assertNumQueries(1):
data1 = [o.int_range for o in PostgresModel.objects.all()]
@ -378,13 +386,16 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, PostgresModel)
self.assert_query_cached(qs, [NumericRange(1900, 2000)])
PostgresModel.objects.create(int_range=[1900, 1900])
obj = PostgresModel.objects.create(int_range=[1900, 1900])
qs = (PostgresModel.objects.filter(int_range__isempty=True)
.values_list('int_range', flat=True))
self.assert_tables(qs, PostgresModel)
self.assert_query_cached(qs, [NumericRange(empty=True)])
obj.delete()
@all_final_sql_checks
@skipUnless(is_pg_field_available("FloatRangeField"),
"FloatRangeField was removed in Dj 3.1")
def test_float_range(self):
@ -398,6 +409,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
NumericRange(Decimal('-1000.0'), Decimal('9.87654321')),
NumericRange(Decimal('0.0'))])
@all_final_sql_checks
@skipUnless(is_pg_field_available("DecimalRangeField"),
"DecimalRangeField was added in Dj 2.2")
def test_decimal_range(self):
@ -407,6 +419,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
NumericRange(Decimal('-1000.0'), Decimal('9.87654321')),
NumericRange(Decimal('0.0'))])
@all_final_sql_checks
def test_date_range(self):
qs = PostgresModel.objects.values_list('date_range', flat=True)
self.assert_tables(qs, PostgresModel)
@ -414,6 +427,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
DateRange(date(1678, 3, 4), date(1741, 7, 28)),
DateRange(date(1989, 1, 30))])
@all_final_sql_checks
def test_datetime_range(self):
qs = PostgresModel.objects.values_list('datetime_range', flat=True)
self.assert_tables(qs, PostgresModel)
@ -422,6 +436,7 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
tzinfo=timezone('Europe/Paris'))),
DateTimeTZRange(bounds='()')])
@all_final_sql_checks
def test_transaction_now(self):
"""
Checks that queries with a TransactionNow() parameter are not cached.
@ -431,3 +446,5 @@ class PostgresReadTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertRaises(UncachableQuery):
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [obj], after=1)
obj.delete()

View file

@ -4,6 +4,7 @@ from uuid import UUID
from decimal import Decimal
from django import VERSION as django_version
from django.conf import settings
from django.contrib.auth.models import Group, Permission, User
from django.contrib.contenttypes.models import ContentType
from django.db import (
@ -22,6 +23,8 @@ from ..utils import UncachableQuery
from .models import Test, TestChild, TestParent, UnmanagedModel
from .test_utils import TestUtilsMixin
from .tests_decorators import all_final_sql_checks, with_final_sql_check, no_final_sql_check
def is_field_available(name):
fields = []
@ -125,6 +128,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, data1)
self.assertListEqual(data2, [self.t1, self.t2])
@all_final_sql_checks
def test_filter(self):
qs = Test.objects.filter(public=True)
self.assert_tables(qs, Test)
@ -142,11 +146,13 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t1])
@all_final_sql_checks
def test_filter_empty(self):
qs = Test.objects.filter(public=True, name='user')
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [])
@all_final_sql_checks
def test_exclude(self):
qs = Test.objects.exclude(public=True)
self.assert_tables(qs, Test)
@ -156,11 +162,13 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t1])
@all_final_sql_checks
def test_slicing(self):
qs = Test.objects.all()[:1]
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t1])
@all_final_sql_checks
def test_order_by(self):
qs = Test.objects.order_by('pk')
self.assert_tables(qs, Test)
@ -170,12 +178,38 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t2, self.t1])
@all_final_sql_checks
def test_random_order_by(self):
qs = Test.objects.order_by('?')
with self.assertRaises(UncachableQuery):
self.assert_tables(qs, Test)
self.assert_query_cached(qs, after=1, compare_results=False)
@with_final_sql_check
def test_order_by_field_of_another_table_with_check(self):
qs = Test.objects.order_by('owner__username')
self.assert_tables(qs, Test, User)
self.assert_query_cached(qs, [self.t2, self.t1])
@no_final_sql_check
def test_order_by_field_of_another_table_no_check(self):
qs = Test.objects.order_by('owner__username')
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t2, self.t1])
@with_final_sql_check
def test_order_by_field_of_another_table_with_expression_with_check(self):
qs = Test.objects.order_by(Coalesce('name', 'owner__username'))
self.assert_tables(qs, Test, User)
self.assert_query_cached(qs, [self.t1, self.t2])
@no_final_sql_check
def test_order_by_field_of_another_table_with_expression_no_check(self):
qs = Test.objects.order_by(Coalesce('name', 'owner__username'))
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t1, self.t2])
@all_final_sql_checks
@skipIf(connection.vendor == 'mysql',
'MySQL does not support limit/offset on a subquery. '
'Since Django only applies ordering in subqueries when they are '
@ -187,11 +221,13 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test)
self.assert_query_cached(qs, after=1, compare_results=False)
@all_final_sql_checks
def test_reverse(self):
qs = Test.objects.reverse()
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [self.t2, self.t1])
@all_final_sql_checks
def test_distinct(self):
# We ensure that the query without distinct should return duplicate
# objects, in order to have a real-world example.
@ -222,12 +258,14 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertDictEqual(data2, data1)
self.assertDictEqual(data2, {self.t2.pk: self.t2})
@all_final_sql_checks
def test_values(self):
qs = Test.objects.values('name', 'public')
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [{'name': 'test1', 'public': False},
{'name': 'test2', 'public': True}])
@all_final_sql_checks
def test_values_list(self):
qs = Test.objects.values_list('name', flat=True)
self.assert_tables(qs, Test)
@ -249,18 +287,21 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(data2, data1)
self.assertEqual(data2, self.t2)
@all_final_sql_checks
def test_dates(self):
qs = Test.objects.dates('date', 'year')
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [datetime.date(1789, 1, 1),
datetime.date(1944, 1, 1)])
@all_final_sql_checks
def test_datetimes(self):
qs = Test.objects.datetimes('datetime', 'hour')
self.assert_tables(qs, Test)
self.assert_query_cached(qs, [datetime.datetime(1789, 7, 14, 16),
datetime.datetime(1944, 6, 6, 6)])
@all_final_sql_checks
@skipIf(connection.vendor == 'mysql',
'Time zones are not supported by MySQL.')
@override_settings(USE_TZ=True)
@ -271,6 +312,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
datetime.datetime(1789, 7, 14, 16, tzinfo=UTC),
datetime.datetime(1944, 6, 6, 6, tzinfo=UTC)])
@all_final_sql_checks
def test_foreign_key(self):
with self.assertNumQueries(3):
data1 = [t.owner for t in Test.objects.all()]
@ -283,7 +325,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test, User)
self.assert_query_cached(qs, [self.user.pk, self.admin.pk])
def test_many_to_many(self):
def _test_many_to_many(self):
u = User.objects.create_user('test_user')
ct = ContentType.objects.get_for_model(User)
u.user_permissions.add(
@ -293,50 +335,93 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
name='Can touch', content_type=ct, codename='touch'),
Permission.objects.create(
name='Can cuddle', content_type=ct, codename='cuddle'))
qs = u.user_permissions.values_list('codename', flat=True)
return u.user_permissions.values_list('codename', flat=True)
@with_final_sql_check
def test_many_to_many_when_sql_check(self):
qs = self._test_many_to_many()
self.assert_tables(qs, User, User.user_permissions.through, Permission, ContentType)
self.assert_query_cached(qs, ['cuddle', 'discuss', 'touch'])
@no_final_sql_check
def test_many_to_many_when_no_sql_check(self):
qs = self._test_many_to_many()
self.assert_tables(qs, User, User.user_permissions.through, Permission)
self.assert_query_cached(qs, ['cuddle', 'discuss', 'touch'])
@all_final_sql_checks
def test_subquery(self):
additional_tables = []
if django_version[0] >= 4 and settings.CACHALOT_FINAL_SQL_CHECK:
# with Django 4.0 comes some query optimalizations that do selects little differently.
additional_tables.append('django_content_type')
qs = Test.objects.filter(owner__in=User.objects.all())
self.assert_tables(qs, Test, User)
self.assert_query_cached(qs, [self.t1, self.t2])
qs = Test.objects.filter(
owner__groups__permissions__in=Permission.objects.all())
self.assert_tables(qs, Test, User, User.groups.through, Group,
Group.permissions.through, Permission)
owner__groups__permissions__in=Permission.objects.all()
)
self.assert_tables(
qs, Test, User, User.groups.through, Group,
Group.permissions.through, Permission,
*additional_tables
)
self.assert_query_cached(qs, [self.t1, self.t1, self.t1])
qs = Test.objects.filter(
owner__groups__permissions__in=Permission.objects.all()
).distinct()
self.assert_tables(qs, Test, User, User.groups.through, Group,
Group.permissions.through, Permission)
self.assert_tables(
qs, Test, User, User.groups.through, Group,
Group.permissions.through, Permission,
*additional_tables
)
self.assert_query_cached(qs, [self.t1])
qs = TestChild.objects.exclude(permissions__isnull=True)
self.assert_tables(qs, TestParent, TestChild,
TestChild.permissions.through, Permission)
self.assert_tables(
qs, TestParent, TestChild,
TestChild.permissions.through, Permission
)
self.assert_query_cached(qs, [])
qs = TestChild.objects.exclude(permissions__name='')
self.assert_tables(qs, TestParent, TestChild,
TestChild.permissions.through, Permission)
self.assert_tables(
qs, TestParent, TestChild,
TestChild.permissions.through, Permission
)
self.assert_query_cached(qs, [])
def test_custom_subquery(self):
@with_final_sql_check
def test_custom_subquery_with_check(self):
tests = Test.objects.filter(permission=OuterRef('pk')).values('name')
qs = Permission.objects.annotate(first_permission=Subquery(tests[:1]))
self.assert_tables(qs, Permission, Test, ContentType)
self.assert_query_cached(qs, list(Permission.objects.all()))
@no_final_sql_check
def test_custom_subquery_no_check(self):
tests = Test.objects.filter(permission=OuterRef('pk')).values('name')
qs = Permission.objects.annotate(first_permission=Subquery(tests[:1]))
self.assert_tables(qs, Permission, Test)
self.assert_query_cached(qs, list(Permission.objects.all()))
@with_final_sql_check
def test_custom_subquery_exists(self):
tests = Test.objects.filter(permission=OuterRef('pk'))
qs = Permission.objects.annotate(has_tests=Exists(tests))
self.assert_tables(qs, Permission, Test, ContentType)
self.assert_query_cached(qs, list(Permission.objects.all()))
@no_final_sql_check
def test_custom_subquery_exists(self):
tests = Test.objects.filter(permission=OuterRef('pk'))
qs = Permission.objects.annotate(has_tests=Exists(tests))
self.assert_tables(qs, Permission, Test)
self.assert_query_cached(qs, list(Permission.objects.all()))
@all_final_sql_checks
def test_raw_subquery(self):
with self.assertNumQueries(0):
raw_sql = RawSQL('SELECT id FROM auth_permission WHERE id = %s',
@ -350,28 +435,34 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test, Permission)
self.assert_query_cached(qs, [self.t1])
@all_final_sql_checks
def test_aggregate(self):
Test.objects.create(name='test3', owner=self.user)
test3 = Test.objects.create(name='test3', owner=self.user)
with self.assertNumQueries(1):
n1 = User.objects.aggregate(n=Count('test'))['n']
with self.assertNumQueries(0):
n2 = User.objects.aggregate(n=Count('test'))['n']
self.assertEqual(n2, n1)
self.assertEqual(n2, 3)
test3.delete()
@all_final_sql_checks
def test_annotate(self):
Test.objects.create(name='test3', owner=self.user)
test3 = Test.objects.create(name='test3', owner=self.user)
qs = (User.objects.annotate(n=Count('test')).order_by('pk')
.values_list('n', flat=True))
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [2, 1])
test3.delete()
@all_final_sql_checks
def test_annotate_subquery(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(first_test=Subquery(tests[:1]))
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
@all_final_sql_checks
def test_annotate_case_with_when_and_query_in_default(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
@ -383,6 +474,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
@all_final_sql_checks
def test_annotate_case_with_when(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
@ -394,6 +486,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
@all_final_sql_checks
def test_annotate_coalesce(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
@ -405,6 +498,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
@all_final_sql_checks
def test_annotate_raw(self):
qs = User.objects.annotate(
perm_id=RawSQL('SELECT id FROM auth_permission WHERE id = %s',
@ -413,6 +507,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Permission)
self.assert_query_cached(qs, [self.user, self.admin])
@all_final_sql_checks
def test_only(self):
with self.assertNumQueries(1):
t1 = Test.objects.only('name').first()
@ -428,6 +523,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(t2.name, t1.name)
self.assertEqual(t2.public, t1.public)
@all_final_sql_checks
def test_defer(self):
with self.assertNumQueries(1):
t1 = Test.objects.defer('name').first()
@ -443,6 +539,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(t2.name, t1.name)
self.assertEqual(t2.public, t1.public)
@all_final_sql_checks
def test_select_related(self):
# Simple select_related
with self.assertNumQueries(1):
@ -468,6 +565,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertEqual(t4, t3)
self.assertEqual(t4, self.t1)
@all_final_sql_checks
def test_prefetch_related(self):
# Simple prefetch_related
with self.assertNumQueries(2):
@ -530,35 +628,74 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(permissions8, permissions7)
self.assertListEqual(permissions8, self.group__permissions)
def test_filtered_relation(self):
@all_final_sql_checks
def test_test_parent(self):
child = TestChild.objects.create(name='child')
qs = TestChild.objects.filter(name='child')
self.assert_query_cached(qs)
parent = TestParent.objects.all().first()
parent.name = 'another name'
parent.save()
child = TestChild.objects.all().first()
self.assertEqual(child.name, 'another name')
def _filtered_relation(self):
"""
Resulting query:
SELECT "cachalot_testparent"."id", "cachalot_testparent"."name",
"cachalot_testchild"."testparent_ptr_id", "cachalot_testchild"."public"
FROM "cachalot_testchild" INNER JOIN "cachalot_testparent" ON
("cachalot_testchild"."testparent_ptr_id" = "cachalot_testparent"."id")
"""
from django.db.models import FilteredRelation
qs = TestChild.objects.annotate(
filtered_permissions=FilteredRelation(
'permissions', condition=Q(permissions__pk__gt=1)))
self.assert_tables(qs, TestChild)
'permissions', condition=Q(permissions__pk__gt=1))
)
return qs
def _filtered_relation_common_asserts(self, qs):
self.assert_query_cached(qs)
values_qs = qs.values('filtered_permissions')
self.assert_tables(
values_qs, TestChild, TestChild.permissions.through, Permission)
values_qs, TestParent, TestChild, TestChild.permissions.through, Permission
)
self.assert_query_cached(values_qs)
filtered_qs = qs.filter(filtered_permissions__pk__gt=2)
self.assert_tables(
values_qs, TestChild, TestChild.permissions.through, Permission)
values_qs, TestParent, TestChild, TestChild.permissions.through, Permission
)
self.assert_query_cached(filtered_qs)
@skipUnlessDBFeature('supports_select_union')
def test_union(self):
qs = (Test.objects.filter(pk__lt=5)
| Test.objects.filter(permission__name__contains='a'))
@with_final_sql_check
def test_filtered_relation_with_check(self):
qs = self._filtered_relation()
self.assert_tables(qs, TestParent, TestChild)
self._filtered_relation_common_asserts(qs)
@no_final_sql_check
def test_filtered_relation_no_check(self):
qs = self._filtered_relation()
self.assert_tables(qs, TestChild)
self._filtered_relation_common_asserts(qs)
def _test_union(self, check: bool):
qs = (
Test.objects.filter(pk__lt=5)
| Test.objects.filter(permission__name__contains='a')
)
self.assert_tables(qs, Test, Permission)
self.assert_query_cached(qs)
with self.assertRaisesMessage(
AssertionError if django_version[0] < 4 else TypeError,
'Cannot combine queries on two different base models.'):
AssertionError if django_version[0] < 4 else TypeError,
'Cannot combine queries on two different base models.'
):
Test.objects.all() | Permission.objects.all()
qs = Test.objects.filter(pk__lt=5)
@ -576,12 +713,25 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
qs = qs.order_by()
sub_qs = sub_qs.order_by()
qs = qs.union(sub_qs)
self.assert_tables(qs, Test, Permission)
tables = {Test, Permission}
# Sqlite does not do an ORDER BY django_content_type
if not self.is_sqlite and check:
tables.add(ContentType)
self.assert_tables(qs, *tables)
with self.assertRaises((ProgrammingError, OperationalError)):
self.assert_query_cached(qs)
@skipUnlessDBFeature('supports_select_intersection')
def test_intersection(self):
@with_final_sql_check
@skipUnlessDBFeature('supports_select_union')
def test_union_with_sql_check(self):
self._test_union(check=True)
@no_final_sql_check
@skipUnlessDBFeature('supports_select_union')
def test_union_with_sql_check(self):
self._test_union(check=False)
def _test_intersection(self, check: bool):
qs = (Test.objects.filter(pk__lt=5)
& Test.objects.filter(permission__name__contains='a'))
self.assert_tables(qs, Test, Permission)
@ -607,12 +757,24 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
qs = qs.order_by()
sub_qs = sub_qs.order_by()
qs = qs.intersection(sub_qs)
self.assert_tables(qs, Test, Permission)
tables = {Test, Permission}
if not self.is_sqlite and check:
tables.add(ContentType)
self.assert_tables(qs, *tables)
with self.assertRaises((ProgrammingError, OperationalError)):
self.assert_query_cached(qs)
@skipUnlessDBFeature('supports_select_difference')
def test_difference(self):
@with_final_sql_check
@skipUnlessDBFeature('supports_select_intersection')
def test_intersection_with_check(self):
self._test_intersection(check=True)
@no_final_sql_check
@skipUnlessDBFeature('supports_select_intersection')
def test_intersection_with_check(self):
self._test_intersection(check=False)
def _test_difference(self, check: bool):
qs = Test.objects.filter(pk__lt=5)
sub_qs = Test.objects.filter(permission__name__contains='a')
if self.is_sqlite:
@ -628,10 +790,23 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
qs = qs.order_by()
sub_qs = sub_qs.order_by()
qs = qs.difference(sub_qs)
self.assert_tables(qs, Test, Permission)
tables = {Test, Permission}
if not self.is_sqlite and check:
tables.add(ContentType)
self.assert_tables(qs, *tables)
with self.assertRaises((ProgrammingError, OperationalError)):
self.assert_query_cached(qs)
@with_final_sql_check
@skipUnlessDBFeature('supports_select_difference')
def test_difference_with_check(self):
self._test_difference(check=True)
@no_final_sql_check
@skipUnlessDBFeature('supports_select_difference')
def test_difference_with_check(self):
self._test_difference(check=False)
@skipUnlessDBFeature('has_select_for_update')
def test_select_for_update(self):
"""
@ -665,6 +840,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual([t.name for t in data4],
['test1', 'test2'])
@all_final_sql_checks
def test_having(self):
qs = (User.objects.annotate(n=Count('user_permissions')).filter(n__gte=1))
self.assert_tables(qs, User, User.user_permissions.through, Permission)
@ -697,6 +873,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, [self.t1, self.t2])
self.assertListEqual([o.username_length for o in data2], [4, 5])
@all_final_sql_checks
def test_extra_where(self):
sql_condition = ("owner_id IN "
"(SELECT id FROM auth_user WHERE username = 'admin')")
@ -704,12 +881,14 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test, User)
self.assert_query_cached(qs, [self.t2])
@all_final_sql_checks
def test_extra_tables(self):
qs = Test.objects.extra(tables=['auth_user'],
select={'extra_id': 'auth_user.id'})
self.assert_tables(qs, Test, User)
self.assert_query_cached(qs)
@all_final_sql_checks
def test_extra_order_by(self):
qs = Test.objects.extra(order_by=['-cachalot_test.name'])
self.assert_tables(qs, Test)
@ -850,6 +1029,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assertListEqual(data2, data1)
self.assertListEqual(data2, [(1,), (2,)])
@all_final_sql_checks
def test_missing_table_cache_key(self):
qs = Test.objects.all()
self.assert_tables(qs, Test)
@ -861,6 +1041,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_query_cached(qs)
@all_final_sql_checks
def test_broken_query_cache_value(self):
"""
In some undetermined cases, cache.get_many return wrong values such
@ -889,6 +1070,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertRaises(Test.DoesNotExist):
Test.objects.get(name='Clémentine')
@all_final_sql_checks
def test_unicode_table_name(self):
"""
Tests if using unicode in table names does not break caching.
@ -908,6 +1090,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
with connection.cursor() as cursor:
cursor.execute('DROP TABLE %s;' % table_name)
@all_final_sql_checks
def test_unmanaged_model(self):
qs = UnmanagedModel.objects.all()
self.assert_tables(qs, UnmanagedModel)
@ -917,9 +1100,10 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
"""Check that queries with a Now() annotation are not cached #193"""
qs = Test.objects.annotate(now=Now())
self.assert_query_cached(qs, after=1)
class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
@all_final_sql_checks
def test_tuple(self):
qs = Test.objects.filter(pk__in=(1, 2, 3))
self.assert_tables(qs, Test)
@ -929,6 +1113,7 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test)
self.assert_query_cached(qs)
@all_final_sql_checks
def test_list(self):
qs = Test.objects.filter(pk__in=[1, 2, 3])
self.assert_tables(qs, Test)
@ -949,6 +1134,7 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, Test)
self.assert_query_cached(qs)
@all_final_sql_checks
def test_binary(self):
"""
Binary data should be cached on PostgreSQL & MySQL, but not on SQLite,
@ -990,11 +1176,12 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(0):
Test.objects.get(a_float=0.123456789)
@all_final_sql_checks
def test_decimal(self):
with self.assertNumQueries(1):
Test.objects.create(name='test1', a_decimal=Decimal('123.45'))
test1 = Test.objects.create(name='test1', a_decimal=Decimal('123.45'))
with self.assertNumQueries(1):
Test.objects.create(name='test1', a_decimal=Decimal('12.3'))
test2 = Test.objects.create(name='test2', a_decimal=Decimal('12.3'))
qs = Test.objects.values_list('a_decimal', flat=True).filter(
a_decimal__isnull=False).order_by('a_decimal')
@ -1006,11 +1193,15 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(0):
Test.objects.get(a_decimal=Decimal('123.45'))
test1.delete()
test2.delete()
@all_final_sql_checks
def test_ipv4_address(self):
with self.assertNumQueries(1):
Test.objects.create(name='test1', ip='127.0.0.1')
test1 = Test.objects.create(name='test1', ip='127.0.0.1')
with self.assertNumQueries(1):
Test.objects.create(name='test2', ip='192.168.0.1')
test2 = Test.objects.create(name='test2', ip='192.168.0.1')
qs = Test.objects.values_list('ip', flat=True).filter(
ip__isnull=False).order_by('ip')
@ -1022,11 +1213,15 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(0):
Test.objects.get(ip='127.0.0.1')
test1.delete()
test2.delete()
@all_final_sql_checks
def test_ipv6_address(self):
with self.assertNumQueries(1):
Test.objects.create(name='test1', ip='2001:db8:a0b:12f0::1/64')
test1 = Test.objects.create(name='test1', ip='2001:db8:a0b:12f0::1/64')
with self.assertNumQueries(1):
Test.objects.create(name='test2', ip='2001:db8:0:85a3::ac1f:8001')
test2 = Test.objects.create(name='test2', ip='2001:db8:0:85a3::ac1f:8001')
qs = Test.objects.values_list('ip', flat=True).filter(
ip__isnull=False).order_by('ip')
@ -1039,11 +1234,15 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(0):
Test.objects.get(ip='2001:db8:0:85a3::ac1f:8001')
test1.delete()
test2.delete()
@all_final_sql_checks
def test_duration(self):
with self.assertNumQueries(1):
Test.objects.create(name='test1', duration=datetime.timedelta(30))
test1 = Test.objects.create(name='test1', duration=datetime.timedelta(30))
with self.assertNumQueries(1):
Test.objects.create(name='test2', duration=datetime.timedelta(60))
test2 = Test.objects.create(name='test2', duration=datetime.timedelta(60))
qs = Test.objects.values_list('duration', flat=True).filter(
duration__isnull=False).order_by('duration')
@ -1056,12 +1255,16 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(0):
Test.objects.get(duration=datetime.timedelta(30))
test1.delete()
test2.delete()
@all_final_sql_checks
def test_uuid(self):
with self.assertNumQueries(1):
Test.objects.create(name='test1',
test1 = Test.objects.create(name='test1',
uuid='1cc401b7-09f4-4520-b8d0-c267576d196b')
with self.assertNumQueries(1):
Test.objects.create(name='test2',
test2 = Test.objects.create(name='test2',
uuid='ebb3b6e1-1737-4321-93e3-4c35d61ff491')
qs = Test.objects.values_list('uuid', flat=True).filter(
@ -1076,6 +1279,9 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(0):
Test.objects.get(uuid=UUID('1cc401b7-09f4-4520-b8d0-c267576d196b'))
test1.delete()
test2.delete()
def test_now(self):
"""
Checks that queries with a Now() parameter are not cached.

View file

@ -1,17 +1,19 @@
from time import sleep
from unittest import skipIf
from unittest.mock import MagicMock, patch
from django.conf import settings
from django.contrib.auth.models import User
from django.core.cache import DEFAULT_CACHE_ALIAS
from django.core.checks import run_checks, Tags, Warning, Error
from django.core.checks import Error, Tags, Warning, run_checks
from django.db import connection
from django.test import TransactionTestCase
from django.test.utils import override_settings
from ..api import invalidate
from ..settings import SUPPORTED_ONLY, SUPPORTED_DATABASE_ENGINES
from .models import Test, TestParent, TestChild, UnmanagedModel
from ..settings import SUPPORTED_DATABASE_ENGINES, SUPPORTED_ONLY
from ..utils import _get_tables
from .models import Test, TestChild, TestParent, UnmanagedModel
from .test_utils import TestUtilsMixin
@ -314,3 +316,29 @@ class SettingsTestCase(TestUtilsMixin, TransactionTestCase):
with self.settings(CACHALOT_DATABASES='invalid value'):
errors = run_checks(tags=[Tags.compatibility])
self.assertListEqual(errors, [error002])
def call_get_tables(self):
qs = Test.objects.all()
compiler_mock = MagicMock()
compiler_mock.__cachalot_generated_sql = ''
tables = _get_tables(qs.db, qs.query, compiler_mock)
self.assertTrue(tables)
return tables
@override_settings(CACHALOT_FINAL_SQL_CHECK=True)
@patch('cachalot.utils._get_tables_from_sql')
def test_cachalot_final_sql_check_when_true(self, _get_tables_from_sql):
_get_tables_from_sql.return_value = {'patched'}
tables = self.call_get_tables()
_get_tables_from_sql.assert_called_once()
self.assertIn('patched', tables)
@override_settings(CACHALOT_FINAL_SQL_CHECK=False)
@patch('cachalot.utils._get_tables_from_sql')
def test_cachalot_final_sql_check_when_false(self, _get_tables_from_sql):
_get_tables_from_sql.return_value = {'patched'}
tables = self.call_get_tables()
_get_tables_from_sql.assert_not_called()
self.assertNotIn('patched', tables)

View file

@ -2,8 +2,8 @@ from django import VERSION as DJANGO_VERSION
from django.core.management.color import no_style
from django.db import connection, transaction
from .models import PostgresModel
from ..utils import _get_tables
from .models import PostgresModel
class TestUtilsMixin:
@ -36,7 +36,7 @@ class TestUtilsMixin:
def assert_tables(self, queryset, *tables):
tables = {table if isinstance(table, str)
else table._meta.db_table for table in tables}
self.assertSetEqual(_get_tables(queryset.db, queryset.query), tables)
self.assertSetEqual(_get_tables(queryset.db, queryset.query), tables, str(queryset.query))
def assert_query_cached(self, queryset, result=None, result_type=None,
compare_results=True, before=1, after=0):

View file

@ -0,0 +1,50 @@
import logging
from functools import wraps
from django.core.cache import cache
from django.test.utils import override_settings
logger = logging.getLogger(__name__)
def all_final_sql_checks(func):
"""
Runs test as two sub-tests:
one with CACHALOT_FINAL_SQL_CHECK setting True, one with False
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
for final_sql_check in (True, False):
with self.subTest(msg=f'CACHALOT_FINAL_SQL_CHECK = {final_sql_check}'):
with override_settings(
CACHALOT_FINAL_SQL_CHECK=final_sql_check
):
func(self, *args, **kwargs)
cache.clear()
return wrapper
def no_final_sql_check(func):
"""
Runs test with CACHALOT_FINAL_SQL_CHECK = False
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
with override_settings(CACHALOT_FINAL_SQL_CHECK=False):
func(self, *args, **kwargs)
return wrapper
def with_final_sql_check(func):
"""
Runs test with CACHALOT_FINAL_SQL_CHECK = True
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
with override_settings(CACHALOT_FINAL_SQL_CHECK=True):
func(self, *args, **kwargs)
return wrapper

View file

@ -1,6 +1,8 @@
from cachalot.transaction import AtomicCache
from django.contrib.auth.models import User
from django.core.cache import cache
from django.db import transaction, connection, IntegrityError
from django.test import TransactionTestCase, skipUnlessDBFeature
from django.test import SimpleTestCase, TransactionTestCase, skipUnlessDBFeature
from .models import Test
from .test_utils import TestUtilsMixin
@ -167,7 +169,7 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
with self.assertNumQueries(1):
data3 = list(Test.objects.all())
self.assertListEqual(data3, [t1])
@skipUnlessDBFeature('can_defer_constraint_checks')
def test_deferred_error(self):
"""
@ -187,3 +189,13 @@ class AtomicTestCase(TestUtilsMixin, TransactionTestCase):
'-- ' + Test._meta.db_table) # Should invalidate Test.
with self.assertNumQueries(1):
list(Test.objects.all())
class AtomicCacheTestCase(SimpleTestCase):
def setUp(self):
self.atomic_cache = AtomicCache(cache, 'db_alias')
def test_set(self):
self.assertDictEqual(self.atomic_cache, {})
self.atomic_cache.set('key', 'value', None)
self.assertDictEqual(self.atomic_cache, {'key': 'value'})

View file

@ -3,7 +3,7 @@ from .settings import cachalot_settings
class AtomicCache(dict):
def __init__(self, parent_cache, db_alias):
super(AtomicCache, self).__init__()
super().__init__()
self.parent_cache = parent_cache
self.db_alias = db_alias
self.to_be_invalidated = set()

View file

@ -83,6 +83,10 @@ def get_query_cache_key(compiler):
check_parameter_types(params)
cache_key = '%s:%s:%s' % (compiler.using, sql,
[str(p) for p in params])
# Set attribute on compiler for later access
# to the generated SQL. This prevents another as_sql() call!
compiler.__cachalot_generated_sql = sql.lower()
return sha1(cache_key.encode('utf-8')).hexdigest()
@ -101,9 +105,23 @@ def get_table_cache_key(db_alias, table):
return sha1(cache_key.encode('utf-8')).hexdigest()
def _get_tables_from_sql(connection, lowercased_sql):
return {t for t in connection.introspection.django_table_names()
+ cachalot_settings.CACHALOT_ADDITIONAL_TABLES if t in lowercased_sql}
def _get_tables_from_sql(connection, lowercased_sql, enable_quote: bool = False):
"""Returns names of involved tables after analyzing the final SQL query."""
return {table for table in (connection.introspection.django_table_names()
+ cachalot_settings.CACHALOT_ADDITIONAL_TABLES)
if _quote_table_name(table, connection, enable_quote) in lowercased_sql}
def _quote_table_name(table_name, connection, enable_quote: bool):
"""
Returns quoted table name.
Put database-specific quotation marks around the table name
to preven that tables with substrings of the table are considered.
E.g. cachalot_testparent must not return cachalot_test.
"""
return f'{connection.ops.quote_name(table_name)}' \
if enable_quote else table_name
def _find_rhs_lhs_subquery(side):
@ -170,7 +188,7 @@ def filter_cachable(tables):
return tables
def _flatten(expression: "BaseExpression"):
def _flatten(expression: 'BaseExpression'):
"""
Recursively yield this expression and all subexpressions, in
depth-first order.
@ -187,7 +205,7 @@ def _flatten(expression: "BaseExpression"):
yield expr
def _get_tables(db_alias, query):
def _get_tables(db_alias, query, compiler=False):
if query.select_for_update or (
not cachalot_settings.CACHALOT_CACHE_RANDOM
and '?' in query.order_by):
@ -196,6 +214,7 @@ def _get_tables(db_alias, query):
try:
if query.extra_select:
raise IsRawQuery
# Gets all tables already found by the ORM.
tables = set(query.table_map)
tables.add(query.get_meta().db_table)
@ -206,8 +225,10 @@ def _get_tables(db_alias, query):
raise UncachableQuery
for expression in _flatten(annotation):
if isinstance(expression, Subquery):
if hasattr(expression, "queryset"):
# Django 2.2 only: no query, only queryset
if not hasattr(expression, 'query'):
tables.update(_get_tables(db_alias, expression.queryset.query))
# Django 3+
else:
tables.update(_get_tables(db_alias, expression.query))
elif isinstance(expression, RawSQL):
@ -230,6 +251,18 @@ def _get_tables(db_alias, query):
except IsRawQuery:
sql = query.get_compiler(db_alias).as_sql()[0].lower()
tables = _get_tables_from_sql(connections[db_alias], sql)
else:
# Additional check of the final SQL.
# Potentially overlooked tables are added here. Tables may be overlooked by the regular checks
# as not all expressions are handled yet. This final check acts as safety net.
if cachalot_settings.CACHALOT_FINAL_SQL_CHECK:
if compiler:
# Access generated SQL stored when caching the query!
sql = compiler.__cachalot_generated_sql
else:
sql = query.get_compiler(db_alias).as_sql()[0].lower()
final_check_tables = _get_tables_from_sql(connections[db_alias], sql, enable_quote=True)
tables.update(final_check_tables)
if not are_all_cachable(tables):
raise UncachableQuery
@ -240,7 +273,7 @@ def _get_table_cache_keys(compiler):
db_alias = compiler.using
get_table_cache_key = cachalot_settings.CACHALOT_TABLE_KEYGEN
return [get_table_cache_key(db_alias, t)
for t in _get_tables(db_alias, compiler.query)]
for t in _get_tables(db_alias, compiler.query, compiler)]
def _invalidate_tables(cache, db_alias, tables):

View file

@ -196,6 +196,39 @@ Settings
Clear your cache after changing this setting (its not enough
to use ``./manage.py invalidate_cachalot``).
``CACHALOT_FINAL_SQL_CHECK``
~~~~~~~~~~~~~~~~~~~~~~~~~
:Default: ``False``
:Description:
If set to ``True``, the final SQL check will be performed.
The `Final SQL check` checks for potentially overlooked tables when looking up involved tables
(eg. Ordering by referenced table). See tests for more details
(eg. ``test_order_by_field_of_another_table_with_check``).
Enabling this setting comes with a small performance cost::
CACHALOT_FINAL_SQL_CHECK=False:
mysql is 1.4× slower then 9.9× faster
postgresql is 1.3× slower then 11.7× faster
sqlite is 1.4× slower then 3.0× faster
filebased is 1.4× slower then 9.5× faster
locmem is 1.3× slower then 11.3× faster
pylibmc is 1.4× slower then 8.5× faster
pymemcache is 1.4× slower then 7.3× faster
redis is 1.4× slower then 6.8× faster
CACHALOT_FINAL_SQL_CHECK=True:
mysql is 1.5× slower then 9.0× faster
postgresql is 1.3× slower then 10.5× faster
sqlite is 1.4× slower then 2.6× faster
filebased is 1.4× slower then 9.1× faster
locmem is 1.3× slower then 9.9× faster
pylibmc is 1.4× slower then 7.5× faster
pymemcache is 1.4× slower then 6.5× faster
redis is 1.5× slower then 6.2× faster
.. _Command: