Add support for custom masking functions (#725)

* Add test cases for the `mask_str` function

* Add custom masking function support through mask_callable

* Add test cases for custom masking function

* Update documentation for custom masking function

* fix test case

* rename `AUDITLOG_DEFAULT_MASK_CALLABLE` variable

-AUDITLOG_DEFAULT_MASK_CALLABLE to `AUDITLOG_MASK_CALLABLE`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update `CHANGELOG.md` to include mask function customization feature

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Youngkwang Yang 2025-06-09 22:29:59 +09:00 committed by GitHub
parent 3a58e0a999
commit af78976e53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 168 additions and 9 deletions

View file

@ -5,6 +5,7 @@
#### Improvements #### Improvements
- feat: Support storing JSON in the changes field when ```AUDITLOG_STORE_JSON_CHANGES``` is enabled. ([#719](https://github.com/jazzband/django-auditlog/pull/719)) - feat: Support storing JSON in the changes field when ```AUDITLOG_STORE_JSON_CHANGES``` is enabled. ([#719](https://github.com/jazzband/django-auditlog/pull/719))
- feat: Added `AUDITLOG_MASK_CALLABLE` setting to allow custom masking functions ([#725](https://github.com/jazzband/django-auditlog/pull/725))
#### Fixes #### Fixes

View file

@ -60,3 +60,5 @@ settings.AUDITLOG_CHANGE_DISPLAY_TRUNCATE_LENGTH = getattr(
settings.AUDITLOG_STORE_JSON_CHANGES = getattr( settings.AUDITLOG_STORE_JSON_CHANGES = getattr(
settings, "AUDITLOG_STORE_JSON_CHANGES", False settings, "AUDITLOG_STORE_JSON_CHANGES", False
) )
settings.AUDITLOG_MASK_CALLABLE = getattr(settings, "AUDITLOG_MASK_CALLABLE", None)

View file

@ -1,12 +1,13 @@
import json import json
from datetime import timezone from datetime import timezone
from typing import Optional from typing import Callable, Optional
from django.conf import settings from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.db.models import NOT_PROVIDED, DateTimeField, ForeignKey, JSONField, Model from django.db.models import NOT_PROVIDED, DateTimeField, ForeignKey, JSONField, Model
from django.utils import timezone as django_timezone from django.utils import timezone as django_timezone
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
from django.utils.module_loading import import_string
def track_field(field): def track_field(field):
@ -130,6 +131,29 @@ def is_primitive(obj) -> bool:
return isinstance(obj, primitive_types) return isinstance(obj, primitive_types)
def get_mask_function(mask_callable: Optional[str] = None) -> Callable[[str], str]:
"""
Get the masking function to use based on the following priority:
1. Model-specific mask_callable if provided
2. mask_callable from settings if configured
3. Default mask_str function
:param mask_callable: The dotted path to a callable that will be used for masking.
:type mask_callable: str
:return: A callable that takes a string and returns a masked version.
:rtype: Callable[[str], str]
"""
if mask_callable:
return import_string(mask_callable)
default_mask_callable = settings.AUDITLOG_MASK_CALLABLE
if default_mask_callable:
return import_string(default_mask_callable)
return mask_str
def mask_str(value: str) -> str: def mask_str(value: str) -> str:
""" """
Masks the first half of the input string to remove sensitive data. Masks the first half of the input string to remove sensitive data.
@ -226,9 +250,11 @@ def model_instance_diff(
if old_value != new_value: if old_value != new_value:
if model_fields and field.name in model_fields["mask_fields"]: if model_fields and field.name in model_fields["mask_fields"]:
mask_func = get_mask_function(model_fields.get("mask_callable"))
diff[field.name] = ( diff[field.name] = (
mask_str(smart_str(old_value)), mask_func(smart_str(old_value)),
mask_str(smart_str(new_value)), mask_func(smart_str(new_value)),
) )
else: else:
if not use_json_for_changes: if not use_json_for_changes:

View file

@ -23,7 +23,7 @@ from django.utils import timezone as django_timezone
from django.utils.encoding import smart_str from django.utils.encoding import smart_str
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from auditlog.diff import mask_str from auditlog.diff import get_mask_function
DEFAULT_OBJECT_REPR = "<error forming object repr>" DEFAULT_OBJECT_REPR = "<error forming object repr>"
@ -247,7 +247,7 @@ class LogEntryManager(models.Manager):
mask_fields = model_fields["mask_fields"] mask_fields = model_fields["mask_fields"]
if mask_fields: if mask_fields:
data = self._mask_serialized_fields(data, mask_fields) data = self._mask_serialized_fields(data, mask_fields, model_fields)
return data return data
@ -287,14 +287,15 @@ class LogEntryManager(models.Manager):
return list(set(include_fields or all_field_names).difference(exclude_fields)) return list(set(include_fields or all_field_names).difference(exclude_fields))
def _mask_serialized_fields( def _mask_serialized_fields(
self, data: dict[str, Any], mask_fields: list[str] self, data: dict[str, Any], mask_fields: list[str], model_fields: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
all_field_data = data.pop("fields") all_field_data = data.pop("fields")
mask_func = get_mask_function(model_fields.get("mask_callable"))
masked_field_data = {} masked_field_data = {}
for key, value in all_field_data.items(): for key, value in all_field_data.items():
if isinstance(value, str) and key in mask_fields: if isinstance(value, str) and key in mask_fields:
masked_field_data[key] = mask_str(value) masked_field_data[key] = mask_func(value)
else: else:
masked_field_data[key] = value masked_field_data[key] = value

View file

@ -66,6 +66,7 @@ class AuditlogModelRegistry:
exclude_fields: Optional[list[str]] = None, exclude_fields: Optional[list[str]] = None,
mapping_fields: Optional[dict[str, str]] = None, mapping_fields: Optional[dict[str, str]] = None,
mask_fields: Optional[list[str]] = None, mask_fields: Optional[list[str]] = None,
mask_callable: Optional[str] = None,
m2m_fields: Optional[Collection[str]] = None, m2m_fields: Optional[Collection[str]] = None,
serialize_data: bool = False, serialize_data: bool = False,
serialize_kwargs: Optional[dict[str, Any]] = None, serialize_kwargs: Optional[dict[str, Any]] = None,
@ -79,6 +80,8 @@ class AuditlogModelRegistry:
:param exclude_fields: The fields to exclude. Overrides the fields to include. :param exclude_fields: The fields to exclude. Overrides the fields to include.
:param mapping_fields: Mapping from field names to strings in diff. :param mapping_fields: Mapping from field names to strings in diff.
:param mask_fields: The fields to mask for sensitive info. :param mask_fields: The fields to mask for sensitive info.
:param mask_callable: The dotted path to a callable that will be used for masking. If not provided,
the default mask_callable will be used.
:param m2m_fields: The fields to handle as many to many. :param m2m_fields: The fields to handle as many to many.
:param serialize_data: Option to include a dictionary of the objects state in the auditlog. :param serialize_data: Option to include a dictionary of the objects state in the auditlog.
:param serialize_kwargs: Optional kwargs to pass to Django serializer :param serialize_kwargs: Optional kwargs to pass to Django serializer
@ -120,6 +123,7 @@ class AuditlogModelRegistry:
"exclude_fields": exclude_fields, "exclude_fields": exclude_fields,
"mapping_fields": mapping_fields, "mapping_fields": mapping_fields,
"mask_fields": mask_fields, "mask_fields": mask_fields,
"mask_callable": mask_callable,
"m2m_fields": m2m_fields, "m2m_fields": m2m_fields,
"serialize_data": serialize_data, "serialize_data": serialize_data,
"serialize_kwargs": serialize_kwargs, "serialize_kwargs": serialize_kwargs,
@ -172,6 +176,7 @@ class AuditlogModelRegistry:
"exclude_fields": list(self._registry[model]["exclude_fields"]), "exclude_fields": list(self._registry[model]["exclude_fields"]),
"mapping_fields": dict(self._registry[model]["mapping_fields"]), "mapping_fields": dict(self._registry[model]["mapping_fields"]),
"mask_fields": list(self._registry[model]["mask_fields"]), "mask_fields": list(self._registry[model]["mask_fields"]),
"mask_callable": self._registry[model]["mask_callable"],
} }
def get_serialize_options(self, model: ModelBase): def get_serialize_options(self, model: ModelBase):

View file

@ -0,0 +1,6 @@
def custom_mask_str(value: str) -> str:
"""Custom masking function that only shows the last 4 characters."""
if len(value) > 4:
return "****" + value[-4:]
return value

View file

@ -425,6 +425,13 @@ class AutoManyRelatedModel(models.Model):
related = models.ManyToManyField(SimpleModel) related = models.ManyToManyField(SimpleModel)
class CustomMaskModel(models.Model):
credit_card = models.CharField(max_length=16)
text = models.TextField()
history = AuditlogHistoryField(delete_related=True)
auditlog.register(AltPrimaryKeyModel) auditlog.register(AltPrimaryKeyModel)
auditlog.register(UUIDPrimaryKeyModel) auditlog.register(UUIDPrimaryKeyModel)
auditlog.register(ModelPrimaryKeyModel) auditlog.register(ModelPrimaryKeyModel)
@ -462,3 +469,8 @@ auditlog.register(
serialize_data=True, serialize_data=True,
serialize_kwargs={"use_natural_foreign_keys": True}, serialize_kwargs={"use_natural_foreign_keys": True},
) )
auditlog.register(
CustomMaskModel,
mask_fields=["credit_card"],
mask_callable="auditlog_tests.test_app.mask.custom_mask_str",
)

View file

@ -33,6 +33,7 @@ from test_app.models import (
AutoManyRelatedModel, AutoManyRelatedModel,
CharfieldTextfieldModel, CharfieldTextfieldModel,
ChoicesFieldModel, ChoicesFieldModel,
CustomMaskModel,
DateTimeFieldModel, DateTimeFieldModel,
JSONModel, JSONModel,
ManyRelatedModel, ManyRelatedModel,
@ -62,7 +63,7 @@ from test_app.models import (
from auditlog.admin import LogEntryAdmin from auditlog.admin import LogEntryAdmin
from auditlog.cid import get_cid from auditlog.cid import get_cid
from auditlog.context import disable_auditlog, set_actor from auditlog.context import disable_auditlog, set_actor
from auditlog.diff import model_instance_diff from auditlog.diff import mask_str, model_instance_diff
from auditlog.middleware import AuditlogMiddleware from auditlog.middleware import AuditlogMiddleware
from auditlog.models import DEFAULT_OBJECT_REPR, LogEntry from auditlog.models import DEFAULT_OBJECT_REPR, LogEntry
from auditlog.registry import AuditlogModelRegistry, AuditLogRegistrationError, auditlog from auditlog.registry import AuditlogModelRegistry, AuditLogRegistrationError, auditlog
@ -810,6 +811,21 @@ class SimpleMaskedFieldsModelTest(TestCase):
msg="The diff function masks 'address' field.", msg="The diff function masks 'address' field.",
) )
@override_settings(
AUDITLOG_MASK_CALLABLE="auditlog_tests.test_app.mask.custom_mask_str"
)
def test_global_mask_callable(self):
"""Test that global mask_callable from settings is used when model-specific one is not provided"""
instance = SimpleMaskedModel.objects.create(
address="1234567890123456", text="Some text"
)
self.assertEqual(
instance.history.latest().changes_dict["address"][1],
"****3456",
msg="The global masking function should be used when model-specific one is not provided",
)
class AdditionalDataModelTest(TestCase): class AdditionalDataModelTest(TestCase):
"""Log additional data if get_additional_data is defined in the model""" """Log additional data if get_additional_data is defined in the model"""
@ -1276,7 +1292,7 @@ class RegisterModelSettingsTest(TestCase):
self.assertTrue(self.test_auditlog.contains(SimpleExcludeModel)) self.assertTrue(self.test_auditlog.contains(SimpleExcludeModel))
self.assertTrue(self.test_auditlog.contains(ChoicesFieldModel)) self.assertTrue(self.test_auditlog.contains(ChoicesFieldModel))
self.assertEqual(len(self.test_auditlog.get_models()), 32) self.assertEqual(len(self.test_auditlog.get_models()), 33)
def test_register_models_register_model_with_attrs(self): def test_register_models_register_model_with_attrs(self):
self.test_auditlog._register_models( self.test_auditlog._register_models(
@ -2888,3 +2904,62 @@ class ModelManagerTest(TestCase):
log = LogEntry.objects.get_for_object(self.public).first() log = LogEntry.objects.get_for_object(self.public).first()
self.assertEqual(log.action, LogEntry.Action.UPDATE) self.assertEqual(log.action, LogEntry.Action.UPDATE)
self.assertEqual(log.changes_dict["name"], ["Public", "Updated"]) self.assertEqual(log.changes_dict["name"], ["Public", "Updated"])
class TestMaskStr(TestCase):
"""Test the mask_str function that masks sensitive data."""
def test_mask_str_empty(self):
self.assertEqual(mask_str(""), "")
def test_mask_str_single_char(self):
self.assertEqual(mask_str("a"), "a")
def test_mask_str_even_length(self):
self.assertEqual(mask_str("1234"), "**34")
def test_mask_str_odd_length(self):
self.assertEqual(mask_str("12345"), "**345")
def test_mask_str_long_text(self):
self.assertEqual(mask_str("confidential"), "******ential")
class CustomMaskModelTest(TestCase):
def test_custom_mask_function(self):
instance = CustomMaskModel.objects.create(
credit_card="1234567890123456", text="Some text"
)
self.assertEqual(
instance.history.latest().changes_dict["credit_card"][1],
"****3456",
msg="The custom masking function should mask all but last 4 digits",
)
def test_custom_mask_function_short_value(self):
"""Test that custom masking function handles short values correctly"""
instance = CustomMaskModel.objects.create(credit_card="123", text="Some text")
self.assertEqual(
instance.history.latest().changes_dict["credit_card"][1],
"123",
msg="The custom masking function should not mask values shorter than 4 characters",
)
def test_custom_mask_function_serialized_data(self):
instance = CustomMaskModel.objects.create(
credit_card="1234567890123456", text="Some text"
)
log = instance.history.latest()
self.assertTrue(isinstance(log, LogEntry))
self.assertEqual(log.action, LogEntry.Action.CREATE)
# Update to trigger serialization
instance.credit_card = "9876543210987654"
instance.save()
log = instance.history.latest()
self.assertEqual(
log.changes_dict["credit_card"][1],
"****7654",
msg="The custom masking function should be used in serialized data",
)

View file

@ -132,6 +132,37 @@ For example, to mask the field ``address``, use::
auditlog.register(MyModel, mask_fields=['address']) auditlog.register(MyModel, mask_fields=['address'])
You can also specify a custom masking function by passing ``mask_callable`` to the ``register``
method. The ``mask_callable`` should be a dotted path to a function that takes a string and returns
a masked version of that string.
For example, to use a custom masking function::
# In your_app/utils.py
def custom_mask(value: str) -> str:
return "****" + value[-4:] # Only show last 4 characters
# In your models.py
auditlog.register(
MyModel,
mask_fields=['credit_card'],
mask_callable='your_app.utils.custom_mask'
)
Additionally, you can set a global default masking function that will be used when a model-specific
mask_callable is not provided. To do this, add the following to your Django settings::
AUDITLOG_MASK_CALLABLE = 'your_app.utils.custom_mask'
The masking function priority is as follows:
1. Model-specific ``mask_callable`` if provided in ``register()``
2. ``AUDITLOG_MASK_CALLABLE`` from settings if configured
3. Default ``mask_str`` function which masks the first half of the string with asterisks
If ``mask_callable`` is not specified and no global default is configured, the default masking function will be used which masks
the first half of the string with asterisks.
.. versionadded:: 2.0.0 .. versionadded:: 2.0.0
Masking fields Masking fields