From 4684a1a04fffe309808c4b4e4e2f8a4691b14834 Mon Sep 17 00:00:00 2001 From: jd Date: Tue, 22 May 2018 14:54:45 -0700 Subject: [PATCH] Add a setting to supply a callable that can return a correct username given a request object #318 --- axes/attempts.py | 6 +++--- axes/conf.py | 3 +++ axes/decorators.py | 4 ++-- axes/tests/test_utils.py | 30 +++++++++++++++++++++++++++++- axes/utils.py | 6 ++++++ 5 files changed, 43 insertions(+), 6 deletions(-) diff --git a/axes/attempts.py b/axes/attempts.py index ac8cdaf..caf55c3 100644 --- a/axes/attempts.py +++ b/axes/attempts.py @@ -8,7 +8,7 @@ from django.utils import timezone from axes.conf import settings from axes.models import AccessAttempt -from axes.utils import get_axes_cache, get_client_ip +from axes.utils import get_axes_cache, get_client_ip, get_client_username def _query_user_attempts(request): @@ -16,7 +16,7 @@ def _query_user_attempts(request): Otherwise return None. """ ip = get_client_ip(request) - username = request.POST.get(settings.AXES_USERNAME_FORM_FIELD, None) + username = get_client_username(request) if settings.AXES_ONLY_USER_FAILURES: attempts = AccessAttempt.objects.filter(username=username) @@ -158,7 +158,7 @@ def is_user_lockable(request): try: field = getattr(get_user_model(), 'USERNAME_FIELD', 'username') kwargs = { - field: request.POST.get(settings.AXES_USERNAME_FORM_FIELD) + field: get_client_username(request) } user = get_user_model().objects.get(**kwargs) diff --git a/axes/conf.py b/axes/conf.py index 336db56..f2be02e 100644 --- a/axes/conf.py +++ b/axes/conf.py @@ -20,6 +20,9 @@ class MyAppConf(AppConf): # use a specific password field to retrieve from login POST data PASSWORD_FORM_FIELD = 'password' + # use a provided callable to transform the POSTed username into the one used in credentials + USERNAME_CALLABLE = None + # only check user name and not location or user_agent ONLY_USER_FAILURES = False diff --git a/axes/decorators.py b/axes/decorators.py index 51daa77..712122f 100644 --- a/axes/decorators.py +++ b/axes/decorators.py @@ -12,7 +12,7 @@ from django.shortcuts import render from axes import get_version from axes.conf import settings from axes.attempts import is_already_locked -from axes.utils import iso8601, get_lockout_message +from axes.utils import iso8601, get_client_username, get_lockout_message log = logging.getLogger(settings.AXES_LOGGER) if settings.AXES_VERBOSE: @@ -50,7 +50,7 @@ def axes_form_invalid(func): def lockout_response(request): context = { 'failure_limit': settings.AXES_FAILURE_LIMIT, - 'username': request.POST.get(settings.AXES_USERNAME_FORM_FIELD, '') + 'username': get_client_username(request) or '' } cool_off = settings.AXES_COOLOFF_TIME diff --git a/axes/tests/test_utils.py b/axes/tests/test_utils.py index 9c1f0bc..bf3b40a 100644 --- a/axes/tests/test_utils.py +++ b/axes/tests/test_utils.py @@ -2,10 +2,11 @@ from __future__ import unicode_literals import datetime +from django.http import HttpRequest from django.test import TestCase, override_settings from django.utils import six -from axes.utils import iso8601, is_ipv6, get_client_str +from axes.utils import iso8601, is_ipv6, get_client_str, get_client_username class UtilsTest(TestCase): @@ -143,3 +144,30 @@ class UtilsTest(TestCase): actual = get_client_str(username, ip, user_agent, path_info) self.assertEqual(expected, actual) + + @override_settings(AXES_USERNAME_FORM_FIELD='username') + def test_default_get_client_username(self): + expected = 'test-username' + + request = HttpRequest() + request.POST['username'] = expected + + actual = get_client_username(request) + + self.assertEqual(expected, actual) + + def sample_customize_username(request): + return 'prefixed-' + request.POST.get('username') + + @override_settings(AXES_USERNAME_FORM_FIELD='username') + @override_settings(AXES_USERNAME_CALLABLE=sample_customize_username) + def test_custom_get_client_username(self): + provided = 'test-username' + expected = 'prefixed-' + provided + + request = HttpRequest() + request.POST['username'] = provided + + actual = get_client_username(request) + + self.assertEqual(expected, actual) diff --git a/axes/utils.py b/axes/utils.py index 69c08d8..0596f7a 100644 --- a/axes/utils.py +++ b/axes/utils.py @@ -69,6 +69,12 @@ def get_client_ip(request): return getattr(request, client_ip_attribute) +def get_client_username(request): + if settings.AXES_USERNAME_CALLABLE: + return settings.AXES_USERNAME_CALLABLE(request) + return request.POST.get(settings.AXES_USERNAME_FORM_FIELD, None) + + def is_ipv6(ip): try: inet_pton(AF_INET6, ip)