Add logic to track changes to m2m fields (#309)

This commit is contained in:
Alieh Rymašeŭski 2022-06-08 18:09:27 +03:00 committed by GitHub
parent 2e9466d1b4
commit 10c47181bb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 382 additions and 42 deletions

View file

@ -6,6 +6,7 @@
- feat: Add db_index to the `LogEntry.timestamp` column ([#364](https://github.com/jazzband/django-auditlog/pull/364))
- feat: Add register model from settings ([#368](https://github.com/jazzband/django-auditlog/pull/368))
- Context manager set_actor() for use in Celery tasks ([#262](https://github.com/jazzband/django-auditlog/pull/262))
- Tracking of changes in many-to-many fields ([#309](https://github.com/jazzband/django-auditlog/pull/309))
#### Fixes

View file

@ -3,7 +3,7 @@ import json
from django import urls as urlresolvers
from django.conf import settings
from django.urls.exceptions import NoReverseMatch
from django.utils.html import format_html
from django.utils.html import format_html, format_html_join
from django.utils.safestring import mark_safe
from auditlog.models import LogEntry
@ -63,16 +63,61 @@ class LogEntryAdminMixin:
if obj.action == LogEntry.Action.DELETE:
return "" # delete
changes = json.loads(obj.changes)
msg = "<table><tr><th>#</th><th>Field</th><th>From</th><th>To</th></tr>"
for i, field in enumerate(sorted(changes), 1):
value = [i, field] + (
["***", "***"] if field == "password" else changes[field]
)
msg += format_html(
"<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td></tr>", *value
)
msg += "</table>"
return mark_safe(msg)
atom_changes = {}
m2m_changes = {}
for field, change in changes.items():
if isinstance(change, dict):
assert (
change["type"] == "m2m"
), "Only m2m operations are expected to produce dict changes now"
m2m_changes[field] = change
else:
atom_changes[field] = change
msg = []
if atom_changes:
msg.append("<table>")
msg.append(self._format_header("#", "Field", "From", "To"))
for i, (field, change) in enumerate(sorted(atom_changes.items()), 1):
value = [i, field] + (["***", "***"] if field == "password" else change)
msg.append(self._format_line(*value))
msg.append("</table>")
if m2m_changes:
msg.append("<table>")
msg.append(self._format_header("#", "Relationship", "Action", "Objects"))
for i, (field, change) in enumerate(sorted(m2m_changes.items()), 1):
change_html = format_html_join(
mark_safe("<br>"),
"{}",
[(value,) for value in change["objects"]],
)
msg.append(
format_html(
"<tr><td>{}</td><td>{}</td><td>{}</td><td>{}</td></tr>",
i,
field,
change["operation"],
change_html,
)
)
msg.append("</table>")
return mark_safe("".join(msg))
msg.short_description = "Changes"
def _format_header(self, *labels):
return format_html(
"".join(["<tr>", "<th>{}</th>" * len(labels), "</tr>"]), *labels
)
def _format_line(self, *values):
return format_html(
"".join(["<tr>", "<td>{}</td>" * len(values), "</tr>"]), *values
)

View file

@ -69,6 +69,54 @@ class LogEntryManager(models.Manager):
return self.create(**kwargs)
return None
def log_m2m_changes(
self, changed_queryset, instance, operation, field_name, **kwargs
):
"""Create a new "changed" log entry from m2m record.
:param changed_queryset: The added or removed related objects.
:type changed_queryset: QuerySet
:param instance: The model instance to log a change for.
:type instance: Model
:param operation: "add" or "delete".
:type action: str
:param field_name: The name of the changed m2m field.
:type field_name: str
:param kwargs: Field overrides for the :py:class:`LogEntry` object.
:return: The new log entry or `None` if there were no changes.
:rtype: LogEntry
"""
pk = self._get_pk_value(instance)
if changed_queryset is not None:
kwargs.setdefault(
"content_type", ContentType.objects.get_for_model(instance)
)
kwargs.setdefault("object_pk", pk)
kwargs.setdefault("object_repr", smart_str(instance))
kwargs.setdefault("action", LogEntry.Action.UPDATE)
if isinstance(pk, int):
kwargs.setdefault("object_id", pk)
get_additional_data = getattr(instance, "get_additional_data", None)
if callable(get_additional_data):
kwargs.setdefault("additional_data", get_additional_data())
objects = [smart_str(instance) for instance in changed_queryset]
kwargs["changes"] = json.dumps(
{
field_name: {
"type": "m2m",
"operation": operation,
"objects": objects,
}
}
)
return self.create(**kwargs)
return None
def get_for_object(self, instance):
"""
Get log entries for the specified model instance.

View file

@ -59,3 +59,34 @@ def log_delete(sender, instance, **kwargs):
action=LogEntry.Action.DELETE,
changes=json.dumps(changes),
)
def make_log_m2m_changes(field_name):
"""Return a handler for m2m_changed with field_name enclosed."""
def log_m2m_changes(signal, action, **kwargs):
"""Handle m2m_changed and call LogEntry.objects.log_m2m_changes as needed."""
if action not in ["post_add", "post_clear", "post_remove"]:
return
if action == "post_clear":
changed_queryset = kwargs["model"].objects.all()
else:
changed_queryset = kwargs["model"].objects.filter(pk__in=kwargs["pk_set"])
if action in ["post_add"]:
LogEntry.objects.log_m2m_changes(
changed_queryset,
kwargs["instance"],
"add",
field_name,
)
elif action in ["post_remove", "post_clear"]:
LogEntry.objects.log_m2m_changes(
changed_queryset,
kwargs["instance"],
"delete",
field_name,
)
return log_m2m_changes

View file

@ -1,14 +1,31 @@
import copy
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from collections import defaultdict
from typing import (
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
from django.apps import apps
from django.db.models import Model
from django.db.models.base import ModelBase
from django.db.models.signals import ModelSignal, post_delete, post_save, pre_save
from django.db.models.signals import (
ModelSignal,
m2m_changed,
post_delete,
post_save,
pre_save,
)
from auditlog.conf import settings
DispatchUID = Tuple[int, str, int]
DispatchUID = Tuple[int, int, int]
class AuditlogModelRegistry:
@ -23,12 +40,14 @@ class AuditlogModelRegistry:
create: bool = True,
update: bool = True,
delete: bool = True,
m2m: bool = True,
custom: Optional[Dict[ModelSignal, Callable]] = None,
):
from auditlog.receivers import log_create, log_delete, log_update
self._registry = {}
self._signals = {}
self._m2m_signals = defaultdict(dict)
if create:
self._signals[post_save] = log_create
@ -36,6 +55,7 @@ class AuditlogModelRegistry:
self._signals[pre_save] = log_update
if delete:
self._signals[post_delete] = log_delete
self._m2m = m2m
if custom is not None:
self._signals.update(custom)
@ -47,6 +67,7 @@ class AuditlogModelRegistry:
exclude_fields: Optional[List[str]] = None,
mapping_fields: Optional[Dict[str, str]] = None,
mask_fields: Optional[List[str]] = None,
m2m_fields: Optional[Collection[str]] = None,
):
"""
Register a model with auditlog. Auditlog will then track mutations on this model's instances.
@ -56,6 +77,7 @@ class AuditlogModelRegistry:
:param exclude_fields: The fields to exclude. Overrides the fields to include.
:param mapping_fields: Mapping from field names to strings in diff.
:param mask_fields: The fields to mask for sensitive info.
:param m2m_fields: The fields to handle as many to many.
"""
@ -67,6 +89,8 @@ class AuditlogModelRegistry:
mapping_fields = {}
if mask_fields is None:
mask_fields = []
if m2m_fields is None:
m2m_fields = set()
def registrar(cls):
"""Register models for a given class."""
@ -78,6 +102,7 @@ class AuditlogModelRegistry:
"exclude_fields": exclude_fields,
"mapping_fields": mapping_fields,
"mask_fields": mask_fields,
"m2m_fields": m2m_fields,
}
self._connect_signals(cls)
@ -132,11 +157,26 @@ class AuditlogModelRegistry:
"""
Connect signals for the model.
"""
for signal in self._signals:
receiver = self._signals[signal]
from auditlog.receivers import make_log_m2m_changes
for signal, receiver in self._signals.items():
signal.connect(
receiver, sender=model, dispatch_uid=self._dispatch_uid(signal, model)
receiver,
sender=model,
dispatch_uid=self._dispatch_uid(signal, receiver),
)
if self._m2m:
for field_name in self._registry[model]["m2m_fields"]:
receiver = make_log_m2m_changes(field_name)
self._m2m_signals[model][field_name] = receiver
field = getattr(model, field_name)
m2m_model = getattr(field, "through")
m2m_changed.connect(
receiver,
sender=m2m_model,
dispatch_uid=self._dispatch_uid(m2m_changed, receiver),
)
def _disconnect_signals(self, model):
"""
@ -144,14 +184,20 @@ class AuditlogModelRegistry:
"""
for signal, receiver in self._signals.items():
signal.disconnect(
sender=model, dispatch_uid=self._dispatch_uid(signal, model)
sender=model, dispatch_uid=self._dispatch_uid(signal, receiver)
)
for field_name, receiver in self._m2m_signals[model].items():
field = getattr(model, field_name)
m2m_model = getattr(field, "through")
m2m_changed.disconnect(
sender=m2m_model,
dispatch_uid=self._dispatch_uid(m2m_changed, receiver),
)
del self._m2m_signals[model]
def _dispatch_uid(self, signal, model) -> DispatchUID:
"""
Generate a dispatch_uid.
"""
return self.__hash__(), model.__qualname__, signal.__hash__()
def _dispatch_uid(self, signal, receiver) -> DispatchUID:
"""Generate a dispatch_uid which is unique for a combination of self, signal, and receiver."""
return id(self), id(signal), id(receiver)
def _get_model_classes(self, app_model: str) -> List[ModelBase]:
try:

View file

@ -4,7 +4,9 @@ from django.contrib.postgres.fields import ArrayField
from django.db import models
from auditlog.models import AuditlogHistoryField
from auditlog.registry import auditlog
from auditlog.registry import AuditlogModelRegistry, auditlog
m2m_only_auditlog = AuditlogModelRegistry(create=False, update=False, delete=False)
@auditlog.register()
@ -81,10 +83,23 @@ class RelatedModel(RelatedModelParent):
class ManyRelatedModel(models.Model):
"""
A model with a many to many relation.
A model with many-to-many relations.
"""
related = models.ManyToManyField("self")
recursive = models.ManyToManyField("self")
related = models.ManyToManyField("ManyRelatedOtherModel", related_name="related")
history = AuditlogHistoryField()
def get_additional_data(self):
related = self.related.first()
return {"related_model_id": related.id if related else None}
class ManyRelatedOtherModel(models.Model):
"""
A model related to ManyRelatedModel as many-to-many.
"""
history = AuditlogHistoryField()
@ -250,7 +265,8 @@ auditlog.register(UUIDPrimaryKeyModel)
auditlog.register(ProxyModel)
auditlog.register(RelatedModel)
auditlog.register(ManyRelatedModel)
auditlog.register(ManyRelatedModel.related.through)
auditlog.register(ManyRelatedModel.recursive.through)
m2m_only_auditlog.register(ManyRelatedModel, m2m_fields={"related"})
auditlog.register(SimpleExcludeModel, exclude_fields=["text"])
auditlog.register(SimpleMappingModel, mapping_fields={"sku": "Product No."})
auditlog.register(AdditionalDataIncludedModel)

View file

@ -6,14 +6,15 @@ from unittest import mock
from dateutil.tz import gettz
from django.apps import apps
from django.conf import settings
from django.contrib.admin.sites import AdminSite
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser, User
from django.contrib.contenttypes.models import ContentType
from django.db.models.signals import pre_save
from django.http import HttpResponse
from django.test import RequestFactory, TestCase, override_settings
from django.utils import dateformat, formats, timezone
from auditlog.admin import LogEntryAdmin
from auditlog.context import set_actor
from auditlog.diff import model_instance_diff
from auditlog.middleware import AuditlogMiddleware
@ -27,6 +28,7 @@ from auditlog_tests.models import (
DateTimeFieldModel,
JSONModel,
ManyRelatedModel,
ManyRelatedOtherModel,
NoDeleteHistoryModel,
PostgresArrayFieldModel,
ProxyModel,
@ -300,22 +302,65 @@ class ProxyModelWithActorTest(WithActorMixin, ProxyModelBase):
class ManyRelatedModelTest(TestCase):
"""
Test the behaviour of a many-to-many relationship.
Test the behaviour of many-to-many relationships.
"""
def setUp(self):
self.obj = ManyRelatedModel.objects.create()
self.rel_obj = ManyRelatedModel.objects.create()
self.obj.related.add(self.rel_obj)
self.recursive = ManyRelatedModel.objects.create()
self.related = ManyRelatedOtherModel.objects.create()
self.base_log_entry_count = (
LogEntry.objects.count()
) # created by the create() calls above
def test_related(self):
def test_recursive(self):
self.obj.recursive.add(self.recursive)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).count(),
self.rel_obj.history.count(),
LogEntry.objects.get_for_objects(self.obj.recursive.all()).first(),
self.recursive.history.first(),
)
def test_related_add_from_first_side(self):
self.obj.related.add(self.related)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).first(),
self.rel_obj.history.first(),
self.related.history.first(),
)
self.assertEqual(LogEntry.objects.count(), self.base_log_entry_count + 1)
def test_related_add_from_other_side(self):
self.related.related.add(self.obj)
self.assertEqual(
LogEntry.objects.get_for_objects(self.obj.related.all()).first(),
self.related.history.first(),
)
self.assertEqual(LogEntry.objects.count(), self.base_log_entry_count + 1)
def test_related_remove_from_first_side(self):
self.obj.related.add(self.related)
self.obj.related.remove(self.related)
self.assertEqual(LogEntry.objects.count(), self.base_log_entry_count + 2)
def test_related_remove_from_other_side(self):
self.related.related.add(self.obj)
self.related.related.remove(self.obj)
self.assertEqual(LogEntry.objects.count(), self.base_log_entry_count + 2)
def test_related_clear_from_first_side(self):
self.obj.related.add(self.related)
self.obj.related.clear()
self.assertEqual(LogEntry.objects.count(), self.base_log_entry_count + 2)
def test_related_clear_from_other_side(self):
self.related.related.add(self.obj)
self.related.related.clear()
self.assertEqual(LogEntry.objects.count(), self.base_log_entry_count + 2)
def test_additional_data(self):
self.obj.related.add(self.related)
log_entry = self.obj.history.first()
self.assertEqual(
log_entry.additional_data, {"related_model_id": self.related.id}
)
@ -325,9 +370,6 @@ class MiddlewareTest(TestCase):
"""
def setUp(self):
def get_response(request):
return HttpResponse()
self.get_response_mock = mock.Mock()
self.response_mock = mock.Mock()
self.middleware = AuditlogMiddleware(get_response=self.get_response_mock)
@ -927,7 +969,7 @@ class RegisterModelSettingsTest(TestCase):
self.assertTrue(self.test_auditlog.contains(SimpleExcludeModel))
self.assertTrue(self.test_auditlog.contains(ChoicesFieldModel))
self.assertEqual(len(self.test_auditlog.get_models()), 18)
self.assertEqual(len(self.test_auditlog.get_models()), 19)
def test_register_models_register_model_with_attrs(self):
self.test_auditlog._register_models(
@ -947,6 +989,21 @@ class RegisterModelSettingsTest(TestCase):
self.assertEqual(fields["include_fields"], ["label"])
self.assertEqual(fields["exclude_fields"], ["text"])
def test_register_models_register_model_with_m2m_fields(self):
self.test_auditlog._register_models(
(
{
"model": "auditlog_tests.ManyRelatedModel",
"m2m_fields": {"related"},
},
)
)
self.assertTrue(self.test_auditlog.contains(ManyRelatedModel))
self.assertEqual(
self.test_auditlog._registry[ManyRelatedModel]["m2m_fields"], {"related"}
)
def test_register_from_settings_invalid_settings(self):
with override_settings(AUDITLOG_INCLUDE_ALL_MODELS="str"):
with self.assertRaisesMessage(
@ -1177,6 +1234,87 @@ class AdminPanelTest(TestCase):
assert res.status_code == 200
class DiffMsgTest(TestCase):
def setUp(self):
super().setUp()
self.site = AdminSite()
self.admin = LogEntryAdmin(LogEntry, self.site)
def _create_log_entry(self, action, changes):
return LogEntry.objects.log_create(
SimpleModel.objects.create(), # doesn't affect anything
action=action,
changes=json.dumps(changes),
)
def test_changes_msg__delete(self):
log_entry = self._create_log_entry(LogEntry.Action.DELETE, {})
self.assertEqual(self.admin.msg(log_entry), "")
def test_changes_msg__create(self):
log_entry = self._create_log_entry(
LogEntry.Action.CREATE,
{
"field two": [None, 11],
"field one": [None, "a value"],
},
)
self.assertEqual(
self.admin.msg(log_entry),
(
"<table>"
"<tr><th>#</th><th>Field</th><th>From</th><th>To</th></tr>"
"<tr><td>1</td><td>field one</td><td>None</td><td>a value</td></tr>"
"<tr><td>2</td><td>field two</td><td>None</td><td>11</td></tr>"
"</table>"
),
)
def test_changes_msg__update(self):
log_entry = self._create_log_entry(
LogEntry.Action.UPDATE,
{
"field two": [11, 42],
"field one": ["old value of field one", "new value of field one"],
},
)
self.assertEqual(
self.admin.msg(log_entry),
(
"<table>"
"<tr><th>#</th><th>Field</th><th>From</th><th>To</th></tr>"
"<tr><td>1</td><td>field one</td><td>old value of field one</td><td>new value of field one</td></tr>"
"<tr><td>2</td><td>field two</td><td>11</td><td>42</td></tr>"
"</table>"
),
)
def test_changes_msg__m2m(self):
log_entry = self._create_log_entry(
LogEntry.Action.UPDATE,
{ # mimicking the format used by log_m2m_changes
"some_m2m_field": {
"type": "m2m",
"operation": "add",
"objects": ["Example User (user 1)", "Illustration (user 42)"],
},
},
)
self.assertEqual(
self.admin.msg(log_entry),
(
"<table>"
"<tr><th>#</th><th>Relationship</th><th>Action</th><th>Objects</th></tr>"
"<tr><td>1</td><td>some_m2m_field</td><td>add</td><td>Example User (user 1)<br>Illustration (user 42)</td></tr>"
"</table>"
),
)
class NoDeleteHistoryTest(TestCase):
def test_delete_related(self):
instance = SimpleModel.objects.create(integer=1)

View file

@ -11,6 +11,8 @@ even more convenience, :py:class:`LogEntryManager` provides a number of methods
See :doc:`internals` for all details.
.. _Automatically logging changes:
Automatically logging changes
-----------------------------
@ -91,6 +93,19 @@ For example, to mask the field ``address``, use::
Masking fields
**Many-to-many fields**
Changes to many-to-many fields are not tracked by default. If you want to enable tracking of a many-to-many field on a model, pass ``m2m_fields`` to the ``register`` method:
.. code-block:: python
auditlog.register(MyModel, m2m_fields={"tags", "contacts"})
This functionality is based on the ``m2m_changed`` signal sent by the ``through`` model of the relationship.
Note that when the user changes multiple many-to-many fields on the same object through the admin, both adding and removing some objects from each, this code will generate multiple log entries: each log entry will represent a single operation (add or delete) of a single field, e.g. if you both add and delete values from 2 fields on the same form in the same request, you'll get 4 log entries.
.. versionadded:: 2.1.0
Settings
--------
@ -139,6 +154,7 @@ It must be a list or tuple. Each item in this setting can be a:
"field1": "FIELD",
},
"mask_fields": ["field5", "field6"],
"m2m_fields": ["field7", "field8"],
},
"<appname>.<model3>",
)
@ -250,10 +266,9 @@ Many-to-many relationships
.. versionadded:: 0.3.0
.. warning::
.. note::
To-many relations are not officially supported. However, this section shows a workaround which can be used for now.
In the future, this workaround may be used in an official API or a completly different strategy might be chosen.
This section shows a workaround which can be used to track many-to-many relationships on older versions of django-auditlog. For versions 2.1.0 and onwards, please see the many-to-many fields section of :ref:`Automatically logging changes`.
**Do not rely on the workaround here to be stable across releases.**
By default, many-to-many relationships are not tracked by Auditlog.