diff --git a/axes/middleware.py b/axes/middleware.py index 1f70ce9..b965907 100644 --- a/axes/middleware.py +++ b/axes/middleware.py @@ -1,5 +1,6 @@ from typing import Callable +from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async from django.conf import settings from django.http import HttpRequest, HttpResponse @@ -30,15 +31,36 @@ class AxesMiddleware: - ``AXES_PERMALOCK_MESSAGE``. """ + async_capable = True + sync_capable = True + def __init__(self, get_response: Callable) -> None: self.get_response = get_response + if iscoroutinefunction(self.get_response): + markcoroutinefunction(self) def __call__(self, request: HttpRequest) -> HttpResponse: - response = self.get_response(request) + # Exit out to async mode, if needed + if iscoroutinefunction(self): + return self.__acall__(request) + response = self.get_response(request) if settings.AXES_ENABLED: if getattr(request, "axes_locked_out", None): credentials = getattr(request, "axes_credentials", None) response = get_lockout_response(request, credentials) # type: ignore return response + + async def __acall__(self, request: HttpRequest) -> HttpResponse: + response = await self.get_response(request) + + if settings.AXES_ENABLED: + if getattr(request, "axes_locked_out", None): + credentials = getattr(request, "axes_credentials", None) + response = await sync_to_async( + get_lockout_response, + thread_sensitive=True + )(request, credentials) # type: ignore + + return response