diff --git a/CHANGELOG.md b/CHANGELOG.md index 0880ba5..12b151c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ #### Improvements - feat: Add register model from settings ([#368](https://github.com/jazzband/django-auditlog/pull/368)) +- Context manager set_actor() for use in Celery tasks ([#262](https://github.com/jazzband/django-auditlog/pull/262)) #### Fixes diff --git a/auditlog/context.py b/auditlog/context.py new file mode 100644 index 0000000..6e9513d --- /dev/null +++ b/auditlog/context.py @@ -0,0 +1,66 @@ +import contextlib +import threading +import time +from functools import partial + +from django.contrib.auth import get_user_model +from django.db.models.signals import pre_save + +from auditlog.models import LogEntry + +threadlocal = threading.local() + + +@contextlib.contextmanager +def set_actor(actor, remote_addr=None): + """Connect a signal receiver with current user attached.""" + # Initialize thread local storage + threadlocal.auditlog = { + "signal_duid": ("set_actor", time.time()), + "remote_addr": remote_addr, + } + + # Connect signal for automatic logging + set_actor = partial( + _set_actor, user=actor, signal_duid=threadlocal.auditlog["signal_duid"] + ) + pre_save.connect( + set_actor, + sender=LogEntry, + dispatch_uid=threadlocal.auditlog["signal_duid"], + weak=False, + ) + + try: + yield + finally: + try: + auditlog = threadlocal.auditlog + except AttributeError: + pass + else: + pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog["signal_duid"]) + del threadlocal.auditlog + + +def _set_actor(user, sender, instance, signal_duid, **kwargs): + """Signal receiver with extra 'user' and 'signal_duid' kwargs. + + This function becomes a valid signal receiver when it is curried with the actor and a dispatch id. + """ + try: + auditlog = threadlocal.auditlog + except AttributeError: + pass + else: + if signal_duid != auditlog["signal_duid"]: + return + auth_user_model = get_user_model() + if ( + sender == LogEntry + and isinstance(user, auth_user_model) + and instance.actor is None + ): + instance.actor = user + + instance.remote_addr = auditlog["remote_addr"] diff --git a/auditlog/middleware.py b/auditlog/middleware.py index ec1f563..202f114 100644 --- a/auditlog/middleware.py +++ b/auditlog/middleware.py @@ -1,97 +1,31 @@ -import threading -import time -from functools import partial +import contextlib -from django.apps import apps -from django.conf import settings -from django.db.models.signals import pre_save -from django.utils.deprecation import MiddlewareMixin - -from auditlog.models import LogEntry - -threadlocal = threading.local() +from auditlog.context import set_actor -class AuditlogMiddleware(MiddlewareMixin): +class AuditlogMiddleware: """ Middleware to couple the request's user to log items. This is accomplished by currying the signal receiver with the user from the request (or None if the user is not authenticated). """ - def process_request(self, request): - """ - Gets the current user from the request and prepares and connects a signal receiver with the user already - attached to it. - """ - # Initialize thread local storage - threadlocal.auditlog = { - "signal_duid": (self.__class__, time.time()), - "remote_addr": request.META.get("REMOTE_ADDR"), - } + def __init__(self, get_response=None): + self.get_response = get_response + + def __call__(self, request): - # In case of proxy, set 'original' address if request.META.get("HTTP_X_FORWARDED_FOR"): - threadlocal.auditlog["remote_addr"] = request.META.get( - "HTTP_X_FORWARDED_FOR" - ).split(",")[0] + # In case of proxy, set 'original' address + remote_addr = request.META.get("HTTP_X_FORWARDED_FOR").split(",")[0] + else: + remote_addr = request.META.get("REMOTE_ADDR") - # Connect signal for automatic logging if hasattr(request, "user") and getattr( request.user, "is_authenticated", False ): - set_actor = partial( - self.set_actor, - user=request.user, - signal_duid=threadlocal.auditlog["signal_duid"], - ) - pre_save.connect( - set_actor, - sender=LogEntry, - dispatch_uid=threadlocal.auditlog["signal_duid"], - weak=False, - ) + context = set_actor(actor=request.user, remote_addr=remote_addr) + else: + context = contextlib.nullcontext() - def process_response(self, request, response): - """ - Disconnects the signal receiver to prevent it from staying active. - """ - if hasattr(threadlocal, "auditlog"): - pre_save.disconnect( - sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"] - ) - - return response - - def process_exception(self, request, exception): - """ - Disconnects the signal receiver to prevent it from staying active in case of an exception. - """ - if hasattr(threadlocal, "auditlog"): - pre_save.disconnect( - sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"] - ) - - return None - - @staticmethod - def set_actor(user, sender, instance, signal_duid, **kwargs): - """ - Signal receiver with an extra, required 'user' kwarg. This method becomes a real (valid) signal receiver when - it is curried with the actor. - """ - if hasattr(threadlocal, "auditlog"): - if signal_duid != threadlocal.auditlog["signal_duid"]: - return - try: - app_label, model_name = settings.AUTH_USER_MODEL.split(".") - auth_user_model = apps.get_model(app_label, model_name) - except ValueError: - auth_user_model = apps.get_model("auth", "user") - if ( - sender == LogEntry - and isinstance(user, auth_user_model) - and instance.actor is None - ): - instance.actor = user - - instance.remote_addr = threadlocal.auditlog["remote_addr"] + with context: + return self.get_response(request) diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index 5b18b5a..2f4f28c 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -1,18 +1,20 @@ import datetime +import itertools import json +from unittest import mock -import django from dateutil.tz import gettz from django.apps import apps from django.conf import settings -from django.contrib import auth +from django.contrib.auth import get_user_model from django.contrib.auth.models import AnonymousUser, User -from django.core.exceptions import ValidationError +from django.contrib.contenttypes.models import ContentType from django.db.models.signals import pre_save from django.http import HttpResponse from django.test import RequestFactory, TestCase, override_settings from django.utils import dateformat, formats, timezone +from auditlog.context import set_actor from auditlog.diff import model_instance_diff from auditlog.middleware import AuditlogMiddleware from auditlog.models import LogEntry @@ -40,7 +42,11 @@ from auditlog_tests.models import ( class SimpleModelTest(TestCase): def setUp(self): - self.obj = SimpleModel.objects.create(text="I am not difficult.") + self.obj = self.make_object() + super().setUp() + + def make_object(self): + return SimpleModel.objects.create(text="I am not difficult.") def test_create(self): """Creation is logged correctly.""" @@ -50,17 +56,14 @@ class SimpleModelTest(TestCase): # Check for log entries self.assertEqual(obj.history.count(), 1, msg="There is one log entry") - try: - history = obj.history.get() - except obj.history.DoesNotExist: - self.assertTrue(False, "Log entry exists") - else: - self.assertEqual( - history.action, LogEntry.Action.CREATE, msg="Action is 'CREATE'" - ) - self.assertEqual( - history.object_repr, str(obj), msg="Representation is equal" - ) + history = obj.history.get() + self.check_create_log_entry(obj, history) + + def check_create_log_entry(self, obj, history): + self.assertEqual( + history.action, LogEntry.Action.CREATE, msg="Action is 'CREATE'" + ) + self.assertEqual(history.object_repr, str(obj), msg="Representation is equal") def test_update(self): """Updates are logged correctly.""" @@ -68,8 +71,7 @@ class SimpleModelTest(TestCase): obj = self.obj # Change something - obj.boolean = True - obj.save() + self.update(obj) # Check for log entries self.assertEqual( @@ -79,7 +81,13 @@ class SimpleModelTest(TestCase): ) history = obj.history.get(action=LogEntry.Action.UPDATE) + self.check_update_log_entry(obj, history) + def update(self, obj): + obj.boolean = True + obj.save() + + def check_update_log_entry(self, obj, history): self.assertJSONEqual( history.changes, '{"boolean": ["False", "True"]}', @@ -134,25 +142,27 @@ class SimpleModelTest(TestCase): """Deletion is logged correctly.""" # Get the object to work with obj = self.obj - - history = obj.history.latest() + content_type = ContentType.objects.get_for_model(obj.__class__) + pk = obj.pk # Delete the object - obj.delete() + self.delete(obj) # Check for log entries - self.assertEqual( - LogEntry.objects.filter( - content_type=history.content_type, - object_pk=history.object_pk, - action=LogEntry.Action.DELETE, - ).count(), - 1, - msg="There is one log entry for 'DELETE'", - ) + qs = LogEntry.objects.filter(content_type=content_type, object_pk=pk) + self.assertEqual(qs.count(), 1, msg="There is one log entry for 'DELETE'") + + history = qs.get() + self.check_delete_log_entry(obj, history) + + def delete(self, obj): + obj.delete() + + def check_delete_log_entry(self, obj, history): + pass def test_recreate(self): - SimpleModel.objects.all().delete() + self.obj.delete() self.setUp() self.test_create() @@ -175,16 +185,79 @@ class SimpleModelTest(TestCase): ) # must be created in default database -class AltPrimaryKeyModelTest(SimpleModelTest): +class NoActorMixin: + def check_create_log_entry(self, obj, log_entry): + super().check_create_log_entry(obj, log_entry) + self.assertIsNone(log_entry.actor) + + def check_update_log_entry(self, obj, log_entry): + super().check_update_log_entry(obj, log_entry) + self.assertIsNone(log_entry.actor) + + def check_delete_log_entry(self, obj, log_entry): + super().check_delete_log_entry(obj, log_entry) + self.assertIsNone(log_entry.actor) + + +class WithActorMixin: + sequence = itertools.count() + def setUp(self): - self.obj = AltPrimaryKeyModel.objects.create( + username = "actor_{}".format(next(self.sequence)) + self.user = get_user_model().objects.create( + username=username, + email="{}@example.com".format(username), + password="secret", + ) + super().setUp() + + def tearDown(self): + self.user.delete() + super().tearDown() + + def make_object(self): + with set_actor(self.user): + return super().make_object() + + def check_create_log_entry(self, obj, log_entry): + super().check_create_log_entry(obj, log_entry) + self.assertEqual(log_entry.actor, self.user) + + def update(self, obj): + with set_actor(self.user): + return super().update(obj) + + def check_update_log_entry(self, obj, log_entry): + super().check_update_log_entry(obj, log_entry) + self.assertEqual(log_entry.actor, self.user) + + def delete(self, obj): + with set_actor(self.user): + return super().delete(obj) + + def check_delete_log_entry(self, obj, log_entry): + super().check_delete_log_entry(obj, log_entry) + self.assertEqual(log_entry.actor, self.user) + + +class AltPrimaryKeyModelBase(SimpleModelTest): + def make_object(self): + return AltPrimaryKeyModel.objects.create( key=str(datetime.datetime.now()), text="I am strange." ) -class UUIDPrimaryKeyModelModelTest(SimpleModelTest): - def setUp(self): - self.obj = UUIDPrimaryKeyModel.objects.create(text="I am strange.") +class AltPrimaryKeyModelTest(NoActorMixin, AltPrimaryKeyModelBase): + pass + + +class AltPrimaryKeyModelWithActorTest(WithActorMixin, AltPrimaryKeyModelBase): + pass + + +class UUIDPrimaryKeyModelModelBase(SimpleModelTest): + def make_object(self): + return UUIDPrimaryKeyModel.objects.create(text="I am strange.") def test_get_for_object(self): self.obj.boolean = True @@ -202,9 +275,27 @@ class UUIDPrimaryKeyModelModelTest(SimpleModelTest): ) -class ProxyModelTest(SimpleModelTest): - def setUp(self): - self.obj = ProxyModel.objects.create(text="I am not what you think.") +class UUIDPrimaryKeyModelModelTest(NoActorMixin, UUIDPrimaryKeyModelModelBase): + pass + + +class UUIDPrimaryKeyModelModelWithActorTest( + WithActorMixin, UUIDPrimaryKeyModelModelBase +): + pass + + +class ProxyModelBase(SimpleModelTest): + def make_object(self): + return ProxyModel.objects.create(text="I am not what you think.") + + +class ProxyModelTest(NoActorMixin, ProxyModelBase): + pass + + +class ProxyModelWithActorTest(WithActorMixin, ProxyModelBase): + pass class ManyRelatedModelTest(TestCase): @@ -237,72 +328,66 @@ class MiddlewareTest(TestCase): def get_response(request): return HttpResponse() - self.middleware = AuditlogMiddleware(get_response) + self.get_response_mock = mock.Mock() + self.response_mock = mock.Mock() + self.middleware = AuditlogMiddleware(get_response=self.get_response_mock) self.factory = RequestFactory() self.user = User.objects.create_user( username="test", email="test@example.com", password="top_secret" ) + def side_effect(self, assertion): + def inner(request): + assertion() + return self.response_mock + + return inner + + def assert_has_listeners(self): + self.assertTrue(pre_save.has_listeners(LogEntry)) + + def assert_no_listeners(self): + self.assertFalse(pre_save.has_listeners(LogEntry)) + def test_request_anonymous(self): """No actor will be logged when a user is not logged in.""" - # Create a request request = self.factory.get("/") request.user = AnonymousUser() - # Run middleware - self.middleware.process_request(request) + self.get_response_mock.side_effect = self.side_effect(self.assert_no_listeners) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + response = self.middleware(request) - # Finalize transaction - self.middleware.process_exception(request, None) + self.assertIs(response, self.response_mock) + self.get_response_mock.assert_called_once_with(request) + self.assert_no_listeners() def test_request(self): """The actor will be logged when a user is logged in.""" - # Create a request - request = self.factory.get("/") - request.user = self.user - # Run middleware - self.middleware.process_request(request) - - # Validate result - self.assertTrue(pre_save.has_listeners(LogEntry)) - - # Finalize transaction - self.middleware.process_exception(request, None) - - def test_response(self): - """The signal will be disconnected when the request is processed.""" - # Create a request request = self.factory.get("/") request.user = self.user - # Run middleware - self.middleware.process_request(request) - self.assertTrue( - pre_save.has_listeners(LogEntry) - ) # The signal should be present before trying to disconnect it. - self.middleware.process_response(request, HttpResponse()) + self.get_response_mock.side_effect = self.side_effect(self.assert_has_listeners) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + response = self.middleware(request) + + self.assertIs(response, self.response_mock) + self.get_response_mock.assert_called_once_with(request) + self.assert_no_listeners() def test_exception(self): """The signal will be disconnected when an exception is raised.""" - # Create a request request = self.factory.get("/") request.user = self.user - # Run middleware - self.middleware.process_request(request) - self.assertTrue( - pre_save.has_listeners(LogEntry) - ) # The signal should be present before trying to disconnect it. - self.middleware.process_exception(request, ValidationError("Test")) + SomeException = type("SomeException", (Exception,), {}) - # Validate result - self.assertFalse(pre_save.has_listeners(LogEntry)) + self.get_response_mock.side_effect = SomeException + + with self.assertRaises(SomeException): + self.middleware(request) + + self.assert_no_listeners() class SimpleIncludeModelTest(TestCase): diff --git a/docs/source/usage.rst b/docs/source/usage.rst index f650c7c..d0b3649 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -148,6 +148,9 @@ It must be a list or tuple. Each item in this setting can be a: Actors ------ +Middleware +********** + When using automatic logging, the actor is empty by default. However, auditlog can set the actor from the current request automatically. This does not need any custom code, adding a middleware class is enough. When an actor is logged the remote address of that actor will be logged as well. @@ -169,6 +172,22 @@ It is recommended to keep all middleware that alters the request loaded before A user as actor. To only have some object changes to be logged with the current request's user as actor manual logging is required. +Context manager +*************** + +.. versionadded:: 2.1.0 + +To enable the automatic logging of the actors outside of request context (e.g. in a Celery task), you can use a context +manager:: + + from auditlog.context import set_actor + + def do_stuff(actor_id: int): + actor = get_user(actor_id) + with set_actor(actor): + # if your code here leads to creation of LogEntry instances, these will have the actor set + ... + Object history --------------