mirror of
https://github.com/Hopiu/django-cachalot.git
synced 2026-03-16 21:30:23 +00:00
278 lines
9.5 KiB
Python
278 lines
9.5 KiB
Python
import datetime
|
||
from decimal import Decimal
|
||
from hashlib import sha1
|
||
from time import time
|
||
from typing import TYPE_CHECKING
|
||
from uuid import UUID
|
||
|
||
from django.contrib.postgres.functions import TransactionNow
|
||
from django.db import connections
|
||
from django.db.models import Exists, QuerySet, Subquery
|
||
from django.db.models.expressions import RawSQL
|
||
from django.db.models.functions import Now
|
||
from django.db.models.sql import Query, AggregateQuery
|
||
from django.db.models.sql.where import ExtraWhere, WhereNode, NothingNode
|
||
|
||
from .settings import ITERABLES, cachalot_settings
|
||
from .transaction import AtomicCache
|
||
|
||
|
||
if TYPE_CHECKING:
|
||
from django.db.models.expressions import BaseExpression
|
||
|
||
|
||
class UncachableQuery(Exception):
|
||
pass
|
||
|
||
|
||
class IsRawQuery(Exception):
|
||
pass
|
||
|
||
|
||
CACHABLE_PARAM_TYPES = {
|
||
bool, int, float, Decimal, bytearray, bytes, str, type(None),
|
||
datetime.date, datetime.time, datetime.datetime, datetime.timedelta, UUID,
|
||
}
|
||
UNCACHABLE_FUNCS = {Now, TransactionNow}
|
||
|
||
try:
|
||
from psycopg2 import Binary
|
||
from psycopg2.extras import (
|
||
NumericRange, DateRange, DateTimeRange, DateTimeTZRange, Inet, Json)
|
||
except ImportError:
|
||
pass
|
||
else:
|
||
CACHABLE_PARAM_TYPES.update((
|
||
Binary, NumericRange, DateRange, DateTimeRange, DateTimeTZRange, Inet,
|
||
Json,))
|
||
|
||
|
||
def check_parameter_types(params):
|
||
for p in params:
|
||
cl = p.__class__
|
||
if cl not in CACHABLE_PARAM_TYPES:
|
||
if cl in ITERABLES:
|
||
check_parameter_types(p)
|
||
elif cl is dict:
|
||
check_parameter_types(p.items())
|
||
else:
|
||
raise UncachableQuery
|
||
|
||
|
||
def get_query_cache_key(compiler):
|
||
"""
|
||
Generates a cache key from a SQLCompiler.
|
||
|
||
This cache key is specific to the SQL query and its context
|
||
(which database is used). The same query in the same context
|
||
(= the same database) must generate the same cache key.
|
||
|
||
:arg compiler: A SQLCompiler that will generate the SQL query
|
||
:type compiler: django.db.models.sql.compiler.SQLCompiler
|
||
:return: A cache key
|
||
:rtype: int
|
||
"""
|
||
sql, params = compiler.as_sql()
|
||
check_parameter_types(params)
|
||
cache_key = '%s:%s:%s' % (compiler.using, sql,
|
||
[str(p) for p in params])
|
||
# Set attribute on compiler for later access
|
||
# to the generated SQL. This prevents another as_sql() call!
|
||
compiler.__cachalot_generated_sql = sql.lower()
|
||
|
||
return sha1(cache_key.encode('utf-8')).hexdigest()
|
||
|
||
|
||
def get_table_cache_key(db_alias, table):
|
||
"""
|
||
Generates a cache key from a SQL table.
|
||
|
||
:arg db_alias: Alias of the used database
|
||
:type db_alias: str or unicode
|
||
:arg table: Name of the SQL table
|
||
:type table: str or unicode
|
||
:return: A cache key
|
||
:rtype: int
|
||
"""
|
||
cache_key = '%s:%s' % (db_alias, table)
|
||
return sha1(cache_key.encode('utf-8')).hexdigest()
|
||
|
||
|
||
def _get_tables_from_sql(connection, lowercased_sql, enable_quote: bool = False):
|
||
"""Returns names of involved tables after analyzing the final SQL query."""
|
||
return {table for table in (connection.introspection.django_table_names()
|
||
+ cachalot_settings.CACHALOT_ADDITIONAL_TABLES)
|
||
if _quote_table_name(table, connection, enable_quote) in lowercased_sql}
|
||
|
||
|
||
def _quote_table_name(table_name, connection, enable_quote: bool):
|
||
"""
|
||
Returns quoted table name.
|
||
|
||
Put database-specific quotation marks around the table name
|
||
to preven that tables with substrings of the table are considered.
|
||
E.g. cachalot_testparent must not return cachalot_test.
|
||
"""
|
||
return f'{connection.ops.quote_name(table_name)}' \
|
||
if enable_quote else table_name
|
||
|
||
|
||
def _find_rhs_lhs_subquery(side):
|
||
h_class = side.__class__
|
||
if h_class is Query:
|
||
return side
|
||
elif h_class is QuerySet:
|
||
return side.query
|
||
elif h_class in (Subquery, Exists): # Subquery allows QuerySet & Query
|
||
return side.query.query if side.query.__class__ is QuerySet else side.query
|
||
elif h_class in UNCACHABLE_FUNCS:
|
||
raise UncachableQuery
|
||
|
||
|
||
def _find_subqueries_in_where(children):
|
||
for child in children:
|
||
child_class = child.__class__
|
||
if child_class is WhereNode:
|
||
for grand_child in _find_subqueries_in_where(child.children):
|
||
yield grand_child
|
||
elif child_class is ExtraWhere:
|
||
raise IsRawQuery
|
||
elif child_class is NothingNode:
|
||
pass
|
||
else:
|
||
try:
|
||
child_rhs = child.rhs
|
||
child_lhs = child.lhs
|
||
except AttributeError:
|
||
raise UncachableQuery
|
||
rhs = _find_rhs_lhs_subquery(child_rhs)
|
||
if rhs is not None:
|
||
yield rhs
|
||
lhs = _find_rhs_lhs_subquery(child_lhs)
|
||
if lhs is not None:
|
||
yield lhs
|
||
|
||
|
||
def is_cachable(table):
|
||
whitelist = cachalot_settings.CACHALOT_ONLY_CACHABLE_TABLES
|
||
if whitelist and table not in whitelist:
|
||
return False
|
||
return table not in cachalot_settings.CACHALOT_UNCACHABLE_TABLES
|
||
|
||
|
||
def are_all_cachable(tables):
|
||
whitelist = cachalot_settings.CACHALOT_ONLY_CACHABLE_TABLES
|
||
if whitelist and not tables.issubset(whitelist):
|
||
return False
|
||
return tables.isdisjoint(cachalot_settings.CACHALOT_UNCACHABLE_TABLES)
|
||
|
||
|
||
def filter_cachable(tables):
|
||
whitelist = cachalot_settings.CACHALOT_ONLY_CACHABLE_TABLES
|
||
tables = tables.difference(cachalot_settings.CACHALOT_UNCACHABLE_TABLES)
|
||
if whitelist:
|
||
return tables.intersection(whitelist)
|
||
return tables
|
||
|
||
|
||
def _flatten(expression: 'BaseExpression'):
|
||
"""
|
||
Recursively yield this expression and all subexpressions, in
|
||
depth-first order.
|
||
|
||
Taken from Django 3.2 as the previous Django versions don’t check
|
||
for existence of flatten.
|
||
"""
|
||
yield expression
|
||
for expr in expression.get_source_expressions():
|
||
if expr:
|
||
if hasattr(expr, 'flatten'):
|
||
yield from _flatten(expr)
|
||
else:
|
||
yield expr
|
||
|
||
|
||
def _get_tables(db_alias, query, compiler=False):
|
||
if query.select_for_update or (
|
||
not cachalot_settings.CACHALOT_CACHE_RANDOM
|
||
and '?' in query.order_by):
|
||
raise UncachableQuery
|
||
|
||
try:
|
||
if query.extra_select:
|
||
raise IsRawQuery
|
||
|
||
# Gets all tables already found by the ORM.
|
||
tables = set(query.table_map)
|
||
if query.get_meta():
|
||
tables.add(query.get_meta().db_table)
|
||
|
||
# Gets tables in subquery annotations.
|
||
for annotation in query.annotations.values():
|
||
if type(annotation) in UNCACHABLE_FUNCS:
|
||
raise UncachableQuery
|
||
for expression in _flatten(annotation):
|
||
if isinstance(expression, Subquery):
|
||
# Django 2.2 only: no query, only queryset
|
||
if not hasattr(expression, 'query'):
|
||
tables.update(_get_tables(db_alias, expression.queryset.query))
|
||
# Django 3+
|
||
else:
|
||
tables.update(_get_tables(db_alias, expression.query))
|
||
elif isinstance(expression, RawSQL):
|
||
sql = expression.as_sql(None, None)[0].lower()
|
||
tables.update(_get_tables_from_sql(connections[db_alias], sql))
|
||
# Gets tables in WHERE subqueries.
|
||
for subquery in _find_subqueries_in_where(query.where.children):
|
||
tables.update(_get_tables(db_alias, subquery))
|
||
# Gets tables in HAVING subqueries.
|
||
if isinstance(query, AggregateQuery):
|
||
try:
|
||
tables.update(_get_tables_from_sql(connections[db_alias], query.subquery))
|
||
except TypeError: # For Django 3.2+
|
||
tables.update(_get_tables(db_alias, query.inner_query))
|
||
# Gets tables in combined queries
|
||
# using `.union`, `.intersection`, or `difference`.
|
||
if query.combined_queries:
|
||
for combined_query in query.combined_queries:
|
||
tables.update(_get_tables(db_alias, combined_query))
|
||
except IsRawQuery:
|
||
sql = query.get_compiler(db_alias).as_sql()[0].lower()
|
||
tables = _get_tables_from_sql(connections[db_alias], sql)
|
||
else:
|
||
# Additional check of the final SQL.
|
||
# Potentially overlooked tables are added here. Tables may be overlooked by the regular checks
|
||
# as not all expressions are handled yet. This final check acts as safety net.
|
||
if cachalot_settings.CACHALOT_FINAL_SQL_CHECK:
|
||
if compiler:
|
||
# Access generated SQL stored when caching the query!
|
||
sql = compiler.__cachalot_generated_sql
|
||
else:
|
||
sql = query.get_compiler(db_alias).as_sql()[0].lower()
|
||
final_check_tables = _get_tables_from_sql(connections[db_alias], sql, enable_quote=True)
|
||
tables.update(final_check_tables)
|
||
|
||
if not are_all_cachable(tables):
|
||
raise UncachableQuery
|
||
return tables
|
||
|
||
|
||
def _get_table_cache_keys(compiler):
|
||
db_alias = compiler.using
|
||
get_table_cache_key = cachalot_settings.CACHALOT_TABLE_KEYGEN
|
||
return [get_table_cache_key(db_alias, t)
|
||
for t in _get_tables(db_alias, compiler.query, compiler)]
|
||
|
||
|
||
def _invalidate_tables(cache, db_alias, tables):
|
||
tables = filter_cachable(set(tables))
|
||
if not tables:
|
||
return
|
||
now = time()
|
||
get_table_cache_key = cachalot_settings.CACHALOT_TABLE_KEYGEN
|
||
cache.set_many(
|
||
{get_table_cache_key(db_alias, t): now for t in tables},
|
||
cachalot_settings.CACHALOT_TIMEOUT)
|
||
|
||
if isinstance(cache, AtomicCache):
|
||
cache.to_be_invalidated.update(tables)
|