diff --git a/README.md b/README.md index ca77947..7179056 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,7 @@ Currently build in and unit-tested fields. They have the same APIs as their non- - `EncryptedFloatField` - `EncryptedEmailField` - `EncryptedBooleanField` +- `EncryptedJSONField` ### Compatible Django Version diff --git a/encrypted_fields/fields.py b/encrypted_fields/fields.py index 53cec54..302015d 100644 --- a/encrypted_fields/fields.py +++ b/encrypted_fields/fields.py @@ -1,4 +1,6 @@ import base64 +import json +from typing import Optional, Type from cryptography.fernet import Fernet, MultiFernet, InvalidToken from cryptography.hazmat.backends import default_backend @@ -6,11 +8,11 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from django.conf import settings from django.core import validators +from django.core.serializers.json import DjangoJSONEncoder from django.db import models from django.db.backends.base.operations import BaseDatabaseOperations from django.utils.functional import cached_property - class EncryptedFieldMixin(object): @cached_property def keys(self): @@ -153,3 +155,46 @@ class EncryptedEmailField(EncryptedFieldMixin, models.EmailField): class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField): pass + + +class EncryptedJSONField(EncryptedFieldMixin, models.JSONField): + def _encrypt_values(self, value): + if isinstance(value, dict): + return {key: self._encrypt_values(data) for key, data in value.items()} + elif isinstance(value, list): + return [self._encrypt_values(data) for data in value] + else: + value = str(value) + return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8") + + def _decrypt_values(self, value): + if value is None: + return value + if isinstance(value, dict): + return {key: self._decrypt_values(data) for key, data in value.items()} + elif isinstance(value, list): + return [self._decrypt_values(data) for data in value] + else: + value = str(value) + return self.f.decrypt(bytes(value, "utf-8")).decode("utf-8") + + def get_prep_value(self, value): + return json.dumps(self._encrypt_values(value=value), cls=self.encoder) + + def get_internal_type(self): + return "JSONField" + + def to_python(self, value): + if ( + value is None + or not isinstance(value, str) + or hasattr(self, "_already_decrypted") + ): + return value + try: + value = self._decrypt_values(value=json.loads(value)) + except InvalidToken: + pass + except UnicodeEncodeError: + pass + return super(EncryptedFieldMixin, self).to_python(value) diff --git a/package_test/models.py b/package_test/models.py index 99da24f..07601c0 100644 --- a/package_test/models.py +++ b/package_test/models.py @@ -10,3 +10,4 @@ class TestModel(models.Model): floating = EncryptedFloatField(null=True, blank=True) email = EncryptedEmailField(null=True, blank=True) boolean = EncryptedBooleanField(default=False, null=True) + json = EncryptedJSONField(default={}, null=True, blank=True) diff --git a/package_test/tests.py b/package_test/tests.py index 1bce167..e002443 100644 --- a/package_test/tests.py +++ b/package_test/tests.py @@ -199,6 +199,22 @@ class FieldTest(TestCase): model.full_clean() model.save() + def test_json_field_encrypted(self): + dict_values = {"key": "value", "list": ["nested", {"key": "val"}], "nested": {"child": "sibling"}} + + model = TestModel() + model.json = dict_values + model.full_clean() + model.save() + + ciphertext = self.get_db_value("json", model.id) + + self.assertNotEqual(dict_values, ciphertext) + + fresh_model = TestModel.objects.get(id=model.id) + self.assertEqual(fresh_model.json, dict_values) + + class RotatedSaltTestCase(TestCase): @classmethod