From 43bd5e1b38b11ae3f724d21761aa8923add435cf Mon Sep 17 00:00:00 2001 From: Bertrand Bordage Date: Mon, 26 Oct 2015 02:42:13 +0100 Subject: [PATCH] Fixes the independence issue when nesting atomic block of different databases. --- cachalot/api.py | 21 +++++++--------- cachalot/cache.py | 34 +++++++++++++++---------- cachalot/monkey_patch.py | 20 +++++++++------ cachalot/tests/multi_db.py | 51 +++++++++++++++++++++++++++++++++++++- cachalot/utils.py | 5 +--- 5 files changed, 93 insertions(+), 38 deletions(-) diff --git a/cachalot/api.py b/cachalot/api.py index 710024f..91f8a50 100644 --- a/cachalot/api.py +++ b/cachalot/api.py @@ -1,7 +1,6 @@ # coding: utf-8 from __future__ import unicode_literals -from collections import defaultdict from django.conf import settings from django.db import connections @@ -14,11 +13,10 @@ from .utils import _get_table_cache_key, _invalidate_table_cache_keys __all__ = ('invalidate', 'get_last_invalidation') -def _get_table_cache_keys_per_cache(tables, cache_alias, db_alias): +def _get_table_cache_keys_per_cache_and_db(tables, cache_alias, db_alias): no_tables = not tables cache_aliases = settings.CACHES if cache_alias is None else (cache_alias,) db_aliases = settings.DATABASES if db_alias is None else (db_alias,) - table_cache_keys_per_cache = defaultdict(list) for db_alias in db_aliases: if no_tables: tables = connections[db_alias].introspection.table_names() @@ -26,8 +24,7 @@ def _get_table_cache_keys_per_cache(tables, cache_alias, db_alias): table_cache_keys = [ _get_table_cache_key(db_alias, t) for t in tables] if table_cache_keys: - table_cache_keys_per_cache[cache_alias].extend(table_cache_keys) - return table_cache_keys_per_cache + yield cache_alias, db_alias, table_cache_keys def _get_tables(tables_or_models): @@ -64,11 +61,11 @@ def invalidate(*tables_or_models, **kwargs): raise TypeError( "invalidate() got an unexpected keyword argument '%s'" % k) - table_cache_keys_per_cache = _get_table_cache_keys_per_cache( + table_cache_keys_per_cache = _get_table_cache_keys_per_cache_and_db( _get_tables(tables_or_models), cache_alias, db_alias) - for cache_alias, table_cache_keys in table_cache_keys_per_cache.items(): - _invalidate_table_cache_keys(cachalot_caches.get_cache(cache_alias), - table_cache_keys) + for cache_alias, db_alias, table_cache_keys in table_cache_keys_per_cache: + _invalidate_table_cache_keys( + cachalot_caches.get_cache(cache_alias, db_alias), table_cache_keys) def get_last_invalidation(*tables_or_models, **kwargs): @@ -100,11 +97,11 @@ def get_last_invalidation(*tables_or_models, **kwargs): "keyword argument '%s'" % k) last_invalidation = 0.0 - table_cache_keys_per_cache = _get_table_cache_keys_per_cache( + table_cache_keys_per_cache = _get_table_cache_keys_per_cache_and_db( _get_tables(tables_or_models), cache_alias, db_alias) - for cache_alias, table_cache_keys in table_cache_keys_per_cache.items(): + for cache_alias, db_alias, table_cache_keys in table_cache_keys_per_cache: invalidations = cachalot_caches.get_cache( - cache_alias).get_many(table_cache_keys).values() + cache_alias, db_alias).get_many(table_cache_keys).values() if invalidations: current_last_invalidation = max(invalidations) if current_last_invalidation > last_invalidation: diff --git a/cachalot/cache.py b/cachalot/cache.py index fb68c15..6a35350 100644 --- a/cachalot/cache.py +++ b/cachalot/cache.py @@ -1,9 +1,11 @@ # coding: utf-8 from __future__ import unicode_literals +from collections import defaultdict from threading import local from django.core.cache import caches +from django.db import DEFAULT_DB_ALIAS from .settings import cachalot_settings from .transaction import AtomicCache @@ -13,29 +15,35 @@ class CacheHandler(local): @property def atomic_caches(self): if not hasattr(self, '_atomic_caches'): - self._atomic_caches = [] + self._atomic_caches = defaultdict(list) return self._atomic_caches - def get_atomic_cache(self, cache_alias, level): - if cache_alias not in self.atomic_caches[level]: - self.atomic_caches[level][cache_alias] = AtomicCache( - self.get_cache(cache_alias, level-1)) - return self.atomic_caches[level][cache_alias] + def get_atomic_cache(self, cache_alias, db_alias, level): + if cache_alias not in self.atomic_caches[db_alias][level]: + self.atomic_caches[db_alias][level][cache_alias] = AtomicCache( + self.get_cache(cache_alias, db_alias, level-1)) + return self.atomic_caches[db_alias][level][cache_alias] - def get_cache(self, cache_alias=None, atomic_level=-1): + def get_cache(self, cache_alias=None, db_alias=None, atomic_level=-1): + if db_alias is None: + db_alias = DEFAULT_DB_ALIAS if cache_alias is None: cache_alias = cachalot_settings.CACHALOT_CACHE - min_level = -len(self.atomic_caches) + min_level = -len(self.atomic_caches[db_alias]) if atomic_level < min_level: return caches[cache_alias] - return self.get_atomic_cache(cache_alias, atomic_level) + return self.get_atomic_cache(cache_alias, db_alias, atomic_level) - def enter_atomic(self): - self.atomic_caches.append({}) + def enter_atomic(self, db_alias): + if db_alias is None: + db_alias = DEFAULT_DB_ALIAS + self.atomic_caches[db_alias].append({}) - def exit_atomic(self, commit): - atomic_caches = self.atomic_caches.pop().values() + def exit_atomic(self, db_alias, commit): + if db_alias is None: + db_alias = DEFAULT_DB_ALIAS + atomic_caches = self.atomic_caches[db_alias].pop().values() if commit: for atomic_cache in atomic_caches: atomic_cache.commit() diff --git a/cachalot/monkey_patch.py b/cachalot/monkey_patch.py index 530e9c3..1708c43 100644 --- a/cachalot/monkey_patch.py +++ b/cachalot/monkey_patch.py @@ -35,9 +35,8 @@ def _unset_raw_connection(original): TUPLE_OR_LIST = (tuple, list) -def _get_result_or_execute_query(execute_query_func, cache_key, - table_cache_keys): - cache = cachalot_caches.get_cache() +def _get_result_or_execute_query(execute_query_func, cache, + cache_key, table_cache_keys): data = cache.get_many(table_cache_keys + [cache_key]) new_table_cache_keys = set(table_cache_keys) @@ -80,7 +79,9 @@ def _patch_compiler(original): return execute_query_func() return _get_result_or_execute_query( - execute_query_func, cache_key, table_cache_keys) + execute_query_func, + cachalot_caches.get_cache(db_alias=compiler.using), + cache_key, table_cache_keys) return inner @@ -88,7 +89,10 @@ def _patch_compiler(original): def _patch_write_compiler(original): @wraps(original) def inner(write_compiler, *args, **kwargs): - _invalidate_table(cachalot_caches.get_cache(), write_compiler) + db_alias = write_compiler.using + table = write_compiler.query.get_meta().db_table + _invalidate_table(cachalot_caches.get_cache(db_alias=db_alias), + db_alias, table) return original(write_compiler, *args, **kwargs) return inner @@ -123,7 +127,7 @@ def _patch_atomic(): def patch_enter(original): @wraps(original) def inner(self): - cachalot_caches.enter_atomic() + cachalot_caches.enter_atomic(self.using) original(self) return inner @@ -133,8 +137,8 @@ def _patch_atomic(): def inner(self, exc_type, exc_value, traceback): needs_rollback = get_connection(self.using).needs_rollback original(self, exc_type, exc_value, traceback) - cachalot_caches.exit_atomic(exc_type is None - and not needs_rollback) + cachalot_caches.exit_atomic( + self.using, exc_type is None and not needs_rollback) return inner diff --git a/cachalot/tests/multi_db.py b/cachalot/tests/multi_db.py index 61c8151..413dc5e 100644 --- a/cachalot/tests/multi_db.py +++ b/cachalot/tests/multi_db.py @@ -4,7 +4,7 @@ from __future__ import unicode_literals from unittest import skipIf from django.conf import settings -from django.db import DEFAULT_DB_ALIAS, connections +from django.db import DEFAULT_DB_ALIAS, connections, transaction from django.test import TransactionTestCase from .models import Test @@ -72,3 +72,52 @@ class MultiDatabaseTestCase(TransactionTestCase): with self.assertNumQueries(0): data2 = list(Test.objects.all()) self.assertListEqual(data2, [self.t1, self.t2]) + + def test_heterogeneous_atomics(self): + """ + Checks that an atomic block for a database nested inside + another atomic block for another database has no impact on their + caching. + """ + with transaction.atomic(): + with transaction.atomic(self.db_alias2): + with self.assertNumQueries(1): + data1 = list(Test.objects.all()) + self.assertListEqual(data1, [self.t1, self.t2]) + with self.assertNumQueries(1, using=self.db_alias2): + data2 = list(Test.objects.using(self.db_alias2)) + self.assertListEqual(data2, []) + t3 = Test.objects.using(self.db_alias2).create(name='test3') + with self.assertNumQueries(1, using=self.db_alias2): + data3 = list(Test.objects.using(self.db_alias2)) + self.assertListEqual(data3, [t3]) + + with self.assertNumQueries(0): + data4 = list(Test.objects.all()) + self.assertListEqual(data4, [self.t1, self.t2]) + + with self.assertNumQueries(1): + data5 = list(Test.objects.filter(name='test3')) + self.assertListEqual(data5, []) + + def test_heterogeneous_atomics_independence(self): + """ + Checks that interrupting an atomic block after the commit of another + atomic block for another database nested inside it + correctly invalidates the cache for the committed transaction. + """ + with self.assertNumQueries(1, using=self.db_alias2): + data1 = list(Test.objects.using(self.db_alias2)) + self.assertListEqual(data1, []) + + try: + with transaction.atomic(): + with transaction.atomic(self.db_alias2): + t3 = Test.objects.using( + self.db_alias2).create(name='test3') + raise ZeroDivisionError + except ZeroDivisionError: + pass + with self.assertNumQueries(1, using=self.db_alias2): + data2 = list(Test.objects.using(self.db_alias2)) + self.assertListEqual(data2, [t3]) diff --git a/cachalot/utils.py b/cachalot/utils.py index 3d5eb60..2e4980b 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -155,10 +155,7 @@ def _invalidate_table_cache_keys(cache, table_cache_keys): cache.set_many(d, None) -def _invalidate_table(cache, write_compiler): - db_alias = write_compiler.using - - table = write_compiler.query.get_meta().db_table +def _invalidate_table(cache, db_alias, table): table_cache_key = _get_table_cache_key(db_alias, table) _invalidate_table_cache_keys(cache, (table_cache_key,))