mirror of
https://github.com/jazzband/django-auditlog.git
synced 2026-03-17 06:30:27 +00:00
Move signal management to a context manager
This change allows setting the same signals when the request is not present, i.e. in a celery task.
This commit is contained in:
parent
3eb5d66c39
commit
9629f3f8d7
4 changed files with 124 additions and 116 deletions
66
auditlog/context.py
Normal file
66
auditlog/context.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db.models.signals import pre_save
|
||||
|
||||
from auditlog.models import LogEntry
|
||||
|
||||
threadlocal = threading.local()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_actor(actor, remote_addr=None):
|
||||
"""Connect a signal receiver with current user attached."""
|
||||
# Initialize thread local storage
|
||||
threadlocal.auditlog = {
|
||||
"signal_duid": ("set_actor", time.time()),
|
||||
"remote_addr": remote_addr,
|
||||
}
|
||||
|
||||
# Connect signal for automatic logging
|
||||
set_actor = partial(
|
||||
_set_actor, user=actor, signal_duid=threadlocal.auditlog["signal_duid"]
|
||||
)
|
||||
pre_save.connect(
|
||||
set_actor,
|
||||
sender=LogEntry,
|
||||
dispatch_uid=threadlocal.auditlog["signal_duid"],
|
||||
weak=False,
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
|
||||
finally:
|
||||
try:
|
||||
auditlog = threadlocal.auditlog
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog["signal_duid"])
|
||||
|
||||
|
||||
def _set_actor(user, sender, instance, signal_duid, **kwargs):
|
||||
"""Signal receiver with an extra 'user' kwarg.
|
||||
|
||||
This function becomes a valid signal receiver when it is curried with the actor.
|
||||
"""
|
||||
try:
|
||||
auditlog = threadlocal.auditlog
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if signal_duid != auditlog["signal_duid"]:
|
||||
return
|
||||
auth_user_model = get_user_model()
|
||||
if (
|
||||
sender == LogEntry
|
||||
and isinstance(user, auth_user_model)
|
||||
and instance.actor is None
|
||||
):
|
||||
instance.actor = user
|
||||
|
||||
instance.remote_addr = auditlog["remote_addr"]
|
||||
|
|
@ -1,91 +1,37 @@
|
|||
import threading
|
||||
import time
|
||||
from functools import partial
|
||||
import contextlib
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.db.models.signals import pre_save
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
|
||||
from auditlog.models import LogEntry
|
||||
|
||||
threadlocal = threading.local()
|
||||
from auditlog.context import set_actor
|
||||
|
||||
|
||||
class AuditlogMiddleware(MiddlewareMixin):
|
||||
@contextlib.contextmanager
|
||||
def nullcontext():
|
||||
"""Equivalent to contextlib.nullcontext(None) from Python 3.7."""
|
||||
yield
|
||||
|
||||
|
||||
class AuditlogMiddleware(object):
|
||||
"""
|
||||
Middleware to couple the request's user to log items. This is accomplished by currying the signal receiver with the
|
||||
user from the request (or None if the user is not authenticated).
|
||||
"""
|
||||
|
||||
def process_request(self, request):
|
||||
"""
|
||||
Gets the current user from the request and prepares and connects a signal receiver with the user already
|
||||
attached to it.
|
||||
"""
|
||||
# Initialize thread local storage
|
||||
threadlocal.auditlog = {
|
||||
"signal_duid": (self.__class__, time.time()),
|
||||
"remote_addr": request.META.get("REMOTE_ADDR"),
|
||||
}
|
||||
def __init__(self, get_response=None):
|
||||
self.get_response = get_response
|
||||
|
||||
def __call__(self, request):
|
||||
|
||||
# In case of proxy, set 'original' address
|
||||
if request.META.get("HTTP_X_FORWARDED_FOR"):
|
||||
threadlocal.auditlog["remote_addr"] = request.META.get(
|
||||
"HTTP_X_FORWARDED_FOR"
|
||||
).split(",")[0]
|
||||
# In case of proxy, set 'original' address
|
||||
remote_addr = request.META.get("HTTP_X_FORWARDED_FOR").split(",")[0]
|
||||
else:
|
||||
remote_addr = request.META.get("REMOTE_ADDR")
|
||||
|
||||
# Connect signal for automatic logging
|
||||
if hasattr(request, "user") and getattr(
|
||||
request.user, "is_authenticated", False
|
||||
):
|
||||
set_actor = partial(
|
||||
self.set_actor,
|
||||
user=request.user,
|
||||
signal_duid=threadlocal.auditlog["signal_duid"],
|
||||
)
|
||||
pre_save.connect(
|
||||
set_actor,
|
||||
sender=LogEntry,
|
||||
dispatch_uid=threadlocal.auditlog["signal_duid"],
|
||||
weak=False,
|
||||
)
|
||||
context = set_actor(actor=request.user, remote_addr=remote_addr)
|
||||
else:
|
||||
context = nullcontext()
|
||||
|
||||
def process_response(self, request, response):
|
||||
"""
|
||||
Disconnects the signal receiver to prevent it from staying active.
|
||||
"""
|
||||
if hasattr(threadlocal, "auditlog"):
|
||||
pre_save.disconnect(
|
||||
sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"]
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def process_exception(self, request, exception):
|
||||
"""
|
||||
Disconnects the signal receiver to prevent it from staying active in case of an exception.
|
||||
"""
|
||||
if hasattr(threadlocal, "auditlog"):
|
||||
pre_save.disconnect(
|
||||
sender=LogEntry, dispatch_uid=threadlocal.auditlog["signal_duid"]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_actor(user, sender, instance, signal_duid, **kwargs):
|
||||
"""
|
||||
Signal receiver with an extra, required 'user' kwarg. This method becomes a real (valid) signal receiver when
|
||||
it is curried with the actor.
|
||||
"""
|
||||
if hasattr(threadlocal, "auditlog"):
|
||||
if signal_duid != threadlocal.auditlog["signal_duid"]:
|
||||
return
|
||||
if (
|
||||
sender == LogEntry
|
||||
and isinstance(user, get_user_model())
|
||||
and instance.actor is None
|
||||
):
|
||||
instance.actor = user
|
||||
|
||||
instance.remote_addr = threadlocal.auditlog["remote_addr"]
|
||||
with context:
|
||||
return self.get_response(request)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import datetime
|
|||
import json
|
||||
|
||||
import django
|
||||
import mock
|
||||
from dateutil.tz import gettz
|
||||
from django.conf import settings
|
||||
from django.contrib import auth
|
||||
|
|
@ -167,72 +168,66 @@ class MiddlewareTest(TestCase):
|
|||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.middleware = AuditlogMiddleware()
|
||||
self.get_response_mock = mock.Mock()
|
||||
self.response_mock = mock.Mock()
|
||||
self.middleware = AuditlogMiddleware(get_response=self.get_response_mock)
|
||||
self.factory = RequestFactory()
|
||||
self.user = User.objects.create_user(
|
||||
username="test", email="test@example.com", password="top_secret"
|
||||
)
|
||||
|
||||
def side_effect(self, assertion):
|
||||
def inner(request):
|
||||
assertion()
|
||||
return self.response_mock
|
||||
|
||||
return inner
|
||||
|
||||
def assert_has_listeners(self):
|
||||
self.assertTrue(pre_save.has_listeners(LogEntry))
|
||||
|
||||
def assert_no_listeners(self):
|
||||
self.assertFalse(pre_save.has_listeners(LogEntry))
|
||||
|
||||
def test_request_anonymous(self):
|
||||
"""No actor will be logged when a user is not logged in."""
|
||||
# Create a request
|
||||
request = self.factory.get("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
# Run middleware
|
||||
self.middleware.process_request(request)
|
||||
self.get_response_mock.side_effect = self.side_effect(self.assert_no_listeners)
|
||||
|
||||
# Validate result
|
||||
self.assertFalse(pre_save.has_listeners(LogEntry))
|
||||
response = self.middleware(request)
|
||||
|
||||
# Finalize transaction
|
||||
self.middleware.process_exception(request, None)
|
||||
self.assertIs(response, self.response_mock)
|
||||
self.get_response_mock.assert_called_once_with(request)
|
||||
self.assert_no_listeners()
|
||||
|
||||
def test_request(self):
|
||||
"""The actor will be logged when a user is logged in."""
|
||||
# Create a request
|
||||
request = self.factory.get("/")
|
||||
request.user = self.user
|
||||
# Run middleware
|
||||
self.middleware.process_request(request)
|
||||
|
||||
# Validate result
|
||||
self.assertTrue(pre_save.has_listeners(LogEntry))
|
||||
|
||||
# Finalize transaction
|
||||
self.middleware.process_exception(request, None)
|
||||
|
||||
def test_response(self):
|
||||
"""The signal will be disconnected when the request is processed."""
|
||||
# Create a request
|
||||
request = self.factory.get("/")
|
||||
request.user = self.user
|
||||
|
||||
# Run middleware
|
||||
self.middleware.process_request(request)
|
||||
self.assertTrue(
|
||||
pre_save.has_listeners(LogEntry)
|
||||
) # The signal should be present before trying to disconnect it.
|
||||
self.middleware.process_response(request, HttpResponse())
|
||||
self.get_response_mock.side_effect = self.side_effect(self.assert_has_listeners)
|
||||
|
||||
# Validate result
|
||||
self.assertFalse(pre_save.has_listeners(LogEntry))
|
||||
response = self.middleware(request)
|
||||
|
||||
self.assertIs(response, self.response_mock)
|
||||
self.get_response_mock.assert_called_once_with(request)
|
||||
self.assert_no_listeners()
|
||||
|
||||
def test_exception(self):
|
||||
"""The signal will be disconnected when an exception is raised."""
|
||||
# Create a request
|
||||
request = self.factory.get("/")
|
||||
request.user = self.user
|
||||
|
||||
# Run middleware
|
||||
self.middleware.process_request(request)
|
||||
self.assertTrue(
|
||||
pre_save.has_listeners(LogEntry)
|
||||
) # The signal should be present before trying to disconnect it.
|
||||
self.middleware.process_exception(request, ValidationError("Test"))
|
||||
SomeException = type("SomeException", (Exception,), {})
|
||||
|
||||
# Validate result
|
||||
self.assertFalse(pre_save.has_listeners(LogEntry))
|
||||
self.get_response_mock.side_effect = SomeException
|
||||
|
||||
with self.assertRaises(SomeException):
|
||||
self.middleware(request)
|
||||
|
||||
self.assert_no_listeners()
|
||||
|
||||
|
||||
class SimpeIncludeModelTest(TestCase):
|
||||
|
|
|
|||
1
tox.ini
1
tox.ini
|
|
@ -15,6 +15,7 @@ deps =
|
|||
# Test requirements
|
||||
coverage
|
||||
codecov
|
||||
mock
|
||||
psycopg2-binary
|
||||
passenv=
|
||||
TEST_DB_HOST
|
||||
|
|
|
|||
Loading…
Reference in a new issue