diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index f62b049..50293d4 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -9,12 +9,11 @@ from django.contrib.contenttypes.models import ContentType from django.db import ( connection, transaction, DEFAULT_DB_ALIAS, ProgrammingError, OperationalError) -from django.db.models import Count, Q +from django.db.models import Case, Count, Q, Value, When from django.db.models.expressions import RawSQL, Subquery, OuterRef, Exists from django.db.models.functions import Now from django.db.transaction import TransactionManagementError -from django.test import ( - TransactionTestCase, skipUnlessDBFeature, override_settings) +from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings from pytz import UTC from cachalot.cache import cachalot_caches @@ -373,6 +372,17 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase): self.assert_tables(qs, User, Test) self.assert_query_cached(qs, [self.user, self.admin]) + def test_annotate_case_with_when(self): + tests = Test.objects.filter(owner=OuterRef('pk')).values('name') + qs = User.objects.annotate( + first_test=Case( + When(Q(pk=1), then=Value('noname')), + default=Subquery(tests[:1]) + ) + ) + self.assert_tables(qs, User, Test) + self.assert_query_cached(qs, [self.user, self.admin]) + def test_only(self): with self.assertNumQueries(1): t1 = Test.objects.only('name').first() @@ -1041,8 +1051,7 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase): Checks that queries with a Now() parameter are not cached. """ obj = Test.objects.create(datetime='1992-07-02T12:00:00') - qs = Test.objects.filter( - datetime__lte=Now()) + qs = Test.objects.filter(datetime__lte=Now()) with self.assertNumQueries(1): obj1 = qs.get() with self.assertNumQueries(1): diff --git a/cachalot/utils.py b/cachalot/utils.py index 52ded2a..dcab941 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -6,7 +6,7 @@ from uuid import UUID from django.contrib.postgres.functions import TransactionNow from django.db import connections -from django.db.models import QuerySet, Subquery, Exists +from django.db.models import Case, Exists, QuerySet, Subquery from django.db.models.functions import Now from django.db.models.sql import Query, AggregateQuery from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode @@ -171,13 +171,23 @@ def _get_tables(db_alias, query): # Gets all tables already found by the ORM. tables = set(query.table_map) tables.add(query.get_meta().db_table) + + def __update_annotated_subquery(_annotation: Subquery): + if hasattr(_annotation, "queryset"): + tables.update(_get_tables(db_alias, _annotation.queryset.query)) + else: + tables.update(_get_tables(db_alias, _annotation.query)) + # Gets tables in subquery annotations. for annotation in query.annotations.values(): - if isinstance(annotation, Subquery): - if hasattr(annotation, "queryset"): - tables.update(_get_tables(db_alias, annotation.queryset.query)) - else: - tables.update(_get_tables(db_alias, annotation.query)) + if isinstance(annotation, Case): + for case in annotation.cases: + for subquery in _find_subqueries_in_where(case.condition.children): + tables.update(_get_tables(db_alias, subquery)) + if isinstance(annotation.default, Subquery): + __update_annotated_subquery(annotation.default) + elif isinstance(annotation, Subquery): + __update_annotated_subquery(annotation) elif type(annotation) in UNCACHABLE_FUNCS: raise UncachableQuery # Gets tables in WHERE subqueries.