diff --git a/MANIFEST b/MANIFEST index f2ed177..3a6bf60 100644 --- a/MANIFEST +++ b/MANIFEST @@ -4,6 +4,7 @@ src/auditlog/__init__.py src/auditlog/admin.py src/auditlog/apps.py src/auditlog/compat.py +src/auditlog/context.py src/auditlog/diff.py src/auditlog/filters.py src/auditlog/middleware.py diff --git a/requirements-test.txt b/requirements-test.txt index d240772..30c3bfe 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,3 +5,4 @@ tox>=1.7.0 codecov>=2.0.0 django-multiselectfield==0.1.8 psycopg2-binary +mock diff --git a/src/auditlog/context.py b/src/auditlog/context.py new file mode 100644 index 0000000..291fc25 --- /dev/null +++ b/src/auditlog/context.py @@ -0,0 +1,55 @@ +import contextlib +import time +import threading + +from django.contrib.auth import get_user_model +from django.db.models.signals import pre_save +from django.utils.functional import curry + +from auditlog.models import LogEntry + +threadlocal = threading.local() + + +@contextlib.contextmanager +def set_actor(actor, remote_addr=None): + """Connect a signal receiver with current user attached.""" + # Initialize thread local storage + threadlocal.auditlog = { + 'signal_duid': ('set_actor', time.time()), + 'remote_addr': remote_addr, + } + + # Connect signal for automatic logging + set_actor = curry(_set_actor, user=actor, signal_duid=threadlocal.auditlog['signal_duid']) + pre_save.connect(set_actor, sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid'], weak=False) + + try: + yield + + finally: + try: + auditlog = threadlocal.auditlog + except AttributeError: + pass + else: + pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog['signal_duid']) + + +def _set_actor(user, sender, instance, signal_duid, **kwargs): + """Signal receiver with an extra 'user' kwarg. + + This function becomes a valid signal receiver when it is curried with the actor. + """ + try: + auditlog = threadlocal.auditlog + except AttributeError: + pass + else: + if signal_duid != auditlog['signal_duid']: + return + auth_user_model = get_user_model() + if sender == LogEntry and isinstance(user, auth_user_model) and instance.actor is None: + instance.actor = user + + instance.remote_addr = auditlog['remote_addr'] diff --git a/src/auditlog/middleware.py b/src/auditlog/middleware.py index 1f29035..b983789 100644 --- a/src/auditlog/middleware.py +++ b/src/auditlog/middleware.py @@ -1,78 +1,38 @@ from __future__ import unicode_literals -import threading -import time +import contextlib -from django.contrib.auth import get_user_model -from django.db.models.signals import pre_save -from django.utils.functional import curry -from auditlog.models import LogEntry from auditlog.compat import is_authenticated - -# Use MiddlewareMixin when present (Django >= 1.10) -try: - from django.utils.deprecation import MiddlewareMixin -except ImportError: - MiddlewareMixin = object +from auditlog.context import set_actor -threadlocal = threading.local() +@contextlib.contextmanager +def nullcontext(): + """Equivalent to contextlib.nullcontext(None) from Python 3.7.""" + yield -class AuditlogMiddleware(MiddlewareMixin): +class AuditlogMiddleware(object): """ Middleware to couple the request's user to log items. This is accomplished by currying the signal receiver with the user from the request (or None if the user is not authenticated). """ - def process_request(self, request): - """ - Gets the current user from the request and prepares and connects a signal receiver with the user already - attached to it. - """ - # Initialize thread local storage - threadlocal.auditlog = { - 'signal_duid': (self.__class__, time.time()), - 'remote_addr': request.META.get('REMOTE_ADDR'), - } + def __init__(self, get_response=None): + self.get_response = get_response + + def __call__(self, request): - # In case of proxy, set 'original' address if request.META.get('HTTP_X_FORWARDED_FOR'): - threadlocal.auditlog['remote_addr'] = request.META.get('HTTP_X_FORWARDED_FOR').split(',')[0] + # In case of proxy, set 'original' address + remote_addr = request.META.get('HTTP_X_FORWARDED_FOR').split(',')[0] + else: + remote_addr = request.META.get('REMOTE_ADDR') - # Connect signal for automatic logging if hasattr(request, 'user') and is_authenticated(request.user): - set_actor = curry(self.set_actor, user=request.user, signal_duid=threadlocal.auditlog['signal_duid']) - pre_save.connect(set_actor, sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid'], weak=False) + context = set_actor(actor=request.user, remote_addr=remote_addr) + else: + context = nullcontext() - def process_response(self, request, response): - """ - Disconnects the signal receiver to prevent it from staying active. - """ - if hasattr(threadlocal, 'auditlog'): - pre_save.disconnect(sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid']) - - return response - - def process_exception(self, request, exception): - """ - Disconnects the signal receiver to prevent it from staying active in case of an exception. - """ - if hasattr(threadlocal, 'auditlog'): - pre_save.disconnect(sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid']) - - return None - - @staticmethod - def set_actor(user, sender, instance, signal_duid, **kwargs): - """ - Signal receiver with an extra, required 'user' kwarg. This method becomes a real (valid) signal receiver when - it is curried with the actor. - """ - if hasattr(threadlocal, 'auditlog'): - if signal_duid != threadlocal.auditlog['signal_duid']: - return - if sender == LogEntry and isinstance(user, get_user_model()) and instance.actor is None: - instance.actor = user - - instance.remote_addr = threadlocal.auditlog['remote_addr'] + with context: + return self.get_response(request) diff --git a/src/auditlog_tests/tests.py b/src/auditlog_tests/tests.py index 4560acc..5de9e59 100644 --- a/src/auditlog_tests/tests.py +++ b/src/auditlog_tests/tests.py @@ -9,6 +9,7 @@ from django.http import HttpResponse from django.test import TestCase, RequestFactory from django.utils import dateformat, formats, timezone from dateutil.tz import gettz +import mock from auditlog.middleware import AuditlogMiddleware from auditlog.models import LogEntry @@ -121,66 +122,64 @@ class MiddlewareTest(TestCase): Test the middleware responsible for connecting and disconnecting the signals used in automatic logging. """ def setUp(self): - self.middleware = AuditlogMiddleware() + self.get_response_mock = mock.Mock() + self.response_mock = mock.Mock() + self.middleware = AuditlogMiddleware(get_response=self.get_response_mock) self.factory = RequestFactory() self.user = User.objects.create_user(username='test', email='test@example.com', password='top_secret') + def side_effect(self, assertion): + def inner(request): + assertion() + return self.response_mock + + return inner + + def assert_has_listeners(self): + self.assertTrue(pre_save.has_listeners(LogEntry)) + + def assert_no_listeners(self): + self.assertFalse(pre_save.has_listeners(LogEntry)) + def test_request_anonymous(self): """No actor will be logged when a user is not logged in.""" - # Create a request request = self.factory.get('/') request.user = AnonymousUser() - # Run middleware - self.middleware.process_request(request) + self.get_response_mock.side_effect = self.side_effect(self.assert_no_listeners) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + response = self.middleware(request) - # Finalize transaction - self.middleware.process_exception(request, None) + self.assertIs(response, self.response_mock) + self.get_response_mock.assert_called_once_with(request) + self.assert_no_listeners() def test_request(self): """The actor will be logged when a user is logged in.""" - # Create a request - request = self.factory.get('/') - request.user = self.user - # Run middleware - self.middleware.process_request(request) - - # Validate result - self.assertTrue(pre_save.has_listeners(LogEntry)) - - # Finalize transaction - self.middleware.process_exception(request, None) - - def test_response(self): - """The signal will be disconnected when the request is processed.""" - # Create a request request = self.factory.get('/') request.user = self.user - # Run middleware - self.middleware.process_request(request) - self.assertTrue(pre_save.has_listeners(LogEntry)) # The signal should be present before trying to disconnect it. - self.middleware.process_response(request, HttpResponse()) + self.get_response_mock.side_effect = self.side_effect(self.assert_has_listeners) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + response = self.middleware(request) + + self.assertIs(response, self.response_mock) + self.get_response_mock.assert_called_once_with(request) + self.assert_no_listeners() def test_exception(self): """The signal will be disconnected when an exception is raised.""" - # Create a request request = self.factory.get('/') request.user = self.user - # Run middleware - self.middleware.process_request(request) - self.assertTrue(pre_save.has_listeners(LogEntry)) # The signal should be present before trying to disconnect it. - self.middleware.process_exception(request, ValidationError("Test")) + SomeException = type('SomeException', (Exception,), {}) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + self.get_response_mock.side_effect = SomeException + + with self.assertRaises(SomeException): + self.middleware(request) + + self.assert_no_listeners() class SimpeIncludeModelTest(TestCase):