Merge pull request #1 from StevenMapes/main

Add in support for rotating the salt but allow older records to still be read
This commit is contained in:
fragment 2021-12-10 07:55:45 +09:00 committed by GitHub
commit 3c9a6b9349
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 10 deletions

View file

@ -1,22 +1,33 @@
import base64
from django.conf import settings
from cryptography.fernet import Fernet
from cryptography.fernet import Fernet, MultiFernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from django.db import models
from django.utils.functional import cached_property
class EncryptedFieldMixin(object):
salt = bytes(settings.SALT_KEY, 'utf-8')
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
backend=default_backend())
@cached_property
def keys(self):
keys = []
salt_keys = settings.SALT_KEY if isinstance(settings.SALT_KEY, list) else [settings.SALT_KEY]
for salt_key in salt_keys:
salt = bytes(salt_key, 'utf-8')
kdf = PBKDF2HMAC(algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
backend=default_backend())
keys.append(base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode('utf-8'))))
return keys
key = base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode('utf-8')))
f = Fernet(key)
@cached_property
def f(self):
if len(self.keys) == 1:
return Fernet(self.keys[0])
return MultiFernet([Fernet(k) for k in self.keys])
def get_internal_type(self):
"""

View file

@ -1,7 +1,7 @@
import re
from django.db import connection
from django.test import TestCase
from django.test import TestCase, override_settings
from django.utils import timezone
from .models import TestModel
@ -139,3 +139,36 @@ class FieldTest(TestCase):
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.boolean, plaintext)
class RotatedSaltTestCase(TestCase):
@classmethod
@override_settings(SALT_KEY=['abcdefghijklmnopqrstuvwxyz0123456789'])
def setUpTestData(cls):
"""Create the initial record using the old salt"""
cls.original = TestModel.objects.create(
text="Oh hi test reader"
)
@override_settings(SALT_KEY=['newkeyhere', 'abcdefghijklmnopqrstuvwxyz0123456789'])
def test_rotated_salt(self):
"""Chage the salt, keep the old one as the last in the list for reading"""
plaintext = "Oh hi test reader"
model = TestModel()
model.text = plaintext
model.save()
ciphertext = FieldTest.get_db_value(self, 'text', model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertTrue('test' not in ciphertext)
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.text, plaintext)
old_record = TestModel.objects.get(id=self.original.id)
self.assertEqual(fresh_model.text, old_record.text)
self.assertNotEqual(ciphertext, FieldTest.get_db_value(self, 'text', self.original.pk))