From 76d4ab4c8df2a51e6dd081f662f2f8a854bb91de Mon Sep 17 00:00:00 2001 From: Dominik Bartenstein Date: Sun, 22 Aug 2021 02:55:01 +0200 Subject: [PATCH] Simplify annotation handling by using the flatten method: (#197) * Simplify annotation handling by using the flatten method: https://github.com/django/django/blob/f42ccdd835e5b3f0914b5e6f87621c648136ea36/django/db/models/expressions.py#L370 Handle annotated cases when Subquery is part of the When. * Before Django 3.2 flatten did not check for existence of flatten in the processed nodes. * Add type BaseExpression to function "flatten". * Add test case with annotated Coalesce. * Add support for annotated raw SQL. * Remove unnecessary code. * Use as_sql instead of repr. * Reorganize code * Fix var name * Improve naming: element -> expression Co-authored-by: Dominik Bartenstein Co-authored-by: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com> --- cachalot/tests/read.py | 34 ++++++++++++++++++++++++++-- cachalot/utils.py | 50 ++++++++++++++++++++++++++++-------------- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index 50293d4..a46effe 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -11,7 +11,7 @@ from django.db import ( OperationalError) from django.db.models import Case, Count, Q, Value, When from django.db.models.expressions import RawSQL, Subquery, OuterRef, Exists -from django.db.models.functions import Now +from django.db.models.functions import Coalesce, Now from django.db.transaction import TransactionManagementError from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings from pytz import UTC @@ -372,7 +372,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase): self.assert_tables(qs, User, Test) self.assert_query_cached(qs, [self.user, self.admin]) - def test_annotate_case_with_when(self): + def test_annotate_case_with_when_and_query_in_default(self): tests = Test.objects.filter(owner=OuterRef('pk')).values('name') qs = User.objects.annotate( first_test=Case( @@ -383,6 +383,36 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase): self.assert_tables(qs, User, Test) self.assert_query_cached(qs, [self.user, self.admin]) + def test_annotate_case_with_when(self): + tests = Test.objects.filter(owner=OuterRef('pk')).values('name') + qs = User.objects.annotate( + first_test=Case( + When(Q(pk=1), then=Subquery(tests[:1])), + default=Value('noname') + ) + ) + self.assert_tables(qs, User, Test) + self.assert_query_cached(qs, [self.user, self.admin]) + + def test_annotate_coalesce(self): + tests = Test.objects.filter(owner=OuterRef('pk')).values('name') + qs = User.objects.annotate( + name=Coalesce( + Subquery(tests[:1]), + Value('notest') + ) + ) + self.assert_tables(qs, User, Test) + self.assert_query_cached(qs, [self.user, self.admin]) + + def test_annotate_raw(self): + qs = User.objects.annotate( + perm_id=RawSQL('SELECT id FROM auth_permission WHERE id = %s', + (self.t1__permission.pk,)) + ) + self.assert_tables(qs, User, Permission) + self.assert_query_cached(qs, [self.user, self.admin]) + def test_only(self): with self.assertNumQueries(1): t1 = Test.objects.only('name').first() diff --git a/cachalot/utils.py b/cachalot/utils.py index dcab941..da5f2e4 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -2,11 +2,13 @@ import datetime from decimal import Decimal from hashlib import sha1 from time import time +from typing import TYPE_CHECKING from uuid import UUID from django.contrib.postgres.functions import TransactionNow from django.db import connections -from django.db.models import Case, Exists, QuerySet, Subquery +from django.db.models import Exists, QuerySet, Subquery +from django.db.models.expressions import RawSQL from django.db.models.functions import Now from django.db.models.sql import Query, AggregateQuery from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode @@ -15,6 +17,10 @@ from .settings import ITERABLES, cachalot_settings from .transaction import AtomicCache +if TYPE_CHECKING: + from django.db.models.expressions import BaseExpression + + class UncachableQuery(Exception): pass @@ -159,6 +165,23 @@ def filter_cachable(tables): return tables +def _flatten(expression: "BaseExpression"): + """ + Recursively yield this expression and all subexpressions, in + depth-first order. + + Taken from Django 3.2 as the previous Django versions don’t check + for existence of flatten. + """ + yield expression + for expr in expression.get_source_expressions(): + if expr: + if hasattr(expr, 'flatten'): + yield from _flatten(expr) + else: + yield expr + + def _get_tables(db_alias, query): if query.select_for_update or ( not cachalot_settings.CACHALOT_CACHE_RANDOM @@ -172,24 +195,19 @@ def _get_tables(db_alias, query): tables = set(query.table_map) tables.add(query.get_meta().db_table) - def __update_annotated_subquery(_annotation: Subquery): - if hasattr(_annotation, "queryset"): - tables.update(_get_tables(db_alias, _annotation.queryset.query)) - else: - tables.update(_get_tables(db_alias, _annotation.query)) - # Gets tables in subquery annotations. for annotation in query.annotations.values(): - if isinstance(annotation, Case): - for case in annotation.cases: - for subquery in _find_subqueries_in_where(case.condition.children): - tables.update(_get_tables(db_alias, subquery)) - if isinstance(annotation.default, Subquery): - __update_annotated_subquery(annotation.default) - elif isinstance(annotation, Subquery): - __update_annotated_subquery(annotation) - elif type(annotation) in UNCACHABLE_FUNCS: + if type(annotation) in UNCACHABLE_FUNCS: raise UncachableQuery + for expression in _flatten(annotation): + if isinstance(expression, Subquery): + if hasattr(expression, "queryset"): + tables.update(_get_tables(db_alias, expression.queryset.query)) + else: + tables.update(_get_tables(db_alias, expression.query)) + elif isinstance(expression, RawSQL): + sql = expression.as_sql(None, None)[0].lower() + tables.update(_get_tables_from_sql(connections[db_alias], sql)) # Gets tables in WHERE subqueries. for subquery in _find_subqueries_in_where(query.where.children): tables.update(_get_tables(db_alias, subquery))