diff --git a/axes/attempts.py b/axes/attempts.py index d78b224..d15e9e1 100644 --- a/axes/attempts.py +++ b/axes/attempts.py @@ -11,12 +11,12 @@ from axes.models import AccessAttempt from axes.utils import get_axes_cache, get_client_ip, get_client_username -def _query_user_attempts(request): +def _query_user_attempts(request, credentials): """Returns access attempt record if it exists. Otherwise return None. """ ip = get_client_ip(request) - username = get_client_username(request) + username = get_client_username(request, credentials) if settings.AXES_ONLY_USER_FAILURES: attempts = AccessAttempt.objects.filter(username=username) @@ -49,10 +49,11 @@ def _query_user_attempts(request): return attempts -def get_cache_key(request_or_obj): +def get_cache_key(request_or_obj, credentials=None): """ Build cache key name from request or AccessAttempt object. :param request_or_obj: Request or AccessAttempt object + :param credentials: Dictionary with access credentials - Only supplied when request_or_obj is not an AccessAttempt :return cache-key: String, key to be used in cache system """ if isinstance(request_or_obj, AccessAttempt): @@ -61,8 +62,11 @@ def get_cache_key(request_or_obj): ua = request_or_obj.user_agent else: ip = get_client_ip(request_or_obj) - un = request_or_obj.POST.get(settings.AXES_USERNAME_FORM_FIELD, None) ua = request_or_obj.META.get('HTTP_USER_AGENT', '')[:255] + if credentials is None: + un = request_or_obj.POST.get(settings.AXES_USERNAME_FORM_FIELD, None) + else: + un = credentials.get(settings.AXES_USERNAME_FORM_FIELD, None) ip = ip.encode('utf-8') if ip else ''.encode('utf-8') un = un.encode('utf-8') if un else ''.encode('utf-8') @@ -96,10 +100,10 @@ def get_cache_timeout(): return cache_timeout -def get_user_attempts(request): +def get_user_attempts(request, credentials): force_reload = False - attempts = _query_user_attempts(request) - cache_hash_key = get_cache_key(request) + attempts = _query_user_attempts(request, credentials) + cache_hash_key = get_cache_key(request, credentials) cache_timeout = get_cache_timeout() cool_off = settings.AXES_COOLOFF_TIME @@ -125,13 +129,13 @@ def get_user_attempts(request): # If objects were deleted, we need to update the queryset to reflect this, # so force a reload. if force_reload: - attempts = _query_user_attempts(request) + attempts = _query_user_attempts(request, credentials) return attempts -def reset_user_attempts(request): - attempts = _query_user_attempts(request) +def reset_user_attempts(request, credentials): + attempts = _query_user_attempts(request, credentials) count, _ = attempts.delete() return count @@ -151,7 +155,7 @@ def ip_in_blacklist(ip): return ip in settings.AXES_IP_BLACKLIST -def is_user_lockable(request): +def is_user_lockable(request, credentials): """Check if the user has a profile with nolockout If so, then return the value to see if this user is special and doesn't get their account locked out @@ -165,7 +169,7 @@ def is_user_lockable(request): try: field = getattr(get_user_model(), 'USERNAME_FIELD', 'username') kwargs = { - field: get_client_username(request) + field: get_client_username(request, credentials) } user = get_user_model().objects.get(**kwargs) @@ -182,7 +186,7 @@ def is_user_lockable(request): return True -def is_already_locked(request): +def is_already_locked(request, credentials=None): ip = get_client_ip(request) if ( @@ -200,10 +204,10 @@ def is_already_locked(request): if ip_in_blacklist(ip): return True - if not is_user_lockable(request): + if not is_user_lockable(request, credentials): return False - cache_hash_key = get_cache_key(request) + cache_hash_key = get_cache_key(request, credentials) failures_cached = get_axes_cache().get(cache_hash_key) if failures_cached is not None: return ( @@ -211,7 +215,7 @@ def is_already_locked(request): settings.AXES_LOCK_OUT_AT_FAILURE ) - for attempt in get_user_attempts(request): + for attempt in get_user_attempts(request, credentials): if ( attempt.failures_since_start >= settings.AXES_FAILURE_LIMIT and settings.AXES_LOCK_OUT_AT_FAILURE diff --git a/axes/backends.py b/axes/backends.py index 8e1b2fd..cbd2d5b 100644 --- a/axes/backends.py +++ b/axes/backends.py @@ -4,6 +4,7 @@ from django.contrib.auth.backends import ModelBackend from django.core.exceptions import PermissionDenied from axes.attempts import is_already_locked +from axes.conf import settings from axes.utils import get_lockout_message @@ -30,10 +31,13 @@ class AxesModelBackend(ModelBackend): :return: Nothing, but will update return_context with lockout message if user is locked out. """ + # Create credentials dictionary from username field + credentials = {settings.AXES_USERNAME_FORM_FIELD: username} + if request is None: raise AxesModelBackend.RequestParameterRequired() - if is_already_locked(request): + if is_already_locked(request, credentials): # locked out, don't try to authenticate, just update return_context and return # Its a bit weird to pass a context and expect a response value but its nice to get a "why" back. error_msg = get_lockout_message() diff --git a/axes/signals.py b/axes/signals.py index a3f6b47..c6e2afc 100644 --- a/axes/signals.py +++ b/axes/signals.py @@ -38,7 +38,7 @@ def log_user_login_failed(sender, credentials, request, **kwargs): # pylint: di return ip_address = get_client_ip(request) - username = get_client_username(request) + username = get_client_username(request, credentials) user_agent = request.META.get('HTTP_USER_AGENT', '')[:255] path_info = request.META.get('PATH_INFO', '')[:255] http_accept = request.META.get('HTTP_ACCEPT', '')[:1025] @@ -47,8 +47,8 @@ def log_user_login_failed(sender, credentials, request, **kwargs): # pylint: di return failures = 0 - attempts = get_user_attempts(request) - cache_hash_key = get_cache_key(request) + attempts = get_user_attempts(request, credentials) + cache_hash_key = get_cache_key(request, credentials) cache_timeout = get_cache_timeout() failures_cached = get_axes_cache().get(cache_hash_key) @@ -110,7 +110,7 @@ def log_user_login_failed(sender, credentials, request, **kwargs): # pylint: di if ( failures >= settings.AXES_FAILURE_LIMIT and settings.AXES_LOCK_OUT_AT_FAILURE and - is_user_lockable(request) + is_user_lockable(request, credentials) ): log.warning( 'AXES: locked out %s after repeated login attempts.', @@ -148,7 +148,9 @@ def log_user_logged_in(sender, request, user, **kwargs): # pylint: disable=unus ) if settings.AXES_RESET_ON_SUCCESS: - count = reset_user_attempts(request) + # Create credentials dictionary from the username field + credentials = {settings.AXES_USERNAME_FORM_FIELD: username} + count = reset_user_attempts(request, credentials) log.info( 'AXES: Deleted %d failed login attempts by %s.', count, diff --git a/axes/tests/test_access_attempt.py b/axes/tests/test_access_attempt.py index 5f7c185..c02f0da 100644 --- a/axes/tests/test_access_attempt.py +++ b/axes/tests/test_access_attempt.py @@ -206,8 +206,11 @@ class AccessAttemptTest(TestCase): 'username': self.VALID_USERNAME, 'password': 'test' }) + credentials = { + 'username': self.VALID_USERNAME + } - self.assertEqual(cache_hash_key, get_cache_key(request)) + self.assertEqual(cache_hash_key, get_cache_key(request, credentials)) # Getting cache key from AccessAttempt Object attempt = AccessAttempt( diff --git a/axes/tests/test_utils.py b/axes/tests/test_utils.py index bf3b40a..378379d 100644 --- a/axes/tests/test_utils.py +++ b/axes/tests/test_utils.py @@ -146,7 +146,7 @@ class UtilsTest(TestCase): self.assertEqual(expected, actual) @override_settings(AXES_USERNAME_FORM_FIELD='username') - def test_default_get_client_username(self): + def test_default_get_client_username_from_request(self): expected = 'test-username' request = HttpRequest() @@ -156,12 +156,27 @@ class UtilsTest(TestCase): self.assertEqual(expected, actual) - def sample_customize_username(request): + @override_settings(AXES_USERNAME_FORM_FIELD='username') + def test_default_get_client_username_from_credentials(self): + expected = 'test-username' + expected_in_credentials = 'test-credentials-username' + + request = HttpRequest() + request.POST['username'] = expected + credentials = { + 'username': expected_in_credentials + } + + actual = get_client_username(request, credentials) + + self.assertEqual(expected_in_credentials, actual) + + def sample_customize_username_from_request(request, credentials): return 'prefixed-' + request.POST.get('username') @override_settings(AXES_USERNAME_FORM_FIELD='username') - @override_settings(AXES_USERNAME_CALLABLE=sample_customize_username) - def test_custom_get_client_username(self): + @override_settings(AXES_USERNAME_CALLABLE=sample_customize_username_from_request) + def test_custom_get_client_username_from_request(self): provided = 'test-username' expected = 'prefixed-' + provided @@ -171,3 +186,22 @@ class UtilsTest(TestCase): actual = get_client_username(request) self.assertEqual(expected, actual) + + def sample_customize_username_from_credentials(request, credentials): + return 'prefixed-' + credentials.get('username') + + @override_settings(AXES_USERNAME_FORM_FIELD='username') + @override_settings(AXES_USERNAME_CALLABLE=sample_customize_username_from_credentials) + def test_custom_get_client_username_from_credentials(self): + provided = 'test-username' + expected = 'prefixed-' + provided + provided_in_credentials = 'test-username' + expected_in_credentials = 'prefixed-' + provided_in_credentials + + request = HttpRequest() + request.POST['username'] = provided + credentials = {'username': provided_in_credentials} + + actual = get_client_username(request, credentials) + + self.assertEqual(expected_in_credentials, actual) diff --git a/axes/utils.py b/axes/utils.py index 0596f7a..a77beda 100644 --- a/axes/utils.py +++ b/axes/utils.py @@ -69,10 +69,12 @@ def get_client_ip(request): return getattr(request, client_ip_attribute) -def get_client_username(request): +def get_client_username(request, credentials=None): if settings.AXES_USERNAME_CALLABLE: - return settings.AXES_USERNAME_CALLABLE(request) - return request.POST.get(settings.AXES_USERNAME_FORM_FIELD, None) + return settings.AXES_USERNAME_CALLABLE(request, credentials) + if credentials is None: + return request.POST.get(settings.AXES_USERNAME_FORM_FIELD, None) + return credentials.get(settings.AXES_USERNAME_FORM_FIELD, None) def is_ipv6(ip):