Fixes the independence issue when nesting atomic block of different databases.

This commit is contained in:
Bertrand Bordage 2015-10-26 02:42:13 +01:00
parent 5a077ee607
commit 43bd5e1b38
5 changed files with 93 additions and 38 deletions

View file

@ -1,7 +1,6 @@
# coding: utf-8
from __future__ import unicode_literals
from collections import defaultdict
from django.conf import settings
from django.db import connections
@ -14,11 +13,10 @@ from .utils import _get_table_cache_key, _invalidate_table_cache_keys
__all__ = ('invalidate', 'get_last_invalidation')
def _get_table_cache_keys_per_cache(tables, cache_alias, db_alias):
def _get_table_cache_keys_per_cache_and_db(tables, cache_alias, db_alias):
no_tables = not tables
cache_aliases = settings.CACHES if cache_alias is None else (cache_alias,)
db_aliases = settings.DATABASES if db_alias is None else (db_alias,)
table_cache_keys_per_cache = defaultdict(list)
for db_alias in db_aliases:
if no_tables:
tables = connections[db_alias].introspection.table_names()
@ -26,8 +24,7 @@ def _get_table_cache_keys_per_cache(tables, cache_alias, db_alias):
table_cache_keys = [
_get_table_cache_key(db_alias, t) for t in tables]
if table_cache_keys:
table_cache_keys_per_cache[cache_alias].extend(table_cache_keys)
return table_cache_keys_per_cache
yield cache_alias, db_alias, table_cache_keys
def _get_tables(tables_or_models):
@ -64,11 +61,11 @@ def invalidate(*tables_or_models, **kwargs):
raise TypeError(
"invalidate() got an unexpected keyword argument '%s'" % k)
table_cache_keys_per_cache = _get_table_cache_keys_per_cache(
table_cache_keys_per_cache = _get_table_cache_keys_per_cache_and_db(
_get_tables(tables_or_models), cache_alias, db_alias)
for cache_alias, table_cache_keys in table_cache_keys_per_cache.items():
_invalidate_table_cache_keys(cachalot_caches.get_cache(cache_alias),
table_cache_keys)
for cache_alias, db_alias, table_cache_keys in table_cache_keys_per_cache:
_invalidate_table_cache_keys(
cachalot_caches.get_cache(cache_alias, db_alias), table_cache_keys)
def get_last_invalidation(*tables_or_models, **kwargs):
@ -100,11 +97,11 @@ def get_last_invalidation(*tables_or_models, **kwargs):
"keyword argument '%s'" % k)
last_invalidation = 0.0
table_cache_keys_per_cache = _get_table_cache_keys_per_cache(
table_cache_keys_per_cache = _get_table_cache_keys_per_cache_and_db(
_get_tables(tables_or_models), cache_alias, db_alias)
for cache_alias, table_cache_keys in table_cache_keys_per_cache.items():
for cache_alias, db_alias, table_cache_keys in table_cache_keys_per_cache:
invalidations = cachalot_caches.get_cache(
cache_alias).get_many(table_cache_keys).values()
cache_alias, db_alias).get_many(table_cache_keys).values()
if invalidations:
current_last_invalidation = max(invalidations)
if current_last_invalidation > last_invalidation:

View file

@ -1,9 +1,11 @@
# coding: utf-8
from __future__ import unicode_literals
from collections import defaultdict
from threading import local
from django.core.cache import caches
from django.db import DEFAULT_DB_ALIAS
from .settings import cachalot_settings
from .transaction import AtomicCache
@ -13,29 +15,35 @@ class CacheHandler(local):
@property
def atomic_caches(self):
if not hasattr(self, '_atomic_caches'):
self._atomic_caches = []
self._atomic_caches = defaultdict(list)
return self._atomic_caches
def get_atomic_cache(self, cache_alias, level):
if cache_alias not in self.atomic_caches[level]:
self.atomic_caches[level][cache_alias] = AtomicCache(
self.get_cache(cache_alias, level-1))
return self.atomic_caches[level][cache_alias]
def get_atomic_cache(self, cache_alias, db_alias, level):
if cache_alias not in self.atomic_caches[db_alias][level]:
self.atomic_caches[db_alias][level][cache_alias] = AtomicCache(
self.get_cache(cache_alias, db_alias, level-1))
return self.atomic_caches[db_alias][level][cache_alias]
def get_cache(self, cache_alias=None, atomic_level=-1):
def get_cache(self, cache_alias=None, db_alias=None, atomic_level=-1):
if db_alias is None:
db_alias = DEFAULT_DB_ALIAS
if cache_alias is None:
cache_alias = cachalot_settings.CACHALOT_CACHE
min_level = -len(self.atomic_caches)
min_level = -len(self.atomic_caches[db_alias])
if atomic_level < min_level:
return caches[cache_alias]
return self.get_atomic_cache(cache_alias, atomic_level)
return self.get_atomic_cache(cache_alias, db_alias, atomic_level)
def enter_atomic(self):
self.atomic_caches.append({})
def enter_atomic(self, db_alias):
if db_alias is None:
db_alias = DEFAULT_DB_ALIAS
self.atomic_caches[db_alias].append({})
def exit_atomic(self, commit):
atomic_caches = self.atomic_caches.pop().values()
def exit_atomic(self, db_alias, commit):
if db_alias is None:
db_alias = DEFAULT_DB_ALIAS
atomic_caches = self.atomic_caches[db_alias].pop().values()
if commit:
for atomic_cache in atomic_caches:
atomic_cache.commit()

View file

@ -35,9 +35,8 @@ def _unset_raw_connection(original):
TUPLE_OR_LIST = (tuple, list)
def _get_result_or_execute_query(execute_query_func, cache_key,
table_cache_keys):
cache = cachalot_caches.get_cache()
def _get_result_or_execute_query(execute_query_func, cache,
cache_key, table_cache_keys):
data = cache.get_many(table_cache_keys + [cache_key])
new_table_cache_keys = set(table_cache_keys)
@ -80,7 +79,9 @@ def _patch_compiler(original):
return execute_query_func()
return _get_result_or_execute_query(
execute_query_func, cache_key, table_cache_keys)
execute_query_func,
cachalot_caches.get_cache(db_alias=compiler.using),
cache_key, table_cache_keys)
return inner
@ -88,7 +89,10 @@ def _patch_compiler(original):
def _patch_write_compiler(original):
@wraps(original)
def inner(write_compiler, *args, **kwargs):
_invalidate_table(cachalot_caches.get_cache(), write_compiler)
db_alias = write_compiler.using
table = write_compiler.query.get_meta().db_table
_invalidate_table(cachalot_caches.get_cache(db_alias=db_alias),
db_alias, table)
return original(write_compiler, *args, **kwargs)
return inner
@ -123,7 +127,7 @@ def _patch_atomic():
def patch_enter(original):
@wraps(original)
def inner(self):
cachalot_caches.enter_atomic()
cachalot_caches.enter_atomic(self.using)
original(self)
return inner
@ -133,8 +137,8 @@ def _patch_atomic():
def inner(self, exc_type, exc_value, traceback):
needs_rollback = get_connection(self.using).needs_rollback
original(self, exc_type, exc_value, traceback)
cachalot_caches.exit_atomic(exc_type is None
and not needs_rollback)
cachalot_caches.exit_atomic(
self.using, exc_type is None and not needs_rollback)
return inner

View file

@ -4,7 +4,7 @@ from __future__ import unicode_literals
from unittest import skipIf
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS, connections
from django.db import DEFAULT_DB_ALIAS, connections, transaction
from django.test import TransactionTestCase
from .models import Test
@ -72,3 +72,52 @@ class MultiDatabaseTestCase(TransactionTestCase):
with self.assertNumQueries(0):
data2 = list(Test.objects.all())
self.assertListEqual(data2, [self.t1, self.t2])
def test_heterogeneous_atomics(self):
"""
Checks that an atomic block for a database nested inside
another atomic block for another database has no impact on their
caching.
"""
with transaction.atomic():
with transaction.atomic(self.db_alias2):
with self.assertNumQueries(1):
data1 = list(Test.objects.all())
self.assertListEqual(data1, [self.t1, self.t2])
with self.assertNumQueries(1, using=self.db_alias2):
data2 = list(Test.objects.using(self.db_alias2))
self.assertListEqual(data2, [])
t3 = Test.objects.using(self.db_alias2).create(name='test3')
with self.assertNumQueries(1, using=self.db_alias2):
data3 = list(Test.objects.using(self.db_alias2))
self.assertListEqual(data3, [t3])
with self.assertNumQueries(0):
data4 = list(Test.objects.all())
self.assertListEqual(data4, [self.t1, self.t2])
with self.assertNumQueries(1):
data5 = list(Test.objects.filter(name='test3'))
self.assertListEqual(data5, [])
def test_heterogeneous_atomics_independence(self):
"""
Checks that interrupting an atomic block after the commit of another
atomic block for another database nested inside it
correctly invalidates the cache for the committed transaction.
"""
with self.assertNumQueries(1, using=self.db_alias2):
data1 = list(Test.objects.using(self.db_alias2))
self.assertListEqual(data1, [])
try:
with transaction.atomic():
with transaction.atomic(self.db_alias2):
t3 = Test.objects.using(
self.db_alias2).create(name='test3')
raise ZeroDivisionError
except ZeroDivisionError:
pass
with self.assertNumQueries(1, using=self.db_alias2):
data2 = list(Test.objects.using(self.db_alias2))
self.assertListEqual(data2, [t3])

View file

@ -155,10 +155,7 @@ def _invalidate_table_cache_keys(cache, table_cache_keys):
cache.set_many(d, None)
def _invalidate_table(cache, write_compiler):
db_alias = write_compiler.using
table = write_compiler.query.get_meta().db_table
def _invalidate_table(cache, db_alias, table):
table_cache_key = _get_table_cache_key(db_alias, table)
_invalidate_table_cache_keys(cache, (table_cache_key,))