Add support for custom masking functions (#725)

* Add test cases for the `mask_str` function

* Add custom masking function support through mask_callable

* Add test cases for custom masking function

* Update documentation for custom masking function

* fix test case

* rename `AUDITLOG_DEFAULT_MASK_CALLABLE` variable

Rename `AUDITLOG_DEFAULT_MASK_CALLABLE` to `AUDITLOG_MASK_CALLABLE`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update `CHANGELOG.md` to include mask function customization feature

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Youngkwang Yang 2025-06-09 22:29:59 +09:00 committed by GitHub
parent 3a58e0a999
commit af78976e53
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 168 additions and 9 deletions

View file

@ -5,6 +5,7 @@
#### Improvements
- feat: Support storing JSON in the changes field when `AUDITLOG_STORE_JSON_CHANGES` is enabled. ([#719](https://github.com/jazzband/django-auditlog/pull/719))
- feat: Added `AUDITLOG_MASK_CALLABLE` setting to allow custom masking functions ([#725](https://github.com/jazzband/django-auditlog/pull/725))
#### Fixes

View file

@ -60,3 +60,5 @@ settings.AUDITLOG_CHANGE_DISPLAY_TRUNCATE_LENGTH = getattr(
settings.AUDITLOG_STORE_JSON_CHANGES = getattr(
settings, "AUDITLOG_STORE_JSON_CHANGES", False
)
settings.AUDITLOG_MASK_CALLABLE = getattr(settings, "AUDITLOG_MASK_CALLABLE", None)

View file

@ -1,12 +1,13 @@
import json
from datetime import timezone
from typing import Optional
from typing import Callable, Optional
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import NOT_PROVIDED, DateTimeField, ForeignKey, JSONField, Model
from django.utils import timezone as django_timezone
from django.utils.encoding import smart_str
from django.utils.module_loading import import_string
def track_field(field):
@ -130,6 +131,29 @@ def is_primitive(obj) -> bool:
return isinstance(obj, primitive_types)
def get_mask_function(mask_callable: Optional[str] = None) -> Callable[[str], str]:
    """
    Resolve the callable used to mask sensitive field values.

    The first truthy candidate wins, in this order:

    1. The ``mask_callable`` argument (the model-specific dotted path)
    2. ``settings.AUDITLOG_MASK_CALLABLE`` (the project-wide dotted path)
    3. The built-in ``mask_str`` function

    :param mask_callable: The dotted path to a callable that will be used for masking.
    :type mask_callable: str
    :return: A callable that takes a string and returns a masked version.
    :rtype: Callable[[str], str]
    """
    # ``or`` short-circuits, so the setting is only read when no
    # model-specific path was supplied.
    dotted_path = mask_callable or settings.AUDITLOG_MASK_CALLABLE
    if not dotted_path:
        return mask_str
    return import_string(dotted_path)
def mask_str(value: str) -> str:
"""
Masks the first half of the input string to remove sensitive data.
@ -226,9 +250,11 @@ def model_instance_diff(
if old_value != new_value:
if model_fields and field.name in model_fields["mask_fields"]:
mask_func = get_mask_function(model_fields.get("mask_callable"))
diff[field.name] = (
mask_str(smart_str(old_value)),
mask_str(smart_str(new_value)),
mask_func(smart_str(old_value)),
mask_func(smart_str(new_value)),
)
else:
if not use_json_for_changes:

View file

@ -23,7 +23,7 @@ from django.utils import timezone as django_timezone
from django.utils.encoding import smart_str
from django.utils.translation import gettext_lazy as _
from auditlog.diff import mask_str
from auditlog.diff import get_mask_function
DEFAULT_OBJECT_REPR = "<error forming object repr>"
@ -247,7 +247,7 @@ class LogEntryManager(models.Manager):
mask_fields = model_fields["mask_fields"]
if mask_fields:
data = self._mask_serialized_fields(data, mask_fields)
data = self._mask_serialized_fields(data, mask_fields, model_fields)
return data
@ -287,14 +287,15 @@ class LogEntryManager(models.Manager):
return list(set(include_fields or all_field_names).difference(exclude_fields))
def _mask_serialized_fields(
self, data: dict[str, Any], mask_fields: list[str]
self, data: dict[str, Any], mask_fields: list[str], model_fields: dict[str, Any]
) -> dict[str, Any]:
all_field_data = data.pop("fields")
mask_func = get_mask_function(model_fields.get("mask_callable"))
masked_field_data = {}
for key, value in all_field_data.items():
if isinstance(value, str) and key in mask_fields:
masked_field_data[key] = mask_str(value)
masked_field_data[key] = mask_func(value)
else:
masked_field_data[key] = value

View file

@ -66,6 +66,7 @@ class AuditlogModelRegistry:
exclude_fields: Optional[list[str]] = None,
mapping_fields: Optional[dict[str, str]] = None,
mask_fields: Optional[list[str]] = None,
mask_callable: Optional[str] = None,
m2m_fields: Optional[Collection[str]] = None,
serialize_data: bool = False,
serialize_kwargs: Optional[dict[str, Any]] = None,
@ -79,6 +80,8 @@ class AuditlogModelRegistry:
:param exclude_fields: The fields to exclude. Overrides the fields to include.
:param mapping_fields: Mapping from field names to strings in diff.
:param mask_fields: The fields to mask for sensitive info.
:param mask_callable: The dotted path to a callable that will be used for masking. If not provided,
the callable from the ``AUDITLOG_MASK_CALLABLE`` setting is used when configured, otherwise the default ``mask_str``.
:param m2m_fields: The fields to handle as many to many.
:param serialize_data: Option to include a dictionary of the objects state in the auditlog.
:param serialize_kwargs: Optional kwargs to pass to Django serializer
@ -120,6 +123,7 @@ class AuditlogModelRegistry:
"exclude_fields": exclude_fields,
"mapping_fields": mapping_fields,
"mask_fields": mask_fields,
"mask_callable": mask_callable,
"m2m_fields": m2m_fields,
"serialize_data": serialize_data,
"serialize_kwargs": serialize_kwargs,
@ -172,6 +176,7 @@ class AuditlogModelRegistry:
"exclude_fields": list(self._registry[model]["exclude_fields"]),
"mapping_fields": dict(self._registry[model]["mapping_fields"]),
"mask_fields": list(self._registry[model]["mask_fields"]),
"mask_callable": self._registry[model]["mask_callable"],
}
def get_serialize_options(self, model: ModelBase):

View file

@ -0,0 +1,6 @@
def custom_mask_str(value: str) -> str:
    """Return ``value`` with everything but its last 4 characters replaced by ``****``.

    Values of 4 characters or fewer are returned unchanged.
    """
    visible = 4
    if len(value) <= visible:
        return value
    return "****" + value[-visible:]

View file

@ -425,6 +425,13 @@ class AutoManyRelatedModel(models.Model):
related = models.ManyToManyField(SimpleModel)
# Test model for the custom masking feature: its ``credit_card`` field is the
# one listed in ``mask_fields`` in this module's ``auditlog.register`` call for
# this model.
class CustomMaskModel(models.Model):
# Value that the tests expect to appear masked in the audit log.
credit_card = models.CharField(max_length=16)
# Extra field that is not listed in ``mask_fields``.
text = models.TextField()
# Accessor the tests use as ``instance.history.latest()`` to read log entries.
# NOTE(review): ``delete_related=True`` presumably removes those entries
# together with the instance -- confirm against AuditlogHistoryField.
history = AuditlogHistoryField(delete_related=True)
auditlog.register(AltPrimaryKeyModel)
auditlog.register(UUIDPrimaryKeyModel)
auditlog.register(ModelPrimaryKeyModel)
@ -462,3 +469,8 @@ auditlog.register(
serialize_data=True,
serialize_kwargs={"use_natural_foreign_keys": True},
)
# Register CustomMaskModel with a model-specific masking function: only
# ``credit_card`` is masked, and the callable is given as a dotted import path.
auditlog.register(
CustomMaskModel,
mask_fields=["credit_card"],
mask_callable="auditlog_tests.test_app.mask.custom_mask_str",
)

View file

@ -33,6 +33,7 @@ from test_app.models import (
AutoManyRelatedModel,
CharfieldTextfieldModel,
ChoicesFieldModel,
CustomMaskModel,
DateTimeFieldModel,
JSONModel,
ManyRelatedModel,
@ -62,7 +63,7 @@ from test_app.models import (
from auditlog.admin import LogEntryAdmin
from auditlog.cid import get_cid
from auditlog.context import disable_auditlog, set_actor
from auditlog.diff import model_instance_diff
from auditlog.diff import mask_str, model_instance_diff
from auditlog.middleware import AuditlogMiddleware
from auditlog.models import DEFAULT_OBJECT_REPR, LogEntry
from auditlog.registry import AuditlogModelRegistry, AuditLogRegistrationError, auditlog
@ -810,6 +811,21 @@ class SimpleMaskedFieldsModelTest(TestCase):
msg="The diff function masks 'address' field.",
)
@override_settings(
    AUDITLOG_MASK_CALLABLE="auditlog_tests.test_app.mask.custom_mask_str"
)
def test_global_mask_callable(self):
    """Test that the global mask callable from settings is applied when no model-specific one is provided."""
    masked_model = SimpleMaskedModel.objects.create(
        address="1234567890123456", text="Some text"
    )
    latest_entry = masked_model.history.latest()
    # Index 1 of the change pair is the value asserted on by this test.
    logged_address = latest_entry.changes_dict["address"][1]
    self.assertEqual(
        logged_address,
        "****3456",
        msg="The global masking function should be used when model-specific one is not provided",
    )
class AdditionalDataModelTest(TestCase):
"""Log additional data if get_additional_data is defined in the model"""
@ -1276,7 +1292,7 @@ class RegisterModelSettingsTest(TestCase):
self.assertTrue(self.test_auditlog.contains(SimpleExcludeModel))
self.assertTrue(self.test_auditlog.contains(ChoicesFieldModel))
self.assertEqual(len(self.test_auditlog.get_models()), 32)
self.assertEqual(len(self.test_auditlog.get_models()), 33)
def test_register_models_register_model_with_attrs(self):
self.test_auditlog._register_models(
@ -2888,3 +2904,62 @@ class ModelManagerTest(TestCase):
log = LogEntry.objects.get_for_object(self.public).first()
self.assertEqual(log.action, LogEntry.Action.UPDATE)
self.assertEqual(log.changes_dict["name"], ["Public", "Updated"])
class TestMaskStr(TestCase):
"""Test the mask_str function that masks sensitive data."""
# An empty string is expected back unchanged.
def test_mask_str_empty(self):
self.assertEqual(mask_str(""), "")
# A single character is expected back unchanged (nothing is masked).
def test_mask_str_single_char(self):
self.assertEqual(mask_str("a"), "a")
# Even length: the first two of four characters are replaced by asterisks.
def test_mask_str_even_length(self):
self.assertEqual(mask_str("1234"), "**34")
# Odd length: only the first two of five characters are replaced, so the
# masked part is the smaller half.
def test_mask_str_odd_length(self):
self.assertEqual(mask_str("12345"), "**345")
# Longer text: the first six of twelve characters are replaced.
def test_mask_str_long_text(self):
self.assertEqual(mask_str("confidential"), "******ential")
class CustomMaskModelTest(TestCase):
    """Exercise the model-specific ``mask_callable`` registered for ``CustomMaskModel``."""

    def test_custom_mask_function(self):
        """A 16-character value is logged as ``****`` followed by its last 4 characters."""
        instance = CustomMaskModel.objects.create(
            credit_card="1234567890123456", text="Some text"
        )
        self.assertEqual(
            instance.history.latest().changes_dict["credit_card"][1],
            "****3456",
            msg="The custom masking function should mask all but last 4 digits",
        )

    def test_custom_mask_function_short_value(self):
        """Test that custom masking function handles short values correctly"""
        instance = CustomMaskModel.objects.create(credit_card="123", text="Some text")
        self.assertEqual(
            instance.history.latest().changes_dict["credit_card"][1],
            "123",
            msg="The custom masking function should not mask values shorter than 4 characters",
        )

    def test_custom_mask_function_serialized_data(self):
        """The custom mask is also applied to the diff logged for a later save.

        NOTE(review): despite the test name, the final assertion inspects
        ``changes_dict`` rather than ``serialized_data``, and the model is
        registered without ``serialize_data``. Confirm whether coverage of
        the serialized-data masking path was intended here.
        """
        instance = CustomMaskModel.objects.create(
            credit_card="1234567890123456", text="Some text"
        )
        log = instance.history.latest()
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only reports "False is not true".
        self.assertIsInstance(log, LogEntry)
        self.assertEqual(log.action, LogEntry.Action.CREATE)

        # Change the masked field and save so that a new entry is logged
        instance.credit_card = "9876543210987654"
        instance.save()

        log = instance.history.latest()
        self.assertEqual(
            log.changes_dict["credit_card"][1],
            "****7654",
            msg="The custom masking function should be used in serialized data",
        )

View file

@ -132,6 +132,37 @@ For example, to mask the field ``address``, use::
auditlog.register(MyModel, mask_fields=['address'])
You can also specify a custom masking function by passing ``mask_callable`` to the ``register``
method. The ``mask_callable`` should be a dotted path to a function that takes a string and returns
a masked version of that string.
For example, to use a custom masking function::
# In your_app/utils.py
def custom_mask(value: str) -> str:
return "****" + value[-4:] # Only show last 4 characters
# In your models.py
auditlog.register(
MyModel,
mask_fields=['credit_card'],
mask_callable='your_app.utils.custom_mask'
)
Additionally, you can set a global default masking function that will be used when a model-specific
``mask_callable`` is not provided. To do this, add the following to your Django settings::
AUDITLOG_MASK_CALLABLE = 'your_app.utils.custom_mask'
The masking function priority is as follows:
1. Model-specific ``mask_callable`` if provided in ``register()``
2. ``AUDITLOG_MASK_CALLABLE`` from settings if configured
3. Default ``mask_str`` function which masks the first half of the string with asterisks
If ``mask_callable`` is not specified and no global default is configured, the default masking function is used, which masks
the first half of the string with asterisks.
.. versionadded:: 2.0.0
Masking fields