diff --git a/encrypted_fields/fields.py b/encrypted_fields/fields.py index 82bc216..b6ec057 100644 --- a/encrypted_fields/fields.py +++ b/encrypted_fields/fields.py @@ -28,20 +28,22 @@ class EncryptedFieldMixin: if isinstance(settings.SALT_KEY, list) else [settings.SALT_KEY] ) - for salt_key in salt_keys: - salt = bytes(salt_key, "utf-8") - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=salt, - iterations=100000, - backend=default_backend(), - ) - keys.append( - base64.urlsafe_b64encode( - kdf.derive(settings.SECRET_KEY.encode("utf-8")) + secret_keys = [settings.SECRET_KEY] + (settings.SECRET_KEY_FALLBACKS or []) + for secret_key in secret_keys: + for salt_key in salt_keys: + salt = bytes(salt_key, "utf-8") + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100_000, + backend=default_backend(), + ) + keys.append( + base64.urlsafe_b64encode( + kdf.derive(secret_key.encode("utf-8")) + ) ) - ) return keys @cached_property diff --git a/package_test/tests.py b/package_test/tests.py index d62e9b0..abeb65a 100644 --- a/package_test/tests.py +++ b/package_test/tests.py @@ -255,3 +255,57 @@ class RotatedSaltTestCase(TestCase): assert fresh_model.text == old_record.text assert ciphertext != FieldTest.get_db_value(self, "text", self.original.pk) + + +class RotatedSecretKeyTestCase(TestCase): + + @staticmethod + def clear_cached_properties(): + # we have to clear the cached properties of EncryptedFieldMixin so we have the right encryption keys + text_field = TestModel._meta.get_field('text') + if hasattr(text_field, 'keys'): + del text_field.keys + if hasattr(text_field, 'f'): + del text_field.f + + @classmethod + @override_settings(SECRET_KEY="oldkey") + def setUpTestData(cls) -> None: + """Create the initial record using the old key""" + cls.clear_cached_properties() + cls.original = TestModel.objects.create(text="Oh hi test reader") + cls.clear_cached_properties() + + def tearDown(self): + self.clear_cached_properties() + + @override_settings(SECRET_KEY="newkey", SECRET_KEY_FALLBACKS=["oldkey"]) + def test_old_and_new_secret_keys(self) -> None: + + plaintext = "Oh hi test reader" + model = TestModel() + model.text = plaintext + model.save() + + fresh_model = TestModel.objects.get(id=model.id) + assert fresh_model.text == plaintext + + old_record = TestModel.objects.get(id=self.original.id) + assert old_record.text == plaintext + + @override_settings(SECRET_KEY="newkey") + def test_cannot_decrypt_old_record_with_new_key(self) -> None: + plaintext = "Oh hi test reader" + model = TestModel() + model.text = plaintext + model.save() + + fresh_model = TestModel.objects.get(id=model.id) + assert fresh_model.text == plaintext + + old_record = TestModel.objects.get(id=self.original.id) + # assert that old record text is still encrypted + assert old_record.text.endswith("=") + # assert that old record cannot be decrypted now + assert old_record.text != plaintext +