Enhance get_lockout_response to support original_response parameter

This commit is contained in:
Mounir Messelmeni 2025-09-04 13:57:13 +02:00 committed by Aleksi Häkli
parent 69c97d5c7b
commit 04fd39fa57
3 changed files with 39 additions and 9 deletions

View file

@ -460,15 +460,26 @@ def get_lockout_message() -> str:
def get_lockout_response(
request: HttpRequest, credentials: Optional[dict] = None
request: HttpRequest,
original_response: Optional[HttpResponse] = None,
credentials: Optional[dict] = None,
) -> HttpResponse:
if settings.AXES_LOCKOUT_CALLABLE:
if callable(settings.AXES_LOCKOUT_CALLABLE):
return settings.AXES_LOCKOUT_CALLABLE( # pylint: disable=not-callable
request, credentials
)
# Try calling with 3 args, fallback to 2 for backward compatibility
try:
return settings.AXES_LOCKOUT_CALLABLE(
request, original_response, credentials
)
except TypeError:
# Fallback: old signature without original_response
return settings.AXES_LOCKOUT_CALLABLE(request, credentials)
if isinstance(settings.AXES_LOCKOUT_CALLABLE, str):
return import_string(settings.AXES_LOCKOUT_CALLABLE)(request, credentials)
callable_obj = import_string(settings.AXES_LOCKOUT_CALLABLE)
try:
return callable_obj(request, original_response, credentials)
except TypeError:
return callable_obj(request, credentials)
raise TypeError(
"settings.AXES_LOCKOUT_CALLABLE needs to be a string, callable, or None."
)

View file

@ -48,7 +48,7 @@ class AxesMiddleware:
if settings.AXES_ENABLED:
if getattr(request, "axes_locked_out", None):
credentials = getattr(request, "axes_credentials", None)
response = get_lockout_response(request, credentials) # type: ignore
response = get_lockout_response(request, response, credentials) # type: ignore
return response
@ -60,8 +60,6 @@ class AxesMiddleware:
credentials = getattr(request, "axes_credentials", None)
response = await sync_to_async(
get_lockout_response, thread_sensitive=True
)(
request, credentials
) # type: ignore
)(request, credentials) # type: ignore
return response

View file

@ -1013,9 +1013,16 @@ def mock_get_lockout_response(request, credentials):
return HttpResponse(status=400)
def mock_get_lockout_response_with_original_response_param(
request, response, credentials
):
return HttpResponse(status=400)
class AxesLockoutTestCase(AxesTestCase):
def setUp(self):
self.request = HttpRequest()
self.response = HttpResponse()
self.credentials = dict()
def test_get_lockout_response(self):
@ -1039,6 +1046,20 @@ class AxesLockoutTestCase(AxesTestCase):
response = get_lockout_response(self.request, self.credentials)
self.assertEqual(400, response.status_code)
@override_settings(
AXES_LOCKOUT_CALLABLE=mock_get_lockout_response_with_original_response_param
)
def test_get_lockout_response_override_callable_with_original_response_param(self):
response = get_lockout_response(self.request, self.response, self.credentials)
self.assertEqual(400, response.status_code)
@override_settings(
AXES_LOCKOUT_CALLABLE="tests.test_helpers.mock_get_lockout_response_with_original_response_param"
)
def test_get_lockout_response_override_path_with_original_response_param(self):
response = get_lockout_response(self.request, self.response, self.credentials)
self.assertEqual(400, response.status_code)
@override_settings(AXES_LOCKOUT_CALLABLE=42)
def test_get_lockout_response_override_invalid(self):
with self.assertRaises(TypeError):