feat: Support F and Concat expressions in annotate()

Refs #735
This commit is contained in:
wookkl 2024-05-24 22:17:30 +09:00 committed by GitHub
parent 47dcb9cd46
commit a0aeb58b47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 64 additions and 3 deletions

View file

@ -15,13 +15,15 @@ from django.contrib.admin.utils import get_model_from_relation
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.backends.utils import CursorWrapper
from django.db.models import Field, Model
from django.db.models import Field, Model, F
from django.db.models.expressions import Col
from django.db.models.functions import Concat, ConcatPair
from django.db.models.lookups import Lookup
from django.db.models.query import QuerySet, ValuesIterable
from django.db.models.utils import create_namedtuple_class
from django.utils.tree import Node
from modeltranslation._typing import Self, AutoPopulate
from modeltranslation.fields import TranslationField
from modeltranslation.thread_context import auto_populate_mode
from modeltranslation.utils import (
@ -30,7 +32,6 @@ from modeltranslation.utils import (
get_language,
resolution_order,
)
from modeltranslation._typing import Self, AutoPopulate
_C2F_CACHE: dict[tuple[type[Model], str], Field] = {}
_F2TM_CACHE: dict[type[Model], dict[str, type[Model]]] = {}
@ -513,6 +514,27 @@ class MultilingualQuerySet(QuerySet[_T]):
new_key = rewrite_lookup_key(self.model, field_name)
return super().dates(new_key, *args, **kwargs)
def _rewrite_concat(self, concat: Concat | ConcatPair):
new_source_expressions = []
for exp in concat.source_expressions:
if isinstance(exp, (Concat, ConcatPair)):
exp = self._rewrite_concat(exp)
if isinstance(exp, F):
exp = self._rewrite_f(exp)
new_source_expressions.append(exp)
concat.set_source_expressions(new_source_expressions)
return concat
def annotate(self, *args: Any, **kwargs: Any) -> Self:
if not self._rewrite:
return super().annotate(*args, **kwargs)
for key, val in list(kwargs.items()):
if isinstance(val, models.F):
kwargs[key] = self._rewrite_f(val)
if isinstance(val, Concat):
kwargs[key] = self._rewrite_concat(val)
return super().annotate(*args, **kwargs)
class FallbackValuesIterable(ValuesIterable):
queryset: MultilingualQuerySet[Model]

View file

@ -18,7 +18,7 @@ from django.core.management import call_command
from django.core.management.base import CommandError
from django.db import IntegrityError
from django.db.models import CharField, Count, F, Q, TextField, Value
from django.db.models.functions import Cast
from django.db.models.functions import Cast, Concat
from django.test import TestCase, TransactionTestCase
from django.test.utils import override_settings
from django.utils.translation import get_language, override, trans_real
@ -3670,6 +3670,45 @@ class TestManager(ModeltranslationTestBase):
assert titles_for_en == (("title_1_en", "desc_1_en"), ("title_2_en", "desc_1_en"))
assert titles_for_de == (("title_1_de", "desc_1_de"), ("title_2_de", "desc_1_de"))
def test_annotate(self):
"""Test if annotating is language-aware."""
test = models.TestModel.objects.create(title_en="title_en", title_de="title_de")
assert "en" == get_language()
assert (
models.TestModel.objects.annotate(custom_title=F("title")).values_list(
"custom_title", flat=True
)[0]
== "title_en"
)
with override("de"):
assert (
models.TestModel.objects.annotate(custom_title=F("title")).values_list(
"custom_title", flat=True
)[0]
== "title_de"
)
assert (
models.TestModel.objects.annotate(
custom_title=Concat(F("title"), Value("value1"), Value("value2"))
).values_list("custom_title", flat=True)[0]
== "title_devalue1value2"
)
assert (
models.TestModel.objects.annotate(
custom_title=Concat(F("title"), Concat(F("title"), Value("value")))
).values_list("custom_title", flat=True)[0]
== "title_detitle_devalue"
)
models.ForeignKeyModel.objects.create(test=test)
models.ForeignKeyModel.objects.create(test=test)
assert (
models.TestModel.objects.annotate(Count("test_fks")).values_list(
"test_fks__count", flat=True
)[0]
== 2
)
class TranslationModelFormTest(ModeltranslationTestBase):
def test_fields(self):