From 46fc5381932b9b98dae7d0a9f7f0a02ab0f65fa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksi=20H=C3=A4kli?= Date: Sat, 23 Feb 2019 01:22:11 +0200 Subject: [PATCH] Add cache handler and refactor tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Aleksi Häkli --- axes/attempts.py | 124 ++-- axes/conf.py | 2 +- axes/handlers/base.py | 65 +- axes/handlers/cache.py | 116 ++++ axes/handlers/database.py | 194 ++---- axes/handlers/proxy.py | 14 +- axes/management/commands/axes_reset.py | 2 +- axes/management/commands/axes_reset_ip.py | 2 +- .../commands/axes_reset_username.py | 2 +- axes/tests/base.py | 162 +++++ axes/tests/settings.py | 9 +- axes/tests/test_attempt.py | 590 ------------------ axes/tests/test_attempts.py | 54 ++ axes/tests/test_backends.py | 5 +- axes/tests/test_checks.py | 5 +- axes/tests/test_decorators.py | 4 +- axes/tests/test_handlers.py | 158 ++--- axes/tests/test_logging.py | 91 ++- axes/tests/test_login.py | 135 +++- axes/tests/test_management.py | 4 +- axes/tests/test_middleware.py | 4 +- axes/tests/test_models.py | 6 +- axes/tests/test_signals.py | 18 + axes/tests/test_utils.py | 237 +++++-- axes/utils.py | 57 +- 25 files changed, 1075 insertions(+), 985 deletions(-) create mode 100644 axes/handlers/cache.py create mode 100644 axes/tests/base.py delete mode 100644 axes/tests/test_attempt.py create mode 100644 axes/tests/test_attempts.py create mode 100644 axes/tests/test_signals.py diff --git a/axes/attempts.py b/axes/attempts.py index 013f90e..ff7317b 100644 --- a/axes/attempts.py +++ b/axes/attempts.py @@ -1,51 +1,32 @@ -from hashlib import md5 from logging import getLogger -from typing import Union +from django.contrib.auth import get_user_model from django.db.models import QuerySet from django.http import HttpRequest -from django.utils.timezone import now +from django.utils.timezone import datetime, now from axes.conf import settings from axes.models import AccessAttempt from axes.utils import ( - get_axes_cache, get_client_ip_address, get_client_username, get_client_user_agent, - get_cache_timeout, - get_cool_off, get_client_parameters, + get_cool_off, ) log = getLogger(settings.AXES_LOGGER) -def get_cache_key(request_or_attempt: Union[HttpRequest, AccessAttempt], credentials: dict = None) -> str: +def get_cool_off_threshold(attempt_time: datetime = None) -> datetime: """ - Build cache key name from request or AccessAttempt object. - - :param request_or_attempt: HttpRequest or AccessAttempt object - :param credentials: credentials containing user information - :return cache_key: Hash key that is usable for Django cache backends + Get threshold for fetching access attempts from the database. """ - if isinstance(request_or_attempt, AccessAttempt): - username = request_or_attempt.username - ip_address = request_or_attempt.ip_address - user_agent = request_or_attempt.user_agent - else: - username = get_client_username(request_or_attempt, credentials) - ip_address = get_client_ip_address(request_or_attempt) - user_agent = get_client_user_agent(request_or_attempt) + if attempt_time is None: + return now() - get_cool_off() - filter_kwargs = get_client_parameters(username, ip_address, user_agent) - - cache_key_components = ''.join(filter_kwargs.values()) - cache_key_digest = md5(cache_key_components.encode()).hexdigest() - cache_key = 'axes-{}'.format(cache_key_digest) - - return cache_key + return attempt_time - get_cool_off() def filter_user_attempts(request: HttpRequest, credentials: dict = None) -> QuerySet: @@ -62,35 +43,35 @@ def filter_user_attempts(request: HttpRequest, credentials: dict = None) -> Quer return AccessAttempt.objects.filter(**filter_kwargs) -def get_user_attempts(request: HttpRequest, credentials: dict = None) -> QuerySet: +def get_user_attempts(request: HttpRequest, credentials: dict = None, attempt_time: datetime = None) -> QuerySet: """ - Get valid user attempts and delete expired attempts which have cool offs in the past. + Get valid user attempts that match the given request and credentials. """ attempts = filter_user_attempts(request, credentials) - # If settings.AXES_COOLOFF_TIME is not configured return the attempts - cool_off = get_cool_off() - if cool_off is None: + if settings.AXES_COOLOFF_TIME is None: + log.debug('AXES: Getting all access attempts from database because no AXES_COOLOFF_TIME is configured') return attempts - # Else AccessAttempts that have expired need to be cleaned up from the database - num_deleted, _ = attempts.filter(attempt_time__lte=now() - cool_off).delete() - if not num_deleted: - return attempts + threshold = get_cool_off_threshold(attempt_time) + log.debug('AXES: Getting access attempts that are newer than %s', threshold) + return attempts.filter(attempt_time__gte=threshold) - # If there deletions the cache needs to be updated - cache_key = get_cache_key(request, credentials) - num_failures_cached = get_axes_cache().get(cache_key) - if num_failures_cached is not None: - get_axes_cache().set( - cache_key, - num_failures_cached - num_deleted, - get_cache_timeout(), - ) - # AccessAttempts need to be refreshed from the database because of the delete before returning them - return attempts.all() +def clean_expired_user_attempts(attempt_time: datetime = None) -> int: + """ + Clean expired user attempts from the database. + """ + + if settings.AXES_COOLOFF_TIME is None: + log.debug('AXES: Skipping clean for expired access attempts because no AXES_COOLOFF_TIME is configured') + return 0 + + threshold = get_cool_off_threshold(attempt_time) + count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete() + log.info('AXES: Cleaned up %s expired access attempts from database that were older than %s', count, threshold) + return count def reset_user_attempts(request: HttpRequest, credentials: dict = None) -> int: @@ -99,6 +80,55 @@ def reset_user_attempts(request: HttpRequest, credentials: dict = None) -> int: """ attempts = filter_user_attempts(request, credentials) + count, _ = attempts.delete() + log.info('AXES: Reset %s access attempts from database.', count) return count + + +def reset(ip: str = None, username: str = None) -> int: + """ + Reset records that match IP or username, and return the count of removed attempts. + + This utility method is meant to be used from the CLI or via Python API. + """ + + attempts = AccessAttempt.objects.all() + + if ip: + attempts = attempts.filter(ip_address=ip) + if username: + attempts = attempts.filter(username=username) + + count, _ = attempts.delete() + log.info('AXES: Reset %s access attempts from database.', count) + + return count + + +def is_user_attempt_whitelisted(request: HttpRequest, credentials: dict = None) -> bool: + """ + Check if the given request or credentials refer to a whitelisted username. + + A whitelisted user has the magic ``nolockout`` property set. + + If the property is unknown or False or the user can not be found, + this implementation fails gracefully and returns True. + """ + + username_field = getattr(get_user_model(), 'USERNAME_FIELD', 'username') + username_value = get_client_username(request, credentials) + kwargs = { + username_field: username_value + } + + user_model = get_user_model() + + try: + user = user_model.objects.get(**kwargs) + return user.nolockout + except (user_model.DoesNotExist, AttributeError): + pass + + return False diff --git a/axes/conf.py b/axes/conf.py index 4a66f29..681aac8 100644 --- a/axes/conf.py +++ b/axes/conf.py @@ -4,7 +4,7 @@ from django.utils.translation import gettext_lazy as _ from appconf import AppConf -class MyAppConf(AppConf): +class AxesAppConf(AppConf): # see if the user has overridden the failure limit FAILURE_LIMIT = 3 diff --git a/axes/handlers/base.py b/axes/handlers/base.py index 4e7cd4f..fd336e4 100644 --- a/axes/handlers/base.py +++ b/axes/handlers/base.py @@ -1,6 +1,12 @@ -from typing import Any, Dict, Optional - from django.http import HttpRequest +from django.utils.timezone import datetime + +from axes.conf import settings +from axes.utils import ( + is_client_ip_address_blacklisted, + is_client_ip_address_whitelisted, + is_client_method_whitelisted, +) class AxesBaseHandler: # pylint: disable=unused-argument @@ -11,7 +17,7 @@ class AxesBaseHandler: # pylint: disable=unused-argument and define the class to be used with ``settings.AXES_HANDLER = 'dotted.full.path.to.YourClass'``. """ - def is_allowed(self, request: HttpRequest, credentials: Optional[Dict[str, Any]] = None) -> bool: + def is_allowed(self, request: HttpRequest, credentials: dict = None) -> bool: """ Check if the user is allowed to access or use given functionality such as a login view or authentication. @@ -26,9 +32,18 @@ class AxesBaseHandler: # pylint: disable=unused-argument and inspiration on some common checks and access restrictions before writing your own implementation. """ - raise NotImplementedError('The Axes handler class needs a method definition for is_allowed') + if self.is_blacklisted(request, credentials): + return False - def user_login_failed(self, sender, credentials: Dict[str, Any], request: HttpRequest, **kwargs): + if self.is_whitelisted(request, credentials): + return True + + if self.is_locked(request, credentials): + return False + + return True + + def user_login_failed(self, sender, credentials: dict, request: HttpRequest = None, **kwargs): """ Handle the Django user_login_failed authentication signal. """ @@ -52,3 +67,43 @@ class AxesBaseHandler: # pylint: disable=unused-argument """ Handle the Axes AccessAttempt object post delete signal. """ + + def is_blacklisted(self, request: HttpRequest, credentials: dict = None) -> bool: # pylint: disable=unused-argument + """ + Check if the request or given credentials are blacklisted from access. + """ + + if is_client_ip_address_blacklisted(request): + return True + + return False + + def is_whitelisted(self, request: HttpRequest, credentials: dict = None) -> bool: # pylint: disable=unused-argument + """ + Check if the request or given credentials are whitelisted for access. + """ + + if is_client_ip_address_whitelisted(request): + return True + + if is_client_method_whitelisted(request): + return True + + return False + + def is_locked(self, request: HttpRequest, credentials: dict = None, attempt_time: datetime = None) -> bool: + """ + Check if the request or given credentials are locked. + """ + + if settings.AXES_LOCK_OUT_AT_FAILURE: + return self.get_failures(request, credentials, attempt_time) >= settings.AXES_FAILURE_LIMIT + + return False + + def get_failures(self, request: HttpRequest, credentials: dict = None, attempt_time: datetime = None) -> int: + """ + Check the number of failures associated to the given request and credentials. + """ + + raise NotImplementedError('The Axes handler class needs a method definition for get_failures') diff --git a/axes/handlers/cache.py b/axes/handlers/cache.py new file mode 100644 index 0000000..1a00d41 --- /dev/null +++ b/axes/handlers/cache.py @@ -0,0 +1,116 @@ +from logging import getLogger + +from axes.conf import settings +from axes.exceptions import AxesSignalPermissionDenied +from axes.handlers.base import AxesBaseHandler +from axes.signals import user_locked_out +from axes.utils import ( + get_axes_cache, + get_client_cache_key, + get_client_ip_address, + get_client_path_info, + get_client_str, + get_client_username, + get_client_user_agent, + get_credentials, + get_cool_off, +) + +log = getLogger(settings.AXES_LOGGER) + + +class AxesCacheHandler(AxesBaseHandler): # pylint: disable=too-many-locals + """ + Signal handler implementation that records user login attempts to cache and locks users out if necessary. + """ + + def get_failures(self, request, credentials=None, attempt_time=None) -> int: + cache = get_axes_cache() + cache_key = get_client_cache_key(request, credentials) + return cache.get(cache_key, default=0) + + def user_login_failed(self, sender, credentials, request=None, **kwargs): # pylint: disable=too-many-locals + """ + When user login fails, save attempt record in cache and lock user out if necessary. + + :raises AxesSignalPermissionDenied: if user should be locked out. + """ + + if request is None: + log.error('AXES: AxesCacheHandler.user_login_failed does not function without a request.') + return + + username = get_client_username(request, credentials) + ip_address = get_client_ip_address(request) + user_agent = get_client_user_agent(request) + path_info = get_client_path_info(request) + client_str = get_client_str(username, ip_address, user_agent, path_info) + + if self.is_whitelisted(request, credentials): + log.info('AXES: Login failed from whitelisted client %s.', client_str) + return + + failures_since_start = 1 + self.get_failures(request, credentials) + + if failures_since_start > 1: + log.warning( + 'AXES: Repeated login failure by %s. Count = %d of %d. Updating existing record in the cache.', + client_str, + failures_since_start, + settings.AXES_FAILURE_LIMIT, + ) + else: + log.warning( + 'AXES: New login failure by %s. Creating new record in the cache.', + client_str, + ) + + cache = get_axes_cache() + cache_key = get_client_cache_key(request, credentials) + cache_timeout = get_cool_off().total_seconds() + + cache.set(cache_key, failures_since_start, cache_timeout) + + if failures_since_start >= settings.AXES_FAILURE_LIMIT: + log.warning('AXES: Locking out %s after repeated login failures.', client_str) + + user_locked_out.send( + 'axes', + request=request, + username=username, + ip_address=ip_address, + ) + + raise AxesSignalPermissionDenied('Locked out due to repeated login failures.') + + def user_logged_in(self, sender, request, user, **kwargs): # pylint: disable=unused-argument + """ + When user logs in, update the AccessLog related to the user. + """ + + username = user.get_username() + credentials = get_credentials(username) + ip_address = get_client_ip_address(request) + user_agent = get_client_user_agent(request) + path_info = get_client_path_info(request) + client_str = get_client_str(username, ip_address, user_agent, path_info) + + log.info('AXES: Successful login by %s.', client_str) + + if settings.AXES_RESET_ON_SUCCESS: + cache = get_axes_cache() + cache_key = get_client_cache_key(request, credentials) + + failures_since_start = cache.get(cache_key, default=0) + log.info('AXES: Deleting %d failed login attempts by %s from cache.', failures_since_start, client_str) + + cache.delete(cache_key) + + def user_logged_out(self, sender, request, user, **kwargs): + username = user.get_username() + ip_address = get_client_ip_address(request) + user_agent = get_client_user_agent(request) + path_info = get_client_path_info(request) + client_str = get_client_str(username, ip_address, user_agent, path_info) + + log.info('AXES: Successful logout by %s.', client_str) diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 1f4097c..a97b826 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -1,24 +1,21 @@ from logging import getLogger -from typing import Any, Dict, Optional from django.db.models import Max, Value from django.db.models.functions import Concat -from django.http import HttpRequest from django.utils.timezone import now from axes.attempts import ( - get_cache_key, + clean_expired_user_attempts, get_user_attempts, + is_user_attempt_whitelisted, reset_user_attempts, ) from axes.conf import settings from axes.exceptions import AxesSignalPermissionDenied +from axes.handlers.base import AxesBaseHandler from axes.models import AccessLog, AccessAttempt from axes.signals import user_locked_out -from axes.handlers.base import AxesBaseHandler from axes.utils import ( - get_axes_cache, - get_cache_timeout, get_client_ip_address, get_client_path_info, get_client_http_accept, @@ -27,10 +24,6 @@ from axes.utils import ( get_client_user_agent, get_credentials, get_query_str, - is_client_ip_address_blacklisted, - is_client_ip_address_whitelisted, - is_client_method_whitelisted, - is_client_username_whitelisted, ) @@ -42,69 +35,30 @@ class AxesDatabaseHandler(AxesBaseHandler): # pylint: disable=too-many-locals Signal handler implementation that records user login attempts to database and locks users out if necessary. """ - def is_allowed(self, request: HttpRequest, credentials: Optional[Dict[str, Any]] = None) -> bool: - """ - Check if the request or given credentials are already locked by Axes. + def get_failures(self, request, credentials=None, attempt_time=None) -> int: + attempts = get_user_attempts(request, credentials, attempt_time) + return attempts.aggregate(Max('failures_since_start'))['failures_since_start__max'] or 0 - This function is called from - - - function decorators defined in ``axes.decorators``, - - authentication backends defined in ``axes.backends``, and - - signal handlers defined in ``axes.handlers``. - - This function checks the following facts for a given request: - - 1. Is the request IP address _blacklisted_? If it is, return ``False``. - 2. Is the request IP address _whitelisted_? If it is, return ``True``. - 4. Is the request HTTP method _whitelisted_? If it is, return ``True``. - 3. Is the request user _whitelisted_? If it is, return ``True``. - 5. Is failed authentication attempt always allowed to proceed? If it is, return ``True``. - 6. Is failed authentication attempt count over the attempt limit? If it is, return ``False``. - - Refer to the function source code for the exact implementation. - """ - - if is_client_ip_address_blacklisted(request): + def is_locked(self, request, credentials=None, attempt_time=None): + if is_user_attempt_whitelisted(request, credentials): return False - if is_client_ip_address_whitelisted(request): - return True + return super().is_locked(request, credentials, attempt_time) - if is_client_method_whitelisted(request): - return True - - if is_client_username_whitelisted(request, credentials): - return True - - if not settings.AXES_LOCK_OUT_AT_FAILURE: - return True - - # Check failure statistics against cache - cache_hash_key = get_cache_key(request, credentials) - num_failures_cached = get_axes_cache().get(cache_hash_key) - - # Do not hit the database if we have an answer in the cache - if num_failures_cached is not None: - return num_failures_cached < settings.AXES_FAILURE_LIMIT - - # Check failure statistics against database - attempts = get_user_attempts(request, credentials) - - lockouts = attempts.filter( - failures_since_start__gte=settings.AXES_FAILURE_LIMIT, - ) - - return not lockouts.exists() - - def user_login_failed(self, sender, credentials, request, **kwargs): # pylint: disable=too-many-locals + def user_login_failed(self, sender, credentials, request=None, **kwargs): # pylint: disable=too-many-locals """ When user login fails, save AccessAttempt record in database and lock user out if necessary. - :raises AxesSignalPermissionDenied: if user should is locked out + :raises AxesSignalPermissionDenied: if user should be locked out. """ + attempt_time = now() + + # 1. database query: Clean up expired user attempts from the database before logging new attempts + clean_expired_user_attempts(attempt_time) + if request is None: - log.warning('AXES: AxesDatabaseHandler.user_login_failed does not function without a request.') + log.error('AXES: AxesDatabaseHandler.user_login_failed does not function without a request.') return username = get_client_username(request, credentials) @@ -116,45 +70,30 @@ class AxesDatabaseHandler(AxesBaseHandler): # pylint: disable=too-many-locals get_data = get_query_str(request.GET) post_data = get_query_str(request.POST) - attempt_time = now() - if is_client_ip_address_whitelisted(request): - log.info('AXES: Login failed from whitelisted IP %s.', ip_address) + if self.is_whitelisted(request, credentials): + log.info('AXES: Login failed from whitelisted client %s.', client_str) return - attempts = get_user_attempts(request, credentials) - cache_key = get_cache_key(request, credentials) - num_failures_cached = get_axes_cache().get(cache_key) + # 2. database query: Calculate the current maximum failure number from the existing attempts + failures_since_start = 1 + self.get_failures(request, credentials, attempt_time) - if num_failures_cached: - failures_since_start = num_failures_cached - elif attempts: - failures_since_start = attempts.aggregate( - Max('failures_since_start'), - )['failures_since_start__max'] - else: - failures_since_start = 0 + # 3. database query: Insert or update access records with the new failure data + if failures_since_start > 1: + # Update failed attempt information but do not touch the username, IP address, or user agent fields, + # because attackers can request the site with multiple different configurations + # in order to bypass the defense mechanisms that are used by the site. - # add a failed attempt for this user - failures_since_start += 1 - get_axes_cache().set( - cache_key, - failures_since_start, - get_cache_timeout(), - ) - - if attempts: - # Update existing attempt information but do not touch the username, ip_address, or user_agent fields, - # because attackers can request the site with multiple different usernames, addresses, or programs. - - log.info( - 'AXES: Repeated login failure by %s. Count = %d of %d', + log.warning( + 'AXES: Repeated login failure by %s. Count = %d of %d. Updating existing record in the database.', client_str, failures_since_start, settings.AXES_FAILURE_LIMIT, ) separator = '\n---------\n' + + attempts = get_user_attempts(request, credentials, attempt_time) attempts.update( get_data=Concat('get_data', Value(separator + get_data)), post_data=Concat('post_data', Value(separator + post_data)), @@ -164,11 +103,12 @@ class AxesDatabaseHandler(AxesBaseHandler): # pylint: disable=too-many-locals attempt_time=attempt_time, ) else: - # Record failed attempt. Whether or not the username, IP address or user agent is - # used in counting failures is handled elsewhere, and we just record everything here. + # Record failed attempt with all the relevant information. + # Filtering based on username, IP address and user agent handled elsewhere, + # and this handler just records the available information for further use. - log.info( - 'AXES: New login failure by %s. Creating access record.', + log.warning( + 'AXES: New login failure by %s. Creating new record in the database.', client_str, ) @@ -184,11 +124,8 @@ class AxesDatabaseHandler(AxesBaseHandler): # pylint: disable=too-many-locals attempt_time=attempt_time, ) - if not self.is_allowed(request, credentials): - log.warning( - 'AXES: Locked out %s after repeated login failures.', - client_str, - ) + if failures_since_start >= settings.AXES_FAILURE_LIMIT: + log.warning('AXES: Locking out %s after repeated login failures.', client_str) user_locked_out.send( 'axes', @@ -204,6 +141,11 @@ class AxesDatabaseHandler(AxesBaseHandler): # pylint: disable=too-many-locals When user logs in, update the AccessLog related to the user. """ + attempt_time = now() + + # 1. database query: Clean up expired user attempts from the database + clean_expired_user_attempts(attempt_time) + username = user.get_username() credentials = get_credentials(username) ip_address = get_client_ip_address(request) @@ -212,72 +154,48 @@ class AxesDatabaseHandler(AxesBaseHandler): # pylint: disable=too-many-locals http_accept = get_client_http_accept(request) client_str = get_client_str(username, ip_address, user_agent, path_info) - log.info( - 'AXES: Successful login by %s.', - client_str, - ) + log.info('AXES: Successful login by %s.', client_str) if not settings.AXES_DISABLE_SUCCESS_ACCESS_LOG: + # 2. database query: Insert new access logs with login time AccessLog.objects.create( username=username, ip_address=ip_address, user_agent=user_agent, http_accept=http_accept, path_info=path_info, + attempt_time=attempt_time, trusted=True, ) if settings.AXES_RESET_ON_SUCCESS: + # 3. database query: Reset failed attempts for the logging in user count = reset_user_attempts(request, credentials) - log.info( - 'AXES: Deleted %d failed login attempts by %s.', - count, - client_str, - ) + log.info('AXES: Deleted %d failed login attempts by %s from database.', count, client_str) def user_logged_out(self, sender, request, user, **kwargs): # pylint: disable=unused-argument """ When user logs out, update the AccessLog related to the user. """ + attempt_time = now() + + # 1. database query: Clean up expired user attempts from the database + clean_expired_user_attempts(attempt_time) + username = user.get_username() ip_address = get_client_ip_address(request) user_agent = get_client_user_agent(request) path_info = get_client_path_info(request) client_str = get_client_str(username, ip_address, user_agent, path_info) - logout_time = now() - log.info( - 'AXES: Successful logout by %s.', - client_str, - ) + log.info('AXES: Successful logout by %s.', client_str) - if user and not settings.AXES_DISABLE_ACCESS_LOG: + if username and not settings.AXES_DISABLE_ACCESS_LOG: + # 2. database query: Update existing attempt logs with logout time AccessLog.objects.filter( username=username, logout_time__isnull=True, ).update( - logout_time=logout_time, + logout_time=attempt_time, ) - - def post_save_access_attempt(self, instance, **kwargs): # pylint: disable=unused-argument - """ - Update cache after saving AccessAttempts. - """ - - cache_key = get_cache_key(instance) - - if not get_axes_cache().get(cache_key): - get_axes_cache().set( - cache_key, - instance.failures_since_start, - get_cache_timeout(), - ) - - def post_delete_access_attempt(self, instance, **kwargs): # pylint: disable=unused-argument - """ - Update cache after deleting AccessAttempts. - """ - - cache_hash_key = get_cache_key(instance) - get_axes_cache().delete(cache_hash_key) diff --git a/axes/handlers/proxy.py b/axes/handlers/proxy.py index c4e9a7f..81a3fc0 100644 --- a/axes/handlers/proxy.py +++ b/axes/handlers/proxy.py @@ -1,8 +1,8 @@ from logging import getLogger -from typing import Any, Dict, Optional from django.http import HttpRequest from django.utils.module_loading import import_string +from django.utils.timezone import datetime from axes.conf import settings from axes.handlers.base import AxesBaseHandler @@ -37,19 +37,23 @@ class AxesProxyHandler(AxesBaseHandler): return cls.implementation @classmethod - def is_allowed(cls, request: HttpRequest, credentials: Optional[Dict[str, Any]] = None) -> bool: + def is_locked(cls, request: HttpRequest, credentials: dict = None, attempt_time: datetime = None) -> bool: + return cls.get_implementation().is_locked(request, credentials) + + @classmethod + def is_allowed(cls, request: HttpRequest, credentials: dict = None) -> bool: return cls.get_implementation().is_allowed(request, credentials) @classmethod - def user_login_failed(cls, sender: Any, credentials: Dict[str, Any], request: HttpRequest, **kwargs): + def user_login_failed(cls, sender, credentials: dict, request: HttpRequest = None, **kwargs): return cls.get_implementation().user_login_failed(sender, credentials, request, **kwargs) @classmethod - def user_logged_in(cls, sender: Any, request: HttpRequest, user, **kwargs): + def user_logged_in(cls, sender, request: HttpRequest, user, **kwargs): return cls.get_implementation().user_logged_in(sender, request, user, **kwargs) @classmethod - def user_logged_out(cls, sender: Any, request: HttpRequest, user, **kwargs): + def user_logged_out(cls, sender, request: HttpRequest, user, **kwargs): return cls.get_implementation().user_logged_out(sender, request, user, **kwargs) @classmethod diff --git a/axes/management/commands/axes_reset.py b/axes/management/commands/axes_reset.py index f28c9f4..aadf039 100644 --- a/axes/management/commands/axes_reset.py +++ b/axes/management/commands/axes_reset.py @@ -1,6 +1,6 @@ from django.core.management.base import BaseCommand -from axes.utils import reset +from axes.attempts import reset class Command(BaseCommand): diff --git a/axes/management/commands/axes_reset_ip.py b/axes/management/commands/axes_reset_ip.py index 657a1ac..899fdec 100644 --- a/axes/management/commands/axes_reset_ip.py +++ b/axes/management/commands/axes_reset_ip.py @@ -1,6 +1,6 @@ from django.core.management.base import BaseCommand -from axes.utils import reset +from axes.attempts import reset class Command(BaseCommand): diff --git a/axes/management/commands/axes_reset_username.py b/axes/management/commands/axes_reset_username.py index 2e6ec7c..6bbca5f 100644 --- a/axes/management/commands/axes_reset_username.py +++ b/axes/management/commands/axes_reset_username.py @@ -1,6 +1,6 @@ from django.core.management.base import BaseCommand -from axes.utils import reset +from axes.attempts import reset class Command(BaseCommand): diff --git a/axes/tests/base.py b/axes/tests/base.py new file mode 100644 index 0000000..05186df --- /dev/null +++ b/axes/tests/base.py @@ -0,0 +1,162 @@ +from random import choice +from string import ascii_letters, digits +from time import sleep + +from django.contrib.auth import get_user_model +from django.http import HttpRequest +from django.test import TestCase +from django.urls import reverse + +from axes.attempts import reset +from axes.conf import settings +from axes.utils import get_axes_cache, get_cool_off, get_credentials +from axes.models import AccessLog, AccessAttempt + + +class AxesTestCase(TestCase): + """ + Test case using custom settings for testing. + """ + + VALID_USERNAME = 'axes-valid-username' + VALID_PASSWORD = 'axes-valid-password' + VALID_EMAIL = 'axes-valid-email@example.com' + VALID_USER_AGENT = 'axes-user-agent' + VALID_IP_ADDRESS = '127.0.0.1' + + INVALID_USERNAME = 'axes-invalid-username' + INVALID_PASSWORD = 'axes-invalid-password' + INVALID_EMAIL = 'axes-invalid-email@example.com' + + LOCKED_MESSAGE = 'Account locked: too many login attempts.' + LOGOUT_MESSAGE = 'Logged out' + LOGIN_FORM_KEY = '' + + SUCCESS = 200 + ALLOWED = 302 + BLOCKED = 403 + + def setUp(self): + """ + Create a valid user for login. + """ + + self.username = self.VALID_USERNAME + self.password = self.VALID_PASSWORD + self.email = self.VALID_EMAIL + + self.ip_address = self.VALID_IP_ADDRESS + self.user_agent = self.VALID_USER_AGENT + self.path_info = reverse('admin:login') + + self.user = get_user_model().objects.create_superuser( + username=self.username, + password=self.password, + email=self.email, + ) + + self.request = HttpRequest() + self.request.method = 'POST' + self.request.META['REMOTE_ADDR'] = self.ip_address + self.request.META['HTTP_USER_AGENT'] = self.user_agent + self.request.META['PATH_INFO'] = self.path_info + + self.credentials = get_credentials(self.username) + + def tearDown(self): + get_axes_cache().clear() + + def get_kwargs_with_defaults(self, **kwargs): + defaults = { + 'user_agent': self.user_agent, + 'ip_address': self.ip_address, + 'username': self.username, + 'failures_since_start': 1, + } + + defaults.update(kwargs) + return defaults + + def create_attempt(self, **kwargs): + return AccessAttempt.objects.create(**self.get_kwargs_with_defaults(**kwargs)) + + def reset(self, ip=None, username=None): + return reset(ip, username) + + def login(self, is_valid_username=False, is_valid_password=False, **kwargs): + """ + Login a user. + + A valid credential is used when is_valid_username is True, + otherwise it will use a random string to make a failed login. + """ + + if is_valid_username: + username = self.VALID_USERNAME + else: + username = ''.join( + choice(ascii_letters + digits) + for _ in range(10) + ) + + if is_valid_password: + password = self.VALID_PASSWORD + else: + password = self.INVALID_PASSWORD + + post_data = { + 'username': username, + 'password': password, + **kwargs + } + + return self.client.post( + reverse('admin:login'), + post_data, + REMOTE_ADDR=self.ip_address, + HTTP_USER_AGENT=self.user_agent, + ) + + def logout(self): + return self.client.post( + reverse('admin:logout'), + REMOTE_ADDR=self.ip_address, + HTTP_USER_AGENT=self.user_agent, + ) + + def check_login(self): + response = self.login(is_valid_username=True, is_valid_password=True) + self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=self.ALLOWED, html=True) + + def almost_lockout(self): + for _ in range(1, settings.AXES_FAILURE_LIMIT): + response = self.login() + self.assertContains(response, self.LOGIN_FORM_KEY, html=True) + + def lockout(self): + self.almost_lockout() + return self.login() + + def check_lockout(self): + response = self.lockout() + self.assertContains(response, self.LOCKED_MESSAGE, status_code=self.BLOCKED) + + def cool_off(self): + sleep(get_cool_off().total_seconds()) + + def check_logout(self): + response = self.logout() + self.assertContains(response, self.LOGOUT_MESSAGE, status_code=self.SUCCESS) + + def check_handler(self): + """ + Check a handler and its basic functionality with lockouts, cool offs, login, and logout. + + This is a check that is intended to successfully run for each and every new handler. + """ + + self.check_lockout() + self.cool_off() + self.check_login() + self.check_logout() + diff --git a/axes/tests/settings.py b/axes/tests/settings.py index 825172b..e150568 100644 --- a/axes/tests/settings.py +++ b/axes/tests/settings.py @@ -1,3 +1,6 @@ +import os.path +import tempfile + DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', @@ -7,7 +10,11 @@ DATABASES = { CACHES = { 'default': { - 'BACKEND': 'django.core.cache.backends.dummy.DummyCache' + 'BACKEND': 'django.core.cache.backends.filebased.FileBasedCache', + 'LOCATION': os.path.abspath(os.path.join(tempfile.gettempdir(), 'axes')), + 'OPTIONS': { + 'MAX_ENTRIES': 420, + } } } diff --git a/axes/tests/test_attempt.py b/axes/tests/test_attempt.py deleted file mode 100644 index 2db0a77..0000000 --- a/axes/tests/test_attempt.py +++ /dev/null @@ -1,590 +0,0 @@ -import datetime -import hashlib -import random -import string -import time -from unittest.mock import patch, MagicMock - -from django.contrib.auth import authenticate -from django.contrib.auth.models import User -from django.http import HttpRequest -from django.test import TestCase, override_settings -from django.test.client import RequestFactory -from django.urls import reverse - -from axes.attempts import ( - get_cache_key, - get_client_parameters, - get_user_attempts, -) -from axes.conf import settings -from axes.models import AccessAttempt, AccessLog -from axes.signals import user_locked_out - - -class AccessAttemptTest(TestCase): - """ - Test case using custom settings for testing. - """ - - VALID_USERNAME = 'valid-username' - VALID_PASSWORD = 'valid-password' - LOCKED_MESSAGE = 'Account locked: too many login attempts.' - LOGIN_FORM_KEY = '' - - def _login(self, is_valid_username=False, is_valid_password=False, **kwargs): - """ - Login a user. - - A valid credential is used when is_valid_username is True, - otherwise it will use a random string to make a failed login. - """ - - if is_valid_username: - # Use a valid username - username = self.VALID_USERNAME - else: - # Generate a wrong random username - chars = string.ascii_uppercase + string.digits - username = ''.join(random.choice(chars) for _ in range(10)) - - if is_valid_password: - password = self.VALID_PASSWORD - else: - password = 'invalid-password' - - post_data = { - 'username': username, - 'password': password, - 'this_is_the_login_form': 1, - } - - post_data.update(kwargs) - - return self.client.post( - reverse('admin:login'), - post_data, - HTTP_USER_AGENT='test-browser', - ) - - def setUp(self): - """ - Create a valid user for login. - """ - - self.username = self.VALID_USERNAME - self.ip_address = '127.0.0.1' - self.user_agent = 'test-browser' - - self.user = User.objects.create_superuser( - username=self.VALID_USERNAME, - email='test@example.com', - password=self.VALID_PASSWORD, - ) - - def test_failure_limit_once(self): - """ - Test the login lock trying to login one more time than failure limit. - """ - - # test until one try before the limit - for _ in range(1, settings.AXES_FAILURE_LIMIT): - response = self._login() - # Check if we are in the same login page - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - - # So, we shouldn't have gotten a lock-out yet. - # But we should get one now - response = self._login() - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - - def test_failure_limit_many(self): - """ - Test the login lock trying to login a lot of times more than failure limit. - """ - - for _ in range(1, settings.AXES_FAILURE_LIMIT): - response = self._login() - # Check if we are in the same login page - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - - # So, we shouldn't have gotten a lock-out yet. - # We should get a locked message each time we try again - for _ in range(random.randrange(1, settings.AXES_FAILURE_LIMIT)): - response = self._login() - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - - def test_valid_login(self): - """ - Test a valid login for a real username. - """ - - response = self._login(is_valid_username=True, is_valid_password=True) - self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=302, html=True) - - def test_valid_logout(self): - """ - Test a valid logout and make sure the logout_time is updated. - """ - - response = self._login(is_valid_username=True, is_valid_password=True) - self.assertEqual(AccessLog.objects.latest('id').logout_time, None) - - response = self.client.get(reverse('admin:logout')) - self.assertNotEqual(AccessLog.objects.latest('id').logout_time, None) - self.assertContains(response, 'Logged out') - - @override_settings(AXES_COOLOFF_TIME=datetime.timedelta(milliseconds=420)) - def test_cool_off_on_login(self): - """ - Test if the cooling time allows a user to login. - """ - - self.test_failure_limit_once() - - # Wait for the cooling off period - time.sleep(settings.AXES_COOLOFF_TIME.total_seconds()) - - # It should be possible to login again, make sure it is. - self.test_valid_login() - - @override_settings(AXES_COOLOFF_TIME=datetime.timedelta(milliseconds=420)) - @patch('axes.attempts.get_axes_cache') - def test_cooling_off_on_get_user_attempts_updates_cache(self, get_cache): - cache = MagicMock() - cache.get.return_value = 1 - cache.set.return_value = None - get_cache.return_value = cache - - attempt = AccessAttempt.objects.create( - username=self.username, - ip_address=self.ip_address, - user_agent=self.user_agent, - failures_since_start=0, - ) - - request = HttpRequest() - request.META['REMOTE_ADDR'] = self.ip_address - request.META['HTTP_USER_AGENT'] = self.user_agent - credentials = {'username': self.username} - - # Check that the function does nothing if cool off has not passed - cache.get.assert_not_called() - cache.set.assert_not_called() - - self.assertEqual( - list(get_user_attempts(request, credentials)), - [attempt], - ) - - cache.get.assert_not_called() - cache.set.assert_not_called() - - time.sleep(settings.AXES_COOLOFF_TIME.total_seconds()) - - self.assertEqual( - list(get_user_attempts(request, credentials)), - [], - ) - - self.assertTrue(cache.get.call_count) - self.assertTrue(cache.set.call_count) - - def test_long_user_agent_valid(self): - """ - Test if can handle a long user agent. - """ - - long_user_agent = 'ie6' * 1024 - response = self._login( - is_valid_username=True, - is_valid_password=True, - user_agent=long_user_agent, - ) - self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=302, html=True) - - def test_long_user_agent_not_valid(self): - """ - Test if can handle a long user agent with failure. - """ - - long_user_agent = 'ie6' * 1024 - for _ in range(settings.AXES_FAILURE_LIMIT + 1): - response = self._login(user_agent=long_user_agent) - - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - - def test_reset_ip(self): - """ - Test resetting all attempts for an IP address. - """ - - # Make a lockout - self.test_failure_limit_once() - - # Reset the ip so we can try again - AccessAttempt.objects.filter(ip_address='127.0.0.1').delete() - - # Make a login attempt again - self.test_valid_login() - - def test_reset_all(self): - """ - Test resetting all attempts. - """ - - # Make a lockout - self.test_failure_limit_once() - - # Reset all attempts so we can try again - AccessAttempt.objects.all().delete() - - # Make a login attempt again - self.test_valid_login() - - @override_settings( - AXES_ONLY_USER_FAILURES=True, - ) - def test_get_filter_kwargs_user(self): - self.assertEqual( - dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), - {'username': self.username}, - ) - - @override_settings( - AXES_ONLY_USER_FAILURES=False, - AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=False, - AXES_USE_USER_AGENT=False, - ) - def test_get_filter_kwargs_ip(self): - self.assertEqual( - dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), - {'ip_address': self.ip_address}, - ) - - @override_settings( - AXES_ONLY_USER_FAILURES=False, - AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True, - AXES_USE_USER_AGENT=False, - ) - def test_get_filter_kwargs_user_and_ip(self): - self.assertEqual( - dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), - {'username': self.username, 'ip_address': self.ip_address}, - ) - - @override_settings( - AXES_ONLY_USER_FAILURES=False, - AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=False, - AXES_USE_USER_AGENT=True, - ) - def test_get_filter_kwargs_ip_and_agent(self): - self.assertEqual( - dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), - {'ip_address': self.ip_address, 'user_agent': self.user_agent}, - ) - - @override_settings( - AXES_ONLY_USER_FAILURES=False, - AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True, - AXES_USE_USER_AGENT=True, - ) - def test_get_filter_kwargs_user_ip_agent(self): - self.assertEqual( - dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), - {'username': self.username, 'ip_address': self.ip_address, 'user_agent': self.user_agent}, - ) - - @patch('axes.utils.get_client_ip_address', return_value='127.0.0.1') - def test_get_cache_key(self, _): - """ - Test the cache key format. - """ - - # Getting cache key from request - ip_address = '127.0.0.1' - cache_hash_key = 'axes-{}'.format( - hashlib.md5(ip_address.encode()).hexdigest() - ) - - request_factory = RequestFactory() - request = request_factory.post( - '/admin/login/', - data={ - 'username': self.VALID_USERNAME, - 'password': 'test', - }, - ) - - self.assertEqual(cache_hash_key, get_cache_key(request)) - - # Getting cache key from AccessAttempt Object - attempt = AccessAttempt( - user_agent='', - ip_address=ip_address, - username=self.VALID_USERNAME, - get_data='', - post_data='', - http_accept=request.META.get('HTTP_ACCEPT', ''), - path_info=request.META.get('PATH_INFO', ''), - failures_since_start=0, - ) - - self.assertEqual(cache_hash_key, get_cache_key(attempt)) - - @patch('axes.utils.get_client_ip_address', return_value='127.0.0.1') - def test_get_cache_key_credentials(self, _): - """ - Test the cache key format. - """ - - # Getting cache key from request - ip_address = '127.0.0.1' - cache_hash_key = 'axes-{}'.format( - hashlib.md5(ip_address.encode()).hexdigest() - ) - - request_factory = RequestFactory() - request = request_factory.post( - '/admin/login/', - data={ - 'username': self.VALID_USERNAME, - 'password': 'test' - } - ) - - # Difference between the upper test: new call signature with credentials - credentials = {'username': self.VALID_USERNAME} - - self.assertEqual(cache_hash_key, get_cache_key(request, credentials)) - - # Getting cache key from AccessAttempt Object - attempt = AccessAttempt( - user_agent='', - ip_address=ip_address, - username=self.VALID_USERNAME, - get_data='', - post_data='', - http_accept=request.META.get('HTTP_ACCEPT', ''), - path_info=request.META.get('PATH_INFO', ''), - failures_since_start=0, - ) - self.assertEqual(cache_hash_key, get_cache_key(attempt)) - - def test_send_lockout_signal(self): - """ - Test if the lockout signal is emitted. - """ - - # this "hack" is needed so we don't have to use global variables or python3 features - class Scope(object): pass - scope = Scope() - scope.signal_received = 0 - - def signal_handler(request, username, ip_address, *args, **kwargs): # pylint: disable=unused-argument - scope.signal_received += 1 - self.assertIsNotNone(request) - - # Connect signal handler - user_locked_out.connect(signal_handler) - - # Make a lockout - self.test_failure_limit_once() - self.assertEqual(scope.signal_received, 1) - - AccessAttempt.objects.all().delete() - - # Make another lockout - self.test_failure_limit_once() - self.assertEqual(scope.signal_received, 2) - - @override_settings(AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True) - def test_lockout_by_combination_user_and_ip(self): - """ - Test login failure when AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP is True. - """ - - # test until one try before the limit - for _ in range(1, settings.AXES_FAILURE_LIMIT): - response = self._login( - is_valid_username=True, - is_valid_password=False, - ) - # Check if we are in the same login page - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - - # So, we shouldn't have gotten a lock-out yet. - # But we should get one now - response = self._login(is_valid_username=True, is_valid_password=False) - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - - @override_settings(AXES_ONLY_USER_FAILURES=True) - def test_lockout_by_user_only(self): - """ - Test login failure when AXES_ONLY_USER_FAILURES is True. - """ - - # test until one try before the limit - for _ in range(1, settings.AXES_FAILURE_LIMIT): - response = self._login( - is_valid_username=True, - is_valid_password=False, - ) - # Check if we are in the same login page - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - - # So, we shouldn't have gotten a lock-out yet. - # But we should get one now - response = self._login(is_valid_username=True, is_valid_password=False) - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - - # reset the username only and make sure we can log in now even though - # our IP has failed each time - AccessAttempt.objects.filter(username=self.VALID_USERNAME).delete() - response = self._login( - is_valid_username=True, - is_valid_password=True, - ) - # Check if we are still in the login page - self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=302, html=True) - - # now create failure_limit + 1 failed logins and then we should still - # be able to login with valid_username - for _ in range(settings.AXES_FAILURE_LIMIT): - response = self._login( - is_valid_username=False, - is_valid_password=False, - ) - # Check if we can still log in with valid user - response = self._login(is_valid_username=True, is_valid_password=True) - self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=302, html=True) - - def test_log_data_truncated(self): - """ - Test that get_query_str properly truncates data to the max_length (default 1024). - """ - - # An impossibly large post dict - extra_data = {string.ascii_letters * x: x for x in range(0, 1000)} - self._login(**extra_data) - self.assertEqual( - len(AccessAttempt.objects.latest('id').post_data), 1024 - ) - - @override_settings(AXES_DISABLE_SUCCESS_ACCESS_LOG=True) - def test_valid_logout_without_success_log(self): - AccessLog.objects.all().delete() - - response = self._login(is_valid_username=True, is_valid_password=True) - response = self.client.get(reverse('admin:logout')) - - self.assertEqual(AccessLog.objects.all().count(), 0) - self.assertContains(response, 'Logged out', html=True) - - @override_settings(AXES_DISABLE_SUCCESS_ACCESS_LOG=True) - def test_valid_login_without_success_log(self): - """ - Test that a valid login does not generate an AccessLog when DISABLE_SUCCESS_ACCESS_LOG is True. - """ - - AccessLog.objects.all().delete() - - response = self._login(is_valid_username=True, is_valid_password=True) - - self.assertEqual(response.status_code, 302) - self.assertEqual(AccessLog.objects.all().count(), 0) - - @override_settings(AXES_DISABLE_ACCESS_LOG=True) - def test_valid_logout_without_log(self): - AccessLog.objects.all().delete() - - response = self._login(is_valid_username=True, is_valid_password=True) - response = self.client.get(reverse('admin:logout')) - - self.assertEqual(AccessLog.objects.first().logout_time, None) - self.assertContains(response, 'Logged out', html=True) - - @override_settings(AXES_DISABLE_ACCESS_LOG=True) - def test_non_valid_login_without_log(self): - """ - Test that a non-valid login does generate an AccessLog when DISABLE_ACCESS_LOG is True. - """ - AccessLog.objects.all().delete() - - response = self._login(is_valid_username=True, is_valid_password=False) - self.assertEqual(response.status_code, 200) - - self.assertEqual(AccessLog.objects.all().count(), 0) - - @override_settings(AXES_DISABLE_ACCESS_LOG=True) - def test_check_is_not_made_on_GET(self): - AccessLog.objects.all().delete() - - response = self.client.get(reverse('admin:login')) - self.assertEqual(response.status_code, 200) - - response = self._login(is_valid_username=True, is_valid_password=True) - self.assertEqual(response.status_code, 302) - - response = self.client.get(reverse('admin:index')) - self.assertEqual(response.status_code, 200) - - def test_custom_authentication_backend(self): - """ - Test that log_user_login_failed skips if an attempt to authenticate with a custom authentication backend fails. - """ - - request = HttpRequest() - request.META['REMOTE_ADDR'] = '127.0.0.1' - authenticate(request=request, foo='bar') - self.assertEqual(AccessLog.objects.all().count(), 0) - - def _assert_resets_on_success(self): - """ - Sets the AXES_RESET_ON_SUCCESS up for testing. - """ - - # test until one try before the limit - for _ in range(settings.AXES_FAILURE_LIMIT - 1): - response = self._login() - # Check if we are in the same login page - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - - # Perform a valid login - response = self._login(is_valid_username=True, is_valid_password=True) - self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=302, html=True) - - return self._login() - - # by default, AXES_RESET_ON_SUCCESS = False - def test_reset_on_success_default(self): - """ - Test that the failure attempts does not reset after one successful attempt by default. - """ - - response = self._assert_resets_on_success() - - # So, we shouldn't have found a lock-out yet. - # But we should find one now - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - - @override_settings(AXES_RESET_ON_SUCCESS=True) - def test_reset_on_success(self): - """ - Test that the failure attempts resets after one successful attempt when using the corresponding setting. - """ - - response = self._assert_resets_on_success() - - # So, we shouldn't have found a lock-out yet. - # And we shouldn't find one now - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - for _ in range(settings.AXES_FAILURE_LIMIT - 2): - response = self._login() - # Check if we are on the same login page. - self.assertContains(response, self.LOGIN_FORM_KEY, html=True) - - # But we should find one now - response = self._login() - self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) - diff --git a/axes/tests/test_attempts.py b/axes/tests/test_attempts.py new file mode 100644 index 0000000..9e5efe0 --- /dev/null +++ b/axes/tests/test_attempts.py @@ -0,0 +1,54 @@ +from unittest.mock import patch + +from django.contrib.auth import get_user_model +from django.http import HttpRequest + +from axes.attempts import ( + is_user_attempt_whitelisted, + reset, +) +from axes.models import AccessAttempt +from axes.tests.base import AxesTestCase + + +class ResetTestCase(AxesTestCase): + def test_reset(self): + self.create_attempt() + reset() + self.assertFalse(AccessAttempt.objects.count()) + + def test_reset_ip(self): + self.create_attempt(ip_address=self.ip_address) + reset(ip=self.ip_address) + self.assertFalse(AccessAttempt.objects.count()) + + def test_reset_username(self): + self.create_attempt(username=self.username) + reset(username=self.username) + self.assertFalse(AccessAttempt.objects.count()) + + +class UserWhitelistTestCase(AxesTestCase): + def setUp(self): + self.user_model = get_user_model() + self.user = self.user_model.objects.create(username='jane.doe') + self.request = HttpRequest() + + def test_is_client_username_whitelisted(self): + with patch.object(self.user_model, 'nolockout', True, create=True): + self.assertTrue(is_user_attempt_whitelisted( + self.request, + {self.user_model.USERNAME_FIELD: self.user.username}, + )) + + def test_is_client_username_whitelisted_not(self): + self.assertFalse(is_user_attempt_whitelisted( + self.request, + {self.user_model.USERNAME_FIELD: self.user.username}, + )) + + def test_is_client_username_whitelisted_does_not_exist(self): + self.assertFalse(is_user_attempt_whitelisted( + self.request, + {self.user_model.USERNAME_FIELD: 'not.' + self.user.username}, + )) diff --git a/axes/tests/test_backends.py b/axes/tests/test_backends.py index 4374926..6d338c3 100644 --- a/axes/tests/test_backends.py +++ b/axes/tests/test_backends.py @@ -1,12 +1,11 @@ from unittest.mock import patch, MagicMock -from django.test import TestCase - from axes.backends import AxesBackend from axes.exceptions import AxesBackendRequestParameterRequired, AxesBackendPermissionDenied +from axes.tests.base import AxesTestCase -class BackendTestCase(TestCase): +class BackendTestCase(AxesTestCase): def test_authenticate_raises_on_missing_request(self): request = None diff --git a/axes/tests/test_checks.py b/axes/tests/test_checks.py index fe58f47..217f04c 100644 --- a/axes/tests/test_checks.py +++ b/axes/tests/test_checks.py @@ -1,11 +1,12 @@ from django.core.checks import run_checks, Error -from django.test import TestCase, override_settings +from django.test import override_settings from axes.checks import Messages, Hints, Codes from axes.conf import settings +from axes.tests.base import AxesTestCase -class CacheCheckTestCase(TestCase): +class CacheCheckTestCase(AxesTestCase): @override_settings( AXES_CACHE='nonexistent', ) diff --git a/axes/tests/test_decorators.py b/axes/tests/test_decorators.py index 593c2a8..e4440f0 100644 --- a/axes/tests/test_decorators.py +++ b/axes/tests/test_decorators.py @@ -1,12 +1,12 @@ from unittest.mock import MagicMock, patch from django.http import HttpResponse -from django.test import TestCase from axes.decorators import axes_dispatch, axes_form_invalid +from axes.tests.base import AxesTestCase -class DecoratorTestCase(TestCase): +class DecoratorTestCase(AxesTestCase): SUCCESS_RESPONSE = HttpResponse(status=200, content='Dispatched') LOCKOUT_RESPONSE = HttpResponse(status=403, content='Locked out') diff --git a/axes/tests/test_handlers.py b/axes/tests/test_handlers.py index 36948cf..3e93981 100644 --- a/axes/tests/test_handlers.py +++ b/axes/tests/test_handlers.py @@ -1,20 +1,42 @@ from unittest.mock import MagicMock, patch from django.http import HttpRequest -from django.test import TestCase, override_settings +from django.test import override_settings +from django.utils.timezone import timedelta from axes.handlers.proxy import AxesProxyHandler -from axes.models import AccessAttempt +from axes.tests.base import AxesTestCase +from axes.utils import get_client_str -class AxesBaseHandlerTestCase(TestCase): - @override_settings(AXES_HANDLER='axes.handlers.base.AxesBaseHandler') +@override_settings(AXES_HANDLER='axes.handlers.base.AxesBaseHandler') +class AxesBaseHandlerTestCase(AxesTestCase): def test_base_handler_raises_on_undefined_is_allowed_to_authenticate(self): with self.assertRaises(NotImplementedError): - AxesProxyHandler.is_allowed(HttpRequest(), {}) + AxesProxyHandler.is_allowed(self.request, {}) + + @override_settings(AXES_IP_BLACKLIST=['127.0.0.1']) + def test_is_allowed_with_blacklisted_ip_address(self): + self.assertFalse(AxesProxyHandler.is_allowed(self.request)) + + @override_settings( + AXES_NEVER_LOCKOUT_WHITELIST=True, + AXES_IP_WHITELIST=['127.0.0.1'], + ) + def test_is_allowed_with_whitelisted_ip_address(self): + self.assertTrue(AxesProxyHandler.is_allowed(self.request)) + + @override_settings(AXES_NEVER_LOCKOUT_GET=True) + def test_is_allowed_with_whitelisted_method(self): + self.request.method = 'GET' + self.assertTrue(AxesProxyHandler.is_allowed(self.request)) + + @override_settings(AXES_LOCK_OUT_AT_FAILURE=False) + def test_is_allowed_no_lock_out(self): + self.assertTrue(AxesProxyHandler.is_allowed(self.request)) -class AxesProxyHandlerTestCase(TestCase): +class AxesProxyHandlerTestCase(AxesTestCase): def setUp(self): self.sender = MagicMock() self.credentials = MagicMock() @@ -60,86 +82,68 @@ class AxesProxyHandlerTestCase(TestCase): self.assertTrue(handler.post_delete_access_attempt.called) -class AxesDatabaseHandlerTestCase(TestCase): - def setUp(self): - self.attempt = AccessAttempt.objects.create( - username='jane.doe', - ip_address='127.0.0.1', - user_agent='test-browser', - failures_since_start=42, - ) +class AxesHandlerTestCase(AxesTestCase): + def check_whitelist(self, log): + with override_settings( + AXES_NEVER_LOCKOUT_WHITELIST=True, + AXES_IP_WHITELIST=[self.ip_address], + ): + AxesProxyHandler.user_login_failed(sender=None, request=self.request, credentials=self.credentials) + client_str = get_client_str(self.username, self.ip_address, self.user_agent, self.path_info) + log.info.assert_called_with('AXES: Login failed from whitelisted client %s.', client_str) - self.request = HttpRequest() - self.request.method = 'POST' - self.request.META['REMOTE_ADDR'] = '127.0.0.1' - - @patch('axes.handlers.database.log') - def test_user_login_failed_no_request(self, log): + def check_empty_request(self, log, handler): AxesProxyHandler.user_login_failed(sender=None, credentials={}, request=None) - log.warning.assert_called_with( - 'AXES: AxesDatabaseHandler.user_login_failed does not function without a request.' + log.error.assert_called_with( + 'AXES: {handler}.user_login_failed does not function without a request.'.format(handler=handler) ) - @patch('axes.handlers.database.get_client_ip_address', return_value='127.0.0.1') - @patch('axes.handlers.database.is_client_ip_address_whitelisted', return_value=True) + +@override_settings( + AXES_HANDLER='axes.handlers.database.AxesDatabaseHandler', + AXES_COOLOFF_TIME=timedelta(seconds=1), + AXES_RESET_ON_SUCCESS=True, +) +class AxesDatabaseHandlerTestCase(AxesHandlerTestCase): + @override_settings(AXES_RESET_ON_SUCCESS=True) + def test_handler(self): + self.check_handler() + + @override_settings(AXES_RESET_ON_SUCCESS=False) + def test_handler_without_reset(self): + self.check_handler() + @patch('axes.handlers.database.log') - def test_user_login_failed_whitelist(self, log, _, __): - AxesProxyHandler.user_login_failed(sender=None, credentials={}, request=self.request) - log.info.assert_called_with('AXES: Login failed from whitelisted IP %s.', '127.0.0.1') + def test_empty_request(self, log): + self.check_empty_request(log, 'AxesDatabaseHandler') - @patch('axes.handlers.database.get_axes_cache') - def test_post_save_access_attempt_updates_cache(self, get_cache): - cache = MagicMock() - cache.get.return_value = None - cache.set.return_value = None + @patch('axes.handlers.database.log') + def test_whitelist(self, log): + self.check_whitelist(log) - get_cache.return_value = cache + @patch('axes.handlers.database.is_user_attempt_whitelisted', return_value=True) + def test_user_whitelisted(self, is_whitelisted): + self.assertFalse(AxesProxyHandler().is_locked(self.request, self.credentials)) + self.assertEqual(1, is_whitelisted.call_count) - self.assertFalse(cache.get.call_count) - self.assertFalse(cache.set.call_count) - AxesProxyHandler.post_save_access_attempt(self.attempt) +@override_settings( + AXES_HANDLER='axes.handlers.cache.AxesCacheHandler', + AXES_COOLOFF_TIME=timedelta(seconds=1), +) +class AxesCacheHandlerTestCase(AxesHandlerTestCase): + @override_settings(AXES_RESET_ON_SUCCESS=True) + def test_handler(self): + self.check_handler() - self.assertTrue(cache.get.call_count) - self.assertTrue(cache.set.call_count) + @override_settings(AXES_RESET_ON_SUCCESS=False) + def test_handler_without_reset(self): + self.check_handler() - @patch('axes.handlers.database.get_axes_cache') - def test_user_login_failed_utilizes_cache(self, get_cache): - cache = MagicMock() - cache.get.return_value = 1 - get_cache.return_value = cache + @patch('axes.handlers.cache.log') + def test_empty_request(self, log): + self.check_empty_request(log, 'AxesCacheHandler') - sender = MagicMock() - credentials = {'username': self.attempt.username} - - self.assertFalse(cache.get.call_count) - - AxesProxyHandler.user_login_failed(sender, credentials, self.request) - - self.assertTrue(cache.get.call_count) - - @override_settings(AXES_LOCK_OUT_AT_FAILURE=True) - @override_settings(AXES_FAILURE_LIMIT=40) - @patch('axes.handlers.database.get_axes_cache') - def test_is_already_locked_cache(self, get_cache): - cache = MagicMock() - cache.get.return_value = 42 - get_cache.return_value = cache - - self.assertFalse(AxesProxyHandler.is_allowed(self.request, {})) - - @override_settings(AXES_LOCK_OUT_AT_FAILURE=False) - @override_settings(AXES_FAILURE_LIMIT=40) - @patch('axes.handlers.database.get_axes_cache') - def test_is_already_locked_do_not_lock_out_at_failure(self, get_cache): - cache = MagicMock() - cache.get.return_value = 42 - get_cache.return_value = cache - - self.assertTrue(AxesProxyHandler.is_allowed(self.request, {})) - - @override_settings(AXES_NEVER_LOCKOUT_GET=True) - def test_is_already_locked_never_lockout_get(self): - self.request.method = 'GET' - - self.assertTrue(AxesProxyHandler.is_allowed(self.request, {})) + @patch('axes.handlers.cache.log') + def test_whitelist(self, log): + self.check_whitelist(log) diff --git a/axes/tests/test_logging.py b/axes/tests/test_logging.py index d2b956e..1a4bccd 100644 --- a/axes/tests/test_logging.py +++ b/axes/tests/test_logging.py @@ -1,13 +1,18 @@ from unittest.mock import patch -from django.test import TestCase, override_settings +from django.contrib.auth import authenticate +from django.http import HttpRequest +from django.test import override_settings +from django.urls import reverse from axes.apps import AppConfig +from axes.models import AccessAttempt, AccessLog +from axes.tests.base import AxesTestCase @patch('axes.apps.AppConfig.logging_initialized', False) @patch('axes.apps.log') -class AppsTestCase(TestCase): +class AppsTestCase(AxesTestCase): def test_axes_config_log_re_entrant(self, log): """ Test that initialize call count does not increase on repeat calls. @@ -41,3 +46,85 @@ class AppsTestCase(TestCase): def test_axes_config_log_user_ip(self, log): AppConfig.initialize() log.info.assert_called_with('AXES: blocking by combination of username and IP.') + + +class AccessLogTestCase(AxesTestCase): + def test_authenticate_invalid_parameters(self): + """ + Test that logging is not done if an attempt to authenticate with a custom authentication backend fails. + """ + + request = HttpRequest() + request.META['REMOTE_ADDR'] = '127.0.0.1' + authenticate(request=request, foo='bar') + self.assertEqual(AccessLog.objects.all().count(), 0) + + def test_access_log_on_logout(self): + """ + Test a valid logout and make sure the logout_time is updated. + """ + + self.login(is_valid_username=True, is_valid_password=True) + self.assertIsNone(AccessLog.objects.latest('id').logout_time) + + response = self.client.get(reverse('admin:logout')) + self.assertContains(response, 'Logged out') + + self.assertIsNotNone(AccessLog.objects.latest('id').logout_time) + + def test_log_data_truncated(self): + """ + Test that get_query_str properly truncates data to the max_length (default 1024). + """ + + # An impossibly large post dict + extra_data = {'a' * x: x for x in range(1024)} + self.login(**extra_data) + self.assertEqual( + len(AccessAttempt.objects.latest('id').post_data), 1024 + ) + + @override_settings(AXES_DISABLE_SUCCESS_ACCESS_LOG=True) + def test_valid_logout_without_success_log(self): + AccessLog.objects.all().delete() + + response = self.login(is_valid_username=True, is_valid_password=True) + response = self.client.get(reverse('admin:logout')) + + self.assertEqual(AccessLog.objects.all().count(), 0) + self.assertContains(response, 'Logged out', html=True) + + @override_settings(AXES_DISABLE_SUCCESS_ACCESS_LOG=True) + def test_valid_login_without_success_log(self): + """ + Test that a valid login does not generate an AccessLog when DISABLE_SUCCESS_ACCESS_LOG is True. + """ + + AccessLog.objects.all().delete() + + response = self.login(is_valid_username=True, is_valid_password=True) + + self.assertEqual(response.status_code, 302) + self.assertEqual(AccessLog.objects.all().count(), 0) + + @override_settings(AXES_DISABLE_ACCESS_LOG=True) + def test_valid_logout_without_log(self): + AccessLog.objects.all().delete() + + response = self.login(is_valid_username=True, is_valid_password=True) + response = self.client.get(reverse('admin:logout')) + + self.assertEqual(AccessLog.objects.first().logout_time, None) + self.assertContains(response, 'Logged out', html=True) + + @override_settings(AXES_DISABLE_ACCESS_LOG=True) + def test_non_valid_login_without_log(self): + """ + Test that a non-valid login does generate an AccessLog when DISABLE_ACCESS_LOG is True. + """ + AccessLog.objects.all().delete() + + response = self.login(is_valid_username=True, is_valid_password=False) + self.assertEqual(response.status_code, 200) + + self.assertEqual(AccessLog.objects.all().count(), 0) diff --git a/axes/tests/test_login.py b/axes/tests/test_login.py index 6e3a5c1..a72f562 100644 --- a/axes/tests/test_login.py +++ b/axes/tests/test_login.py @@ -1,15 +1,19 @@ """ -Test access from purely the +Integration tests for the login handling. + +TODO: Clean up the tests in this module. """ -from django.test import TestCase, override_settings +from django.test import override_settings from django.urls import reverse from django.contrib.auth import get_user_model from axes.conf import settings +from axes.models import AccessLog, AccessAttempt +from axes.tests.base import AxesTestCase -class LoginTestCase(TestCase): +class LoginTestCase(AxesTestCase): """ Test for lockouts under different configurations and circumstances to prevent false positives and false negatives. @@ -21,7 +25,15 @@ class LoginTestCase(TestCase): IP_2 = '10.2.2.2' USER_1 = 'valid-user-1' USER_2 = 'valid-user-2' + EMAIL_1 = 'valid-email-1@example.com' + EMAIL_2 = 'valid-email-2@example.com' + + VALID_USERNAME = USER_1 + VALID_EMAIL = EMAIL_1 VALID_PASSWORD = 'valid-password' + + VALID_IP_ADDRESS = IP_1 + WRONG_PASSWORD = 'wrong-password' LOCKED_MESSAGE = 'Account locked: too many login attempts.' LOGIN_FORM_KEY = '' @@ -38,7 +50,6 @@ class LoginTestCase(TestCase): post_data = { 'username': username, 'password': password, - 'this_is_the_login_form': 1, } post_data.update(kwargs) @@ -70,17 +81,119 @@ class LoginTestCase(TestCase): Create two valid users for authentication. """ - self.user = get_user_model().objects.create_superuser( - username=self.USER_1, - email='test_1@example.com', - password=self.VALID_PASSWORD, - ) - self.user = get_user_model().objects.create_superuser( + super().setUp() + + self.user2 = get_user_model().objects.create_superuser( username=self.USER_2, - email='test_2@example.com', + email=self.EMAIL_2, password=self.VALID_PASSWORD, + is_staff=True, + is_superuser=True, ) + def test_login(self): + """ + Test a valid login for a real username. + """ + + response = self._login(self.username, self.password) + self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=self.ALLOWED, html=True) + + def test_lockout_limit_once(self): + """ + Test the login lock trying to login one more time than failure limit. + """ + + response = self.lockout() + self.assertContains(response, self.LOCKED_MESSAGE, status_code=self.BLOCKED) + + def test_lockout_limit_many(self): + """ + Test the login lock trying to login a lot of times more than failure limit. + """ + + self.lockout() + + for _ in range(settings.AXES_FAILURE_LIMIT): + response = self.login() + self.assertContains(response, self.LOCKED_MESSAGE, status_code=self.BLOCKED) + + @override_settings(AXES_RESET_ON_SUCCESS=False) + def test_reset_on_success_false(self): + self.almost_lockout() + self.login(is_valid_username=True, is_valid_password=True) + + response = self.login() + self.assertContains(response, self.LOCKED_MESSAGE, status_code=self.BLOCKED) + self.assertTrue(AccessAttempt.objects.count()) + + @override_settings(AXES_RESET_ON_SUCCESS=True) + def test_reset_on_success_true(self): + self.almost_lockout() + self.assertTrue(AccessAttempt.objects.count()) + + self.login(is_valid_username=True, is_valid_password=True) + self.assertFalse(AccessAttempt.objects.count()) + + response = self.lockout() + self.assertContains(response, self.LOCKED_MESSAGE, status_code=self.BLOCKED) + self.assertTrue(AccessAttempt.objects.count()) + + @override_settings(AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True) + def test_lockout_by_combination_user_and_ip(self): + """ + Test login failure when AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP is True. + """ + + # test until one try before the limit + for _ in range(1, settings.AXES_FAILURE_LIMIT): + response = self.login( + is_valid_username=True, + is_valid_password=False, + ) + # Check if we are in the same login page + self.assertContains(response, self.LOGIN_FORM_KEY, html=True) + + # So, we shouldn't have gotten a lock-out yet. + # But we should get one now + response = self.login(is_valid_username=True, is_valid_password=False) + self.assertContains(response, self.LOCKED_MESSAGE, status_code=403) + + @override_settings(AXES_ONLY_USER_FAILURES=True) + def test_lockout_by_only_user_failures(self): + """ + Test login failure when AXES_ONLY_USER_FAILURES is True. + """ + + # test until one try before the limit + for _ in range(1, settings.AXES_FAILURE_LIMIT): + response = self._login(self.username, self.WRONG_PASSWORD) + + # Check if we are in the same login page + self.assertContains(response, self.LOGIN_FORM_KEY, html=True) + + # So, we shouldn't have gotten a lock-out yet. + # But we should get one now + response = self._login(self.username, self.WRONG_PASSWORD) + self.assertContains(response, self.LOCKED_MESSAGE, status_code=self.BLOCKED) + + # reset the username only and make sure we can log in now even though our IP has failed each time + self.reset(username=self.username) + + response = self._login(self.username, self.password) + + # Check if we are still in the login page + self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=self.ALLOWED, html=True) + + # now create failure_limit + 1 failed logins and then we should still + # be able to login with valid_username + for _ in range(settings.AXES_FAILURE_LIMIT): + response = self._login(self.username, self.password) + + # Check if we can still log in with valid user + response = self._login(self.username, self.password) + self.assertNotContains(response, self.LOGIN_FORM_KEY, status_code=self.ALLOWED, html=True) + # Test for true and false positives when blocking by IP *OR* user (default) # Cache disabled. Default settings. def test_lockout_by_ip_blocks_when_same_user_same_ip_without_cache(self): diff --git a/axes/tests/test_management.py b/axes/tests/test_management.py index 942cb2d..976e028 100644 --- a/axes/tests/test_management.py +++ b/axes/tests/test_management.py @@ -1,12 +1,12 @@ from io import StringIO from django.core.management import call_command -from django.test import TestCase from axes.models import AccessAttempt +from axes.tests.base import AxesTestCase -class ManagementCommandTestCase(TestCase): +class ManagementCommandTestCase(AxesTestCase): def setUp(self): AccessAttempt.objects.create( username='jane.doe', diff --git a/axes/tests/test_middleware.py b/axes/tests/test_middleware.py index cf47d3b..f803754 100644 --- a/axes/tests/test_middleware.py +++ b/axes/tests/test_middleware.py @@ -1,13 +1,13 @@ from unittest.mock import patch, MagicMock from django.http import HttpResponse -from django.test import TestCase from axes.exceptions import AxesSignalPermissionDenied from axes.middleware import AxesMiddleware +from axes.tests.base import AxesTestCase -class MiddlewareTestCase(TestCase): +class MiddlewareTestCase(AxesTestCase): SUCCESS_RESPONSE = HttpResponse(status=200, content='Dispatched') LOCKOUT_RESPONSE = HttpResponse(status=403, content='Locked out') diff --git a/axes/tests/test_models.py b/axes/tests/test_models.py index 8f84353..d5ebadf 100644 --- a/axes/tests/test_models.py +++ b/axes/tests/test_models.py @@ -3,12 +3,12 @@ from django.db import connection from django.db.migrations.autodetector import MigrationAutodetector from django.db.migrations.executor import MigrationExecutor from django.db.migrations.state import ProjectState -from django.test import TestCase from axes.models import AccessAttempt, AccessLog +from axes.tests.base import AxesTestCase -class ModelsTestCase(TestCase): +class ModelsTestCase(AxesTestCase): def setUp(self): self.failures_since_start = 42 @@ -24,7 +24,7 @@ class ModelsTestCase(TestCase): self.assertIn('Access', str(self.access_log)) -class MigrationsTestCase(TestCase): +class MigrationsTestCase(AxesTestCase): def test_missing_migrations(self): executor = MigrationExecutor(connection) autodetector = MigrationAutodetector( diff --git a/axes/tests/test_signals.py b/axes/tests/test_signals.py new file mode 100644 index 0000000..9324f17 --- /dev/null +++ b/axes/tests/test_signals.py @@ -0,0 +1,18 @@ +from unittest.mock import MagicMock + +from axes.tests.base import AxesTestCase +from axes.signals import user_locked_out + + +class SignalTestCase(AxesTestCase): + def test_send_lockout_signal(self): + """ + Test if the lockout signal is correctly emitted when user is locked out. + """ + + handler = MagicMock() + user_locked_out.connect(handler) + + self.assertEqual(0, handler.call_count) + self.lockout() + self.assertEqual(1, handler.call_count) diff --git a/axes/tests/test_utils.py b/axes/tests/test_utils.py index 9a1b415..98c66bf 100644 --- a/axes/tests/test_utils.py +++ b/axes/tests/test_utils.py @@ -1,73 +1,50 @@ from datetime import timedelta +from hashlib import md5 from unittest.mock import patch -from django.contrib.auth import get_user_model from django.http import HttpRequest, JsonResponse, HttpResponseRedirect, HttpResponse -from django.test import TestCase, override_settings +from django.test import override_settings, RequestFactory from axes import get_version +from axes.models import AccessAttempt +from axes.tests.base import AxesTestCase from axes.utils import ( - get_cool_off_iso8601, + get_cache_timeout, get_client_str, get_client_username, + get_client_cache_key, + get_client_parameters, + get_cool_off_iso8601, get_lockout_response, is_client_ip_address_blacklisted, + is_client_ip_address_whitelisted, is_ip_address_in_blacklist, is_ip_address_in_whitelist, - get_cache_timeout, - is_client_username_whitelisted, - is_client_ip_address_whitelisted) + is_client_method_whitelisted) -def get_username(request: HttpRequest, credentials: dict) -> str: - return 'username' - - -def get_expected_client_str(*args, **kwargs): - client_str_template = '{{username: "{0}", ip_address: "{1}", user_agent: "{2}", path_info: "{3}"}}' - return client_str_template.format(*args, **kwargs) - - -class VersionTestCase(TestCase): +class VersionTestCase(AxesTestCase): @patch('axes.__version__', 'test') def test_get_version(self): self.assertEqual(get_version(), 'test') -class CacheTestCase(TestCase): +class CacheTestCase(AxesTestCase): @override_settings(AXES_COOLOFF_TIME=3) # hours - def test_get_cache_timeout(self): + def test_get_cache_timeout_integer(self): timeout_seconds = float(60 * 60 * 3) self.assertEqual(get_cache_timeout(), timeout_seconds) + @override_settings(AXES_COOLOFF_TIME=timedelta(seconds=420)) + def test_get_cache_timeout_timedelta(self): + self.assertEqual(get_cache_timeout(), 420) -class UserTestCase(TestCase): - def setUp(self): - self.user_model = get_user_model() - self.user = self.user_model.objects.create(username='jane.doe') - self.request = HttpRequest() - - def test_is_client_username_whitelisted(self): - with patch.object(self.user_model, 'nolockout', True, create=True): - self.assertTrue(is_client_username_whitelisted( - self.request, - {self.user_model.USERNAME_FIELD: self.user.username}, - )) - - def test_is_client_username_whitelisted_not(self): - self.assertFalse(is_client_username_whitelisted( - self.request, - {self.user_model.USERNAME_FIELD: self.user.username}, - )) - - def test_is_client_username_whitelisted_does_not_exist(self): - self.assertFalse(is_client_username_whitelisted( - self.request, - {self.user_model.USERNAME_FIELD: 'not.' + self.user.username}, - )) + @override_settings(AXES_COOLOFF_TIME=None) + def test_get_cache_timeout_none(self): + self.assertEqual(get_cache_timeout(), None) -class TimestampTestCase(TestCase): +class TimestampTestCase(AxesTestCase): def test_iso8601(self): """ Test get_cool_off_iso8601 correctly translates datetime.timdelta to ISO 8601 formatted duration. @@ -97,7 +74,12 @@ class TimestampTestCase(TestCase): self.assertEqual(get_cool_off_iso8601(delta), iso_duration) -class ClientStringTestCase(TestCase): +class ClientStringTestCase(AxesTestCase): + @staticmethod + def get_expected_client_str(*args, **kwargs): + client_str_template = '{{username: "{0}", ip_address: "{1}", user_agent: "{2}", path_info: "{3}"}}' + return client_str_template.format(*args, **kwargs) + @override_settings(AXES_VERBOSE=True) def test_verbose_ip_only_client_details(self): username = 'test@example.com' @@ -105,7 +87,7 @@ class ClientStringTestCase(TestCase): user_agent = 'Googlebot/2.1 (+http://www.googlebot.com/bot.html)' path_info = '/admin/' - expected = get_expected_client_str(username, ip_address, user_agent, path_info) + expected = self.get_expected_client_str(username, ip_address, user_agent, path_info) actual = get_client_str(username, ip_address, user_agent, path_info) self.assertEqual(expected, actual) @@ -117,7 +99,7 @@ class ClientStringTestCase(TestCase): user_agent = 'Googlebot/2.1 (+http://www.googlebot.com/bot.html)' path_info = ('admin', 'login') - expected = get_expected_client_str(username, ip_address, user_agent, path_info[0]) + expected = self.get_expected_client_str(username, ip_address, user_agent, path_info[0]) actual = get_client_str(username, ip_address, user_agent, path_info) self.assertEqual(expected, actual) @@ -142,7 +124,7 @@ class ClientStringTestCase(TestCase): user_agent = 'Googlebot/2.1 (+http://www.googlebot.com/bot.html)' path_info = '/admin/' - expected = get_expected_client_str(username, ip_address, user_agent, path_info) + expected = self.get_expected_client_str(username, ip_address, user_agent, path_info) actual = get_client_str(username, ip_address, user_agent, path_info) self.assertEqual(expected, actual) @@ -168,7 +150,7 @@ class ClientStringTestCase(TestCase): user_agent = 'Googlebot/2.1 (+http://www.googlebot.com/bot.html)' path_info = '/admin/' - expected = get_expected_client_str(username, ip_address, user_agent, path_info) + expected = self.get_expected_client_str(username, ip_address, user_agent, path_info) actual = get_client_str(username, ip_address, user_agent, path_info) self.assertEqual(expected, actual) @@ -194,7 +176,7 @@ class ClientStringTestCase(TestCase): user_agent = 'Googlebot/2.1 (+http://www.googlebot.com/bot.html)' path_info = '/admin/' - expected = get_expected_client_str(username, ip_address, user_agent, path_info) + expected = self.get_expected_client_str(username, ip_address, user_agent, path_info) actual = get_client_str(username, ip_address, user_agent, path_info) self.assertEqual(expected, actual) @@ -213,7 +195,138 @@ class ClientStringTestCase(TestCase): self.assertEqual(expected, actual) -class UsernameTestCase(TestCase): +class ClientParametersTestCase(AxesTestCase): + @override_settings( + AXES_ONLY_USER_FAILURES=True, + ) + def test_get_filter_kwargs_user(self): + self.assertEqual( + dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), + {'username': self.username}, + ) + + @override_settings( + AXES_ONLY_USER_FAILURES=False, + AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=False, + AXES_USE_USER_AGENT=False, + ) + def test_get_filter_kwargs_ip(self): + self.assertEqual( + dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), + {'ip_address': self.ip_address}, + ) + + @override_settings( + AXES_ONLY_USER_FAILURES=False, + AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True, + AXES_USE_USER_AGENT=False, + ) + def test_get_filter_kwargs_user_and_ip(self): + self.assertEqual( + dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), + {'username': self.username, 'ip_address': self.ip_address}, + ) + + @override_settings( + AXES_ONLY_USER_FAILURES=False, + AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=False, + AXES_USE_USER_AGENT=True, + ) + def test_get_filter_kwargs_ip_and_agent(self): + self.assertEqual( + dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), + {'ip_address': self.ip_address, 'user_agent': self.user_agent}, + ) + + @override_settings( + AXES_ONLY_USER_FAILURES=False, + AXES_LOCK_OUT_BY_COMBINATION_USER_AND_IP=True, + AXES_USE_USER_AGENT=True, + ) + def test_get_filter_kwargs_user_ip_agent(self): + self.assertEqual( + dict(get_client_parameters(self.username, self.ip_address, self.user_agent)), + {'username': self.username, 'ip_address': self.ip_address, 'user_agent': self.user_agent}, + ) + + +class ClientCacheKeyTestCase(AxesTestCase): + def test_get_cache_key(self): + """ + Test the cache key format. + """ + + # Getting cache key from request + cache_hash_key = 'axes-{}'.format( + md5(self.ip_address.encode()).hexdigest() + ) + + request_factory = RequestFactory() + request = request_factory.post( + '/admin/login/', + data={ + 'username': self.username, + 'password': 'test', + }, + ) + + self.assertEqual(cache_hash_key, get_client_cache_key(request)) + + # Getting cache key from AccessAttempt Object + attempt = AccessAttempt( + user_agent='', + ip_address=self.ip_address, + username=self.username, + get_data='', + post_data='', + http_accept=request.META.get('HTTP_ACCEPT', ''), + path_info=request.META.get('PATH_INFO', ''), + failures_since_start=0, + ) + + self.assertEqual(cache_hash_key, get_client_cache_key(attempt)) + + def test_get_cache_key_credentials(self): + """ + Test the cache key format. + """ + + # Getting cache key from request + ip_address = self.ip_address + cache_hash_key = 'axes-{}'.format( + md5(ip_address.encode()).hexdigest() + ) + + + request_factory = RequestFactory() + request = request_factory.post( + '/admin/login/', + data={ + 'username': self.username, + 'password': 'test' + } + ) + + # Difference between the upper test: new call signature with credentials + credentials = {'username': self.username} + + self.assertEqual(cache_hash_key, get_client_cache_key(request, credentials)) + + # Getting cache key from AccessAttempt Object + attempt = AccessAttempt( + user_agent='', + ip_address=ip_address, + username=self.username, + get_data='', + post_data='', + http_accept=request.META.get('HTTP_ACCEPT', ''), + path_info=request.META.get('PATH_INFO', ''), + failures_since_start=0, + ) + self.assertEqual(cache_hash_key, get_client_cache_key(attempt)) + + +class UsernameTestCase(AxesTestCase): @override_settings(AXES_USERNAME_FORM_FIELD='username') def test_default_get_client_username(self): expected = 'test-username' @@ -249,7 +362,6 @@ class UsernameTestCase(TestCase): provided = 'test-username' expected = 'prefixed-' + provided provided_in_credentials = 'test-credentials-username' - expected_in_credentials = 'prefixed-' + provided_in_credentials request = HttpRequest() request.POST['username'] = provided @@ -266,7 +378,6 @@ class UsernameTestCase(TestCase): @override_settings(AXES_USERNAME_CALLABLE=sample_customize_username_credentials) def test_custom_get_client_username_from_credentials(self): provided = 'test-username' - expected = 'prefixed-' + provided provided_in_credentials = 'test-credentials-username' expected_in_credentials = 'prefixed-' + provided_in_credentials @@ -305,7 +416,11 @@ class UsernameTestCase(TestCase): ) -class WhitelistTestCase(TestCase): +def get_username(request: HttpRequest, credentials: dict) -> str: + return 'username' + + +class IPWhitelistTestCase(AxesTestCase): def setUp(self): self.request = HttpRequest() self.request.method = 'POST' @@ -363,7 +478,21 @@ class WhitelistTestCase(TestCase): self.assertFalse(is_client_ip_address_whitelisted(self.request)) -class LockoutResponseTestCase(TestCase): +class MethodWhitelistTestCase(AxesTestCase): + def setUp(self): + self.request = HttpRequest() + self.request.method = 'GET' + + @override_settings(AXES_NEVER_LOCKOUT_GET=True) + def test_is_client_method_whitelisted(self): + self.assertTrue(is_client_method_whitelisted(self.request)) + + @override_settings(AXES_NEVER_LOCKOUT_GET=False) + def test_is_client_method_whitelisted_not(self): + self.assertFalse(is_client_method_whitelisted(self.request)) + + +class LockoutResponseTestCase(AxesTestCase): def setUp(self): self.request = HttpRequest() diff --git a/axes/utils.py b/axes/utils.py index fe044bc..22458fc 100644 --- a/axes/utils.py +++ b/axes/utils.py @@ -1,9 +1,9 @@ from collections import OrderedDict from datetime import timedelta +from hashlib import md5 from logging import getLogger -from typing import Optional, Type +from typing import Any, Optional, Type, Union -from django.contrib.auth import get_user_model from django.core.cache import caches, BaseCache from django.http import HttpResponse, HttpResponseRedirect, HttpRequest, JsonResponse, QueryDict from django.shortcuts import render @@ -12,27 +12,10 @@ from django.utils.module_loading import import_string import ipware.ip2 from axes.conf import settings -from axes.models import AccessAttempt logger = getLogger(__name__) -def reset(ip: str = None, username: str = None) -> int: - """ - Reset records that match IP or username, and return the count of removed attempts. - """ - - attempts = AccessAttempt.objects.all() - if ip: - attempts = attempts.filter(ip_address=ip) - if username: - attempts = attempts.filter(username=username) - - count, _ = attempts.delete() - - return count - - def get_axes_cache() -> BaseCache: """ Get the cache instance Axes is configured to use with ``settings.AXES_CACHE`` and use ``'default'`` if not set. @@ -383,28 +366,28 @@ def is_client_method_whitelisted(request: HttpRequest) -> bool: return False -def is_client_username_whitelisted(request: HttpRequest, credentials: dict = None) -> bool: +def get_client_cache_key(request_or_attempt: Union[HttpRequest, Any], credentials: dict = None) -> str: """ - Check if the given request or credentials refer to a whitelisted username. + Build cache key name from request or AccessAttempt object. - A whitelisted user has the magic ``nolockout`` property set. - - If the property is unknown or False or the user can not be found, - this implementation fails gracefully and returns True. + :param request_or_attempt: HttpRequest or AccessAttempt object + :param credentials: credentials containing user information + :return cache_key: Hash key that is usable for Django cache backends """ - username_field = getattr(get_user_model(), 'USERNAME_FIELD', 'username') - username_value = get_client_username(request, credentials) - kwargs = { - username_field: username_value - } + if isinstance(request_or_attempt, HttpRequest): + username = get_client_username(request_or_attempt, credentials) + ip_address = get_client_ip_address(request_or_attempt) + user_agent = get_client_user_agent(request_or_attempt) + else: + username = request_or_attempt.username + ip_address = request_or_attempt.ip_address + user_agent = request_or_attempt.user_agent - user_model = get_user_model() + filter_kwargs = get_client_parameters(username, ip_address, user_agent) - try: - user = user_model.objects.get(**kwargs) - return user.nolockout - except (user_model.DoesNotExist, AttributeError): - pass + cache_key_components = ''.join(filter_kwargs.values()) + cache_key_digest = md5(cache_key_components.encode()).hexdigest() + cache_key = 'axes-{}'.format(cache_key_digest) - return False + return cache_key