mirror of
https://github.com/jazzband/django-axes.git
synced 2026-03-16 22:30:23 +00:00
Enhance get_lockout_response to support original_response parameter
This commit is contained in:
parent
69c97d5c7b
commit
04fd39fa57
3 changed files with 39 additions and 9 deletions
|
|
@ -460,15 +460,26 @@ def get_lockout_message() -> str:
|
|||
|
||||
|
||||
def get_lockout_response(
|
||||
request: HttpRequest, credentials: Optional[dict] = None
|
||||
request: HttpRequest,
|
||||
original_response: Optional[HttpResponse] = None,
|
||||
credentials: Optional[dict] = None,
|
||||
) -> HttpResponse:
|
||||
if settings.AXES_LOCKOUT_CALLABLE:
|
||||
if callable(settings.AXES_LOCKOUT_CALLABLE):
|
||||
return settings.AXES_LOCKOUT_CALLABLE( # pylint: disable=not-callable
|
||||
request, credentials
|
||||
)
|
||||
# Try calling with 3 args, fallback to 2 for backward compatibility
|
||||
try:
|
||||
return settings.AXES_LOCKOUT_CALLABLE(
|
||||
request, original_response, credentials
|
||||
)
|
||||
except TypeError:
|
||||
# Fallback: old signature without original_response
|
||||
return settings.AXES_LOCKOUT_CALLABLE(request, credentials)
|
||||
if isinstance(settings.AXES_LOCKOUT_CALLABLE, str):
|
||||
return import_string(settings.AXES_LOCKOUT_CALLABLE)(request, credentials)
|
||||
callable_obj = import_string(settings.AXES_LOCKOUT_CALLABLE)
|
||||
try:
|
||||
return callable_obj(request, original_response, credentials)
|
||||
except TypeError:
|
||||
return callable_obj(request, credentials)
|
||||
raise TypeError(
|
||||
"settings.AXES_LOCKOUT_CALLABLE needs to be a string, callable, or None."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class AxesMiddleware:
|
|||
if settings.AXES_ENABLED:
|
||||
if getattr(request, "axes_locked_out", None):
|
||||
credentials = getattr(request, "axes_credentials", None)
|
||||
response = get_lockout_response(request, credentials) # type: ignore
|
||||
response = get_lockout_response(request, response, credentials) # type: ignore
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -60,8 +60,6 @@ class AxesMiddleware:
|
|||
credentials = getattr(request, "axes_credentials", None)
|
||||
response = await sync_to_async(
|
||||
get_lockout_response, thread_sensitive=True
|
||||
)(
|
||||
request, credentials
|
||||
) # type: ignore
|
||||
)(request, credentials) # type: ignore
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1013,9 +1013,16 @@ def mock_get_lockout_response(request, credentials):
|
|||
return HttpResponse(status=400)
|
||||
|
||||
|
||||
def mock_get_lockout_response_with_original_response_param(
|
||||
request, response, credentials
|
||||
):
|
||||
return HttpResponse(status=400)
|
||||
|
||||
|
||||
class AxesLockoutTestCase(AxesTestCase):
|
||||
def setUp(self):
|
||||
self.request = HttpRequest()
|
||||
self.response = HttpResponse()
|
||||
self.credentials = dict()
|
||||
|
||||
def test_get_lockout_response(self):
|
||||
|
|
@ -1039,6 +1046,20 @@ class AxesLockoutTestCase(AxesTestCase):
|
|||
response = get_lockout_response(self.request, self.credentials)
|
||||
self.assertEqual(400, response.status_code)
|
||||
|
||||
@override_settings(
|
||||
AXES_LOCKOUT_CALLABLE=mock_get_lockout_response_with_original_response_param
|
||||
)
|
||||
def test_get_lockout_response_override_callable_with_original_response_param(self):
|
||||
response = get_lockout_response(self.request, self.response, self.credentials)
|
||||
self.assertEqual(400, response.status_code)
|
||||
|
||||
@override_settings(
|
||||
AXES_LOCKOUT_CALLABLE="tests.test_helpers.mock_get_lockout_response_with_original_response_param"
|
||||
)
|
||||
def test_get_lockout_response_override_path_with_original_response_param(self):
|
||||
response = get_lockout_response(self.request, self.response, self.credentials)
|
||||
self.assertEqual(400, response.status_code)
|
||||
|
||||
@override_settings(AXES_LOCKOUT_CALLABLE=42)
|
||||
def test_get_lockout_response_override_invalid(self):
|
||||
with self.assertRaises(TypeError):
|
||||
|
|
|
|||
Loading…
Reference in a new issue