diff --git a/axes/handlers/cache.py b/axes/handlers/cache.py index 3f65d44..06da63c 100644 --- a/axes/handlers/cache.py +++ b/axes/handlers/cache.py @@ -11,6 +11,7 @@ from axes.helpers import ( get_credentials, get_failure_limit, ) +from axes.models import AccessBase from axes.signals import user_locked_out log = getLogger(__name__) @@ -25,6 +26,35 @@ class AxesCacheHandler(AbstractAxesHandler, AxesBaseHandler): self.cache = get_cache() self.cache_timeout = get_cache_timeout() + def reset_attempts( + self, + *, + ip_address: str = None, + username: str = None, + ip_or_username: bool = False, + ) -> int: + cache_keys: list = [] + count = 0 + + if ip_address is None and username is None: + raise NotImplementedError("Cannot clear all entries from cache") + if ip_or_username: + raise NotImplementedError( + "Due to the cache key ip_or_username=True is not supported" + ) + + cache_keys.extend( + get_client_cache_key(AccessBase(username=username, ip_address=ip_address)) + ) + + for cache_key in cache_keys: + deleted = self.cache.delete(cache_key) + count += int(deleted) if deleted is not None else 1 + + log.info("AXES: Reset %d access attempts from database.", count) + + return count + def get_failures(self, request, credentials: dict = None) -> int: cache_keys = get_client_cache_key(request, credentials) failure_count = max( diff --git a/tests/base.py b/tests/base.py index 4dc3407..3cfecad 100644 --- a/tests/base.py +++ b/tests/base.py @@ -105,7 +105,13 @@ class AxesTestCase(TestCase): def reset(self, ip=None, username=None): return reset(ip, username) - def login(self, is_valid_username=False, is_valid_password=False, **kwargs): + def login( + self, + is_valid_username=False, + is_valid_password=False, + remote_addr=None, + **kwargs + ): """ Login a user. @@ -128,7 +134,7 @@ class AxesTestCase(TestCase): return self.client.post( reverse("admin:login"), post_data, - REMOTE_ADDR=self.ip_address, + REMOTE_ADDR=remote_addr or self.ip_address, HTTP_USER_AGENT=self.user_agent, ) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 7c31d1a..0e397de 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -330,6 +330,98 @@ class AxesDatabaseHandlerTestCase(AxesHandlerBaseTestCase): AccessAttempt.objects.all().delete() +@override_settings(AXES_HANDLER="axes.handlers.cache.AxesCacheHandler") +class ResetAttemptsCacheHandlerTestCase(AxesHandlerBaseTestCase): + """ Test reset attempts for the cache handler """ + + USERNAME_1 = "foo_username" + USERNAME_2 = "bar_username" + IP_1 = "127.1.0.1" + IP_2 = "127.1.0.2" + + def set_up_login_attemtps(self): + """Set up the login attempts.""" + self.login(username=self.USERNAME_1, remote_addr=self.IP_1) + self.login(username=self.USERNAME_1, remote_addr=self.IP_2) + self.login(username=self.USERNAME_2, remote_addr=self.IP_1) + self.login(username=self.USERNAME_2, remote_addr=self.IP_2) + + def check_failures(self, failures, username=None, ip_address=None): + if ip_address is None and username is None: + raise NotImplementedError("Must supply ip_address or username") + try: + prev_ip = self.request.META["REMOTE_ADDR"] + credentials = {"username": username} if username else {} + if ip_address is not None: + self.request.META["REMOTE_ADDR"] = ip_address + self.assertEqual( + failures, + AxesProxyHandler.get_failures(self.request, credentials=credentials), + ) + finally: + self.request.META["REMOTE_ADDR"] = prev_ip + + def test_handler_reset_attempts(self): + with self.assertRaises(NotImplementedError): + AxesProxyHandler.reset_attempts() + + @override_settings(AXES_ONLY_USER_FAILURES=True) + def test_handler_reset_attempts_username(self): + self.set_up_login_attemtps() + self.assertEqual( + 2, + AxesProxyHandler.get_failures( + self.request, credentials={"username": self.USERNAME_1} + ), + ) + self.assertEqual( + 2, + AxesProxyHandler.get_failures( + self.request, credentials={"username": self.USERNAME_2} + ), + ) + self.assertEqual(1, AxesProxyHandler.reset_attempts(username=self.USERNAME_1)) + self.assertEqual( + 0, + AxesProxyHandler.get_failures( + self.request, credentials={"username": self.USERNAME_1} + ), + ) + self.assertEqual( + 2, + AxesProxyHandler.get_failures( + self.request, credentials={"username": self.USERNAME_2} + ), + ) + + def test_handler_reset_attempts_ip(self): + self.set_up_login_attemtps() + self.check_failures(2, ip_address=self.IP_1) + self.assertEqual(1, AxesProxyHandler.reset_attempts(ip_address=self.IP_1)) + self.check_failures(0, ip_address=self.IP_1) + self.check_failures(2, ip_address=self.IP_2) + + @override_settings(AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True) + def test_handler_reset_attempts_ip_and_username(self): + self.set_up_login_attemtps() + self.check_failures(1, username=self.USERNAME_1, ip_address=self.IP_1) + self.check_failures(1, username=self.USERNAME_2, ip_address=self.IP_1) + self.check_failures(1, username=self.USERNAME_1, ip_address=self.IP_2) + self.assertEqual( + 1, + AxesProxyHandler.reset_attempts( + ip_address=self.IP_1, username=self.USERNAME_1 + ), + ) + self.check_failures(0, username=self.USERNAME_1, ip_address=self.IP_1) + self.check_failures(1, username=self.USERNAME_2, ip_address=self.IP_1) + self.check_failures(1, username=self.USERNAME_1, ip_address=self.IP_2) + + def test_handler_reset_attempts_ip_or_username(self): + with self.assertRaises(NotImplementedError): + AxesProxyHandler.reset_attempts() + + @override_settings( AXES_HANDLER="axes.handlers.cache.AxesCacheHandler", AXES_COOLOFF_TIME=timedelta(seconds=1),