From 8ed0d82384353564e30738906aca2da03086b87d Mon Sep 17 00:00:00 2001 From: Bruno Alla Date: Tue, 1 Oct 2024 10:13:49 -0300 Subject: [PATCH] refactor: remove attempt_time parameter As we pass down the whole request, we no longer need to extract the axes_attempt_time anymore. This is a potential breaking change, but the impacted functions are not part of the documented API. --- axes/attempts.py | 13 +++++-------- axes/handlers/database.py | 6 +++--- tests/test_attempts.py | 11 ++++++----- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/axes/attempts.py b/axes/attempts.py index 9eed18b..d518b53 100644 --- a/axes/attempts.py +++ b/axes/attempts.py @@ -12,9 +12,7 @@ from axes.models import AccessAttempt log = getLogger(__name__) -def get_cool_off_threshold( - attempt_time: Optional[datetime] = None, request: Optional[HttpRequest] = None -) -> datetime: +def get_cool_off_threshold(request: Optional[HttpRequest] = None) -> datetime: """ Get threshold for fetching access attempts from the database. """ @@ -25,6 +23,7 @@ def get_cool_off_threshold( "Cool off threshold can not be calculated with settings.AXES_COOLOFF_TIME set to None" ) + attempt_time = request.axes_attempt_time if attempt_time is None: return now() - cool_off return attempt_time - cool_off @@ -64,14 +63,12 @@ def get_user_attempts( ) return attempts_list - threshold = get_cool_off_threshold(request.axes_attempt_time, request) + threshold = get_cool_off_threshold(request) log.debug("AXES: Getting access attempts that are newer than %s", threshold) return [attempts.filter(attempt_time__gte=threshold) for attempts in attempts_list] -def clean_expired_user_attempts( - attempt_time: Optional[datetime] = None, request: Optional[HttpRequest] = None -) -> int: +def clean_expired_user_attempts(request: Optional[HttpRequest] = None) -> int: """ Clean expired user attempts from the database. """ @@ -82,7 +79,7 @@ def clean_expired_user_attempts( ) return 0 - threshold = get_cool_off_threshold(attempt_time, request) + threshold = get_cool_off_threshold(request) count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete() log.info( "AXES: Cleaned up %s expired access attempts from database that were older than %s", diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 80df685..1b5f1d0 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -132,7 +132,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): return # 1. database query: Clean up expired user attempts from the database before logging new attempts - clean_expired_user_attempts(request.axes_attempt_time, request) + clean_expired_user_attempts(request) username = get_client_username(request, credentials) client_str = get_client_str( @@ -262,7 +262,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): """ # 1. database query: Clean up expired user attempts from the database - clean_expired_user_attempts(request.axes_attempt_time, request) + clean_expired_user_attempts(request) username = user.get_username() credentials = get_credentials(username) @@ -305,7 +305,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): """ # 1. database query: Clean up expired user attempts from the database - clean_expired_user_attempts(request.axes_attempt_time, request) + clean_expired_user_attempts(request) username = user.get_username() if user else None client_str = get_client_str( diff --git a/tests/test_attempts.py b/tests/test_attempts.py index 04af617..188071b 100644 --- a/tests/test_attempts.py +++ b/tests/test_attempts.py @@ -1,7 +1,7 @@ from unittest.mock import patch from django.http import HttpRequest -from django.test import override_settings +from django.test import override_settings, RequestFactory from django.utils.timezone import now from axes.attempts import get_cool_off_threshold @@ -15,12 +15,13 @@ class GetCoolOffThresholdTestCase(AxesTestCase): def test_get_cool_off_threshold(self): timestamp = now() + request = RequestFactory().post("/") with patch("axes.attempts.now", return_value=timestamp): - attempt_time = timestamp - threshold_now = get_cool_off_threshold(attempt_time) + request.axes_attempt_time = timestamp + threshold_now = get_cool_off_threshold(request) - attempt_time = None - threshold_none = get_cool_off_threshold(attempt_time) + request.axes_attempt_time = None + threshold_none = get_cool_off_threshold(request) self.assertEqual(threshold_now, threshold_none)