diff --git a/auditlog/mixins.py b/auditlog/mixins.py index d2e315e..316054e 100644 --- a/auditlog/mixins.py +++ b/auditlog/mixins.py @@ -3,7 +3,7 @@ import json from django import urls as urlresolvers from django.conf import settings from django.urls.exceptions import NoReverseMatch -from django.utils.html import format_html +from django.utils.html import format_html, format_html_join from django.utils.safestring import mark_safe from django.utils.timezone import localtime @@ -63,14 +63,49 @@ class LogEntryAdminMixin(object): return "" # delete changes = json.loads(obj.changes) + atom_changes = {} + m2m_changes = {} + + for field, change in changes.items(): + if isinstance(change, dict): + assert ( + change["type"] == "m2m" + ), "Only m2m operations are expected to produce dict changes now" + m2m_changes[field] = change + else: + atom_changes[field] = change + msg = [] - msg.append("") - msg.append(self._format_header("#", "Field", "From", "To")) - for i, (field, change) in enumerate(sorted(changes.items()), 1): - value = [i, field] + (["***", "***"] if field == "password" else change) - msg.append(self._format_line(*value)) - msg.append("
") + if atom_changes: + msg.append("") + msg.append(self._format_header("#", "Field", "From", "To")) + for i, (field, change) in enumerate(sorted(atom_changes.items()), 1): + value = [i, field] + (["***", "***"] if field == "password" else change) + msg.append(self._format_line(*value)) + msg.append("
") + + if m2m_changes: + msg.append("") + msg.append(self._format_header("#", "Relationship", "Action", "Objects")) + for i, (field, change) in enumerate(sorted(m2m_changes.items()), 1): + change_html = format_html_join( + mark_safe("
"), + "{}", + [(value,) for value in change["objects"]], + ) + + msg.append( + format_html( + "", + i, + field, + change["operation"], + change_html, + ) + ) + + msg.append("
{}{}{}{}
") return mark_safe("".join(msg)) diff --git a/auditlog/models.py b/auditlog/models.py index d55ffa9..0b36ff6 100644 --- a/auditlog/models.py +++ b/auditlog/models.py @@ -76,6 +76,55 @@ class LogEntryManager(models.Manager): ) return None + def log_m2m_changes( + self, changed_queryset, instance, operation, field_name, **kwargs + ): + """Create a new "changed" log entry from m2m record. + + :param instance: The model instance to log a change for. + :type instance: Model + :param operation: "add" or "delete". + :type action: str + :param kwargs: Field overrides for the :py:class:`LogEntry` object. + :return: The new log entry or `None` if there were no changes. + :rtype: LogEntry + """ + + pk = self._get_pk_value(instance) + if changed_queryset is not None: + kwargs.setdefault( + "content_type", ContentType.objects.get_for_model(instance) + ) + kwargs.setdefault("object_pk", pk) + kwargs.setdefault("object_repr", smart_str(instance)) + kwargs.setdefault("action", LogEntry.Action.UPDATE) + + if isinstance(pk, int): + kwargs.setdefault("object_id", pk) + + get_additional_data = getattr(instance, "get_additional_data", None) + if callable(get_additional_data): + kwargs.setdefault("additional_data", get_additional_data()) + + objects = [smart_str(instance) for instance in changed_queryset] + kwargs["changes"] = json.dumps( + { + field_name: { + "type": "m2m", + "operation": operation, + "objects": objects, + } + } + ) + db = instance._state.db + return ( + self.create(**kwargs) + if db is None or db == "" + else self.using(db).create(**kwargs) + ) + + return None + def get_for_object(self, instance): """ Get log entries for the specified model instance. diff --git a/auditlog/receivers.py b/auditlog/receivers.py index 25a6228..e16fee6 100644 --- a/auditlog/receivers.py +++ b/auditlog/receivers.py @@ -59,3 +59,34 @@ def log_delete(sender, instance, **kwargs): action=LogEntry.Action.DELETE, changes=json.dumps(changes), ) + + +def make_log_m2m_changes(field_name): + """Return a handler for m2m_changed with field_name enclosed.""" + + def log_m2m_changes(signal, action, **kwargs): + """Handle m2m_changed and call LogEntry.objects.log_m2m_changes as needed.""" + if action not in ["post_add", "post_clear", "post_remove"]: + return + + if action == "post_clear": + changed_queryset = kwargs["model"].objects.all() + else: + changed_queryset = kwargs["model"].objects.filter(pk__in=kwargs["pk_set"]) + + if action in ["post_add"]: + LogEntry.objects.log_m2m_changes( + changed_queryset, + kwargs["instance"], + "add", + field_name, + ) + elif action in ["post_remove", "post_clear"]: + LogEntry.objects.log_m2m_changes( + changed_queryset, + kwargs["instance"], + "delete", + field_name, + ) + + return log_m2m_changes diff --git a/auditlog/registry.py b/auditlog/registry.py index acfc3f3..acdf5fc 100644 --- a/auditlog/registry.py +++ b/auditlog/registry.py @@ -1,8 +1,15 @@ -from typing import Callable, Dict, List, Optional, Tuple +from collections import defaultdict +from typing import Callable, Collection, Dict, List, Optional, Tuple from django.db.models import Model from django.db.models.base import ModelBase -from django.db.models.signals import ModelSignal, post_delete, post_save, pre_save +from django.db.models.signals import ( + ModelSignal, + m2m_changed, + post_delete, + post_save, + pre_save, +) DispatchUID = Tuple[int, int, int] @@ -17,12 +24,14 @@ class AuditlogModelRegistry(object): create: bool = True, update: bool = True, delete: bool = True, + m2m: bool = True, custom: Optional[Dict[ModelSignal, Callable]] = None, ): from auditlog.receivers import log_create, log_delete, log_update self._registry = {} self._signals = {} + self._m2m_signals = defaultdict(dict) if create: self._signals[post_save] = log_create @@ -30,6 +39,7 @@ class AuditlogModelRegistry(object): self._signals[pre_save] = log_update if delete: self._signals[post_delete] = log_delete + self._m2m = m2m if custom is not None: self._signals.update(custom) @@ -40,6 +50,7 @@ class AuditlogModelRegistry(object): include_fields: Optional[List[str]] = None, exclude_fields: Optional[List[str]] = None, mapping_fields: Optional[Dict[str, str]] = None, + m2m_fields: Optional[Collection[str]] = None, ): """ Register a model with auditlog. Auditlog will then track mutations on this model's instances. @@ -48,6 +59,7 @@ class AuditlogModelRegistry(object): :param include_fields: The fields to include. Implicitly excludes all other fields. :param exclude_fields: The fields to exclude. Overrides the fields to include. :param mapping_fields: Mapping from field names to strings in diff. + :param m2m_fields: The fields to map as many to many. """ @@ -57,6 +69,8 @@ class AuditlogModelRegistry(object): exclude_fields = [] if mapping_fields is None: mapping_fields = {} + if m2m_fields is None: + m2m_fields = set() def registrar(cls): """Register models for a given class.""" @@ -67,6 +81,7 @@ class AuditlogModelRegistry(object): "include_fields": include_fields, "exclude_fields": exclude_fields, "mapping_fields": mapping_fields, + "m2m_fields": m2m_fields, } self._connect_signals(cls) @@ -121,12 +136,26 @@ class AuditlogModelRegistry(object): """ Connect signals for the model. """ + from auditlog.receivers import make_log_m2m_changes + for signal, receiver in self._signals.items(): signal.connect( receiver, sender=model, dispatch_uid=self._dispatch_uid(signal, receiver), ) + if self._m2m: + for field_name in self._registry[model]["m2m_fields"]: + receiver = make_log_m2m_changes(field_name) + self._m2m_signals[model][field_name] = receiver + field = getattr(model, field_name) + m2m_model = getattr(field, "through") + + m2m_changed.connect( + receiver, + sender=m2m_model, + dispatch_uid=self._dispatch_uid(m2m_changed, receiver), + ) def _disconnect_signals(self, model): """ @@ -136,6 +165,14 @@ class AuditlogModelRegistry(object): signal.disconnect( sender=model, dispatch_uid=self._dispatch_uid(signal, receiver) ) + for field_name, receiver in self._m2m_signals[model].items(): + field = getattr(model, field_name) + m2m_model = getattr(field, "through") + m2m_changed.disconnect( + sender=m2m_model, + dispatch_uid=self._dispatch_uid(m2m_changed, receiver), + ) + del self._m2m_signals[model] def _dispatch_uid(self, signal, receiver) -> DispatchUID: """Generate a dispatch_uid which is unique for a combination of self, signal, and receiver.""" diff --git a/auditlog_tests/models.py b/auditlog_tests/models.py index eddede4..4eea4d0 100644 --- a/auditlog_tests/models.py +++ b/auditlog_tests/models.py @@ -4,7 +4,9 @@ from django.contrib.postgres.fields import ArrayField from django.db import models from auditlog.models import AuditlogHistoryField -from auditlog.registry import auditlog +from auditlog.registry import AuditlogModelRegistry, auditlog + +m2m_only_auditlog = AuditlogModelRegistry(create=False, update=False, delete=False) @auditlog.register() @@ -80,6 +82,24 @@ class ManyRelatedModel(models.Model): history = AuditlogHistoryField() +class FirstManyRelatedModel(models.Model): + """ + A model with a many to many relation to another model similar. + """ + + related = models.ManyToManyField("OtherManyRelatedModel", related_name="related") + + history = AuditlogHistoryField() + + +class OtherManyRelatedModel(models.Model): + """ + A model that 'receives' the other side of the many to many relation from 'FirstManyRelatedModel'. + """ + + history = AuditlogHistoryField() + + @auditlog.register(include_fields=["label"]) class SimpleIncludeModel(models.Model): """ @@ -224,6 +244,9 @@ auditlog.register(ProxyModel) auditlog.register(RelatedModel) auditlog.register(ManyRelatedModel) auditlog.register(ManyRelatedModel.related.through) +m2m_only_auditlog.register( + FirstManyRelatedModel, include_fields=["pk", "history"], m2m_fields={"related"} +) auditlog.register(SimpleExcludeModel, exclude_fields=["text"]) auditlog.register(SimpleMappingModel, mapping_fields={"sku": "Product No."}) auditlog.register(AdditionalDataIncludedModel) diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index c8bbd35..3985419 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -23,8 +23,10 @@ from auditlog_tests.models import ( CharfieldTextfieldModel, ChoicesFieldModel, DateTimeFieldModel, + FirstManyRelatedModel, ManyRelatedModel, NoDeleteHistoryModel, + OtherManyRelatedModel, PostgresArrayFieldModel, ProxyModel, RelatedModel, @@ -234,7 +236,7 @@ class ProxyModelWithActorTest(WithActorMixin, ProxyModelBase): class ManyRelatedModelTest(TestCase): """ - Test the behaviour of a many-to-many relationship. + Test the behaviour of a default many-to-many relationship. """ def setUp(self): @@ -253,6 +255,60 @@ class ManyRelatedModelTest(TestCase): ) +class FirstManyRelatedModelTest(TestCase): + """ + Test the behaviour of a many-to-many relationship. + """ + + def setUp(self): + self.obj = FirstManyRelatedModel.objects.create() + self.rel_obj = OtherManyRelatedModel.objects.create() + + def test_related_add_from_first_side(self): + self.obj.related.add(self.rel_obj) + self.assertEqual( + LogEntry.objects.get_for_objects(self.obj.related.all()).count(), + self.rel_obj.history.count(), + ) + self.assertEqual( + LogEntry.objects.get_for_objects(self.obj.related.all()).first(), + self.rel_obj.history.first(), + ) + self.assertEqual(LogEntry.objects.count(), 1) + + def test_related_add_from_other_side(self): + self.rel_obj.related.add(self.obj) + self.assertEqual( + LogEntry.objects.get_for_objects(self.obj.related.all()).count(), + self.rel_obj.history.count(), + ) + self.assertEqual( + LogEntry.objects.get_for_objects(self.obj.related.all()).first(), + self.rel_obj.history.first(), + ) + self.assertEqual(LogEntry.objects.count(), 1) + + def test_related_remove_from_first_side(self): + self.obj.related.add(self.rel_obj) + self.obj.related.remove(self.rel_obj) + self.assertEqual(LogEntry.objects.count(), 2) + + def test_related_remove_from_other_side(self): + self.rel_obj.related.add(self.obj) + self.rel_obj.related.remove(self.obj) + self.assertEqual(LogEntry.objects.count(), 2) + + def test_related_clear_from_first_side(self): + self.obj.related.add(self.rel_obj) + self.obj.related.clear() + self.assertEqual(LogEntry.objects.count(), 2) + + def test_related_clear_from_other_side(self): + self.rel_obj.related.add(self.obj) + self.rel_obj.related.clear() + self.assertEqual(LogEntry.objects.count(), 2) + + class MiddlewareTest(TestCase): """ Test the middleware responsible for connecting and disconnecting the signals used in automatic logging.