From 9e1447f6ff34be5d994548d8a266c473e7a2a638 Mon Sep 17 00:00:00 2001 From: Bertrand Bordage Date: Wed, 18 Feb 2015 00:32:51 +0100 Subject: [PATCH] Handles rare complex queries. --- cachalot/utils.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/cachalot/utils.py b/cachalot/utils.py index 59b8b8a..c1c7ee0 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -6,7 +6,7 @@ from time import time import django from django.db import connections -from django.db.models.sql.where import ExtraWhere +from django.db.models.sql.where import ExtraWhere, SubqueryConstraint if django.VERSION[:2] >= (1, 7): from django.utils.module_loading import import_string else: @@ -75,6 +75,32 @@ def _get_tables_from_sql(connection, lowercased_sql): if t in lowercased_sql] +def _find_subqueries(children): + for child in children: + if isinstance(child, SubqueryConstraint): + yield child.query_object.query + elif isinstance(child, tuple) and hasattr(child[-1], 'query'): + yield child[-1].query + if hasattr(child, 'children'): + for grand_child in _find_subqueries(child.children): + yield grand_child + + +def _get_tables(query, db_alias): + tables = set(query.table_map) + tables.add(query.model._meta.db_table) + subquery_constraints = _find_subqueries(query.where.children + + query.having.children) + for subquery in subquery_constraints: + tables.update(_get_tables(subquery, db_alias)) + if query.extra_select or hasattr(query, 'subquery') \ + or any(isinstance(c, ExtraWhere) for c in query.where.children): + sql = query.get_compiler(db_alias).as_sql()[0].lower() + additional_tables = _get_tables_from_sql(connections[db_alias], sql) + tables.update(additional_tables) + return tables + + def _get_table_cache_keys(compiler): """ Returns a ``list`` of cache keys for all the SQL tables used @@ -86,18 +112,9 @@ def _get_table_cache_keys(compiler): :rtype: list """ - query = compiler.query - using = compiler.using - - tables = set(query.table_map) - tables.add(query.model._meta.db_table) - if query.extra_select or any([isinstance(c, ExtraWhere) - for c in query.where.children]): - sql = compiler.as_sql()[0].lower() - connection = connections[using] - additional_tables = _get_tables_from_sql(connection, sql) - tables.update(additional_tables) - return [_get_table_cache_key(using, t) for t in tables] + db_alias = compiler.using + tables = _get_tables(compiler.query, db_alias) + return [_get_table_cache_key(db_alias, t) for t in tables] def _invalidate_table_cache_keys(cache, table_cache_keys):