From af78976e534599ff22da71dad8dfb888e5568af7 Mon Sep 17 00:00:00 2001 From: Youngkwang Yang Date: Mon, 9 Jun 2025 22:29:59 +0900 Subject: [PATCH] Add support for custom masking functions (#725) * Add test cases for the `mask_str` function * Add custom masking function support through mask_callable * Add test cases for custom masking function * Update documentation for custom masking function * fix test case * rename `AUDITLOG_DEFAULT_MASK_CALLABLE` variable -AUDITLOG_DEFAULT_MASK_CALLABLE to `AUDITLOG_MASK_CALLABLE` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update `CHANGELOG.md` to include mask function customization feature --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 1 + auditlog/conf.py | 2 + auditlog/diff.py | 32 +++++++++++-- auditlog/models.py | 9 ++-- auditlog/registry.py | 5 ++ auditlog_tests/test_app/mask.py | 6 +++ auditlog_tests/test_app/models.py | 12 +++++ auditlog_tests/tests.py | 79 ++++++++++++++++++++++++++++++- docs/source/usage.rst | 31 ++++++++++++ 9 files changed, 168 insertions(+), 9 deletions(-) create mode 100644 auditlog_tests/test_app/mask.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 274f48a..fe816b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ #### Improvements - feat: Support storing JSON in the changes field when ```AUDITLOG_STORE_JSON_CHANGES``` is enabled. ([#719](https://github.com/jazzband/django-auditlog/pull/719)) +- feat: Added `AUDITLOG_MASK_CALLABLE` setting to allow custom masking functions ([#725](https://github.com/jazzband/django-auditlog/pull/725)) #### Fixes diff --git a/auditlog/conf.py b/auditlog/conf.py index 01d9f29..b151b24 100644 --- a/auditlog/conf.py +++ b/auditlog/conf.py @@ -60,3 +60,5 @@ settings.AUDITLOG_CHANGE_DISPLAY_TRUNCATE_LENGTH = getattr( settings.AUDITLOG_STORE_JSON_CHANGES = getattr( settings, "AUDITLOG_STORE_JSON_CHANGES", False ) + +settings.AUDITLOG_MASK_CALLABLE = getattr(settings, "AUDITLOG_MASK_CALLABLE", None) diff --git a/auditlog/diff.py b/auditlog/diff.py index b8455f1..ba61e28 100644 --- a/auditlog/diff.py +++ b/auditlog/diff.py @@ -1,12 +1,13 @@ import json from datetime import timezone -from typing import Optional +from typing import Callable, Optional from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.db.models import NOT_PROVIDED, DateTimeField, ForeignKey, JSONField, Model from django.utils import timezone as django_timezone from django.utils.encoding import smart_str +from django.utils.module_loading import import_string def track_field(field): @@ -130,6 +131,29 @@ def is_primitive(obj) -> bool: return isinstance(obj, primitive_types) +def get_mask_function(mask_callable: Optional[str] = None) -> Callable[[str], str]: + """ + Get the masking function to use based on the following priority: + 1. Model-specific mask_callable if provided + 2. mask_callable from settings if configured + 3. Default mask_str function + + :param mask_callable: The dotted path to a callable that will be used for masking. + :type mask_callable: str + :return: A callable that takes a string and returns a masked version. + :rtype: Callable[[str], str] + """ + + if mask_callable: + return import_string(mask_callable) + + default_mask_callable = settings.AUDITLOG_MASK_CALLABLE + if default_mask_callable: + return import_string(default_mask_callable) + + return mask_str + + def mask_str(value: str) -> str: """ Masks the first half of the input string to remove sensitive data. @@ -226,9 +250,11 @@ def model_instance_diff( if old_value != new_value: if model_fields and field.name in model_fields["mask_fields"]: + mask_func = get_mask_function(model_fields.get("mask_callable")) + diff[field.name] = ( - mask_str(smart_str(old_value)), - mask_str(smart_str(new_value)), + mask_func(smart_str(old_value)), + mask_func(smart_str(new_value)), ) else: if not use_json_for_changes: diff --git a/auditlog/models.py b/auditlog/models.py index a0029a1..9ec3bfa 100644 --- a/auditlog/models.py +++ b/auditlog/models.py @@ -23,7 +23,7 @@ from django.utils import timezone as django_timezone from django.utils.encoding import smart_str from django.utils.translation import gettext_lazy as _ -from auditlog.diff import mask_str +from auditlog.diff import get_mask_function DEFAULT_OBJECT_REPR = "" @@ -247,7 +247,7 @@ class LogEntryManager(models.Manager): mask_fields = model_fields["mask_fields"] if mask_fields: - data = self._mask_serialized_fields(data, mask_fields) + data = self._mask_serialized_fields(data, mask_fields, model_fields) return data @@ -287,14 +287,15 @@ class LogEntryManager(models.Manager): return list(set(include_fields or all_field_names).difference(exclude_fields)) def _mask_serialized_fields( - self, data: dict[str, Any], mask_fields: list[str] + self, data: dict[str, Any], mask_fields: list[str], model_fields: dict[str, Any] ) -> dict[str, Any]: all_field_data = data.pop("fields") + mask_func = get_mask_function(model_fields.get("mask_callable")) masked_field_data = {} for key, value in all_field_data.items(): if isinstance(value, str) and key in mask_fields: - masked_field_data[key] = mask_str(value) + masked_field_data[key] = mask_func(value) else: masked_field_data[key] = value diff --git a/auditlog/registry.py b/auditlog/registry.py index 2da9997..c8ca907 100644 --- a/auditlog/registry.py +++ b/auditlog/registry.py @@ -66,6 +66,7 @@ class AuditlogModelRegistry: exclude_fields: Optional[list[str]] = None, mapping_fields: Optional[dict[str, str]] = None, mask_fields: Optional[list[str]] = None, + mask_callable: Optional[str] = None, m2m_fields: Optional[Collection[str]] = None, serialize_data: bool = False, serialize_kwargs: Optional[dict[str, Any]] = None, @@ -79,6 +80,8 @@ class AuditlogModelRegistry: :param exclude_fields: The fields to exclude. Overrides the fields to include. :param mapping_fields: Mapping from field names to strings in diff. :param mask_fields: The fields to mask for sensitive info. + :param mask_callable: The dotted path to a callable that will be used for masking. If not provided, + the default mask_callable will be used. :param m2m_fields: The fields to handle as many to many. :param serialize_data: Option to include a dictionary of the objects state in the auditlog. :param serialize_kwargs: Optional kwargs to pass to Django serializer @@ -120,6 +123,7 @@ class AuditlogModelRegistry: "exclude_fields": exclude_fields, "mapping_fields": mapping_fields, "mask_fields": mask_fields, + "mask_callable": mask_callable, "m2m_fields": m2m_fields, "serialize_data": serialize_data, "serialize_kwargs": serialize_kwargs, @@ -172,6 +176,7 @@ class AuditlogModelRegistry: "exclude_fields": list(self._registry[model]["exclude_fields"]), "mapping_fields": dict(self._registry[model]["mapping_fields"]), "mask_fields": list(self._registry[model]["mask_fields"]), + "mask_callable": self._registry[model]["mask_callable"], } def get_serialize_options(self, model: ModelBase): diff --git a/auditlog_tests/test_app/mask.py b/auditlog_tests/test_app/mask.py new file mode 100644 index 0000000..0dc55d3 --- /dev/null +++ b/auditlog_tests/test_app/mask.py @@ -0,0 +1,6 @@ +def custom_mask_str(value: str) -> str: + """Custom masking function that only shows the last 4 characters.""" + if len(value) > 4: + return "****" + value[-4:] + + return value diff --git a/auditlog_tests/test_app/models.py b/auditlog_tests/test_app/models.py index cf35cc4..efd6552 100644 --- a/auditlog_tests/test_app/models.py +++ b/auditlog_tests/test_app/models.py @@ -425,6 +425,13 @@ class AutoManyRelatedModel(models.Model): related = models.ManyToManyField(SimpleModel) +class CustomMaskModel(models.Model): + credit_card = models.CharField(max_length=16) + text = models.TextField() + + history = AuditlogHistoryField(delete_related=True) + + auditlog.register(AltPrimaryKeyModel) auditlog.register(UUIDPrimaryKeyModel) auditlog.register(ModelPrimaryKeyModel) @@ -462,3 +469,8 @@ auditlog.register( serialize_data=True, serialize_kwargs={"use_natural_foreign_keys": True}, ) +auditlog.register( + CustomMaskModel, + mask_fields=["credit_card"], + mask_callable="auditlog_tests.test_app.mask.custom_mask_str", +) diff --git a/auditlog_tests/tests.py b/auditlog_tests/tests.py index 4be1251..f16b533 100644 --- a/auditlog_tests/tests.py +++ b/auditlog_tests/tests.py @@ -33,6 +33,7 @@ from test_app.models import ( AutoManyRelatedModel, CharfieldTextfieldModel, ChoicesFieldModel, + CustomMaskModel, DateTimeFieldModel, JSONModel, ManyRelatedModel, @@ -62,7 +63,7 @@ from test_app.models import ( from auditlog.admin import LogEntryAdmin from auditlog.cid import get_cid from auditlog.context import disable_auditlog, set_actor -from auditlog.diff import model_instance_diff +from auditlog.diff import mask_str, model_instance_diff from auditlog.middleware import AuditlogMiddleware from auditlog.models import DEFAULT_OBJECT_REPR, LogEntry from auditlog.registry import AuditlogModelRegistry, AuditLogRegistrationError, auditlog @@ -810,6 +811,21 @@ class SimpleMaskedFieldsModelTest(TestCase): msg="The diff function masks 'address' field.", ) + @override_settings( + AUDITLOG_MASK_CALLABLE="auditlog_tests.test_app.mask.custom_mask_str" + ) + def test_global_mask_callable(self): + """Test that global mask_callable from settings is used when model-specific one is not provided""" + instance = SimpleMaskedModel.objects.create( + address="1234567890123456", text="Some text" + ) + + self.assertEqual( + instance.history.latest().changes_dict["address"][1], + "****3456", + msg="The global masking function should be used when model-specific one is not provided", + ) + class AdditionalDataModelTest(TestCase): """Log additional data if get_additional_data is defined in the model""" @@ -1276,7 +1292,7 @@ class RegisterModelSettingsTest(TestCase): self.assertTrue(self.test_auditlog.contains(SimpleExcludeModel)) self.assertTrue(self.test_auditlog.contains(ChoicesFieldModel)) - self.assertEqual(len(self.test_auditlog.get_models()), 32) + self.assertEqual(len(self.test_auditlog.get_models()), 33) def test_register_models_register_model_with_attrs(self): self.test_auditlog._register_models( @@ -2888,3 +2904,62 @@ class ModelManagerTest(TestCase): log = LogEntry.objects.get_for_object(self.public).first() self.assertEqual(log.action, LogEntry.Action.UPDATE) self.assertEqual(log.changes_dict["name"], ["Public", "Updated"]) + + +class TestMaskStr(TestCase): + """Test the mask_str function that masks sensitive data.""" + + def test_mask_str_empty(self): + self.assertEqual(mask_str(""), "") + + def test_mask_str_single_char(self): + self.assertEqual(mask_str("a"), "a") + + def test_mask_str_even_length(self): + self.assertEqual(mask_str("1234"), "**34") + + def test_mask_str_odd_length(self): + self.assertEqual(mask_str("12345"), "**345") + + def test_mask_str_long_text(self): + self.assertEqual(mask_str("confidential"), "******ential") + + +class CustomMaskModelTest(TestCase): + def test_custom_mask_function(self): + instance = CustomMaskModel.objects.create( + credit_card="1234567890123456", text="Some text" + ) + self.assertEqual( + instance.history.latest().changes_dict["credit_card"][1], + "****3456", + msg="The custom masking function should mask all but last 4 digits", + ) + + def test_custom_mask_function_short_value(self): + """Test that custom masking function handles short values correctly""" + instance = CustomMaskModel.objects.create(credit_card="123", text="Some text") + self.assertEqual( + instance.history.latest().changes_dict["credit_card"][1], + "123", + msg="The custom masking function should not mask values shorter than 4 characters", + ) + + def test_custom_mask_function_serialized_data(self): + instance = CustomMaskModel.objects.create( + credit_card="1234567890123456", text="Some text" + ) + log = instance.history.latest() + self.assertTrue(isinstance(log, LogEntry)) + self.assertEqual(log.action, LogEntry.Action.CREATE) + + # Update to trigger serialization + instance.credit_card = "9876543210987654" + instance.save() + + log = instance.history.latest() + self.assertEqual( + log.changes_dict["credit_card"][1], + "****7654", + msg="The custom masking function should be used in serialized data", + ) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index de731bb..ab35fdd 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -132,6 +132,37 @@ For example, to mask the field ``address``, use:: auditlog.register(MyModel, mask_fields=['address']) +You can also specify a custom masking function by passing ``mask_callable`` to the ``register`` +method. The ``mask_callable`` should be a dotted path to a function that takes a string and returns +a masked version of that string. + +For example, to use a custom masking function:: + + # In your_app/utils.py + def custom_mask(value: str) -> str: + return "****" + value[-4:] # Only show last 4 characters + + # In your models.py + auditlog.register( + MyModel, + mask_fields=['credit_card'], + mask_callable='your_app.utils.custom_mask' + ) + +Additionally, you can set a global default masking function that will be used when a model-specific +mask_callable is not provided. To do this, add the following to your Django settings:: + + AUDITLOG_MASK_CALLABLE = 'your_app.utils.custom_mask' + +The masking function priority is as follows: + +1. Model-specific ``mask_callable`` if provided in ``register()`` +2. ``AUDITLOG_MASK_CALLABLE`` from settings if configured +3. Default ``mask_str`` function which masks the first half of the string with asterisks + +If ``mask_callable`` is not specified and no global default is configured, the default masking function will be used which masks +the first half of the string with asterisks. + .. versionadded:: 2.0.0 Masking fields