diff --git a/CHANGELOG.md b/CHANGELOG.md index ff21e6a..9a20307 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,8 @@ #### Improvements -- feat: Added support for Correlation ID +- feat: Added support for Correlation ID. ([#481](https://github.com/jazzband/django-auditlog/pull/481)) +- feat: Added pre-log and post-log signals. ([#483](https://github.com/jazzband/django-auditlog/pull/483)) #### Fixes diff --git a/auditlog/diff.py b/auditlog/diff.py index e657ee1..0d25c86 100644 --- a/auditlog/diff.py +++ b/auditlog/diff.py @@ -1,4 +1,5 @@ from datetime import timezone +from typing import Optional from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -98,7 +99,9 @@ def mask_str(value: str) -> str: return "*" * mask_limit + value[mask_limit:] -def model_instance_diff(old, new, fields_to_check=None): +def model_instance_diff( + old: Optional[Model], new: Optional[Model], fields_to_check=None +): """ Calculates the differences between two model instances. One of the instances may be ``None`` (i.e., a newly created model or deleted model). This will cause all fields with a value to have diff --git a/auditlog/receivers.py b/auditlog/receivers.py index 2a2c475..561013e 100644 --- a/auditlog/receivers.py +++ b/auditlog/receivers.py @@ -6,6 +6,7 @@ from django.conf import settings from auditlog.context import threadlocal from auditlog.diff import model_instance_diff from auditlog.models import LogEntry +from auditlog.signals import post_log, pre_log def check_disable(signal_handler): @@ -33,12 +34,12 @@ def log_create(sender, instance, created, **kwargs): Direct use is discouraged, connect your model through :py:func:`auditlog.registry.register` instead. """ if created: - changes = model_instance_diff(None, instance) - - LogEntry.objects.log_create( - instance, + _create_log_entry( action=LogEntry.Action.CREATE, - changes=json.dumps(changes), + instance=instance, + sender=sender, + diff_old=None, + diff_new=instance, ) @@ -50,22 +51,16 @@ def log_update(sender, instance, **kwargs): Direct use is discouraged, connect your model through :py:func:`auditlog.registry.register` instead. """ if instance.pk is not None: - try: - old = sender.objects.get(pk=instance.pk) - except sender.DoesNotExist: - pass - else: - new = instance - update_fields = kwargs.get("update_fields", None) - changes = model_instance_diff(old, new, fields_to_check=update_fields) - - # Log an entry only if there are changes - if changes: - LogEntry.objects.log_create( - instance, - action=LogEntry.Action.UPDATE, - changes=json.dumps(changes), - ) + update_fields = kwargs.get("update_fields", None) + old = sender.objects.filter(pk=instance.pk).first() + _create_log_entry( + action=LogEntry.Action.UPDATE, + instance=instance, + sender=sender, + diff_old=old, + diff_new=instance, + fields_to_check=update_fields, + ) @check_disable @@ -76,12 +71,12 @@ def log_delete(sender, instance, **kwargs): Direct use is discouraged, connect your model through :py:func:`auditlog.registry.register` instead. """ if instance.pk is not None: - changes = model_instance_diff(instance, None) - - LogEntry.objects.log_create( - instance, + _create_log_entry( action=LogEntry.Action.DELETE, - changes=json.dumps(changes), + instance=instance, + sender=sender, + diff_old=instance, + diff_new=None, ) @@ -92,14 +87,50 @@ def log_access(sender, instance, **kwargs): Direct use is discouraged, connect your model through :py:func:`auditlog.registry.register` instead. """ if instance.pk is not None: - - LogEntry.objects.log_create( - instance, + _create_log_entry( action=LogEntry.Action.ACCESS, - changes="null", + instance=instance, + sender=sender, + diff_old=None, + diff_new=None, + force_log=True, ) +def _create_log_entry( + action, instance, sender, diff_old, diff_new, fields_to_check=None, force_log=False +): + pre_log_results = pre_log.send( + sender, + instance=instance, + action=action, + ) + error = None + try: + changes = model_instance_diff( + diff_old, diff_new, fields_to_check=fields_to_check + ) + + if force_log or changes: + LogEntry.objects.log_create( + instance, + action=action, + changes=json.dumps(changes), + ) + except BaseException as e: + error = e + finally: + post_log.send( + sender, + instance=instance, + action=action, + error=error, + pre_log_results=pre_log_results, + ) + if error: + raise error + + def make_log_m2m_changes(field_name): """Return a handler for m2m_changed with field_name enclosed.""" diff --git a/auditlog/signals.py b/auditlog/signals.py index 67e518c..aec291a 100644 --- a/auditlog/signals.py +++ b/auditlog/signals.py @@ -1,3 +1,53 @@ import django.dispatch accessed = django.dispatch.Signal() + + +pre_log = django.dispatch.Signal() +""" +Whenever an audit log entry is written, this signal +is sent before writing the log. +Keyword arguments sent with this signal: + +:param class sender: + The model class that's being audited. + +:param Any instance: + The actual instance that's being audited. + +:param Action action: + The action on the model resulting in an + audit log entry. Type: :class:`auditlog.models.LogEntry.Action` + +The receivers' return values are sent to any :func:`post_log` +signal receivers. +""" + +post_log = django.dispatch.Signal() +""" +Whenever an audit log entry is written, this signal +is sent after writing the log. +Keyword arguments sent with this signal: + +:param class sender: + The model class that's being audited. + +:param Any instance: + The actual instance that's being audited. + +:param Action action: + The action on the model resulting in an + audit log entry. Type: :class:`auditlog.models.LogEntry.Action` + +:param Optional[Exception] error: + The error, if one occurred while saving the audit log entry. ``None``, + otherwise + +:param List[Tuple[method,Any]] pre_log_results: + List of tuple pairs ``[(pre_log_receiver, pre_log_response)]``, where + ``pre_log_receiver`` is the receiver method, and ``pre_log_response`` is the + corresponding response of that method. If there are no :const:`pre_log` receivers, + then the list will be empty. ``pre_log_receiver`` is guaranteed to be + non-null, but ``pre_log_response`` may be ``None``. This depends on the corresponding + ``pre_log_receiver``'s return value. +""" diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index 290fdc6..c747582 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -1,9 +1,11 @@ import datetime import itertools import json +import random import warnings from datetime import timezone from unittest import mock +from unittest.mock import patch import freezegun from dateutil.tz import gettz @@ -27,6 +29,7 @@ from auditlog.diff import model_instance_diff from auditlog.middleware import AuditlogMiddleware from auditlog.models import LogEntry from auditlog.registry import AuditlogModelRegistry, AuditLogRegistrationError, auditlog +from auditlog.signals import post_log, pre_log from auditlog_tests.fixtures.custom_get_cid import get_cid as custom_get_cid from auditlog_tests.models import ( AdditionalDataIncludedModel, @@ -1911,6 +1914,167 @@ class TestAccessLog(TestCase): self.assertEqual(log_entry.changes_dict, {}) +class SignalTests(TestCase): + def setUp(self): + self.obj = SimpleModel.objects.create(text="I am not difficult.") + self.my_pre_log_data = { + "is_called": False, + "my_sender": None, + "my_instance": None, + "my_action": None, + } + self.my_post_log_data = { + "is_called": False, + "my_sender": None, + "my_instance": None, + "my_action": None, + "my_error": None, + } + + def assertSignals(self, action): + self.assertTrue( + self.my_pre_log_data["is_called"], "pre_log hook receiver not called" + ) + self.assertIs(self.my_pre_log_data["my_sender"], self.obj.__class__) + self.assertIs(self.my_pre_log_data["my_instance"], self.obj) + self.assertEqual(self.my_pre_log_data["my_action"], action) + + self.assertTrue( + self.my_post_log_data["is_called"], "post_log hook receiver not called" + ) + self.assertIs(self.my_post_log_data["my_sender"], self.obj.__class__) + self.assertIs(self.my_post_log_data["my_instance"], self.obj) + self.assertEqual(self.my_post_log_data["my_action"], action) + self.assertIsNone(self.my_post_log_data["my_error"]) + + def test_custom_signals(self): + my_ret_val = random.randint(0, 10000) + my_other_ret_val = random.randint(0, 10000) + + def pre_log_receiver(sender, instance, action, **_kwargs): + self.my_pre_log_data["is_called"] = True + self.my_pre_log_data["my_sender"] = sender + self.my_pre_log_data["my_instance"] = instance + self.my_pre_log_data["my_action"] = action + return my_ret_val + + def pre_log_receiver_extra(*_args, **_kwargs): + return my_other_ret_val + + def post_log_receiver( + sender, instance, action, error, pre_log_results, **_kwargs + ): + self.my_post_log_data["is_called"] = True + self.my_post_log_data["my_sender"] = sender + self.my_post_log_data["my_instance"] = instance + self.my_post_log_data["my_action"] = action + self.my_post_log_data["my_error"] = error + + self.assertEqual(len(pre_log_results), 2) + + found_first_result = False + found_second_result = False + for pre_log_fn, pre_log_result in pre_log_results: + if pre_log_fn is pre_log_receiver and pre_log_result == my_ret_val: + found_first_result = True + for pre_log_fn, pre_log_result in pre_log_results: + if ( + pre_log_fn is pre_log_receiver_extra + and pre_log_result == my_other_ret_val + ): + found_second_result = True + + self.assertTrue(found_first_result) + self.assertTrue(found_second_result) + + return my_ret_val + + pre_log.connect(pre_log_receiver) + pre_log.connect(pre_log_receiver_extra) + post_log.connect(post_log_receiver) + + self.obj = SimpleModel.objects.create(text="I am not difficult.") + + self.assertSignals(LogEntry.Action.CREATE) + + def test_custom_signals_update(self): + def pre_log_receiver(sender, instance, action, **_kwargs): + self.my_pre_log_data["is_called"] = True + self.my_pre_log_data["my_sender"] = sender + self.my_pre_log_data["my_instance"] = instance + self.my_pre_log_data["my_action"] = action + + def post_log_receiver(sender, instance, action, error, **_kwargs): + self.my_post_log_data["is_called"] = True + self.my_post_log_data["my_sender"] = sender + self.my_post_log_data["my_instance"] = instance + self.my_post_log_data["my_action"] = action + self.my_post_log_data["my_error"] = error + + pre_log.connect(pre_log_receiver) + post_log.connect(post_log_receiver) + + self.obj.text = "Changed Text" + self.obj.save() + + self.assertSignals(LogEntry.Action.UPDATE) + + def test_custom_signals_delete(self): + def pre_log_receiver(sender, instance, action, **_kwargs): + self.my_pre_log_data["is_called"] = True + self.my_pre_log_data["my_sender"] = sender + self.my_pre_log_data["my_instance"] = instance + self.my_pre_log_data["my_action"] = action + + def post_log_receiver(sender, instance, action, error, **_kwargs): + self.my_post_log_data["is_called"] = True + self.my_post_log_data["my_sender"] = sender + self.my_post_log_data["my_instance"] = instance + self.my_post_log_data["my_action"] = action + self.my_post_log_data["my_error"] = error + + pre_log.connect(pre_log_receiver) + post_log.connect(post_log_receiver) + + self.obj.delete() + + self.assertSignals(LogEntry.Action.DELETE) + + @patch("auditlog.receivers.LogEntry.objects") + def test_signals_errors(self, log_entry_objects_mock): + class CustomSignalError(BaseException): + pass + + def post_log_receiver(error, **_kwargs): + self.my_post_log_data["my_error"] = error + + post_log.connect(post_log_receiver) + + # create + error_create = CustomSignalError(LogEntry.Action.CREATE) + log_entry_objects_mock.log_create.side_effect = error_create + with self.assertRaises(CustomSignalError): + SimpleModel.objects.create(text="I am not difficult.") + self.assertEqual(self.my_post_log_data["my_error"], error_create) + + # update + error_update = CustomSignalError(LogEntry.Action.UPDATE) + log_entry_objects_mock.log_create.side_effect = error_update + with self.assertRaises(CustomSignalError): + obj = SimpleModel.objects.get(pk=self.obj.pk) + obj.text = "updating" + obj.save() + self.assertEqual(self.my_post_log_data["my_error"], error_update) + + # delete + error_delete = CustomSignalError(LogEntry.Action.DELETE) + log_entry_objects_mock.log_create.side_effect = error_delete + with self.assertRaises(CustomSignalError): + obj = SimpleModel.objects.get(pk=self.obj.pk) + obj.delete() + self.assertEqual(self.my_post_log_data["my_error"], error_delete) + + @override_settings(AUDITLOG_DISABLE_ON_RAW_SAVE=True) class DisableTest(TestCase): """ diff --git a/docs/source/internals.rst b/docs/source/internals.rst index 57163d2..9c869c1 100644 --- a/docs/source/internals.rst +++ b/docs/source/internals.rst @@ -31,6 +31,18 @@ Signal receivers .. automodule:: auditlog.receivers :members: +Custom Signals +-------------- +Django Auditlog provides two custom signals that will hook in before +and after any Auditlog record is written from a ``create``, ``update``, +``delete``, or ``accessed`` action on an audited model. + +.. automodule:: auditlog.signals + :members: + :member-order: bysource + +.. versionadded:: 3.0.0 + Calculating changes -------------------