diff --git a/cachalot/monkey_patch.py b/cachalot/monkey_patch.py index 305dec5..dd49f03 100644 --- a/cachalot/monkey_patch.py +++ b/cachalot/monkey_patch.py @@ -257,25 +257,38 @@ def _patch_atomic(): Atomic.__exit__ = patch_exit(Atomic.__exit__) -def _patch_test_teardown(): - def patch_teardown(original): +def _patch_tests(): + def patch_before(original): @wraps(original) def inner(*args, **kwargs): - original(*args, **kwargs) clear_all_caches() + return original(*args, **kwargs) inner.original = original return inner - TransactionTestCase._fixture_setup = patch_teardown( + def patch_after(original): + @wraps(original) + def inner(*args, **kwargs): + out = original(*args, **kwargs) + clear_all_caches() + return out + + inner.original = original + return inner + + creation = connection.creation + creation.create_test_db = patch_after(creation.create_test_db) + creation.destroy_test_db = patch_before(creation.destroy_test_db) + TransactionTestCase._fixture_setup = patch_after( TransactionTestCase._fixture_setup) - TransactionTestCase._fixture_teardown = patch_teardown( + TransactionTestCase._fixture_teardown = patch_after( TransactionTestCase._fixture_teardown) def patch(): global PATCHED - _patch_test_teardown() + _patch_tests() _patch_orm_write() _patch_orm_read() _patch_atomic()