diff --git a/cachalot/monkey_patch.py b/cachalot/monkey_patch.py index 8713dff..048d00e 100644 --- a/cachalot/monkey_patch.py +++ b/cachalot/monkey_patch.py @@ -16,8 +16,7 @@ if django_version >= (1, 7): from django.db.models.signals import post_migrate else: from django.db.models.signals import post_syncdb as post_migrate -from django.db.models.sql.compiler import ( - SQLCompiler, SQLInsertCompiler, SQLUpdateCompiler, SQLDeleteCompiler) +from django.db.models.sql.compiler import SQLCompiler from django.db.transaction import Atomic, get_connection from django.test import TransactionTestCase @@ -26,10 +25,8 @@ from .cache import cachalot_caches from .settings import cachalot_settings from .utils import ( _get_query_cache_key, _invalidate_tables, - _get_table_cache_keys, _get_tables_from_sql, RandomQueryException) - - -WRITE_COMPILERS = (SQLInsertCompiler, SQLUpdateCompiler, SQLDeleteCompiler) + _get_table_cache_keys, _get_tables_from_sql, RandomQueryException, + WRITE_COMPILERS) PATCHED = False @@ -102,9 +99,9 @@ def _patch_compiler(original): def _patch_write_compiler(original): @wraps(original) - def inner(compiler, *args, **kwargs): - _invalidate_tables(cachalot_caches.get_cache(), compiler) - return original(compiler, *args, **kwargs) + def inner(write_compiler, *args, **kwargs): + _invalidate_tables(cachalot_caches.get_cache(), write_compiler) + return original(write_compiler, *args, **kwargs) inner.original = original return inner diff --git a/cachalot/utils.py b/cachalot/utils.py index 631871f..02102ca 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -7,6 +7,8 @@ from time import time import django from django.db import connections from django.db.models.sql import Query +from django.db.models.sql.compiler import ( + SQLInsertCompiler, SQLUpdateCompiler, SQLDeleteCompiler) from django.db.models.sql.where import ExtraWhere, SubqueryConstraint DJANGO_GTE_1_7 = django.VERSION[:2] >= (1, 7) if DJANGO_GTE_1_7: @@ -18,6 +20,9 @@ from .settings import cachalot_settings from .signals import post_invalidation +WRITE_COMPILERS = (SQLInsertCompiler, SQLUpdateCompiler, SQLDeleteCompiler) + + class RandomQueryException(Exception): pass @@ -125,11 +130,11 @@ def _invalidate_table_cache_keys(cache, table_cache_keys): cache.set_many(d, None) -def _invalidate_tables(cache, compiler): - db_alias = compiler.using - tables = _get_tables(compiler.query, db_alias) - table_cache_keys = [_get_table_cache_key(db_alias, t) for t in tables] - _invalidate_table_cache_keys(cache, table_cache_keys) +def _invalidate_tables(cache, write_compiler): + db_alias = write_compiler.using - for table in tables: - post_invalidation.send(table, db_alias=db_alias) + table = write_compiler.query.get_meta().db_table + table_cache_key = _get_table_cache_key(db_alias, table) + _invalidate_table_cache_keys(cache, (table_cache_key,)) + + post_invalidation.send(table, db_alias=db_alias)