diff --git a/axes/attempts.py b/axes/attempts.py index 17f67f4..83367f9 100644 --- a/axes/attempts.py +++ b/axes/attempts.py @@ -1,13 +1,10 @@ from logging import getLogger -from typing import List, Optional +from typing import Optional -from django.db.models import QuerySet from django.http import HttpRequest from django.utils.timezone import datetime, now -from axes.conf import settings -from axes.helpers import get_client_username, get_client_parameters, get_cool_off -from axes.models import AccessAttempt +from axes.helpers import get_cool_off log = getLogger(__name__) @@ -27,83 +24,3 @@ def get_cool_off_threshold(request: Optional[HttpRequest] = None) -> datetime: if attempt_time is None: return now() - cool_off return attempt_time - cool_off - - -def filter_user_attempts( - request: HttpRequest, credentials: Optional[dict] = None -) -> List[QuerySet]: - """ - Return a list querysets of AccessAttempts that match the given request and credentials. - """ - - username = get_client_username(request, credentials) - - filter_kwargs_list = get_client_parameters( - username, request.axes_ip_address, request.axes_user_agent, request, credentials - ) - attempts_list = [ - AccessAttempt.objects.filter(**filter_kwargs) - for filter_kwargs in filter_kwargs_list - ] - return attempts_list - - -def get_user_attempts( - request: HttpRequest, credentials: Optional[dict] = None -) -> List[QuerySet]: - """ - Get list of querysets with valid user attempts that match the given request and credentials. - """ - - attempts_list = filter_user_attempts(request, credentials) - - if settings.AXES_COOLOFF_TIME is None: - log.debug( - "AXES: Getting all access attempts from database because no AXES_COOLOFF_TIME is configured" - ) - return attempts_list - - threshold = get_cool_off_threshold(request) - log.debug("AXES: Getting access attempts that are newer than %s", threshold) - return [attempts.filter(attempt_time__gte=threshold) for attempts in attempts_list] - - -def clean_expired_user_attempts( - request: Optional[HttpRequest] = None, credentials: Optional[dict] = None -) -> int: - """ - Clean expired user attempts from the database. - """ - - if settings.AXES_COOLOFF_TIME is None: - log.debug( - "AXES: Skipping clean for expired access attempts because no AXES_COOLOFF_TIME is configured" - ) - return 0 - - threshold = get_cool_off_threshold(request) - count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete() - log.info( - "AXES: Cleaned up %s expired access attempts from database that were older than %s", - count, - threshold, - ) - return count - - -def reset_user_attempts( - request: HttpRequest, credentials: Optional[dict] = None -) -> int: - """ - Reset all user attempts that match the given request and credentials. - """ - - attempts_list = filter_user_attempts(request, credentials) - - count = 0 - for attempts in attempts_list: - _count, _ = attempts.delete() - count += _count - log.info("AXES: Reset %s access attempts from database.", count) - - return count diff --git a/axes/handlers/database.py b/axes/handlers/database.py index 2f71229..64f6357 100644 --- a/axes/handlers/database.py +++ b/axes/handlers/database.py @@ -1,20 +1,17 @@ from logging import getLogger -from typing import Optional +from typing import List, Optional from django.db import router, transaction -from django.db.models import F, Q, Sum, Value +from django.db.models import F, Q, QuerySet, Sum, Value from django.db.models.functions import Concat from django.http import HttpRequest from django.utils import timezone -from axes.attempts import ( - clean_expired_user_attempts, - get_user_attempts, - reset_user_attempts, -) +from axes.attempts import get_cool_off_threshold from axes.conf import settings -from axes.handlers.base import AxesBaseHandler, AbstractAxesHandler +from axes.handlers.base import AbstractAxesHandler, AxesBaseHandler from axes.helpers import ( + get_client_parameters, get_client_session_hash, get_client_str, get_client_username, @@ -23,7 +20,7 @@ from axes.helpers import ( get_lockout_parameters, get_query_str, ) -from axes.models import AccessLog, AccessAttempt, AccessFailureLog +from axes.models import AccessAttempt, AccessFailureLog, AccessLog from axes.signals import user_locked_out log = getLogger(__name__) @@ -105,7 +102,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): return count def get_failures(self, request, credentials: Optional[dict] = None) -> int: - attempts_list = get_user_attempts(request, credentials) + attempts_list = self.get_user_attempts(request, credentials) attempt_count = max( ( attempts.aggregate(Sum("failures_since_start"))[ @@ -293,7 +290,7 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): if settings.AXES_RESET_ON_SUCCESS: # 3. database query: Reset failed attempts for the logging in user - count = reset_user_attempts(request, credentials) + count = self.reset_user_attempts(request, credentials) log.info( "AXES: Deleted %d failed login attempts by %s from database.", count, @@ -329,6 +326,88 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): session_hash=get_client_session_hash(request), ).update(logout_time=request.axes_attempt_time) + def filter_user_attempts( + self, request: HttpRequest, credentials: Optional[dict] = None + ) -> List[QuerySet]: + """ + Return a list querysets of AccessAttempts that match the given request and credentials. + """ + + username = get_client_username(request, credentials) + + filter_kwargs_list = get_client_parameters( + username, + request.axes_ip_address, + request.axes_user_agent, + request, + credentials, + ) + attempts_list = [ + AccessAttempt.objects.filter(**filter_kwargs) + for filter_kwargs in filter_kwargs_list + ] + return attempts_list + + def get_user_attempts( + self, request: HttpRequest, credentials: Optional[dict] = None + ) -> List[QuerySet]: + """ + Get list of querysets with valid user attempts that match the given request and credentials. + """ + + attempts_list = self.filter_user_attempts(request, credentials) + + if settings.AXES_COOLOFF_TIME is None: + log.debug( + "AXES: Getting all access attempts from database because no AXES_COOLOFF_TIME is configured" + ) + return attempts_list + + threshold = get_cool_off_threshold(request) + log.debug("AXES: Getting access attempts that are newer than %s", threshold) + return [ + attempts.filter(attempt_time__gte=threshold) for attempts in attempts_list + ] + + def clean_expired_user_attempts( + self, request: Optional[HttpRequest] = None, credentials: Optional[dict] = None + ) -> int: + """ + Clean expired user attempts from the database. + """ + + if settings.AXES_COOLOFF_TIME is None: + log.debug( + "AXES: Skipping clean for expired access attempts because no AXES_COOLOFF_TIME is configured" + ) + return 0 + + threshold = get_cool_off_threshold(request) + count, _ = AccessAttempt.objects.filter(attempt_time__lt=threshold).delete() + log.info( + "AXES: Cleaned up %s expired access attempts from database that were older than %s", + count, + threshold, + ) + return count + + def reset_user_attempts( + self, request: HttpRequest, credentials: Optional[dict] = None + ) -> int: + """ + Reset all user attempts that match the given request and credentials. + """ + + attempts_list = self.filter_user_attempts(request, credentials) + + count = 0 + for attempts in attempts_list: + _count, _ = attempts.delete() + count += _count + log.info("AXES: Reset %s access attempts from database.", count) + + return count + def post_save_access_attempt(self, instance, **kwargs): """ Handles the ``axes.models.AccessAttempt`` object post save signal. @@ -344,13 +423,3 @@ class AxesDatabaseHandler(AbstractAxesHandler, AxesBaseHandler): When needed, all post_delete actions for this backend should be located here. """ - - @staticmethod - def clean_expired_user_attempts( - request: Optional[HttpRequest] = None, credentials: Optional[dict] = None - ) -> int: - """ - Clean expired user attempts from the database. - """ - - return clean_expired_user_attempts(request, credentials)