mirror of
https://github.com/Hopiu/django.git
synced 2026-04-28 10:44:50 +00:00
Thanks to Adam Johnson, Aymeric Augustin, David Smith, Mariusz Felisiak, Nick Pope, and Paul Ganssle for reviews.
339 lines
12 KiB
Python
339 lines
12 KiB
Python
from datetime import datetime
|
|
|
|
from django.conf import settings
|
|
from django.db.models.expressions import Func
|
|
from django.db.models.fields import (
|
|
DateField, DateTimeField, DurationField, Field, IntegerField, TimeField,
|
|
)
|
|
from django.db.models.lookups import (
|
|
Transform, YearExact, YearGt, YearGte, YearLt, YearLte,
|
|
)
|
|
from django.utils import timezone
|
|
|
|
|
|
class TimezoneMixin:
    """Mixin that works out which time zone name to hand to the database."""
    # Optional tzinfo override; None means "use the current time zone".
    tzinfo = None

    def get_tzname(self):
        """
        Return the time zone name to convert to, or None when USE_TZ is off.

        When ``self.tzinfo`` is unset the currently active time zone name is
        used, otherwise the name is derived from ``self.tzinfo``.
        """
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        if not settings.USE_TZ:
            return None
        if self.tzinfo is None:
            return timezone.get_current_timezone_name()
        return timezone._get_timezone_name(self.tzinfo)
|
|
|
|
|
|
class Extract(TimezoneMixin, Transform):
    """
    Transform that pulls one component, named by ``lookup_name``, out of a
    date, time, datetime, or duration expression. The result is typed as an
    IntegerField.
    """
    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        """
        Store the component name and optional tzinfo.

        ``lookup_name`` is only consulted when the class itself does not
        define one. Raise ValueError if neither supplies a name.
        """
        if self.lookup_name is None:
            if lookup_name is None:
                raise ValueError('lookup_name must be provided')
            self.lookup_name = lookup_name
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile the left-hand side and wrap it in the backend's extract SQL,
        chosen by the type of the left-hand side's output field.

        Raise ValueError if tzinfo was given for a non-DateTimeField input or
        if a DurationField is used on a backend without native duration
        support.
        """
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        # DateTimeField has to be tested first: it is the only type that
        # takes a time zone name.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            return connection.ops.datetime_extract_sql(self.lookup_name, sql, tzname), params
        if self.tzinfo is not None:
            raise ValueError('tzinfo can only be used with DateTimeField.')
        if isinstance(lhs_output_field, DateField):
            sql = connection.ops.date_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, TimeField):
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError('Extract requires native DurationField database support.')
            sql = connection.ops.time_extract_sql(self.lookup_name, sql)
        else:
            # resolve_expression has already validated the output_field so this
            # assert should never be hit.
            assert False, "Tried to Extract from an invalid type."
        return sql, params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Resolve the expression and validate the input field type against the
        requested component.

        Raise ValueError for unsupported input field types, for time
        components requested from a plain DateField, and for calendar
        components requested from a DurationField.
        """
        copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        field = getattr(copy.lhs, 'output_field', None)
        if field is None:
            return copy
        supported_fields = (DateField, DateTimeField, TimeField, DurationField)
        if not isinstance(field, supported_fields):
            raise ValueError(
                'Extract input expression must be DateField, DateTimeField, '
                'TimeField, or DurationField.'
            )
        component = copy.lookup_name
        # Passing dates to functions expecting datetimes is most likely a mistake.
        # An exact type check is used so that DateField subclasses are not
        # rejected here.
        if type(field) is DateField and component in ('hour', 'minute', 'second'):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'." % (component, field.name)
            )
        calendar_components = (
            'year', 'iso_year', 'month', 'week', 'week_day', 'iso_week_day', 'quarter',
        )
        if isinstance(field, DurationField) and component in calendar_components:
            raise ValueError(
                "Cannot extract component '%s' from DurationField '%s'."
                % (component, field.name)
            )
        return copy
|
|
|
|
|
|
class ExtractYear(Extract):
    """Extract preset to the 'year' component."""
    lookup_name = 'year'
|
|
|
|
|
|
class ExtractIsoYear(Extract):
    """Extract the week-numbering year as defined by ISO-8601."""
    lookup_name = 'iso_year'
|
|
|
|
|
|
class ExtractMonth(Extract):
    """Extract preset to the 'month' component."""
    lookup_name = 'month'
|
|
|
|
|
|
class ExtractDay(Extract):
    """Extract preset to the 'day' component."""
    lookup_name = 'day'
|
|
|
|
|
|
class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number: 1-52 or 53, with weeks starting on
    Monday.
    """
    lookup_name = 'week'
|
|
|
|
|
|
class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    The Python equivalent is: (mydatetime.isoweekday() % 7) + 1
    """
    lookup_name = 'week_day'
|
|
|
|
|
|
class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week: Monday=1 through Sunday=7."""
    lookup_name = 'iso_week_day'
|
|
|
|
|
|
class ExtractQuarter(Extract):
    """Extract preset to the 'quarter' component."""
    lookup_name = 'quarter'
|
|
|
|
|
|
class ExtractHour(Extract):
    """Extract preset to the 'hour' component."""
    lookup_name = 'hour'
|
|
|
|
|
|
class ExtractMinute(Extract):
    """Extract preset to the 'minute' component."""
    lookup_name = 'minute'
|
|
|
|
|
|
class ExtractSecond(Extract):
    """Extract preset to the 'second' component."""
    lookup_name = 'second'
|
|
|
|
|
|
# Register the date-part transforms as lookups on DateField.
for _transform in (
    ExtractYear, ExtractMonth, ExtractDay, ExtractWeekDay,
    ExtractIsoWeekDay, ExtractWeek, ExtractIsoYear, ExtractQuarter,
):
    DateField.register_lookup(_transform)

# Register the time-part transforms on TimeField first, then DateTimeField.
for _field in (TimeField, DateTimeField):
    for _transform in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_transform)

# Both year transforms get the Year* comparison lookups.
for _transform in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _transform.register_lookup(_lookup)

# Keep the loop helpers out of the module namespace.
del _field, _transform, _lookup
|
|
|
|
|
|
class Now(Func):
    """Expression rendered as CURRENT_TIMESTAMP and typed as a DateTimeField."""
    template = 'CURRENT_TIMESTAMP'
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Compile for PostgreSQL, substituting a statement-level timestamp."""
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        statement_template = 'STATEMENT_TIMESTAMP()'
        return self.as_sql(
            compiler, connection, template=statement_template, **extra_context
        )
|
|
|
|
|
|
class TruncBase(TimezoneMixin, Transform):
    """
    Base transform for truncating a date, time, or datetime expression to
    the precision named by ``kind``.

    Subclasses set ``kind`` (and optionally a class-level ``output_field``).
    """
    kind = None
    tzinfo = None

    # RemovedInDjango50Warning: when the deprecation ends, remove is_dst
    # argument.
    def __init__(self, expression, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
        """
        Store the time zone options and delegate the rest to Transform.

        ``tzinfo`` is used both to compute the time zone name for the SQL
        and to make converted values aware. ``is_dst`` is forwarded to
        ``timezone.make_aware()`` in convert_value().
        """
        self.tzinfo = tzinfo
        self.is_dst = is_dst
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile the left-hand side and wrap it in the backend's truncation
        SQL, chosen by the type of this expression's output field.

        Raise ValueError if tzinfo was given for a non-DateTimeField input
        or if the output field is not a date, time, or datetime field.
        """
        inner_sql, inner_params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError('tzinfo can only be used with DateTimeField.')
        # DateTimeField must be tested before DateField because it is a
        # DateField subclass.
        if isinstance(self.output_field, DateTimeField):
            sql = connection.ops.datetime_trunc_sql(self.kind, inner_sql, tzname)
        elif isinstance(self.output_field, DateField):
            sql = connection.ops.date_trunc_sql(self.kind, inner_sql, tzname)
        elif isinstance(self.output_field, TimeField):
            sql = connection.ops.time_trunc_sql(self.kind, inner_sql, tzname)
        else:
            raise ValueError('Trunc only valid on DateField, TimeField, or DateTimeField.')
        return sql, inner_params

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """
        Resolve the expression and validate that the input field, the output
        field, and ``kind`` are compatible.

        Raise TypeError if the input is not a date/time/datetime field and
        ValueError if the output field or the requested truncation does not
        fit the input field.
        """
        copy = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, (DateField, TimeField)):
            raise TypeError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
            raise ValueError('output_field must be either DateField, TimeField, or DateTimeField')
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = self.__class__.output_field if isinstance(self.__class__.output_field, Field) else None
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = class_output_field or field.__class__ is not copy.output_field.__class__
        # Name reported in the error messages below: the real output field
        # class when one was chosen explicitly, otherwise 'DateTimeField'.
        target_name = output_field.__class__.__name__ if has_explicit_output_field else 'DateTimeField'
        # An exact type check is used so that DateTimeField (a DateField
        # subclass) is not caught by the DateField rule.
        if type(field) == DateField and (
                isinstance(output_field, DateTimeField) or copy.kind in ('hour', 'minute', 'second', 'time')):
            raise ValueError("Cannot truncate DateField '%s' to %s." % (
                field.name, target_name
            ))
        elif isinstance(field, TimeField) and (
                isinstance(output_field, DateTimeField) or
                copy.kind in ('year', 'quarter', 'month', 'week', 'day', 'date')):
            raise ValueError("Cannot truncate TimeField '%s' to %s." % (
                field.name, target_name
            ))
        return copy

    def convert_value(self, value, expression, connection):
        """
        Convert a value returned by the database to match the output field.

        For a DateTimeField output with USE_TZ enabled, a non-None value is
        stripped of its tzinfo and made aware in ``self.tzinfo``; a None
        value raises ValueError when the backend reports no zoneinfo
        database. For other outputs, a datetime value is narrowed to a date
        or a time to match the output field. Anything else is returned
        unchanged.
        """
        if isinstance(self.output_field, DateTimeField):
            if settings.USE_TZ:
                if value is not None:
                    value = value.replace(tzinfo=None)
                    value = timezone.make_aware(value, self.tzinfo, is_dst=self.is_dst)
                elif not connection.features.has_zoneinfo_database:
                    raise ValueError(
                        'Database returned an invalid datetime value. Are time '
                        'zone definitions for your database installed?'
                    )
        elif isinstance(value, datetime):
            # value cannot be None here (None is not a datetime), so no None
            # guard is needed before narrowing.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
|
|
|
|
|
|
class Trunc(TruncBase):
    """Generic truncation whose ``kind`` is supplied by the caller."""

    # RemovedInDjango50Warning: when the deprecation ends, remove is_dst
    # argument.
    def __init__(self, expression, kind, output_field=None, tzinfo=None, is_dst=timezone.NOT_PASSED, **extra):
        """Record ``kind`` on the instance, then initialise as a TruncBase."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            is_dst=is_dst,
            **extra
        )
|
|
|
|
|
|
class TruncYear(TruncBase):
    """Truncation preset to kind 'year'."""
    kind = 'year'
|
|
|
|
|
|
class TruncQuarter(TruncBase):
    """Truncation preset to kind 'quarter'."""
    kind = 'quarter'
|
|
|
|
|
|
class TruncMonth(TruncBase):
    """Truncation preset to kind 'month'."""
    kind = 'month'
|
|
|
|
|
|
class TruncWeek(TruncBase):
    """Truncate to the Monday of the week, at midnight."""
    kind = 'week'
|
|
|
|
|
|
class TruncDay(TruncBase):
    """Truncation preset to kind 'day'."""
    kind = 'day'
|
|
|
|
|
|
class TruncDate(TruncBase):
    """Cast a datetime expression to a date; also usable as the 'date' lookup."""
    kind = 'date'
    lookup_name = 'date'
    output_field = DateField()

    def as_sql(self, compiler, connection):
        """Compile the left-hand side and wrap it in the backend's date cast."""
        # Cast to date rather than truncate to date.
        inner_sql, inner_params = compiler.compile(self.lhs)
        cast_sql = connection.ops.datetime_cast_date_sql(inner_sql, self.get_tzname())
        return cast_sql, inner_params
|
|
|
|
|
|
class TruncTime(TruncBase):
    """Cast a datetime expression to a time; also usable as the 'time' lookup."""
    kind = 'time'
    lookup_name = 'time'
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        """Compile the left-hand side and wrap it in the backend's time cast."""
        # Cast to time rather than truncate to time.
        inner_sql, inner_params = compiler.compile(self.lhs)
        cast_sql = connection.ops.datetime_cast_time_sql(inner_sql, self.get_tzname())
        return cast_sql, inner_params
|
|
|
|
|
|
class TruncHour(TruncBase):
    """Truncation preset to kind 'hour'."""
    kind = 'hour'
|
|
|
|
|
|
class TruncMinute(TruncBase):
    """Truncation preset to kind 'minute'."""
    kind = 'minute'
|
|
|
|
|
|
class TruncSecond(TruncBase):
    """Truncation preset to kind 'second'."""
    kind = 'second'
|
|
|
|
|
|
# Register the date and time casts as lookups on DateTimeField.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
# Keep the loop helper out of the module namespace.
del _cast_transform
|