Simplify annotation handling by using the flatten method: (#197)

* Simplify annotation handling by using the flatten method:
f42ccdd835/django/db/models/expressions.py (L370)

Handle annotated cases when Subquery is part of the When.

* Before Django 3.2 flatten did not check for existence of flatten in the processed nodes.

* Add type BaseExpression to function "flatten".

* Add test case with annotated Coalesce.

* Add support for annotated raw SQL.

* Remove unnecessary code.

* Use as_sql instead of repr.

* Reorganize code

* Fix var name

* Improve naming: element -> expression

Co-authored-by: Dominik Bartenstein <db@zemtu.com>
Co-authored-by: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com>
This commit is contained in:
Dominik Bartenstein 2021-08-22 02:55:01 +02:00 committed by GitHub
parent b15027a627
commit 76d4ab4c8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 18 deletions

View file

@ -11,7 +11,7 @@ from django.db import (
OperationalError)
from django.db.models import Case, Count, Q, Value, When
from django.db.models.expressions import RawSQL, Subquery, OuterRef, Exists
from django.db.models.functions import Now
from django.db.models.functions import Coalesce, Now
from django.db.transaction import TransactionManagementError
from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings
from pytz import UTC
@ -372,7 +372,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
def test_annotate_case_with_when(self):
def test_annotate_case_with_when_and_query_in_default(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
first_test=Case(
@ -383,6 +383,36 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
def test_annotate_case_with_when(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
first_test=Case(
When(Q(pk=1), then=Subquery(tests[:1])),
default=Value('noname')
)
)
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
def test_annotate_coalesce(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
name=Coalesce(
Subquery(tests[:1]),
Value('notest')
)
)
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
def test_annotate_raw(self):
qs = User.objects.annotate(
perm_id=RawSQL('SELECT id FROM auth_permission WHERE id = %s',
(self.t1__permission.pk,))
)
self.assert_tables(qs, User, Permission)
self.assert_query_cached(qs, [self.user, self.admin])
def test_only(self):
with self.assertNumQueries(1):
t1 = Test.objects.only('name').first()

View file

@ -2,11 +2,13 @@ import datetime
from decimal import Decimal
from hashlib import sha1
from time import time
from typing import TYPE_CHECKING
from uuid import UUID
from django.contrib.postgres.functions import TransactionNow
from django.db import connections
from django.db.models import Case, Exists, QuerySet, Subquery
from django.db.models import Exists, QuerySet, Subquery
from django.db.models.expressions import RawSQL
from django.db.models.functions import Now
from django.db.models.sql import Query, AggregateQuery
from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode
@ -15,6 +17,10 @@ from .settings import ITERABLES, cachalot_settings
from .transaction import AtomicCache
if TYPE_CHECKING:
from django.db.models.expressions import BaseExpression
class UncachableQuery(Exception):
pass
@ -159,6 +165,23 @@ def filter_cachable(tables):
return tables
def _flatten(expression: "BaseExpression"):
"""
Recursively yield this expression and all subexpressions, in
depth-first order.
Taken from Django 3.2 as the previous Django versions dont check
for existence of flatten.
"""
yield expression
for expr in expression.get_source_expressions():
if expr:
if hasattr(expr, 'flatten'):
yield from _flatten(expr)
else:
yield expr
def _get_tables(db_alias, query):
if query.select_for_update or (
not cachalot_settings.CACHALOT_CACHE_RANDOM
@ -172,24 +195,19 @@ def _get_tables(db_alias, query):
tables = set(query.table_map)
tables.add(query.get_meta().db_table)
def __update_annotated_subquery(_annotation: Subquery):
if hasattr(_annotation, "queryset"):
tables.update(_get_tables(db_alias, _annotation.queryset.query))
else:
tables.update(_get_tables(db_alias, _annotation.query))
# Gets tables in subquery annotations.
for annotation in query.annotations.values():
if isinstance(annotation, Case):
for case in annotation.cases:
for subquery in _find_subqueries_in_where(case.condition.children):
tables.update(_get_tables(db_alias, subquery))
if isinstance(annotation.default, Subquery):
__update_annotated_subquery(annotation.default)
elif isinstance(annotation, Subquery):
__update_annotated_subquery(annotation)
elif type(annotation) in UNCACHABLE_FUNCS:
if type(annotation) in UNCACHABLE_FUNCS:
raise UncachableQuery
for expression in _flatten(annotation):
if isinstance(expression, Subquery):
if hasattr(expression, "queryset"):
tables.update(_get_tables(db_alias, expression.queryset.query))
else:
tables.update(_get_tables(db_alias, expression.query))
elif isinstance(expression, RawSQL):
sql = expression.as_sql(None, None)[0].lower()
tables.update(_get_tables_from_sql(connections[db_alias], sql))
# Gets tables in WHERE subqueries.
for subquery in _find_subqueries_in_where(query.where.children):
tables.update(_get_tables(db_alias, subquery))