diff --git a/axes/helpers.py b/axes/helpers.py index 5e1a5f9..8347974 100644 --- a/axes/helpers.py +++ b/axes/helpers.py @@ -460,15 +460,26 @@ def get_lockout_message() -> str: def get_lockout_response( - request: HttpRequest, credentials: Optional[dict] = None + request: HttpRequest, + original_response: Optional[HttpResponse] = None, + credentials: Optional[dict] = None, ) -> HttpResponse: if settings.AXES_LOCKOUT_CALLABLE: if callable(settings.AXES_LOCKOUT_CALLABLE): - return settings.AXES_LOCKOUT_CALLABLE( # pylint: disable=not-callable - request, credentials - ) + # Try calling with 3 args, fallback to 2 for backward compatibility + try: + return settings.AXES_LOCKOUT_CALLABLE( + request, original_response, credentials + ) + except TypeError: + # Fallback: old signature without original_response + return settings.AXES_LOCKOUT_CALLABLE(request, credentials) if isinstance(settings.AXES_LOCKOUT_CALLABLE, str): - return import_string(settings.AXES_LOCKOUT_CALLABLE)(request, credentials) + callable_obj = import_string(settings.AXES_LOCKOUT_CALLABLE) + try: + return callable_obj(request, original_response, credentials) + except TypeError: + return callable_obj(request, credentials) raise TypeError( "settings.AXES_LOCKOUT_CALLABLE needs to be a string, callable, or None." ) diff --git a/axes/middleware.py b/axes/middleware.py index 0b8d16c..189ee78 100644 --- a/axes/middleware.py +++ b/axes/middleware.py @@ -48,7 +48,7 @@ class AxesMiddleware: if settings.AXES_ENABLED: if getattr(request, "axes_locked_out", None): credentials = getattr(request, "axes_credentials", None) - response = get_lockout_response(request, credentials) # type: ignore + response = get_lockout_response(request, response, credentials) # type: ignore return response @@ -60,8 +60,6 @@ class AxesMiddleware: credentials = getattr(request, "axes_credentials", None) response = await sync_to_async( get_lockout_response, thread_sensitive=True - )( - request, credentials - ) # type: ignore + )(request, credentials) # type: ignore return response diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 1d5c8aa..584cfc4 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1013,9 +1013,16 @@ def mock_get_lockout_response(request, credentials): return HttpResponse(status=400) +def mock_get_lockout_response_with_original_response_param( + request, response, credentials +): + return HttpResponse(status=400) + + class AxesLockoutTestCase(AxesTestCase): def setUp(self): self.request = HttpRequest() + self.response = HttpResponse() self.credentials = dict() def test_get_lockout_response(self): @@ -1039,6 +1046,20 @@ class AxesLockoutTestCase(AxesTestCase): response = get_lockout_response(self.request, self.credentials) self.assertEqual(400, response.status_code) + @override_settings( + AXES_LOCKOUT_CALLABLE=mock_get_lockout_response_with_original_response_param + ) + def test_get_lockout_response_override_callable_with_original_response_param(self): + response = get_lockout_response(self.request, self.response, self.credentials) + self.assertEqual(400, response.status_code) + + @override_settings( + AXES_LOCKOUT_CALLABLE="tests.test_helpers.mock_get_lockout_response_with_original_response_param" + ) + def test_get_lockout_response_override_path_with_original_response_param(self): + response = get_lockout_response(self.request, self.response, self.credentials) + self.assertEqual(400, response.status_code) + @override_settings(AXES_LOCKOUT_CALLABLE=42) def test_get_lockout_response_override_invalid(self): with self.assertRaises(TypeError):