diff --git a/cachalot/utils.py b/cachalot/utils.py index 3113c6b..72792fc 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -9,8 +9,10 @@ from uuid import UUID from django import VERSION as django_version from django.db import connections +from django.db.models.expressions import RawSQL from django.db.models.sql import Query -from django.db.models.sql.where import ExtraWhere, SubqueryConstraint +from django.db.models.sql.where import ( + ExtraWhere, SubqueryConstraint, WhereNode) from django.utils.module_loading import import_string from django.utils.six import text_type, binary_type @@ -22,6 +24,10 @@ class UncachableQuery(Exception): pass +class IsRawQuery(Exception): + pass + + TUPLE_OR_LIST = {tuple, list} CACHABLE_PARAM_TYPES = { @@ -112,25 +118,30 @@ def _get_tables_from_sql(connection, lowercased_sql): def _find_subqueries(children): for child in children: - if child.__class__ is SubqueryConstraint: + child_class = child.__class__ + if child_class is WhereNode: + for grand_child in _find_subqueries(child.children): + yield grand_child + elif child_class is SubqueryConstraint: if child.query_object.__class__ is Query: yield child.query_object else: yield child.query_object.query + elif child_class is ExtraWhere: + raise IsRawQuery else: rhs = None if hasattr(child, 'rhs'): rhs = child.rhs rhs_class = rhs.__class__ + if rhs_class is RawSQL: + raise IsRawQuery if rhs_class is Query: yield rhs elif hasattr(rhs, 'query'): yield rhs.query elif rhs_class in UNCACHABLE_FUNCS: raise UncachableQuery - if hasattr(child, 'children'): - for grand_child in _find_subqueries(child.children): - yield grand_child def is_cachable(table): @@ -161,16 +172,16 @@ def _get_tables(db_alias, query): and not cachalot_settings.CACHALOT_CACHE_RANDOM): raise UncachableQuery - if query.extra_select or getattr(query, 'subquery', False) \ - or any(c.__class__ is ExtraWhere for c in query.where.children): - sql = query.get_compiler(db_alias).as_sql()[0].lower() - tables = _get_tables_from_sql(connections[db_alias], sql) - else: + try: + if query.extra_select or getattr(query, 'subquery', False): + raise IsRawQuery tables = set(query.table_map) tables.add(query.get_meta().db_table) - subquery_constraints = _find_subqueries(query.where.children) - for subquery in subquery_constraints: + for subquery in _find_subqueries(query.where.children): tables.update(_get_tables(db_alias, subquery)) + except IsRawQuery: + sql = query.get_compiler(db_alias).as_sql()[0].lower() + tables = _get_tables_from_sql(connections[db_alias], sql) if not are_all_cachable(tables): raise UncachableQuery