Move signal management to a context manager

This change allows setting the same signals when the request is not
present, i.e. in a celery task.
This commit is contained in:
Alieh Rymašeŭski 2019-05-11 13:11:12 +03:00
parent 2dc0ac43b5
commit a5381b6195
5 changed files with 112 additions and 96 deletions

View file

@ -4,6 +4,7 @@ src/auditlog/__init__.py
src/auditlog/admin.py
src/auditlog/apps.py
src/auditlog/compat.py
src/auditlog/context.py
src/auditlog/diff.py
src/auditlog/filters.py
src/auditlog/middleware.py

View file

@ -5,3 +5,4 @@ tox>=1.7.0
codecov>=2.0.0
django-multiselectfield==0.1.8
psycopg2-binary
mock

55
src/auditlog/context.py Normal file
View file

@ -0,0 +1,55 @@
import contextlib
import time
import threading
from django.contrib.auth import get_user_model
from django.db.models.signals import pre_save
from django.utils.functional import curry
from auditlog.models import LogEntry
threadlocal = threading.local()
@contextlib.contextmanager
def set_actor(actor, remote_addr=None):
"""Connect a signal receiver with current user attached."""
# Initialize thread local storage
threadlocal.auditlog = {
'signal_duid': ('set_actor', time.time()),
'remote_addr': remote_addr,
}
# Connect signal for automatic logging
set_actor = curry(_set_actor, user=actor, signal_duid=threadlocal.auditlog['signal_duid'])
pre_save.connect(set_actor, sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid'], weak=False)
try:
yield
finally:
try:
auditlog = threadlocal.auditlog
except AttributeError:
pass
else:
pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog['signal_duid'])
def _set_actor(user, sender, instance, signal_duid, **kwargs):
"""Signal receiver with an extra 'user' kwarg.
This function becomes a valid signal receiver when it is curried with the actor.
"""
try:
auditlog = threadlocal.auditlog
except AttributeError:
pass
else:
if signal_duid != auditlog['signal_duid']:
return
auth_user_model = get_user_model()
if sender == LogEntry and isinstance(user, auth_user_model) and instance.actor is None:
instance.actor = user
instance.remote_addr = auditlog['remote_addr']

View file

@ -1,78 +1,38 @@
from __future__ import unicode_literals
import threading
import time
import contextlib
from django.contrib.auth import get_user_model
from django.db.models.signals import pre_save
from django.utils.functional import curry
from auditlog.models import LogEntry
from auditlog.compat import is_authenticated
# Use MiddlewareMixin when present (Django >= 1.10)
try:
from django.utils.deprecation import MiddlewareMixin
except ImportError:
MiddlewareMixin = object
from auditlog.context import set_actor
threadlocal = threading.local()
@contextlib.contextmanager
def nullcontext():
"""Equivalent to contextlib.nullcontext(None) from Python 3.7."""
yield
class AuditlogMiddleware(MiddlewareMixin):
class AuditlogMiddleware(object):
"""
Middleware to couple the request's user to log items. This is accomplished by currying the signal receiver with the
user from the request (or None if the user is not authenticated).
"""
def process_request(self, request):
"""
Gets the current user from the request and prepares and connects a signal receiver with the user already
attached to it.
"""
# Initialize thread local storage
threadlocal.auditlog = {
'signal_duid': (self.__class__, time.time()),
'remote_addr': request.META.get('REMOTE_ADDR'),
}
def __init__(self, get_response=None):
self.get_response = get_response
def __call__(self, request):
# In case of proxy, set 'original' address
if request.META.get('HTTP_X_FORWARDED_FOR'):
threadlocal.auditlog['remote_addr'] = request.META.get('HTTP_X_FORWARDED_FOR').split(',')[0]
# In case of proxy, set 'original' address
remote_addr = request.META.get('HTTP_X_FORWARDED_FOR').split(',')[0]
else:
remote_addr = request.META.get('REMOTE_ADDR')
# Connect signal for automatic logging
if hasattr(request, 'user') and is_authenticated(request.user):
set_actor = curry(self.set_actor, user=request.user, signal_duid=threadlocal.auditlog['signal_duid'])
pre_save.connect(set_actor, sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid'], weak=False)
context = set_actor(actor=request.user, remote_addr=remote_addr)
else:
context = nullcontext()
def process_response(self, request, response):
"""
Disconnects the signal receiver to prevent it from staying active.
"""
if hasattr(threadlocal, 'auditlog'):
pre_save.disconnect(sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid'])
return response
def process_exception(self, request, exception):
"""
Disconnects the signal receiver to prevent it from staying active in case of an exception.
"""
if hasattr(threadlocal, 'auditlog'):
pre_save.disconnect(sender=LogEntry, dispatch_uid=threadlocal.auditlog['signal_duid'])
return None
@staticmethod
def set_actor(user, sender, instance, signal_duid, **kwargs):
"""
Signal receiver with an extra, required 'user' kwarg. This method becomes a real (valid) signal receiver when
it is curried with the actor.
"""
if hasattr(threadlocal, 'auditlog'):
if signal_duid != threadlocal.auditlog['signal_duid']:
return
if sender == LogEntry and isinstance(user, get_user_model()) and instance.actor is None:
instance.actor = user
instance.remote_addr = threadlocal.auditlog['remote_addr']
with context:
return self.get_response(request)

View file

@ -9,6 +9,7 @@ from django.http import HttpResponse
from django.test import TestCase, RequestFactory
from django.utils import dateformat, formats, timezone
from dateutil.tz import gettz
import mock
from auditlog.middleware import AuditlogMiddleware
from auditlog.models import LogEntry
@ -121,66 +122,64 @@ class MiddlewareTest(TestCase):
Test the middleware responsible for connecting and disconnecting the signals used in automatic logging.
"""
def setUp(self):
self.middleware = AuditlogMiddleware()
self.get_response_mock = mock.Mock()
self.response_mock = mock.Mock()
self.middleware = AuditlogMiddleware(get_response=self.get_response_mock)
self.factory = RequestFactory()
self.user = User.objects.create_user(username='test', email='test@example.com', password='top_secret')
def side_effect(self, assertion):
def inner(request):
assertion()
return self.response_mock
return inner
def assert_has_listeners(self):
self.assertTrue(pre_save.has_listeners(LogEntry))
def assert_no_listeners(self):
self.assertFalse(pre_save.has_listeners(LogEntry))
def test_request_anonymous(self):
"""No actor will be logged when a user is not logged in."""
# Create a request
request = self.factory.get('/')
request.user = AnonymousUser()
# Run middleware
self.middleware.process_request(request)
self.get_response_mock.side_effect = self.side_effect(self.assert_no_listeners)
# Validate result
self.assertFalse(pre_save.has_listeners(LogEntry))
response = self.middleware(request)
# Finalize transaction
self.middleware.process_exception(request, None)
self.assertIs(response, self.response_mock)
self.get_response_mock.assert_called_once_with(request)
self.assert_no_listeners()
def test_request(self):
"""The actor will be logged when a user is logged in."""
# Create a request
request = self.factory.get('/')
request.user = self.user
# Run middleware
self.middleware.process_request(request)
# Validate result
self.assertTrue(pre_save.has_listeners(LogEntry))
# Finalize transaction
self.middleware.process_exception(request, None)
def test_response(self):
"""The signal will be disconnected when the request is processed."""
# Create a request
request = self.factory.get('/')
request.user = self.user
# Run middleware
self.middleware.process_request(request)
self.assertTrue(pre_save.has_listeners(LogEntry)) # The signal should be present before trying to disconnect it.
self.middleware.process_response(request, HttpResponse())
self.get_response_mock.side_effect = self.side_effect(self.assert_has_listeners)
# Validate result
self.assertFalse(pre_save.has_listeners(LogEntry))
response = self.middleware(request)
self.assertIs(response, self.response_mock)
self.get_response_mock.assert_called_once_with(request)
self.assert_no_listeners()
def test_exception(self):
"""The signal will be disconnected when an exception is raised."""
# Create a request
request = self.factory.get('/')
request.user = self.user
# Run middleware
self.middleware.process_request(request)
self.assertTrue(pre_save.has_listeners(LogEntry)) # The signal should be present before trying to disconnect it.
self.middleware.process_exception(request, ValidationError("Test"))
SomeException = type('SomeException', (Exception,), {})
# Validate result
self.assertFalse(pre_save.has_listeners(LogEntry))
self.get_response_mock.side_effect = SomeException
with self.assertRaises(SomeException):
self.middleware(request)
self.assert_no_listeners()
class SimpeIncludeModelTest(TestCase):