diff --git a/encrypted_fields/fields.py b/encrypted_fields/fields.py index fb95f23..1e0fa2c 100644 --- a/encrypted_fields/fields.py +++ b/encrypted_fields/fields.py @@ -1,22 +1,33 @@ import base64 from django.conf import settings -from cryptography.fernet import Fernet +from cryptography.fernet import Fernet, MultiFernet from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from django.db import models +from django.utils.functional import cached_property class EncryptedFieldMixin(object): - salt = bytes(settings.SALT_KEY, 'utf-8') - kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), - length=32, - salt=salt, - iterations=100000, - backend=default_backend()) + @cached_property + def keys(self): + keys = [] + salt_keys = settings.SALT_KEY if isinstance(settings.SALT_KEY, list) else [settings.SALT_KEY] + for salt_key in salt_keys: + salt = bytes(salt_key, 'utf-8') + kdf = PBKDF2HMAC(algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + backend=default_backend()) + keys.append(base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode('utf-8')))) + return keys - key = base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode('utf-8'))) - f = Fernet(key) + @cached_property + def f(self): + if len(self.keys) == 1: + return Fernet(self.keys[0]) + return MultiFernet([Fernet(k) for k in self.keys]) def get_internal_type(self): """ diff --git a/package_test/tests.py b/package_test/tests.py index 96caf24..68a4a5d 100644 --- a/package_test/tests.py +++ b/package_test/tests.py @@ -1,7 +1,7 @@ import re from django.db import connection -from django.test import TestCase +from django.test import TestCase, override_settings from django.utils import timezone from .models import TestModel @@ -139,3 +139,36 @@ class FieldTest(TestCase): fresh_model = TestModel.objects.get(id=model.id) self.assertEqual(fresh_model.boolean, plaintext) + + +class RotatedSaltTestCase(TestCase): + + @classmethod + @override_settings(SALT_KEY=['abcdefghijklmnopqrstuvwxyz0123456789']) + def setUpTestData(cls): + """Create the initial record using the old salt""" + cls.original = TestModel.objects.create( + text="Oh hi test reader" + ) + + @override_settings(SALT_KEY=['newkeyhere', 'abcdefghijklmnopqrstuvwxyz0123456789']) + def test_rotated_salt(self): + """Chage the salt, keep the old one as the last in the list for reading""" + plaintext = "Oh hi test reader" + model = TestModel() + model.text = plaintext + model.save() + + ciphertext = FieldTest.get_db_value(self, 'text', model.id) + + self.assertNotEqual(plaintext, ciphertext) + self.assertTrue('test' not in ciphertext) + + fresh_model = TestModel.objects.get(id=model.id) + self.assertEqual(fresh_model.text, plaintext) + + old_record = TestModel.objects.get(id=self.original.id) + self.assertEqual(fresh_model.text, old_record.text) + + self.assertNotEqual(ciphertext, FieldTest.get_db_value(self, 'text', self.original.pk)) +