diff --git a/axes/handlers/cache.py b/axes/handlers/cache.py index e1f21ef..cfaa56a 100644 --- a/axes/handlers/cache.py +++ b/axes/handlers/cache.py @@ -128,6 +128,7 @@ class AxesCacheHandler(AbstractAxesHandler, AxesBaseHandler): ) request.axes_locked_out = True + request.axes_credentials = credentials user_locked_out.send( "axes", request=request, diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 8939814..2995c06 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -186,6 +186,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): ) request.axes_locked_out = True + request.axes_credentials = credentials user_locked_out.send( "axes", request=request, diff --git a/axes/handlers/proxy.py b/axes/handlers/proxy.py index 670fad1..a539db3 100644 --- a/axes/handlers/proxy.py +++ b/axes/handlers/proxy.py @@ -78,6 +78,7 @@ class AxesProxyHandler(AbstractAxesHandler, AxesBaseHandler): request.axes_http_accept = get_client_http_accept(request) request.axes_failures_since_start = None request.axes_updated = True + request.axes_credentials = None @classmethod def is_locked(cls, request, credentials: dict = None) -> bool: diff --git a/axes/middleware.py b/axes/middleware.py index b493f2c..8e29707 100644 --- a/axes/middleware.py +++ b/axes/middleware.py @@ -37,6 +37,7 @@ class AxesMiddleware: if settings.AXES_ENABLED: if getattr(request, "axes_locked_out", None): - response = get_lockout_response(request) # type: ignore + credentials = getattr(request, 'axes_credentials', None) + response = get_lockout_response(request, credentials) # type: ignore return response diff --git a/tests/test_middleware.py b/tests/test_middleware.py index a5c8e41..8b098d1 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,3 +1,4 @@ +from django.conf import settings from django.http import HttpResponse, HttpRequest from django.test import override_settings @@ -5,6 +6,10 @@ from axes.middleware import AxesMiddleware from tests.base import AxesTestCase +def get_username(request, credentials: dict) -> str: + return credentials.get(settings.AXES_USERNAME_FORM_FIELD) + + class MiddlewareTestCase(AxesTestCase): STATUS_SUCCESS = 200 STATUS_LOCKOUT = 403 @@ -28,6 +33,17 @@ class MiddlewareTestCase(AxesTestCase): response = AxesMiddleware(get_response)(self.request) self.assertEqual(response.status_code, self.STATUS_LOCKOUT) + @override_settings(AXES_USERNAME_CALLABLE="tests.test_middleware.get_username") + def test_lockout_response_with_axes_callable_username(self): + def get_response(request): + request.axes_locked_out = True + request.axes_credentials = {settings.AXES_USERNAME_FORM_FIELD: 'username'} + + return HttpResponse() + + response = AxesMiddleware(get_response)(self.request) + self.assertEqual(response.status_code, self.STATUS_LOCKOUT) + @override_settings(AXES_ENABLED=False) def test_respects_enabled_switch(self): def get_response(request):