mirror of
https://github.com/Hopiu/django-cachalot.git
synced 2026-03-16 21:30:23 +00:00
Simplify annotation handling by using the flatten method: (#197)
* Simplify annotation handling by using the flatten method:
f42ccdd835/django/db/models/expressions.py (L370)
Handle annotated cases when Subquery is part of the When.
* Before Django 3.2 flatten did not check for existence of flatten in the processed nodes.
* Add type BaseExpression to function "flatten".
* Add test case with annotated Coalesce.
* Add support for annotated raw SQL.
* Remove unnecessary code.
* Use as_sql instead of repr.
* Reorganize code
* Fix var name
* Improve naming: element -> expression
Co-authored-by: Dominik Bartenstein <db@zemtu.com>
Co-authored-by: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com>
This commit is contained in:
parent
b15027a627
commit
76d4ab4c8d
2 changed files with 66 additions and 18 deletions
|
|
@ -11,7 +11,7 @@ from django.db import (
|
|||
OperationalError)
|
||||
from django.db.models import Case, Count, Q, Value, When
|
||||
from django.db.models.expressions import RawSQL, Subquery, OuterRef, Exists
|
||||
from django.db.models.functions import Now
|
||||
from django.db.models.functions import Coalesce, Now
|
||||
from django.db.transaction import TransactionManagementError
|
||||
from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings
|
||||
from pytz import UTC
|
||||
|
|
@ -372,7 +372,7 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
|
|||
self.assert_tables(qs, User, Test)
|
||||
self.assert_query_cached(qs, [self.user, self.admin])
|
||||
|
||||
def test_annotate_case_with_when(self):
|
||||
def test_annotate_case_with_when_and_query_in_default(self):
|
||||
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
|
||||
qs = User.objects.annotate(
|
||||
first_test=Case(
|
||||
|
|
@ -383,6 +383,36 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
|
|||
self.assert_tables(qs, User, Test)
|
||||
self.assert_query_cached(qs, [self.user, self.admin])
|
||||
|
||||
def test_annotate_case_with_when(self):
|
||||
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
|
||||
qs = User.objects.annotate(
|
||||
first_test=Case(
|
||||
When(Q(pk=1), then=Subquery(tests[:1])),
|
||||
default=Value('noname')
|
||||
)
|
||||
)
|
||||
self.assert_tables(qs, User, Test)
|
||||
self.assert_query_cached(qs, [self.user, self.admin])
|
||||
|
||||
def test_annotate_coalesce(self):
|
||||
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
|
||||
qs = User.objects.annotate(
|
||||
name=Coalesce(
|
||||
Subquery(tests[:1]),
|
||||
Value('notest')
|
||||
)
|
||||
)
|
||||
self.assert_tables(qs, User, Test)
|
||||
self.assert_query_cached(qs, [self.user, self.admin])
|
||||
|
||||
def test_annotate_raw(self):
|
||||
qs = User.objects.annotate(
|
||||
perm_id=RawSQL('SELECT id FROM auth_permission WHERE id = %s',
|
||||
(self.t1__permission.pk,))
|
||||
)
|
||||
self.assert_tables(qs, User, Permission)
|
||||
self.assert_query_cached(qs, [self.user, self.admin])
|
||||
|
||||
def test_only(self):
|
||||
with self.assertNumQueries(1):
|
||||
t1 = Test.objects.only('name').first()
|
||||
|
|
|
|||
|
|
@ -2,11 +2,13 @@ import datetime
|
|||
from decimal import Decimal
|
||||
from hashlib import sha1
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from django.contrib.postgres.functions import TransactionNow
|
||||
from django.db import connections
|
||||
from django.db.models import Case, Exists, QuerySet, Subquery
|
||||
from django.db.models import Exists, QuerySet, Subquery
|
||||
from django.db.models.expressions import RawSQL
|
||||
from django.db.models.functions import Now
|
||||
from django.db.models.sql import Query, AggregateQuery
|
||||
from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode
|
||||
|
|
@ -15,6 +17,10 @@ from .settings import ITERABLES, cachalot_settings
|
|||
from .transaction import AtomicCache
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.db.models.expressions import BaseExpression
|
||||
|
||||
|
||||
class UncachableQuery(Exception):
|
||||
pass
|
||||
|
||||
|
|
@ -159,6 +165,23 @@ def filter_cachable(tables):
|
|||
return tables
|
||||
|
||||
|
||||
def _flatten(expression: "BaseExpression"):
|
||||
"""
|
||||
Recursively yield this expression and all subexpressions, in
|
||||
depth-first order.
|
||||
|
||||
Taken from Django 3.2 as the previous Django versions don’t check
|
||||
for existence of flatten.
|
||||
"""
|
||||
yield expression
|
||||
for expr in expression.get_source_expressions():
|
||||
if expr:
|
||||
if hasattr(expr, 'flatten'):
|
||||
yield from _flatten(expr)
|
||||
else:
|
||||
yield expr
|
||||
|
||||
|
||||
def _get_tables(db_alias, query):
|
||||
if query.select_for_update or (
|
||||
not cachalot_settings.CACHALOT_CACHE_RANDOM
|
||||
|
|
@ -172,24 +195,19 @@ def _get_tables(db_alias, query):
|
|||
tables = set(query.table_map)
|
||||
tables.add(query.get_meta().db_table)
|
||||
|
||||
def __update_annotated_subquery(_annotation: Subquery):
|
||||
if hasattr(_annotation, "queryset"):
|
||||
tables.update(_get_tables(db_alias, _annotation.queryset.query))
|
||||
else:
|
||||
tables.update(_get_tables(db_alias, _annotation.query))
|
||||
|
||||
# Gets tables in subquery annotations.
|
||||
for annotation in query.annotations.values():
|
||||
if isinstance(annotation, Case):
|
||||
for case in annotation.cases:
|
||||
for subquery in _find_subqueries_in_where(case.condition.children):
|
||||
tables.update(_get_tables(db_alias, subquery))
|
||||
if isinstance(annotation.default, Subquery):
|
||||
__update_annotated_subquery(annotation.default)
|
||||
elif isinstance(annotation, Subquery):
|
||||
__update_annotated_subquery(annotation)
|
||||
elif type(annotation) in UNCACHABLE_FUNCS:
|
||||
if type(annotation) in UNCACHABLE_FUNCS:
|
||||
raise UncachableQuery
|
||||
for expression in _flatten(annotation):
|
||||
if isinstance(expression, Subquery):
|
||||
if hasattr(expression, "queryset"):
|
||||
tables.update(_get_tables(db_alias, expression.queryset.query))
|
||||
else:
|
||||
tables.update(_get_tables(db_alias, expression.query))
|
||||
elif isinstance(expression, RawSQL):
|
||||
sql = expression.as_sql(None, None)[0].lower()
|
||||
tables.update(_get_tables_from_sql(connections[db_alias], sql))
|
||||
# Gets tables in WHERE subqueries.
|
||||
for subquery in _find_subqueries_in_where(query.where.children):
|
||||
tables.update(_get_tables(db_alias, subquery))
|
||||
|
|
|
|||
Loading…
Reference in a new issue