Add logic to track m2m relationships

This commit is contained in:
Alieh Rymašeŭski 2021-06-28 16:47:45 +03:00
parent 465bfded80
commit 48adbc5a1e
6 changed files with 242 additions and 11 deletions

View file

@ -3,7 +3,7 @@ import json
from django import urls as urlresolvers
from django.conf import settings
from django.urls.exceptions import NoReverseMatch
from django.utils.html import format_html
from django.utils.html import format_html, format_html_join
from django.utils.safestring import mark_safe
from django.utils.timezone import localtime
@ -63,14 +63,49 @@ class LogEntryAdminMixin(object):
return "" # delete
changes = json.loads(obj.changes)
atom_changes = {}
m2m_changes = {}
for field, change in changes.items():
if isinstance(change, dict):
assert (
change["type"] == "m2m"
), "Only m2m operations are expected to produce dict changes now"
m2m_changes[field] = change
else:
atom_changes[field] = change
msg = []
msg.append("<table>")
msg.append(self._format_header("#", "Field", "From", "To"))
for i, (field, change) in enumerate(sorted(changes.items()), 1):
value = [i, field] + (["***", "***"] if field == "password" else change)
msg.append(self._format_line(*value))
msg.append("</table>")
if atom_changes:
msg.append("<table>")
msg.append(self._format_header("#", "Field", "From", "To"))
for i, (field, change) in enumerate(sorted(atom_changes.items()), 1):
value = [i, field] + (["***", "***"] if field == "password" else change)
msg.append(self._format_line(*value))
msg.append("</table>")
if m2m_changes:
msg.append("<table>")
msg.append(self._format_header("#", "Relationship", "Action", "Objects"))
for i, (field, change) in enumerate(sorted(m2m_changes.items()), 1):
change_html = format_html_join(
mark_safe("</br>"),
"{}",
[(value,) for value in change["objects"]],
)
msg.append(
format_html(
"<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td>",
i,
field,
change["operation"],
change_html,
)
)
msg.append("</table>")
return mark_safe("".join(msg))

View file

@ -76,6 +76,55 @@ class LogEntryManager(models.Manager):
)
return None
def log_m2m_changes(
self, changed_queryset, instance, operation, field_name, **kwargs
):
"""Create a new "changed" log entry from m2m record.
:param instance: The model instance to log a change for.
:type instance: Model
:param operation: "add" or "delete".
:type action: str
:param kwargs: Field overrides for the :py:class:`LogEntry` object.
:return: The new log entry or `None` if there were no changes.
:rtype: LogEntry
"""
pk = self._get_pk_value(instance)
if changed_queryset is not None:
kwargs.setdefault(
"content_type", ContentType.objects.get_for_model(instance)
)
kwargs.setdefault("object_pk", pk)
kwargs.setdefault("object_repr", smart_str(instance))
kwargs.setdefault("action", LogEntry.Action.UPDATE)
if isinstance(pk, int):
kwargs.setdefault("object_id", pk)
get_additional_data = getattr(instance, "get_additional_data", None)
if callable(get_additional_data):
kwargs.setdefault("additional_data", get_additional_data())
objects = [smart_str(instance) for instance in changed_queryset]
kwargs["changes"] = json.dumps(
{
field_name: {
"type": "m2m",
"operation": operation,
"objects": objects,
}
}
)
db = instance._state.db
return (
self.create(**kwargs)
if db is None or db == ""
else self.using(db).create(**kwargs)
)
return None
def get_for_object(self, instance):
"""
Get log entries for the specified model instance.

View file

@ -59,3 +59,34 @@ def log_delete(sender, instance, **kwargs):
action=LogEntry.Action.DELETE,
changes=json.dumps(changes),
)
def make_log_m2m_changes(field_name):
"""Return a handler for m2m_changed with field_name enclosed."""
def log_m2m_changes(signal, action, **kwargs):
"""Handle m2m_changed and call LogEntry.objects.log_m2m_changes as needed."""
if action not in ["post_add", "post_clear", "post_remove"]:
return
if action == "post_clear":
changed_queryset = kwargs["model"].objects.all()
else:
changed_queryset = kwargs["model"].objects.filter(pk__in=kwargs["pk_set"])
if action in ["post_add"]:
LogEntry.objects.log_m2m_changes(
changed_queryset,
kwargs["instance"],
"add",
field_name,
)
elif action in ["post_remove", "post_clear"]:
LogEntry.objects.log_m2m_changes(
changed_queryset,
kwargs["instance"],
"delete",
field_name,
)
return log_m2m_changes

View file

@ -1,8 +1,15 @@
from typing import Callable, Dict, List, Optional, Tuple
from collections import defaultdict
from typing import Callable, Collection, Dict, List, Optional, Tuple
from django.db.models import Model
from django.db.models.base import ModelBase
from django.db.models.signals import ModelSignal, post_delete, post_save, pre_save
from django.db.models.signals import (
ModelSignal,
m2m_changed,
post_delete,
post_save,
pre_save,
)
DispatchUID = Tuple[int, int, int]
@ -17,12 +24,14 @@ class AuditlogModelRegistry(object):
create: bool = True,
update: bool = True,
delete: bool = True,
m2m: bool = True,
custom: Optional[Dict[ModelSignal, Callable]] = None,
):
from auditlog.receivers import log_create, log_delete, log_update
self._registry = {}
self._signals = {}
self._m2m_signals = defaultdict(dict)
if create:
self._signals[post_save] = log_create
@ -30,6 +39,7 @@ class AuditlogModelRegistry(object):
self._signals[pre_save] = log_update
if delete:
self._signals[post_delete] = log_delete
self._m2m = m2m
if custom is not None:
self._signals.update(custom)
@ -40,6 +50,7 @@ class AuditlogModelRegistry(object):
include_fields: Optional[List[str]] = None,
exclude_fields: Optional[List[str]] = None,
mapping_fields: Optional[Dict[str, str]] = None,
m2m_fields: Optional[Collection[str]] = None,
):
"""
Register a model with auditlog. Auditlog will then track mutations on this model's instances.
@ -48,6 +59,7 @@ class AuditlogModelRegistry(object):
:param include_fields: The fields to include. Implicitly excludes all other fields.
:param exclude_fields: The fields to exclude. Overrides the fields to include.
:param mapping_fields: Mapping from field names to strings in diff.
:param m2m_fields: The fields to map as many to many.
"""
@ -57,6 +69,8 @@ class AuditlogModelRegistry(object):
exclude_fields = []
if mapping_fields is None:
mapping_fields = {}
if m2m_fields is None:
m2m_fields = set()
def registrar(cls):
"""Register models for a given class."""
@ -67,6 +81,7 @@ class AuditlogModelRegistry(object):
"include_fields": include_fields,
"exclude_fields": exclude_fields,
"mapping_fields": mapping_fields,
"m2m_fields": m2m_fields,
}
self._connect_signals(cls)
@ -121,12 +136,26 @@ class AuditlogModelRegistry(object):
"""
Connect signals for the model.
"""
from auditlog.receivers import make_log_m2m_changes
for signal, receiver in self._signals.items():
signal.connect(
receiver,
sender=model,
dispatch_uid=self._dispatch_uid(signal, receiver),
)
if self._m2m:
for field_name in self._registry[model]["m2m_fields"]:
receiver = make_log_m2m_changes(field_name)
self._m2m_signals[model][field_name] = receiver
field = getattr(model, field_name)
m2m_model = getattr(field, "through")
m2m_changed.connect(
receiver,
sender=m2m_model,
dispatch_uid=self._dispatch_uid(m2m_changed, receiver),
)
def _disconnect_signals(self, model):
"""
@ -136,6 +165,14 @@ class AuditlogModelRegistry(object):
signal.disconnect(
sender=model, dispatch_uid=self._dispatch_uid(signal, receiver)
)
for field_name, receiver in self._m2m_signals[model].items():
field = getattr(model, field_name)
m2m_model = getattr(field, "through")
m2m_changed.disconnect(
sender=m2m_model,
dispatch_uid=self._dispatch_uid(m2m_changed, receiver),
)
del self._m2m_signals[model]
def _dispatch_uid(self, signal, receiver) -> DispatchUID:
"""Generate a dispatch_uid which is unique for a combination of self, signal, and receiver."""

View file

@ -4,7 +4,9 @@ from django.contrib.postgres.fields import ArrayField
from django.db import models
from auditlog.models import AuditlogHistoryField
from auditlog.registry import auditlog
from auditlog.registry import AuditlogModelRegistry, auditlog
m2m_only_auditlog = AuditlogModelRegistry(create=False, update=False, delete=False)
@auditlog.register()
@ -80,6 +82,24 @@ class ManyRelatedModel(models.Model):
history = AuditlogHistoryField()
class FirstManyRelatedModel(models.Model):
"""
A model with a many to many relation to another model similar.
"""
related = models.ManyToManyField("OtherManyRelatedModel", related_name="related")
history = AuditlogHistoryField()
class OtherManyRelatedModel(models.Model):
"""
A model that 'receives' the other side of the many to many relation from 'FirstManyRelatedModel'.
"""
history = AuditlogHistoryField()
@auditlog.register(include_fields=["label"])
class SimpleIncludeModel(models.Model):
"""
@ -224,6 +244,9 @@ auditlog.register(ProxyModel)
auditlog.register(RelatedModel)
auditlog.register(ManyRelatedModel)
auditlog.register(ManyRelatedModel.related.through)
m2m_only_auditlog.register(
FirstManyRelatedModel, include_fields=["pk", "history"], m2m_fields={"related"}
)
auditlog.register(SimpleExcludeModel, exclude_fields=["text"])
auditlog.register(SimpleMappingModel, mapping_fields={"sku": "Product No."})
auditlog.register(AdditionalDataIncludedModel)

View file

@ -23,8 +23,10 @@ from auditlog_tests.models import (
CharfieldTextfieldModel,
ChoicesFieldModel,
DateTimeFieldModel,
FirstManyRelatedModel,
ManyRelatedModel,
NoDeleteHistoryModel,
OtherManyRelatedModel,
PostgresArrayFieldModel,
ProxyModel,
RelatedModel,
@ -234,7 +236,7 @@ class ProxyModelWithActorTest(WithActorMixin, ProxyModelBase):
class ManyRelatedModelTest(TestCase):
"""
Test the behaviour of a many-to-many relationship.
Test the behaviour of a default many-to-many relationship.
"""
def setUp(self):
@ -253,6 +255,60 @@ class ManyRelatedModelTest(TestCase):
)
class FirstManyRelatedModelTest(TestCase):
"""
Test the behaviour of a many-to-many relationship.
"""
def setUp(self):
self.obj = FirstManyRelatedModel.objects.create()
self.rel_obj = OtherManyRelatedModel.objects.create()
def test_related_add_from_first_side(self):
self.obj.related.add(self.rel_obj)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).count(),
self.rel_obj.history.count(),
)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).first(),
self.rel_obj.history.first(),
)
self.assertEqual(LogEntry.objects.count(), 1)
def test_related_add_from_other_side(self):
self.rel_obj.related.add(self.obj)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).count(),
self.rel_obj.history.count(),
)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).first(),
self.rel_obj.history.first(),
)
self.assertEqual(LogEntry.objects.count(), 1)
def test_related_remove_from_first_side(self):
self.obj.related.add(self.rel_obj)
self.obj.related.remove(self.rel_obj)
self.assertEqual(LogEntry.objects.count(), 2)
def test_related_remove_from_other_side(self):
self.rel_obj.related.add(self.obj)
self.rel_obj.related.remove(self.obj)
self.assertEqual(LogEntry.objects.count(), 2)
def test_related_clear_from_first_side(self):
self.obj.related.add(self.rel_obj)
self.obj.related.clear()
self.assertEqual(LogEntry.objects.count(), 2)
def test_related_clear_from_other_side(self):
self.rel_obj.related.add(self.obj)
self.rel_obj.related.clear()
self.assertEqual(LogEntry.objects.count(), 2)
class MiddlewareTest(TestCase):
"""
Test the middleware responsible for connecting and disconnecting the signals used in automatic logging.