feat: pass down the request in a few more places

This commit is contained in:
Bruno Alla 2024-10-01 10:02:12 -03:00 committed by Aleksi Häkli
parent 510c8d18f5
commit a304380853
3 changed files with 13 additions and 9 deletions

View file

@ -12,12 +12,14 @@ from axes.models import AccessAttempt
log = getLogger(__name__)
def get_cool_off_threshold(attempt_time: Optional[datetime] = None) -> datetime:
def get_cool_off_threshold(
attempt_time: Optional[datetime] = None, request: Optional[HttpRequest] = None
) -> datetime:
"""
Get threshold for fetching access attempts from the database.
"""
cool_off = get_cool_off()
cool_off = get_cool_off(request)
if cool_off is None:
raise TypeError(
"Cool off threshold can not be calculated with settings.AXES_COOLOFF_TIME set to None"
@ -62,12 +64,14 @@ def get_user_attempts(
)
return attempts_list
threshold = get_cool_off_threshold(request.axes_attempt_time)
threshold = get_cool_off_threshold(request.axes_attempt_time, request)
log.debug("AXES: Getting access attempts that are newer than %s", threshold)
return [attempts.filter(attempt_time__gte=threshold) for attempts in attempts_list]
def clean_expired_user_attempts(attempt_time: Optional[datetime] = None) -> int:
def clean_expired_user_attempts(
attempt_time: Optional[datetime] = None, request: Optional[HttpRequest] = None
) -> int:
"""
Clean expired user attempts from the database.
"""
@ -78,7 +82,7 @@ def clean_expired_user_attempts(attempt_time: Optional[datetime] = None) -> int:
)
return 0
threshold = get_cool_off_threshold(attempt_time)
threshold = get_cool_off_threshold(attempt_time, request)
count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete()
log.info(
"AXES: Cleaned up %s expired access attempts from database that were older than %s",

View file

@ -132,7 +132,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler):
return
# 1. database query: Clean up expired user attempts from the database before logging new attempts
clean_expired_user_attempts(request.axes_attempt_time)
clean_expired_user_attempts(request.axes_attempt_time, request)
username = get_client_username(request, credentials)
client_str = get_client_str(
@ -262,7 +262,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler):
"""
# 1. database query: Clean up expired user attempts from the database
clean_expired_user_attempts(request.axes_attempt_time)
clean_expired_user_attempts(request.axes_attempt_time, request)
username = user.get_username()
credentials = get_credentials(username)
@ -305,7 +305,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler):
"""
# 1. database query: Clean up expired user attempts from the database
clean_expired_user_attempts(request.axes_attempt_time)
clean_expired_user_attempts(request.axes_attempt_time, request)
username = user.get_username() if user else None
client_str = get_client_str(

View file

@ -474,7 +474,7 @@ def get_lockout_response(
"username": get_client_username(request, credentials) or "",
}
cool_off = get_cool_off()
cool_off = get_cool_off(request)
if cool_off:
context.update(
{