mirror of
https://github.com/Hopiu/django-cachalot.git
synced 2026-05-20 18:31:51 +00:00
Fixes the independence issue when nesting atomic block of different databases.
This commit is contained in:
parent
5a077ee607
commit
43bd5e1b38
5 changed files with 93 additions and 38 deletions
|
|
@ -1,7 +1,6 @@
|
|||
# coding: utf-8
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from collections import defaultdict
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connections
|
||||
|
|
@ -14,11 +13,10 @@ from .utils import _get_table_cache_key, _invalidate_table_cache_keys
|
|||
__all__ = ('invalidate', 'get_last_invalidation')
|
||||
|
||||
|
||||
def _get_table_cache_keys_per_cache(tables, cache_alias, db_alias):
|
||||
def _get_table_cache_keys_per_cache_and_db(tables, cache_alias, db_alias):
|
||||
no_tables = not tables
|
||||
cache_aliases = settings.CACHES if cache_alias is None else (cache_alias,)
|
||||
db_aliases = settings.DATABASES if db_alias is None else (db_alias,)
|
||||
table_cache_keys_per_cache = defaultdict(list)
|
||||
for db_alias in db_aliases:
|
||||
if no_tables:
|
||||
tables = connections[db_alias].introspection.table_names()
|
||||
|
|
@ -26,8 +24,7 @@ def _get_table_cache_keys_per_cache(tables, cache_alias, db_alias):
|
|||
table_cache_keys = [
|
||||
_get_table_cache_key(db_alias, t) for t in tables]
|
||||
if table_cache_keys:
|
||||
table_cache_keys_per_cache[cache_alias].extend(table_cache_keys)
|
||||
return table_cache_keys_per_cache
|
||||
yield cache_alias, db_alias, table_cache_keys
|
||||
|
||||
|
||||
def _get_tables(tables_or_models):
|
||||
|
|
@ -64,11 +61,11 @@ def invalidate(*tables_or_models, **kwargs):
|
|||
raise TypeError(
|
||||
"invalidate() got an unexpected keyword argument '%s'" % k)
|
||||
|
||||
table_cache_keys_per_cache = _get_table_cache_keys_per_cache(
|
||||
table_cache_keys_per_cache = _get_table_cache_keys_per_cache_and_db(
|
||||
_get_tables(tables_or_models), cache_alias, db_alias)
|
||||
for cache_alias, table_cache_keys in table_cache_keys_per_cache.items():
|
||||
_invalidate_table_cache_keys(cachalot_caches.get_cache(cache_alias),
|
||||
table_cache_keys)
|
||||
for cache_alias, db_alias, table_cache_keys in table_cache_keys_per_cache:
|
||||
_invalidate_table_cache_keys(
|
||||
cachalot_caches.get_cache(cache_alias, db_alias), table_cache_keys)
|
||||
|
||||
|
||||
def get_last_invalidation(*tables_or_models, **kwargs):
|
||||
|
|
@ -100,11 +97,11 @@ def get_last_invalidation(*tables_or_models, **kwargs):
|
|||
"keyword argument '%s'" % k)
|
||||
|
||||
last_invalidation = 0.0
|
||||
table_cache_keys_per_cache = _get_table_cache_keys_per_cache(
|
||||
table_cache_keys_per_cache = _get_table_cache_keys_per_cache_and_db(
|
||||
_get_tables(tables_or_models), cache_alias, db_alias)
|
||||
for cache_alias, table_cache_keys in table_cache_keys_per_cache.items():
|
||||
for cache_alias, db_alias, table_cache_keys in table_cache_keys_per_cache:
|
||||
invalidations = cachalot_caches.get_cache(
|
||||
cache_alias).get_many(table_cache_keys).values()
|
||||
cache_alias, db_alias).get_many(table_cache_keys).values()
|
||||
if invalidations:
|
||||
current_last_invalidation = max(invalidations)
|
||||
if current_last_invalidation > last_invalidation:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
# coding: utf-8
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from collections import defaultdict
|
||||
from threading import local
|
||||
|
||||
from django.core.cache import caches
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
|
||||
from .settings import cachalot_settings
|
||||
from .transaction import AtomicCache
|
||||
|
|
@ -13,29 +15,35 @@ class CacheHandler(local):
|
|||
@property
|
||||
def atomic_caches(self):
|
||||
if not hasattr(self, '_atomic_caches'):
|
||||
self._atomic_caches = []
|
||||
self._atomic_caches = defaultdict(list)
|
||||
return self._atomic_caches
|
||||
|
||||
def get_atomic_cache(self, cache_alias, level):
|
||||
if cache_alias not in self.atomic_caches[level]:
|
||||
self.atomic_caches[level][cache_alias] = AtomicCache(
|
||||
self.get_cache(cache_alias, level-1))
|
||||
return self.atomic_caches[level][cache_alias]
|
||||
def get_atomic_cache(self, cache_alias, db_alias, level):
|
||||
if cache_alias not in self.atomic_caches[db_alias][level]:
|
||||
self.atomic_caches[db_alias][level][cache_alias] = AtomicCache(
|
||||
self.get_cache(cache_alias, db_alias, level-1))
|
||||
return self.atomic_caches[db_alias][level][cache_alias]
|
||||
|
||||
def get_cache(self, cache_alias=None, atomic_level=-1):
|
||||
def get_cache(self, cache_alias=None, db_alias=None, atomic_level=-1):
|
||||
if db_alias is None:
|
||||
db_alias = DEFAULT_DB_ALIAS
|
||||
if cache_alias is None:
|
||||
cache_alias = cachalot_settings.CACHALOT_CACHE
|
||||
|
||||
min_level = -len(self.atomic_caches)
|
||||
min_level = -len(self.atomic_caches[db_alias])
|
||||
if atomic_level < min_level:
|
||||
return caches[cache_alias]
|
||||
return self.get_atomic_cache(cache_alias, atomic_level)
|
||||
return self.get_atomic_cache(cache_alias, db_alias, atomic_level)
|
||||
|
||||
def enter_atomic(self):
|
||||
self.atomic_caches.append({})
|
||||
def enter_atomic(self, db_alias):
|
||||
if db_alias is None:
|
||||
db_alias = DEFAULT_DB_ALIAS
|
||||
self.atomic_caches[db_alias].append({})
|
||||
|
||||
def exit_atomic(self, commit):
|
||||
atomic_caches = self.atomic_caches.pop().values()
|
||||
def exit_atomic(self, db_alias, commit):
|
||||
if db_alias is None:
|
||||
db_alias = DEFAULT_DB_ALIAS
|
||||
atomic_caches = self.atomic_caches[db_alias].pop().values()
|
||||
if commit:
|
||||
for atomic_cache in atomic_caches:
|
||||
atomic_cache.commit()
|
||||
|
|
|
|||
|
|
@ -35,9 +35,8 @@ def _unset_raw_connection(original):
|
|||
TUPLE_OR_LIST = (tuple, list)
|
||||
|
||||
|
||||
def _get_result_or_execute_query(execute_query_func, cache_key,
|
||||
table_cache_keys):
|
||||
cache = cachalot_caches.get_cache()
|
||||
def _get_result_or_execute_query(execute_query_func, cache,
|
||||
cache_key, table_cache_keys):
|
||||
data = cache.get_many(table_cache_keys + [cache_key])
|
||||
|
||||
new_table_cache_keys = set(table_cache_keys)
|
||||
|
|
@ -80,7 +79,9 @@ def _patch_compiler(original):
|
|||
return execute_query_func()
|
||||
|
||||
return _get_result_or_execute_query(
|
||||
execute_query_func, cache_key, table_cache_keys)
|
||||
execute_query_func,
|
||||
cachalot_caches.get_cache(db_alias=compiler.using),
|
||||
cache_key, table_cache_keys)
|
||||
|
||||
return inner
|
||||
|
||||
|
|
@ -88,7 +89,10 @@ def _patch_compiler(original):
|
|||
def _patch_write_compiler(original):
|
||||
@wraps(original)
|
||||
def inner(write_compiler, *args, **kwargs):
|
||||
_invalidate_table(cachalot_caches.get_cache(), write_compiler)
|
||||
db_alias = write_compiler.using
|
||||
table = write_compiler.query.get_meta().db_table
|
||||
_invalidate_table(cachalot_caches.get_cache(db_alias=db_alias),
|
||||
db_alias, table)
|
||||
return original(write_compiler, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
|
@ -123,7 +127,7 @@ def _patch_atomic():
|
|||
def patch_enter(original):
|
||||
@wraps(original)
|
||||
def inner(self):
|
||||
cachalot_caches.enter_atomic()
|
||||
cachalot_caches.enter_atomic(self.using)
|
||||
original(self)
|
||||
|
||||
return inner
|
||||
|
|
@ -133,8 +137,8 @@ def _patch_atomic():
|
|||
def inner(self, exc_type, exc_value, traceback):
|
||||
needs_rollback = get_connection(self.using).needs_rollback
|
||||
original(self, exc_type, exc_value, traceback)
|
||||
cachalot_caches.exit_atomic(exc_type is None
|
||||
and not needs_rollback)
|
||||
cachalot_caches.exit_atomic(
|
||||
self.using, exc_type is None and not needs_rollback)
|
||||
|
||||
return inner
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import unicode_literals
|
|||
from unittest import skipIf
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DEFAULT_DB_ALIAS, connections
|
||||
from django.db import DEFAULT_DB_ALIAS, connections, transaction
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from .models import Test
|
||||
|
|
@ -72,3 +72,52 @@ class MultiDatabaseTestCase(TransactionTestCase):
|
|||
with self.assertNumQueries(0):
|
||||
data2 = list(Test.objects.all())
|
||||
self.assertListEqual(data2, [self.t1, self.t2])
|
||||
|
||||
def test_heterogeneous_atomics(self):
|
||||
"""
|
||||
Checks that an atomic block for a database nested inside
|
||||
another atomic block for another database has no impact on their
|
||||
caching.
|
||||
"""
|
||||
with transaction.atomic():
|
||||
with transaction.atomic(self.db_alias2):
|
||||
with self.assertNumQueries(1):
|
||||
data1 = list(Test.objects.all())
|
||||
self.assertListEqual(data1, [self.t1, self.t2])
|
||||
with self.assertNumQueries(1, using=self.db_alias2):
|
||||
data2 = list(Test.objects.using(self.db_alias2))
|
||||
self.assertListEqual(data2, [])
|
||||
t3 = Test.objects.using(self.db_alias2).create(name='test3')
|
||||
with self.assertNumQueries(1, using=self.db_alias2):
|
||||
data3 = list(Test.objects.using(self.db_alias2))
|
||||
self.assertListEqual(data3, [t3])
|
||||
|
||||
with self.assertNumQueries(0):
|
||||
data4 = list(Test.objects.all())
|
||||
self.assertListEqual(data4, [self.t1, self.t2])
|
||||
|
||||
with self.assertNumQueries(1):
|
||||
data5 = list(Test.objects.filter(name='test3'))
|
||||
self.assertListEqual(data5, [])
|
||||
|
||||
def test_heterogeneous_atomics_independence(self):
|
||||
"""
|
||||
Checks that interrupting an atomic block after the commit of another
|
||||
atomic block for another database nested inside it
|
||||
correctly invalidates the cache for the committed transaction.
|
||||
"""
|
||||
with self.assertNumQueries(1, using=self.db_alias2):
|
||||
data1 = list(Test.objects.using(self.db_alias2))
|
||||
self.assertListEqual(data1, [])
|
||||
|
||||
try:
|
||||
with transaction.atomic():
|
||||
with transaction.atomic(self.db_alias2):
|
||||
t3 = Test.objects.using(
|
||||
self.db_alias2).create(name='test3')
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
with self.assertNumQueries(1, using=self.db_alias2):
|
||||
data2 = list(Test.objects.using(self.db_alias2))
|
||||
self.assertListEqual(data2, [t3])
|
||||
|
|
|
|||
|
|
@ -155,10 +155,7 @@ def _invalidate_table_cache_keys(cache, table_cache_keys):
|
|||
cache.set_many(d, None)
|
||||
|
||||
|
||||
def _invalidate_table(cache, write_compiler):
|
||||
db_alias = write_compiler.using
|
||||
|
||||
table = write_compiler.query.get_meta().db_table
|
||||
def _invalidate_table(cache, db_alias, table):
|
||||
table_cache_key = _get_table_cache_key(db_alias, table)
|
||||
_invalidate_table_cache_keys(cache, (table_cache_key,))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue