Add set_actor context manager (#262)

This commit is contained in:
Alieh Rymašeŭski 2022-05-24 10:33:54 +03:00 committed by GitHub
parent ba19a8ca35
commit bcd0d43566
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 265 additions and 160 deletions

View file

@ -3,6 +3,7 @@
#### Improvements
- feat: Add register model from settings ([#368](https://github.com/jazzband/django-auditlog/pull/368))
- Context manager set_actor() for use in Celery tasks ([#262](https://github.com/jazzband/django-auditlog/pull/262))
#### Fixes

66
auditlog/context.py Normal file
View file

@ -0,0 +1,66 @@
import contextlib
import threading
import time
from functools import partial
from django.contrib.auth import get_user_model
from django.db.models.signals import pre_save
from auditlog.models import LogEntry
threadlocal = threading.local()
@contextlib.contextmanager
def set_actor(actor, remote_addr=None):
"""Connect a signal receiver with current user attached."""
# Initialize thread local storage
threadlocal.auditlog = {
"signal_duid": ("set_actor", time.time()),
"remote_addr": remote_addr,
}
# Connect signal for automatic logging
set_actor = partial(
_set_actor, user=actor, signal_duid=threadlocal.auditlog["signal_duid"]
)
pre_save.connect(
set_actor,
sender=LogEntry,
dispatch_uid=threadlocal.auditlog["signal_duid"],
weak=False,
)
try:
yield
finally:
try:
auditlog = threadlocal.auditlog
except AttributeError:
pass
else:
pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog["signal_duid"])
del threadlocal.auditlog
def _set_actor(user, sender, instance, signal_duid, **kwargs):
"""Signal receiver with extra 'user' and 'signal_duid' kwargs.
This function becomes a valid signal receiver when it is curried with the actor and a dispatch id.
"""
try:
auditlog = threadlocal.auditlog
except AttributeError:
pass
else:
if signal_duid != auditlog["signal_duid"]:
return
auth_user_model = get_user_model()
if (
sender == LogEntry
and isinstance(user, auth_user_model)
and instance.actor is None
):
instance.actor = user
instance.remote_addr = auditlog["remote_addr"]

View file

@ -1,97 +1,31 @@
import threading
import time
from functools import partial
import contextlib
from django.apps import apps
from django.conf import settings
from django.db.models.signals import pre_save
from django.utils.deprecation import MiddlewareMixin
from auditlog.models import LogEntry
threadlocal = threading.local()
from auditlog.context import set_actor
class AuditlogMiddleware(MiddlewareMixin):
class AuditlogMiddleware:
"""
Middleware to couple the request's user to log items. This is accomplished by currying the signal receiver with the
user from the request (or None if the user is not authenticated).
"""
def process_request(self, request):
"""
Gets the current user from the request and prepares and connects a signal receiver with the user already
attached to it.
"""
# Initialize thread local storage
threadlocal.auditlog = {
"signal_duid": (self.__class__, time.time()),
"remote_addr": request.META.get("REMOTE_ADDR"),
}
def __init__(self, get_response=None):
self.get_response = get_response
def __call__(self, request):
# In case of proxy, set 'original' address
if request.META.get("HTTP_X_FORWARDED_FOR"):
threadlocal.auditlog["remote_addr"] = request.META.get(
"HTTP_X_FORWARDED_FOR"
).split(",")[0]
# In case of proxy, set 'original' address
remote_addr = request.META.get("HTTP_X_FORWARDED_FOR").split(",")[0]
else:
remote_addr = request.META.get("REMOTE_ADDR")
# Connect signal for automatic logging
if hasattr(request, "user") and getattr(
request.user, "is_authenticated", False
):
set_actor = partial(
self.set_actor,
user=request.user,
signal_duid=threadlocal.auditlog["signal_duid"],
)
pre_save.connect(
set_actor,
sender=LogEntry,
dispatch_uid=threadlocal.auditlog["signal_duid"],
weak=False,
)
context = set_actor(actor=request.user, remote_addr=remote_addr)
else:
context = contextlib.nullcontext()
def process_response(self, request, response):
"""
Disconnects the signal receiver to prevent it from staying active.
"""
if hasattr(threadlocal, "auditlog"):
pre_save.disconnect(
sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"]
)
return response
def process_exception(self, request, exception):
"""
Disconnects the signal receiver to prevent it from staying active in case of an exception.
"""
if hasattr(threadlocal, "auditlog"):
pre_save.disconnect(
sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"]
)
return None
@staticmethod
def set_actor(user, sender, instance, signal_duid, **kwargs):
"""
Signal receiver with an extra, required 'user' kwarg. This method becomes a real (valid) signal receiver when
it is curried with the actor.
"""
if hasattr(threadlocal, "auditlog"):
if signal_duid != threadlocal.auditlog["signal_duid"]:
return
try:
app_label, model_name = settings.AUTH_USER_MODEL.split(".")
auth_user_model = apps.get_model(app_label, model_name)
except ValueError:
auth_user_model = apps.get_model("auth", "user")
if (
sender == LogEntry
and isinstance(user, auth_user_model)
and instance.actor is None
):
instance.actor = user
instance.remote_addr = threadlocal.auditlog["remote_addr"]
with context:
return self.get_response(request)

View file

@ -1,18 +1,20 @@
import datetime
import itertools
import json
from unittest import mock
import django
from dateutil.tz import gettz
from django.apps import apps
from django.conf import settings
from django.contrib import auth
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser, User
from django.core.exceptions import ValidationError
from django.contrib.contenttypes.models import ContentType
from django.db.models.signals import pre_save
from django.http import HttpResponse
from django.test import RequestFactory, TestCase, override_settings
from django.utils import dateformat, formats, timezone
from auditlog.context import set_actor
from auditlog.diff import model_instance_diff
from auditlog.middleware import AuditlogMiddleware
from auditlog.models import LogEntry
@ -40,7 +42,11 @@ from auditlog_tests.models import (
class SimpleModelTest(TestCase):
def setUp(self):
self.obj = SimpleModel.objects.create(text="I am not difficult.")
self.obj = self.make_object()
super().setUp()
def make_object(self):
return SimpleModel.objects.create(text="I am not difficult.")
def test_create(self):
"""Creation is logged correctly."""
@ -50,17 +56,14 @@ class SimpleModelTest(TestCase):
# Check for log entries
self.assertEqual(obj.history.count(), 1, msg="There is one log entry")
try:
history = obj.history.get()
except obj.history.DoesNotExist:
self.assertTrue(False, "Log entry exists")
else:
self.assertEqual(
history.action, LogEntry.Action.CREATE, msg="Action is 'CREATE'"
)
self.assertEqual(
history.object_repr, str(obj), msg="Representation is equal"
)
history = obj.history.get()
self.check_create_log_entry(obj, history)
def check_create_log_entry(self, obj, history):
self.assertEqual(
history.action, LogEntry.Action.CREATE, msg="Action is 'CREATE'"
)
self.assertEqual(history.object_repr, str(obj), msg="Representation is equal")
def test_update(self):
"""Updates are logged correctly."""
@ -68,8 +71,7 @@ class SimpleModelTest(TestCase):
obj = self.obj
# Change something
obj.boolean = True
obj.save()
self.update(obj)
# Check for log entries
self.assertEqual(
@ -79,7 +81,13 @@ class SimpleModelTest(TestCase):
)
history = obj.history.get(action=LogEntry.Action.UPDATE)
self.check_update_log_entry(obj, history)
def update(self, obj):
obj.boolean = True
obj.save()
def check_update_log_entry(self, obj, history):
self.assertJSONEqual(
history.changes,
'{"boolean": ["False", "True"]}',
@ -134,25 +142,27 @@ class SimpleModelTest(TestCase):
"""Deletion is logged correctly."""
# Get the object to work with
obj = self.obj
history = obj.history.latest()
content_type = ContentType.objects.get_for_model(obj.__class__)
pk = obj.pk
# Delete the object
obj.delete()
self.delete(obj)
# Check for log entries
self.assertEqual(
LogEntry.objects.filter(
content_type=history.content_type,
object_pk=history.object_pk,
action=LogEntry.Action.DELETE,
).count(),
1,
msg="There is one log entry for 'DELETE'",
)
qs = LogEntry.objects.filter(content_type=content_type, object_pk=pk)
self.assertEqual(qs.count(), 1, msg="There is one log entry for 'DELETE'")
history = qs.get()
self.check_delete_log_entry(obj, history)
def delete(self, obj):
obj.delete()
def check_delete_log_entry(self, obj, history):
pass
def test_recreate(self):
SimpleModel.objects.all().delete()
self.obj.delete()
self.setUp()
self.test_create()
@ -175,16 +185,79 @@ class SimpleModelTest(TestCase):
) # must be created in default database
class AltPrimaryKeyModelTest(SimpleModelTest):
class NoActorMixin:
def check_create_log_entry(self, obj, log_entry):
super().check_create_log_entry(obj, log_entry)
self.assertIsNone(log_entry.actor)
def check_update_log_entry(self, obj, log_entry):
super().check_update_log_entry(obj, log_entry)
self.assertIsNone(log_entry.actor)
def check_delete_log_entry(self, obj, log_entry):
super().check_delete_log_entry(obj, log_entry)
self.assertIsNone(log_entry.actor)
class WithActorMixin:
sequence = itertools.count()
def setUp(self):
self.obj = AltPrimaryKeyModel.objects.create(
username = "actor_{}".format(next(self.sequence))
self.user = get_user_model().objects.create(
username=username,
email="{}@example.com".format(username),
password="secret",
)
super().setUp()
def tearDown(self):
self.user.delete()
super().tearDown()
def make_object(self):
with set_actor(self.user):
return super().make_object()
def check_create_log_entry(self, obj, log_entry):
super().check_create_log_entry(obj, log_entry)
self.assertEqual(log_entry.actor, self.user)
def update(self, obj):
with set_actor(self.user):
return super().update(obj)
def check_update_log_entry(self, obj, log_entry):
super().check_update_log_entry(obj, log_entry)
self.assertEqual(log_entry.actor, self.user)
def delete(self, obj):
with set_actor(self.user):
return super().delete(obj)
def check_delete_log_entry(self, obj, log_entry):
super().check_delete_log_entry(obj, log_entry)
self.assertEqual(log_entry.actor, self.user)
class AltPrimaryKeyModelBase(SimpleModelTest):
def make_object(self):
return AltPrimaryKeyModel.objects.create(
key=str(datetime.datetime.now()), text="I am strange."
)
class UUIDPrimaryKeyModelModelTest(SimpleModelTest):
def setUp(self):
self.obj = UUIDPrimaryKeyModel.objects.create(text="I am strange.")
class AltPrimaryKeyModelTest(NoActorMixin, AltPrimaryKeyModelBase):
pass
class AltPrimaryKeyModelWithActorTest(WithActorMixin, AltPrimaryKeyModelBase):
pass
class UUIDPrimaryKeyModelModelBase(SimpleModelTest):
def make_object(self):
return UUIDPrimaryKeyModel.objects.create(text="I am strange.")
def test_get_for_object(self):
self.obj.boolean = True
@ -202,9 +275,27 @@ class UUIDPrimaryKeyModelModelTest(SimpleModelTest):
)
class ProxyModelTest(SimpleModelTest):
def setUp(self):
self.obj = ProxyModel.objects.create(text="I am not what you think.")
class UUIDPrimaryKeyModelModelTest(NoActorMixin, UUIDPrimaryKeyModelModelBase):
pass
class UUIDPrimaryKeyModelModelWithActorTest(
WithActorMixin, UUIDPrimaryKeyModelModelBase
):
pass
class ProxyModelBase(SimpleModelTest):
def make_object(self):
return ProxyModel.objects.create(text="I am not what you think.")
class ProxyModelTest(NoActorMixin, ProxyModelBase):
pass
class ProxyModelWithActorTest(WithActorMixin, ProxyModelBase):
pass
class ManyRelatedModelTest(TestCase):
@ -237,72 +328,66 @@ class MiddlewareTest(TestCase):
def get_response(request):
return HttpResponse()
self.middleware = AuditlogMiddleware(get_response)
self.get_response_mock = mock.Mock()
self.response_mock = mock.Mock()
self.middleware = AuditlogMiddleware(get_response=self.get_response_mock)
self.factory = RequestFactory()
self.user = User.objects.create_user(
username="test", email="test@example.com", password="top_secret"
)
def side_effect(self, assertion):
def inner(request):
assertion()
return self.response_mock
return inner
def assert_has_listeners(self):
self.assertTrue(pre_save.has_listeners(LogEntry))
def assert_no_listeners(self):
self.assertFalse(pre_save.has_listeners(LogEntry))
def test_request_anonymous(self):
"""No actor will be logged when a user is not logged in."""
# Create a request
request = self.factory.get("/")
request.user = AnonymousUser()
# Run middleware
self.middleware.process_request(request)
self.get_response_mock.side_effect = self.side_effect(self.assert_no_listeners)
# Validate result
self.assertFalse(pre_save.has_listeners(LogEntry))
response = self.middleware(request)
# Finalize transaction
self.middleware.process_exception(request, None)
self.assertIs(response, self.response_mock)
self.get_response_mock.assert_called_once_with(request)
self.assert_no_listeners()
def test_request(self):
"""The actor will be logged when a user is logged in."""
# Create a request
request = self.factory.get("/")
request.user = self.user
# Run middleware
self.middleware.process_request(request)
# Validate result
self.assertTrue(pre_save.has_listeners(LogEntry))
# Finalize transaction
self.middleware.process_exception(request, None)
def test_response(self):
"""The signal will be disconnected when the request is processed."""
# Create a request
request = self.factory.get("/")
request.user = self.user
# Run middleware
self.middleware.process_request(request)
self.assertTrue(
pre_save.has_listeners(LogEntry)
) # The signal should be present before trying to disconnect it.
self.middleware.process_response(request, HttpResponse())
self.get_response_mock.side_effect = self.side_effect(self.assert_has_listeners)
# Validate result
self.assertFalse(pre_save.has_listeners(LogEntry))
response = self.middleware(request)
self.assertIs(response, self.response_mock)
self.get_response_mock.assert_called_once_with(request)
self.assert_no_listeners()
def test_exception(self):
"""The signal will be disconnected when an exception is raised."""
# Create a request
request = self.factory.get("/")
request.user = self.user
# Run middleware
self.middleware.process_request(request)
self.assertTrue(
pre_save.has_listeners(LogEntry)
) # The signal should be present before trying to disconnect it.
self.middleware.process_exception(request, ValidationError("Test"))
SomeException = type("SomeException", (Exception,), {})
# Validate result
self.assertFalse(pre_save.has_listeners(LogEntry))
self.get_response_mock.side_effect = SomeException
with self.assertRaises(SomeException):
self.middleware(request)
self.assert_no_listeners()
class SimpleIncludeModelTest(TestCase):

View file

@ -148,6 +148,9 @@ It must be a list or tuple. Each item in this setting can be a:
Actors
------
Middleware
**********
When using automatic logging, the actor is empty by default. However, auditlog can set the actor from the current
request automatically. This does not need any custom code, adding a middleware class is enough. When an actor is logged
the remote address of that actor will be logged as well.
@ -169,6 +172,22 @@ It is recommended to keep all middleware that alters the request loaded before A
user as actor. To only have some object changes to be logged with the current request's user as actor manual logging is
required.
Context manager
***************
.. versionadded:: 2.1.0
To enable the automatic logging of the actors outside of request context (e.g. in a Celery task), you can use a context
manager::
from auditlog.context import set_actor
def do_stuff(actor_id: int):
actor = get_user(actor_id)
with set_actor(actor):
# if your code here leads to creation of LogEntry instances, these will have the actor set
...
Object history
--------------