diff --git a/cachalot/monkey_patch.py b/cachalot/monkey_patch.py index 1914130..6c0f95c 100644 --- a/cachalot/monkey_patch.py +++ b/cachalot/monkey_patch.py @@ -17,6 +17,7 @@ from django.db.models.sql.compiler import ( SQLInsertCompiler, SQLUpdateCompiler, SQLDeleteCompiler) from django.db.models.sql.where import ExtraWhere from django.db.transaction import Atomic +from django.test import TransactionTestCase from .settings import cachalot_settings @@ -256,34 +257,22 @@ def _patch_atomic(): Atomic.__exit__ = patch_exit(Atomic.__exit__) -def _patch_test_db(): - def patch_creation(original): +def _patch_test_teardown(): + def patch_teardown(original): @wraps(original) def inner(*args, **kwargs): - out = original(*args, **kwargs) + original(*args, **kwargs) clear_all_caches() - return out inner.original = original return inner - def patch_destruction(original): - @wraps(original) - def inner(*args, **kwargs): - clear_all_caches() - return original(*args, **kwargs) - - inner.original = original - return inner - - creation = connection.creation - creation.create_test_db = patch_creation(creation.create_test_db) - creation.destroy_test_db = patch_destruction(creation.destroy_test_db) + TransactionTestCase._fixture_teardown = patch_teardown(TransactionTestCase._fixture_teardown) def patch(): global PATCHED - _patch_test_db() + _patch_test_teardown() _patch_orm_write() _patch_orm_read() _patch_atomic()