diff --git a/auditlog/mixins.py b/auditlog/mixins.py
index d2e315e..316054e 100644
--- a/auditlog/mixins.py
+++ b/auditlog/mixins.py
@@ -3,7 +3,7 @@ import json
from django import urls as urlresolvers
from django.conf import settings
from django.urls.exceptions import NoReverseMatch
-from django.utils.html import format_html
+from django.utils.html import format_html, format_html_join
from django.utils.safestring import mark_safe
from django.utils.timezone import localtime
@@ -63,14 +63,49 @@ class LogEntryAdminMixin(object):
return "" # delete
changes = json.loads(obj.changes)
+ atom_changes = {}
+ m2m_changes = {}
+
+ for field, change in changes.items():
+ if isinstance(change, dict):
+ assert (
+ change["type"] == "m2m"
+ ), "Only m2m operations are expected to produce dict changes now"
+ m2m_changes[field] = change
+ else:
+ atom_changes[field] = change
+
msg = []
- msg.append("
")
- msg.append(self._format_header("#", "Field", "From", "To"))
- for i, (field, change) in enumerate(sorted(changes.items()), 1):
- value = [i, field] + (["***", "***"] if field == "password" else change)
- msg.append(self._format_line(*value))
- msg.append("
")
+ if atom_changes:
+ msg.append("")
+ msg.append(self._format_header("#", "Field", "From", "To"))
+ for i, (field, change) in enumerate(sorted(atom_changes.items()), 1):
+ value = [i, field] + (["***", "***"] if field == "password" else change)
+ msg.append(self._format_line(*value))
+ msg.append("
")
+
+ if m2m_changes:
+ msg.append("")
+ msg.append(self._format_header("#", "Relationship", "Action", "Objects"))
+ for i, (field, change) in enumerate(sorted(m2m_changes.items()), 1):
+ change_html = format_html_join(
+ mark_safe(""),
+ "{}",
+ [(value,) for value in change["objects"]],
+ )
+
+ msg.append(
+ format_html(
+ "| {} | {} | {} | {} | ",
+ i,
+ field,
+ change["operation"],
+ change_html,
+ )
+ )
+
+ msg.append("
")
return mark_safe("".join(msg))
diff --git a/auditlog/models.py b/auditlog/models.py
index d55ffa9..0b36ff6 100644
--- a/auditlog/models.py
+++ b/auditlog/models.py
@@ -76,6 +76,55 @@ class LogEntryManager(models.Manager):
)
return None
+ def log_m2m_changes(
+ self, changed_queryset, instance, operation, field_name, **kwargs
+ ):
+ """Create a new "changed" log entry from m2m record.
+
+ :param instance: The model instance to log a change for.
+ :type instance: Model
+ :param operation: "add" or "delete".
+ :type action: str
+ :param kwargs: Field overrides for the :py:class:`LogEntry` object.
+ :return: The new log entry or `None` if there were no changes.
+ :rtype: LogEntry
+ """
+
+ pk = self._get_pk_value(instance)
+ if changed_queryset is not None:
+ kwargs.setdefault(
+ "content_type", ContentType.objects.get_for_model(instance)
+ )
+ kwargs.setdefault("object_pk", pk)
+ kwargs.setdefault("object_repr", smart_str(instance))
+ kwargs.setdefault("action", LogEntry.Action.UPDATE)
+
+ if isinstance(pk, int):
+ kwargs.setdefault("object_id", pk)
+
+ get_additional_data = getattr(instance, "get_additional_data", None)
+ if callable(get_additional_data):
+ kwargs.setdefault("additional_data", get_additional_data())
+
+ objects = [smart_str(instance) for instance in changed_queryset]
+ kwargs["changes"] = json.dumps(
+ {
+ field_name: {
+ "type": "m2m",
+ "operation": operation,
+ "objects": objects,
+ }
+ }
+ )
+ db = instance._state.db
+ return (
+ self.create(**kwargs)
+ if db is None or db == ""
+ else self.using(db).create(**kwargs)
+ )
+
+ return None
+
def get_for_object(self, instance):
"""
Get log entries for the specified model instance.
diff --git a/auditlog/receivers.py b/auditlog/receivers.py
index 25a6228..e16fee6 100644
--- a/auditlog/receivers.py
+++ b/auditlog/receivers.py
@@ -59,3 +59,34 @@ def log_delete(sender, instance, **kwargs):
action=LogEntry.Action.DELETE,
changes=json.dumps(changes),
)
+
+
+def make_log_m2m_changes(field_name):
+ """Return a handler for m2m_changed with field_name enclosed."""
+
+ def log_m2m_changes(signal, action, **kwargs):
+ """Handle m2m_changed and call LogEntry.objects.log_m2m_changes as needed."""
+ if action not in ["post_add", "post_clear", "post_remove"]:
+ return
+
+ if action == "post_clear":
+ changed_queryset = kwargs["model"].objects.all()
+ else:
+ changed_queryset = kwargs["model"].objects.filter(pk__in=kwargs["pk_set"])
+
+ if action in ["post_add"]:
+ LogEntry.objects.log_m2m_changes(
+ changed_queryset,
+ kwargs["instance"],
+ "add",
+ field_name,
+ )
+ elif action in ["post_remove", "post_clear"]:
+ LogEntry.objects.log_m2m_changes(
+ changed_queryset,
+ kwargs["instance"],
+ "delete",
+ field_name,
+ )
+
+ return log_m2m_changes
diff --git a/auditlog/registry.py b/auditlog/registry.py
index acfc3f3..acdf5fc 100644
--- a/auditlog/registry.py
+++ b/auditlog/registry.py
@@ -1,8 +1,15 @@
-from typing import Callable, Dict, List, Optional, Tuple
+from collections import defaultdict
+from typing import Callable, Collection, Dict, List, Optional, Tuple
from django.db.models import Model
from django.db.models.base import ModelBase
-from django.db.models.signals import ModelSignal, post_delete, post_save, pre_save
+from django.db.models.signals import (
+ ModelSignal,
+ m2m_changed,
+ post_delete,
+ post_save,
+ pre_save,
+)
DispatchUID = Tuple[int, int, int]
@@ -17,12 +24,14 @@ class AuditlogModelRegistry(object):
create: bool = True,
update: bool = True,
delete: bool = True,
+ m2m: bool = True,
custom: Optional[Dict[ModelSignal, Callable]] = None,
):
from auditlog.receivers import log_create, log_delete, log_update
self._registry = {}
self._signals = {}
+ self._m2m_signals = defaultdict(dict)
if create:
self._signals[post_save] = log_create
@@ -30,6 +39,7 @@ class AuditlogModelRegistry(object):
self._signals[pre_save] = log_update
if delete:
self._signals[post_delete] = log_delete
+ self._m2m = m2m
if custom is not None:
self._signals.update(custom)
@@ -40,6 +50,7 @@ class AuditlogModelRegistry(object):
include_fields: Optional[List[str]] = None,
exclude_fields: Optional[List[str]] = None,
mapping_fields: Optional[Dict[str, str]] = None,
+ m2m_fields: Optional[Collection[str]] = None,
):
"""
Register a model with auditlog. Auditlog will then track mutations on this model's instances.
@@ -48,6 +59,7 @@ class AuditlogModelRegistry(object):
:param include_fields: The fields to include. Implicitly excludes all other fields.
:param exclude_fields: The fields to exclude. Overrides the fields to include.
:param mapping_fields: Mapping from field names to strings in diff.
+ :param m2m_fields: The fields to map as many to many.
"""
@@ -57,6 +69,8 @@ class AuditlogModelRegistry(object):
exclude_fields = []
if mapping_fields is None:
mapping_fields = {}
+ if m2m_fields is None:
+ m2m_fields = set()
def registrar(cls):
"""Register models for a given class."""
@@ -67,6 +81,7 @@ class AuditlogModelRegistry(object):
"include_fields": include_fields,
"exclude_fields": exclude_fields,
"mapping_fields": mapping_fields,
+ "m2m_fields": m2m_fields,
}
self._connect_signals(cls)
@@ -121,12 +136,26 @@ class AuditlogModelRegistry(object):
"""
Connect signals for the model.
"""
+ from auditlog.receivers import make_log_m2m_changes
+
for signal, receiver in self._signals.items():
signal.connect(
receiver,
sender=model,
dispatch_uid=self._dispatch_uid(signal, receiver),
)
+ if self._m2m:
+ for field_name in self._registry[model]["m2m_fields"]:
+ receiver = make_log_m2m_changes(field_name)
+ self._m2m_signals[model][field_name] = receiver
+ field = getattr(model, field_name)
+ m2m_model = getattr(field, "through")
+
+ m2m_changed.connect(
+ receiver,
+ sender=m2m_model,
+ dispatch_uid=self._dispatch_uid(m2m_changed, receiver),
+ )
def _disconnect_signals(self, model):
"""
@@ -136,6 +165,14 @@ class AuditlogModelRegistry(object):
signal.disconnect(
sender=model, dispatch_uid=self._dispatch_uid(signal, receiver)
)
+ for field_name, receiver in self._m2m_signals[model].items():
+ field = getattr(model, field_name)
+ m2m_model = getattr(field, "through")
+ m2m_changed.disconnect(
+ sender=m2m_model,
+ dispatch_uid=self._dispatch_uid(m2m_changed, receiver),
+ )
+ del self._m2m_signals[model]
def _dispatch_uid(self, signal, receiver) -> DispatchUID:
"""Generate a dispatch_uid which is unique for a combination of self, signal, and receiver."""
diff --git a/auditlog_tests/models.py b/auditlog_tests/models.py
index eddede4..4eea4d0 100644
--- a/auditlog_tests/models.py
+++ b/auditlog_tests/models.py
@@ -4,7 +4,9 @@ from django.contrib.postgres.fields import ArrayField
from django.db import models
from auditlog.models import AuditlogHistoryField
-from auditlog.registry import auditlog
+from auditlog.registry import AuditlogModelRegistry, auditlog
+
+m2m_only_auditlog = AuditlogModelRegistry(create=False, update=False, delete=False)
@auditlog.register()
@@ -80,6 +82,24 @@ class ManyRelatedModel(models.Model):
history = AuditlogHistoryField()
+class FirstManyRelatedModel(models.Model):
+ """
+ A model with a many to many relation to another model similar.
+ """
+
+ related = models.ManyToManyField("OtherManyRelatedModel", related_name="related")
+
+ history = AuditlogHistoryField()
+
+
+class OtherManyRelatedModel(models.Model):
+ """
+ A model that 'receives' the other side of the many to many relation from 'FirstManyRelatedModel'.
+ """
+
+ history = AuditlogHistoryField()
+
+
@auditlog.register(include_fields=["label"])
class SimpleIncludeModel(models.Model):
"""
@@ -224,6 +244,9 @@ auditlog.register(ProxyModel)
auditlog.register(RelatedModel)
auditlog.register(ManyRelatedModel)
auditlog.register(ManyRelatedModel.related.through)
+m2m_only_auditlog.register(
+ FirstManyRelatedModel, include_fields=["pk", "history"], m2m_fields={"related"}
+)
auditlog.register(SimpleExcludeModel, exclude_fields=["text"])
auditlog.register(SimpleMappingModel, mapping_fields={"sku": "Product No."})
auditlog.register(AdditionalDataIncludedModel)
diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py
index c8bbd35..3985419 100644
--- a/auditlog_tests/tests.py
+++ b/auditlog_tests/tests.py
@@ -23,8 +23,10 @@ from auditlog_tests.models import (
CharfieldTextfieldModel,
ChoicesFieldModel,
DateTimeFieldModel,
+ FirstManyRelatedModel,
ManyRelatedModel,
NoDeleteHistoryModel,
+ OtherManyRelatedModel,
PostgresArrayFieldModel,
ProxyModel,
RelatedModel,
@@ -234,7 +236,7 @@ class ProxyModelWithActorTest(WithActorMixin, ProxyModelBase):
class ManyRelatedModelTest(TestCase):
"""
- Test the behaviour of a many-to-many relationship.
+ Test the behaviour of a default many-to-many relationship.
"""
def setUp(self):
@@ -253,6 +255,60 @@ class ManyRelatedModelTest(TestCase):
)
+class FirstManyRelatedModelTest(TestCase):
+ """
+ Test the behaviour of a many-to-many relationship.
+ """
+
+ def setUp(self):
+ self.obj = FirstManyRelatedModel.objects.create()
+ self.rel_obj = OtherManyRelatedModel.objects.create()
+
+ def test_related_add_from_first_side(self):
+ self.obj.related.add(self.rel_obj)
+ self.assertEqual(
+ LogEntry.objects.get_for_objects(self.obj.related.all()).count(),
+ self.rel_obj.history.count(),
+ )
+ self.assertEqual(
+ LogEntry.objects.get_for_objects(self.obj.related.all()).first(),
+ self.rel_obj.history.first(),
+ )
+ self.assertEqual(LogEntry.objects.count(), 1)
+
+ def test_related_add_from_other_side(self):
+ self.rel_obj.related.add(self.obj)
+ self.assertEqual(
+ LogEntry.objects.get_for_objects(self.obj.related.all()).count(),
+ self.rel_obj.history.count(),
+ )
+ self.assertEqual(
+ LogEntry.objects.get_for_objects(self.obj.related.all()).first(),
+ self.rel_obj.history.first(),
+ )
+ self.assertEqual(LogEntry.objects.count(), 1)
+
+ def test_related_remove_from_first_side(self):
+ self.obj.related.add(self.rel_obj)
+ self.obj.related.remove(self.rel_obj)
+ self.assertEqual(LogEntry.objects.count(), 2)
+
+ def test_related_remove_from_other_side(self):
+ self.rel_obj.related.add(self.obj)
+ self.rel_obj.related.remove(self.obj)
+ self.assertEqual(LogEntry.objects.count(), 2)
+
+ def test_related_clear_from_first_side(self):
+ self.obj.related.add(self.rel_obj)
+ self.obj.related.clear()
+ self.assertEqual(LogEntry.objects.count(), 2)
+
+ def test_related_clear_from_other_side(self):
+ self.rel_obj.related.add(self.obj)
+ self.rel_obj.related.clear()
+ self.assertEqual(LogEntry.objects.count(), 2)
+
+
class MiddlewareTest(TestCase):
"""
Test the middleware responsible for connecting and disconnecting the signals used in automatic logging.