diff --git a/auditlog/context.py b/auditlog/context.py new file mode 100644 index 0000000..251a240 --- /dev/null +++ b/auditlog/context.py @@ -0,0 +1,66 @@ +import contextlib +import threading +import time +from functools import partial + +from django.contrib.auth import get_user_model +from django.db.models.signals import pre_save + +from auditlog.models import LogEntry + +threadlocal = threading.local() + + +@contextlib.contextmanager +def set_actor(actor, remote_addr=None): + """Connect a signal receiver with current user attached.""" + # Initialize thread local storage + threadlocal.auditlog = { + "signal_duid": ("set_actor", time.time()), + "remote_addr": remote_addr, + } + + # Connect signal for automatic logging + set_actor = partial( + _set_actor, user=actor, signal_duid=threadlocal.auditlog["signal_duid"] + ) + pre_save.connect( + set_actor, + sender=LogEntry, + dispatch_uid=threadlocal.auditlog["signal_duid"], + weak=False, + ) + + try: + yield + + finally: + try: + auditlog = threadlocal.auditlog + except AttributeError: + pass + else: + pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog["signal_duid"]) + + +def _set_actor(user, sender, instance, signal_duid, **kwargs): + """Signal receiver with an extra 'user' kwarg. + + This function becomes a valid signal receiver when it is curried with the actor. + """ + try: + auditlog = threadlocal.auditlog + except AttributeError: + pass + else: + if signal_duid != auditlog["signal_duid"]: + return + auth_user_model = get_user_model() + if ( + sender == LogEntry + and isinstance(user, auth_user_model) + and instance.actor is None + ): + instance.actor = user + + instance.remote_addr = auditlog["remote_addr"] diff --git a/auditlog/middleware.py b/auditlog/middleware.py index de18037..18fbd0e 100644 --- a/auditlog/middleware.py +++ b/auditlog/middleware.py @@ -1,91 +1,37 @@ -import threading -import time -from functools import partial +import contextlib -from django.contrib.auth import get_user_model -from django.db.models.signals import pre_save -from django.utils.deprecation import MiddlewareMixin - -from auditlog.models import LogEntry - -threadlocal = threading.local() +from auditlog.context import set_actor -class AuditlogMiddleware(MiddlewareMixin): +@contextlib.contextmanager +def nullcontext(): + """Equivalent to contextlib.nullcontext(None) from Python 3.7.""" + yield + + +class AuditlogMiddleware(object): """ Middleware to couple the request's user to log items. This is accomplished by currying the signal receiver with the user from the request (or None if the user is not authenticated). """ - def process_request(self, request): - """ - Gets the current user from the request and prepares and connects a signal receiver with the user already - attached to it. - """ - # Initialize thread local storage - threadlocal.auditlog = { - "signal_duid": (self.__class__, time.time()), - "remote_addr": request.META.get("REMOTE_ADDR"), - } + def __init__(self, get_response=None): + self.get_response = get_response + + def __call__(self, request): - # In case of proxy, set 'original' address if request.META.get("HTTP_X_FORWARDED_FOR"): - threadlocal.auditlog["remote_addr"] = request.META.get( - "HTTP_X_FORWARDED_FOR" - ).split(",")[0] + # In case of proxy, set 'original' address + remote_addr = request.META.get("HTTP_X_FORWARDED_FOR").split(",")[0] + else: + remote_addr = request.META.get("REMOTE_ADDR") - # Connect signal for automatic logging if hasattr(request, "user") and getattr( request.user, "is_authenticated", False ): - set_actor = partial( - self.set_actor, - user=request.user, - signal_duid=threadlocal.auditlog["signal_duid"], - ) - pre_save.connect( - set_actor, - sender=LogEntry, - dispatch_uid=threadlocal.auditlog["signal_duid"], - weak=False, - ) + context = set_actor(actor=request.user, remote_addr=remote_addr) + else: + context = nullcontext() - def process_response(self, request, response): - """ - Disconnects the signal receiver to prevent it from staying active. - """ - if hasattr(threadlocal, "auditlog"): - pre_save.disconnect( - sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"] - ) - - return response - - def process_exception(self, request, exception): - """ - Disconnects the signal receiver to prevent it from staying active in case of an exception. - """ - if hasattr(threadlocal, "auditlog"): - pre_save.disconnect( - sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"] - ) - - return None - - @staticmethod - def set_actor(user, sender, instance, signal_duid, **kwargs): - """ - Signal receiver with an extra, required 'user' kwarg. This method becomes a real (valid) signal receiver when - it is curried with the actor. - """ - if hasattr(threadlocal, "auditlog"): - if signal_duid != threadlocal.auditlog["signal_duid"]: - return - if ( - sender == LogEntry - and isinstance(user, get_user_model()) - and instance.actor is None - ): - instance.actor = user - - instance.remote_addr = threadlocal.auditlog["remote_addr"] + with context: + return self.get_response(request) diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index eb6b858..3c47f1b 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -2,6 +2,7 @@ import datetime import json import django +import mock from dateutil.tz import gettz from django.conf import settings from django.contrib import auth @@ -167,72 +168,66 @@ class MiddlewareTest(TestCase): """ def setUp(self): - self.middleware = AuditlogMiddleware() + self.get_response_mock = mock.Mock() + self.response_mock = mock.Mock() + self.middleware = AuditlogMiddleware(get_response=self.get_response_mock) self.factory = RequestFactory() self.user = User.objects.create_user( username="test", email="test@example.com", password="top_secret" ) + def side_effect(self, assertion): + def inner(request): + assertion() + return self.response_mock + + return inner + + def assert_has_listeners(self): + self.assertTrue(pre_save.has_listeners(LogEntry)) + + def assert_no_listeners(self): + self.assertFalse(pre_save.has_listeners(LogEntry)) + def test_request_anonymous(self): """No actor will be logged when a user is not logged in.""" - # Create a request request = self.factory.get("/") request.user = AnonymousUser() - # Run middleware - self.middleware.process_request(request) + self.get_response_mock.side_effect = self.side_effect(self.assert_no_listeners) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + response = self.middleware(request) - # Finalize transaction - self.middleware.process_exception(request, None) + self.assertIs(response, self.response_mock) + self.get_response_mock.assert_called_once_with(request) + self.assert_no_listeners() def test_request(self): """The actor will be logged when a user is logged in.""" - # Create a request - request = self.factory.get("/") - request.user = self.user - # Run middleware - self.middleware.process_request(request) - - # Validate result - self.assertTrue(pre_save.has_listeners(LogEntry)) - - # Finalize transaction - self.middleware.process_exception(request, None) - - def test_response(self): - """The signal will be disconnected when the request is processed.""" - # Create a request request = self.factory.get("/") request.user = self.user - # Run middleware - self.middleware.process_request(request) - self.assertTrue( - pre_save.has_listeners(LogEntry) - ) # The signal should be present before trying to disconnect it. - self.middleware.process_response(request, HttpResponse()) + self.get_response_mock.side_effect = self.side_effect(self.assert_has_listeners) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + response = self.middleware(request) + + self.assertIs(response, self.response_mock) + self.get_response_mock.assert_called_once_with(request) + self.assert_no_listeners() def test_exception(self): """The signal will be disconnected when an exception is raised.""" - # Create a request request = self.factory.get("/") request.user = self.user - # Run middleware - self.middleware.process_request(request) - self.assertTrue( - pre_save.has_listeners(LogEntry) - ) # The signal should be present before trying to disconnect it. - self.middleware.process_exception(request, ValidationError("Test")) + SomeException = type("SomeException", (Exception,), {}) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + self.get_response_mock.side_effect = SomeException + + with self.assertRaises(SomeException): + self.middleware(request) + + self.assert_no_listeners() class SimpeIncludeModelTest(TestCase): diff --git a/tox.ini b/tox.ini index d95efe1..e6e8f74 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ deps = # Test requirements coverage codecov + mock psycopg2-binary passenv= TEST_DB_HOST