From 6f2584b440959fb29a5ccf2e323629c0bba3629c Mon Sep 17 00:00:00 2001 From: Taikono-Himazin Date: Wed, 13 Dec 2023 13:35:22 +0900 Subject: [PATCH] Add async support to middleware --- axes/middleware.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/axes/middleware.py b/axes/middleware.py index 1f70ce9..b965907 100644 --- a/axes/middleware.py +++ b/axes/middleware.py @@ -1,5 +1,6 @@ from typing import Callable +from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async from django.conf import settings from django.http import HttpRequest, HttpResponse @@ -30,15 +31,36 @@ class AxesMiddleware: - ``AXES_PERMALOCK_MESSAGE``. """ + async_capable = True + sync_capable = True + def __init__(self, get_response: Callable) -> None: self.get_response = get_response + if iscoroutinefunction(self.get_response): + markcoroutinefunction(self) def __call__(self, request: HttpRequest) -> HttpResponse: - response = self.get_response(request) + # Exit out to async mode, if needed + if iscoroutinefunction(self): + return self.__acall__(request) + response = self.get_response(request) if settings.AXES_ENABLED: if getattr(request, "axes_locked_out", None): credentials = getattr(request, "axes_credentials", None) response = get_lockout_response(request, credentials) # type: ignore return response + + async def __acall__(self, request: HttpRequest) -> HttpResponse: + response = await self.get_response(request) + + if settings.AXES_ENABLED: + if getattr(request, "axes_locked_out", None): + credentials = getattr(request, "axes_credentials", None) + response = await sync_to_async( + get_lockout_response, + thread_sensitive=True + )(request, credentials) # type: ignore + + return response