diff --git a/axes/signals.py b/axes/signals.py index 1c59acb..b41dc30 100644 --- a/axes/signals.py +++ b/axes/signals.py @@ -3,6 +3,7 @@ from logging import getLogger from django.contrib.auth.signals import user_logged_in from django.contrib.auth.signals import user_logged_out from django.contrib.auth.signals import user_login_failed +from django.core.signals import setting_changed from django.db.models.signals import post_save, post_delete from django.dispatch import receiver from django.dispatch import Signal @@ -32,14 +33,14 @@ class ProxyHandler: implementation = None # concrete handler that is bootstrapped by the Django application loader @classmethod - def initialize(cls): + def initialize(cls, force=False): """ Fetch and initialize concrete handler implementation and memoize it to avoid reinitialization. This method is re-entrant and can be called multiple times. """ - if cls.implementation is None: + if force or cls.implementation is None: cls.implementation = import_string(settings.AXES_HANDLER)() @classmethod @@ -124,3 +125,14 @@ def handle_post_save_access_attempt(*args, **kwargs): @receiver(post_delete, sender=AccessAttempt) def handle_post_delete_access_attempt(*args, **kwargs): ProxyHandler.post_delete_access_attempt(*args, **kwargs) + + +@receiver(setting_changed) +def handle_setting_changed(sender, setting, value, enter, **kwargs): # pylint: disable=unused-argument + """ + Reinitialize handler implementation if a relevant setting changes + in e.g. application reconfiguration or during testing. + """ + + if enter and setting == 'AXES_HANDLER': + ProxyHandler.initialize(force=enter) diff --git a/axes/tests/test_handlers.py b/axes/tests/test_handlers.py index ebd6fdf..6d05181 100644 --- a/axes/tests/test_handlers.py +++ b/axes/tests/test_handlers.py @@ -16,6 +16,15 @@ class ProxyHandlerTestCase(TestCase): self.user = MagicMock() self.instance = MagicMock() + @patch('axes.signals.import_string', return_value=AxesHandler) + def test_setting_changed_signal_triggers_handler_reimport(self, importer): + self.assertEqual(0, importer.call_count) + + with self.settings( + AXES_HANDLER='axes.handlers.AxesHandler' + ): + self.assertEqual(1, importer.call_count) + @patch('axes.signals.ProxyHandler.implementation', None) @patch('axes.signals.import_string', return_value=AxesHandler) def test_initialize(self, importer):