Check for annotated case (#196)

This commit is contained in:
Andrew Chen Wang 2021-08-19 15:01:56 -04:00 committed by GitHub
parent 4fb23ab029
commit b15027a627
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 11 deletions

View file

@ -9,12 +9,11 @@ from django.contrib.contenttypes.models import ContentType
from django.db import (
connection, transaction, DEFAULT_DB_ALIAS, ProgrammingError,
OperationalError)
from django.db.models import Count, Q
from django.db.models import Case, Count, Q, Value, When
from django.db.models.expressions import RawSQL, Subquery, OuterRef, Exists
from django.db.models.functions import Now
from django.db.transaction import TransactionManagementError
from django.test import (
TransactionTestCase, skipUnlessDBFeature, override_settings)
from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings
from pytz import UTC
from cachalot.cache import cachalot_caches
@ -373,6 +372,17 @@ class ReadTestCase(TestUtilsMixin, TransactionTestCase):
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
def test_annotate_case_with_when(self):
tests = Test.objects.filter(owner=OuterRef('pk')).values('name')
qs = User.objects.annotate(
first_test=Case(
When(Q(pk=1), then=Value('noname')),
default=Subquery(tests[:1])
)
)
self.assert_tables(qs, User, Test)
self.assert_query_cached(qs, [self.user, self.admin])
def test_only(self):
with self.assertNumQueries(1):
t1 = Test.objects.only('name').first()
@ -1041,8 +1051,7 @@ class ParameterTypeTestCase(TestUtilsMixin, TransactionTestCase):
Checks that queries with a Now() parameter are not cached.
"""
obj = Test.objects.create(datetime='1992-07-02T12:00:00')
qs = Test.objects.filter(
datetime__lte=Now())
qs = Test.objects.filter(datetime__lte=Now())
with self.assertNumQueries(1):
obj1 = qs.get()
with self.assertNumQueries(1):

View file

@ -6,7 +6,7 @@ from uuid import UUID
from django.contrib.postgres.functions import TransactionNow
from django.db import connections
from django.db.models import QuerySet, Subquery, Exists
from django.db.models import Case, Exists, QuerySet, Subquery
from django.db.models.functions import Now
from django.db.models.sql import Query, AggregateQuery
from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode
@ -171,13 +171,23 @@ def _get_tables(db_alias, query):
# Gets all tables already found by the ORM.
tables = set(query.table_map)
tables.add(query.get_meta().db_table)
def __update_annotated_subquery(_annotation: Subquery):
if hasattr(_annotation, "queryset"):
tables.update(_get_tables(db_alias, _annotation.queryset.query))
else:
tables.update(_get_tables(db_alias, _annotation.query))
# Gets tables in subquery annotations.
for annotation in query.annotations.values():
if isinstance(annotation, Subquery):
if hasattr(annotation, "queryset"):
tables.update(_get_tables(db_alias, annotation.queryset.query))
else:
tables.update(_get_tables(db_alias, annotation.query))
if isinstance(annotation, Case):
for case in annotation.cases:
for subquery in _find_subqueries_in_where(case.condition.children):
tables.update(_get_tables(db_alias, subquery))
if isinstance(annotation.default, Subquery):
__update_annotated_subquery(annotation.default)
elif isinstance(annotation, Subquery):
__update_annotated_subquery(annotation)
elif type(annotation) in UNCACHABLE_FUNCS:
raise UncachableQuery
# Gets tables in WHERE subqueries.