diff --git a/axes/handlers/base.py b/axes/handlers/base.py index 3e5aca9..04391b1 100644 --- a/axes/handlers/base.py +++ b/axes/handlers/base.py @@ -1,5 +1,6 @@ from axes.conf import settings from axes.helpers import ( + get_failure_limit, is_client_ip_address_blacklisted, is_client_ip_address_whitelisted, is_client_method_whitelisted, @@ -98,7 +99,7 @@ class AxesHandler: # pylint: disable=unused-argument """ if settings.AXES_LOCK_OUT_AT_FAILURE: - return self.get_failures(request, credentials) >= settings.AXES_FAILURE_LIMIT + return self.get_failures(request, credentials) >= get_failure_limit(request, credentials) return False diff --git a/axes/handlers/cache.py b/axes/handlers/cache.py index 517ced2..e3b067c 100644 --- a/axes/handlers/cache.py +++ b/axes/handlers/cache.py @@ -10,6 +10,7 @@ from axes.helpers import ( get_client_str, get_client_username, get_credentials, + get_failure_limit, ) log = getLogger(settings.AXES_LOGGER) @@ -59,7 +60,7 @@ class AxesCacheHandler(AxesHandler): # pylint: disable=too-many-locals 'AXES: Repeated login failure by %s. Count = %d of %d. Updating existing record in the cache.', client_str, failures_since_start, - settings.AXES_FAILURE_LIMIT, + get_failure_limit(request, credentials), ) else: log.warning( diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 07b9aab..19fc01b 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -17,6 +17,7 @@ from axes.helpers import ( get_client_str, get_client_username, get_credentials, + get_failure_limit, get_query_str, ) @@ -82,7 +83,7 @@ class AxesDatabaseHandler(AxesHandler): # pylint: disable=too-many-locals 'AXES: Repeated login failure by %s. Count = %d of %d. Updating existing record in the database.', client_str, failures_since_start, - settings.AXES_FAILURE_LIMIT, + get_failure_limit(request, credentials), ) separator = '\n---------\n' @@ -118,7 +119,7 @@ class AxesDatabaseHandler(AxesHandler): # pylint: disable=too-many-locals attempt_time=request.axes_attempt_time, ) - if failures_since_start >= settings.AXES_FAILURE_LIMIT: + if failures_since_start >= get_failure_limit(request, credentials): log.warning('AXES: Locking out %s after repeated login failures.', client_str) request.axes_locked_out = True diff --git a/axes/helpers.py b/axes/helpers.py index d02e5aa..da9ae2b 100644 --- a/axes/helpers.py +++ b/axes/helpers.py @@ -252,6 +252,14 @@ def get_query_str(query: Type[QueryDict], max_length: int = 1024) -> str: return query_str[:max_length] +def get_failure_limit(request, credentials) -> int: + if callable(settings.AXES_FAILURE_LIMIT): + return settings.AXES_FAILURE_LIMIT(request, credentials) + if isinstance(settings.AXES_FAILURE_LIMIT, int): + return settings.AXES_FAILURE_LIMIT + raise TypeError('settings.AXES_FAILURE_LIMIT needs to be a callable or an integer') + + def get_lockout_message() -> str: if settings.AXES_COOLOFF_TIME: return settings.AXES_COOLOFF_MESSAGE @@ -261,7 +269,7 @@ def get_lockout_message() -> str: def get_lockout_response(request, credentials: dict = None) -> HttpResponse: status = 403 context = { - 'failure_limit': settings.AXES_FAILURE_LIMIT, + 'failure_limit': get_failure_limit(request, credentials), 'username': get_client_username(request, credentials) or '' } diff --git a/axes/tests/base.py b/axes/tests/base.py index fd18c73..499be9d 100644 --- a/axes/tests/base.py +++ b/axes/tests/base.py @@ -18,6 +18,7 @@ from axes.helpers import ( get_client_user_agent, get_cool_off, get_credentials, + get_failure_limit, ) from axes.models import AccessAttempt @@ -144,7 +145,7 @@ class AxesTestCase(TestCase): self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=self.ALLOWED, html=True) def almost_lockout(self): - for _ in range(1, settings.AXES_FAILURE_LIMIT): + for _ in range(1, get_failure_limit(None, None)): response = self.login() self.assertContains(response, self.LOGIN_FORM_KEY, html=True) diff --git a/axes/tests/test_handlers.py b/axes/tests/test_handlers.py index f5cc249..e90eafe 100644 --- a/axes/tests/test_handlers.py +++ b/axes/tests/test_handlers.py @@ -111,6 +111,15 @@ class AxesDatabaseHandlerTestCase(AxesHandlerBaseTestCase): def test_handler_without_reset(self): self.check_handler() + @override_settings(AXES_FAILURE_LIMIT=lambda *args: 3) + def test_handler_callable_failure_limit(self): + self.check_handler() + + @override_settings(AXES_FAILURE_LIMIT='3') + def test_handler_invalid_failure_limit(self): + with self.assertRaises(TypeError): + self.check_handler() + @patch('axes.handlers.database.log') def test_empty_request(self, log): self.check_empty_request(log, 'AxesDatabaseHandler')