mirror of
https://github.com/jazzband/django-fernet-encrypted-fields.git
synced 2026-03-16 22:40:27 +00:00
Fix lint error
This commit is contained in:
parent
755636a27a
commit
cf489b71af
4 changed files with 141 additions and 114 deletions
|
|
@ -1 +1 @@
|
|||
from .fields import *
|
||||
from .fields import * # noqa: F403
|
||||
|
|
|
|||
|
|
@ -1,19 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, MultiFernet, InvalidToken
|
||||
from cryptography.fernet import Fernet, InvalidToken, MultiFernet
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from django.conf import settings
|
||||
from django.core import validators
|
||||
from django.core.validators import MaxValueValidator, MinValueValidator
|
||||
from django.db import models
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.backends.base.operations import BaseDatabaseOperations
|
||||
from django.db.models.expressions import Expression
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
class EncryptedFieldMixin(object):
|
||||
_TypeAny = Any
|
||||
|
||||
|
||||
class EncryptedFieldMixin:
|
||||
@cached_property
|
||||
def keys(self):
|
||||
def keys(self) -> list[bytes]:
|
||||
keys = []
|
||||
salt_keys = (
|
||||
settings.SALT_KEY
|
||||
|
|
@ -37,18 +45,18 @@ class EncryptedFieldMixin(object):
|
|||
return keys
|
||||
|
||||
@cached_property
|
||||
def f(self):
|
||||
def f(self) -> Fernet | MultiFernet:
|
||||
if len(self.keys) == 1:
|
||||
return Fernet(self.keys[0])
|
||||
return MultiFernet([Fernet(k) for k in self.keys])
|
||||
|
||||
def get_internal_type(self):
|
||||
def get_internal_type(self) -> str:
|
||||
"""
|
||||
To treat everything as text
|
||||
"""
|
||||
return "TextField"
|
||||
|
||||
def get_prep_value(self, value):
|
||||
def get_prep_value(self, value: _TypeAny) -> _TypeAny:
|
||||
value = super().get_prep_value(value)
|
||||
if value:
|
||||
if not isinstance(value, str):
|
||||
|
|
@ -56,15 +64,25 @@ class EncryptedFieldMixin(object):
|
|||
return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8")
|
||||
return None
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
def get_db_prep_value(
|
||||
self,
|
||||
value: _TypeAny,
|
||||
connection: BaseDatabaseWrapper, # noqa: ARG002
|
||||
prepared: bool = False, # noqa: FBT001, FBT002
|
||||
) -> _TypeAny:
|
||||
if not prepared:
|
||||
value = self.get_prep_value(value)
|
||||
return value
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
def from_db_value(
|
||||
self,
|
||||
value: _TypeAny,
|
||||
expression: Expression, # noqa: ARG002
|
||||
connection: BaseDatabaseWrapper, # noqa: ARG002
|
||||
) -> _TypeAny:
|
||||
return self.to_python(value)
|
||||
|
||||
def to_python(self, value):
|
||||
def to_python(self, value: _TypeAny) -> _TypeAny:
|
||||
if (
|
||||
value is None
|
||||
or not isinstance(value, str)
|
||||
|
|
@ -77,12 +95,12 @@ class EncryptedFieldMixin(object):
|
|||
pass
|
||||
except UnicodeEncodeError:
|
||||
pass
|
||||
return super(EncryptedFieldMixin, self).to_python(value)
|
||||
return super().to_python(value)
|
||||
|
||||
def clean(self, value, model_instance):
|
||||
def clean(self, value: _TypeAny, model_instance: models.Field) -> _TypeAny:
|
||||
"""
|
||||
Create and assign a semaphore so that to_python method will not try to decrypt an already decrypted value
|
||||
during cleaning of a form
|
||||
Create and assign a semaphore so that to_python method will not try
|
||||
to decrypt an already decrypted value during cleaning of a form
|
||||
"""
|
||||
self._already_decrypted = True
|
||||
ret = super().clean(value, model_instance)
|
||||
|
|
@ -104,15 +122,17 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
|
|||
|
||||
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
|
||||
@cached_property
|
||||
def validators(self):
|
||||
def validators(self) -> list[MinValueValidator | MaxValueValidator]:
|
||||
# These validators can't be added at field initialization time since
|
||||
# they're based on values retrieved from `connection`.
|
||||
validators_ = [*self.default_validators, *self._validators]
|
||||
internal_type = models.IntegerField().get_internal_type()
|
||||
min_value, max_value = BaseDatabaseOperations.integer_field_ranges[internal_type]
|
||||
min_value, max_value = BaseDatabaseOperations.integer_field_ranges[
|
||||
internal_type
|
||||
]
|
||||
if min_value is not None and not any(
|
||||
(
|
||||
isinstance(validator, validators.MinValueValidator)
|
||||
isinstance(validator, MinValueValidator)
|
||||
and (
|
||||
validator.limit_value()
|
||||
if callable(validator.limit_value)
|
||||
|
|
@ -122,10 +142,10 @@ class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
|
|||
)
|
||||
for validator in validators_
|
||||
):
|
||||
validators_.append(validators.MinValueValidator(min_value))
|
||||
validators_.append(MinValueValidator(min_value))
|
||||
if max_value is not None and not any(
|
||||
(
|
||||
isinstance(validator, validators.MaxValueValidator)
|
||||
isinstance(validator, MaxValueValidator)
|
||||
and (
|
||||
validator.limit_value()
|
||||
if callable(validator.limit_value)
|
||||
|
|
@ -135,7 +155,7 @@ class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
|
|||
)
|
||||
for validator in validators_
|
||||
):
|
||||
validators_.append(validators.MaxValueValidator(max_value))
|
||||
validators_.append(MaxValueValidator(max_value))
|
||||
return validators_
|
||||
|
||||
|
||||
|
|
@ -156,33 +176,31 @@ class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField):
|
|||
|
||||
|
||||
class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
|
||||
def _encrypt_values(self, value):
|
||||
def _encrypt_values(self, value: _TypeAny) -> _TypeAny:
|
||||
if isinstance(value, dict):
|
||||
return {key: self._encrypt_values(data) for key, data in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [self._encrypt_values(data) for data in value]
|
||||
else:
|
||||
value = str(value)
|
||||
if isinstance(value, list):
|
||||
return [self._encrypt_values(data) for data in value]
|
||||
value = str(value)
|
||||
return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8")
|
||||
|
||||
def _decrypt_values(self, value):
|
||||
def _decrypt_values(self, value: _TypeAny) -> _TypeAny:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return {key: self._decrypt_values(data) for key, data in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [self._decrypt_values(data) for data in value]
|
||||
else:
|
||||
value = str(value)
|
||||
if isinstance(value, list):
|
||||
return [self._decrypt_values(data) for data in value]
|
||||
value = str(value)
|
||||
return self.f.decrypt(bytes(value, "utf-8")).decode("utf-8")
|
||||
|
||||
def get_prep_value(self, value):
|
||||
def get_prep_value(self, value: _TypeAny) -> str:
|
||||
return json.dumps(self._encrypt_values(value=value), cls=self.encoder)
|
||||
|
||||
def get_internal_type(self):
|
||||
def get_internal_type(self) -> str:
|
||||
return "JSONField"
|
||||
|
||||
def to_python(self, value):
|
||||
def to_python(self, value: _TypeAny) -> _TypeAny:
|
||||
if (
|
||||
value is None
|
||||
or not isinstance(value, str)
|
||||
|
|
@ -195,4 +213,4 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
|
|||
pass
|
||||
except UnicodeEncodeError:
|
||||
pass
|
||||
return super(EncryptedFieldMixin, self).to_python(value)
|
||||
return super().to_python(value)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,16 @@
|
|||
from encrypted_fields.fields import *
|
||||
from django.db import models
|
||||
|
||||
from encrypted_fields.fields import (
|
||||
EncryptedBooleanField,
|
||||
EncryptedCharField,
|
||||
EncryptedDateField,
|
||||
EncryptedDateTimeField,
|
||||
EncryptedEmailField,
|
||||
EncryptedFloatField,
|
||||
EncryptedIntegerField,
|
||||
EncryptedJSONField,
|
||||
EncryptedTextField,
|
||||
)
|
||||
|
||||
|
||||
class TestModel(models.Model):
|
||||
|
|
|
|||
|
|
@ -1,25 +1,24 @@
|
|||
import json
|
||||
import re
|
||||
|
||||
import pytest
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import connection
|
||||
from django.test import TestCase, override_settings
|
||||
from django.utils import timezone
|
||||
from django.core.exceptions import ValidationError
|
||||
|
||||
from .models import TestModel
|
||||
|
||||
|
||||
class FieldTest(TestCase):
|
||||
def get_db_value(self, field, model_id):
|
||||
def get_db_value(self, field: str, model_id: int) -> None:
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(
|
||||
"select {0} "
|
||||
"from package_test_testmodel "
|
||||
"where id = {1};".format(field, model_id)
|
||||
f"select {field} from package_test_testmodel where id = {model_id};"
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def test_char_field_encrypted(self):
|
||||
def test_char_field_encrypted(self) -> None:
|
||||
plaintext = "Oh hi, test reader!"
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -29,13 +28,13 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = self.get_db_value("char", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertTrue("test" not in ciphertext)
|
||||
assert plaintext != ciphertext
|
||||
assert "test" not in ciphertext
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.char, plaintext)
|
||||
assert fresh_model.char == plaintext
|
||||
|
||||
def test_text_field_encrypted(self):
|
||||
def test_text_field_encrypted(self) -> None:
|
||||
plaintext = "Oh hi, test reader!" * 10
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -45,13 +44,13 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = self.get_db_value("text", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertTrue("test" not in ciphertext)
|
||||
assert plaintext != ciphertext
|
||||
assert "test" not in ciphertext
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.text, plaintext)
|
||||
assert fresh_model.text == plaintext
|
||||
|
||||
def test_datetime_field_encrypted(self):
|
||||
def test_datetime_field_encrypted(self) -> None:
|
||||
plaintext = timezone.now()
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -62,19 +61,19 @@ class FieldTest(TestCase):
|
|||
ciphertext = self.get_db_value("datetime", model.id)
|
||||
|
||||
# Django's normal date serialization format
|
||||
self.assertTrue(re.search("^\d\d\d\d-\d\d-\d\d", ciphertext) is None)
|
||||
assert re.search(r"^\d\d\d\d-\d\d-\d\d", ciphertext) is None
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.datetime, plaintext)
|
||||
assert fresh_model.datetime == plaintext
|
||||
|
||||
plaintext = "text"
|
||||
model.datetime = plaintext
|
||||
model.full_clean()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
model.datetime = plaintext
|
||||
model.full_clean()
|
||||
with pytest.raises(ValidationError):
|
||||
model.save()
|
||||
|
||||
def test_integer_field_encrypted(self):
|
||||
def test_integer_field_encrypted(self) -> None:
|
||||
plaintext = 42
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -84,28 +83,26 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = self.get_db_value("integer", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertNotEqual(plaintext, str(ciphertext))
|
||||
assert plaintext != ciphertext
|
||||
assert plaintext != str(ciphertext)
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.integer, plaintext)
|
||||
assert fresh_model.integer == plaintext
|
||||
|
||||
# "IntegerField": (-2147483648, 2147483647)
|
||||
plaintext = 2147483648
|
||||
model.integer = plaintext
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
model.integer = plaintext
|
||||
with pytest.raises(ValidationError):
|
||||
model.full_clean()
|
||||
model.save()
|
||||
|
||||
plaintext = "text"
|
||||
model.integer = plaintext
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
model.integer = plaintext
|
||||
with pytest.raises(TypeError):
|
||||
model.full_clean()
|
||||
model.save()
|
||||
|
||||
def test_date_field_encrypted(self):
|
||||
def test_date_field_encrypted(self) -> None:
|
||||
plaintext = timezone.now().date()
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -116,17 +113,17 @@ class FieldTest(TestCase):
|
|||
ciphertext = self.get_db_value("date", model.id)
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
|
||||
self.assertNotEqual(ciphertext, plaintext.isoformat())
|
||||
self.assertEqual(fresh_model.date, plaintext)
|
||||
assert ciphertext != plaintext.isoformat()
|
||||
assert fresh_model.date == plaintext
|
||||
|
||||
plaintext = "text"
|
||||
model.date = plaintext
|
||||
model.full_clean()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
model.date = plaintext
|
||||
model.full_clean()
|
||||
with pytest.raises(ValidationError):
|
||||
model.save()
|
||||
|
||||
def test_float_field_encrypted(self):
|
||||
def test_float_field_encrypted(self) -> None:
|
||||
plaintext = 42.44
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -136,20 +133,20 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = self.get_db_value("floating", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertNotEqual(plaintext, str(ciphertext))
|
||||
assert plaintext != ciphertext
|
||||
assert plaintext != str(ciphertext)
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.floating, plaintext)
|
||||
assert fresh_model.floating == plaintext
|
||||
|
||||
plaintext = "text"
|
||||
model.floating = plaintext
|
||||
model.full_clean()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
model.floating = plaintext
|
||||
model.full_clean()
|
||||
with pytest.raises(ValueError):
|
||||
model.save()
|
||||
|
||||
def test_email_field_encrypted(self):
|
||||
def test_email_field_encrypted(self) -> None:
|
||||
plaintext = "test@gmail.com"
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -159,20 +156,19 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = self.get_db_value("email", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertTrue("aron" not in ciphertext)
|
||||
assert plaintext != ciphertext
|
||||
assert "aron" not in ciphertext
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.email, plaintext)
|
||||
assert fresh_model.email == plaintext
|
||||
|
||||
plaintext = "text"
|
||||
model.email = plaintext
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
model.email = plaintext
|
||||
with pytest.raises(ValidationError):
|
||||
model.full_clean()
|
||||
model.save()
|
||||
|
||||
def test_boolean_field_encrypted(self):
|
||||
def test_boolean_field_encrypted(self) -> None:
|
||||
plaintext = True
|
||||
|
||||
model = TestModel()
|
||||
|
|
@ -182,26 +178,30 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = self.get_db_value("boolean", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertNotEqual(True, ciphertext)
|
||||
self.assertNotEqual("True", ciphertext)
|
||||
self.assertNotEqual("true", ciphertext)
|
||||
self.assertNotEqual("1", ciphertext)
|
||||
self.assertNotEqual(1, ciphertext)
|
||||
self.assertTrue(not isinstance(ciphertext, bool))
|
||||
assert plaintext != ciphertext
|
||||
assert ciphertext is not True
|
||||
assert ciphertext != "True"
|
||||
assert ciphertext != "true"
|
||||
assert ciphertext != "1"
|
||||
assert ciphertext != 1
|
||||
assert not isinstance(ciphertext, bool)
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.boolean, plaintext)
|
||||
assert fresh_model.boolean == plaintext
|
||||
|
||||
plaintext = "text"
|
||||
model.boolean = plaintext
|
||||
model.full_clean()
|
||||
|
||||
with self.assertRaises(ValidationError):
|
||||
model.boolean = plaintext
|
||||
model.full_clean()
|
||||
with pytest.raises(ValidationError):
|
||||
model.save()
|
||||
|
||||
def test_json_field_encrypted(self):
|
||||
dict_values = {"key": "value", "list": ["nested", {"key": "val"}], "nested": {"child": "sibling"}}
|
||||
def test_json_field_encrypted(self) -> None:
|
||||
dict_values = {
|
||||
"key": "value",
|
||||
"list": ["nested", {"key": "val"}],
|
||||
"nested": {"child": "sibling"},
|
||||
}
|
||||
|
||||
model = TestModel()
|
||||
model.json = dict_values
|
||||
|
|
@ -210,15 +210,14 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = json.loads(self.get_db_value("json", model.id))
|
||||
|
||||
|
||||
self.assertNotEqual(dict_values, ciphertext)
|
||||
assert dict_values != ciphertext
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.json, dict_values)
|
||||
assert fresh_model.json == dict_values
|
||||
|
||||
def test_json_field_retains_keys(self):
|
||||
def test_json_field_retains_keys(self) -> None:
|
||||
plain_value = {"key": "value", "another_key": "some value"}
|
||||
|
||||
|
||||
model = TestModel()
|
||||
model.json = plain_value
|
||||
model.full_clean()
|
||||
|
|
@ -226,18 +225,18 @@ class FieldTest(TestCase):
|
|||
|
||||
ciphertext = json.loads(self.get_db_value("json", model.id))
|
||||
|
||||
self.assertEqual(plain_value.keys(), ciphertext.keys())
|
||||
assert plain_value.keys() == ciphertext.keys()
|
||||
|
||||
|
||||
class RotatedSaltTestCase(TestCase):
|
||||
@classmethod
|
||||
@override_settings(SALT_KEY=["abcdefghijklmnopqrstuvwxyz0123456789"])
|
||||
def setUpTestData(cls):
|
||||
def setUpTestData(cls) -> None:
|
||||
"""Create the initial record using the old salt"""
|
||||
cls.original = TestModel.objects.create(text="Oh hi test reader")
|
||||
|
||||
@override_settings(SALT_KEY=["newkeyhere", "abcdefghijklmnopqrstuvwxyz0123456789"])
|
||||
def test_rotated_salt(self):
|
||||
def test_rotated_salt(self) -> None:
|
||||
"""Change the salt, keep the old one as the last in the list for reading"""
|
||||
plaintext = "Oh hi test reader"
|
||||
model = TestModel()
|
||||
|
|
@ -246,15 +245,13 @@ class RotatedSaltTestCase(TestCase):
|
|||
|
||||
ciphertext = FieldTest.get_db_value(self, "text", model.id)
|
||||
|
||||
self.assertNotEqual(plaintext, ciphertext)
|
||||
self.assertTrue("test" not in ciphertext)
|
||||
assert plaintext != ciphertext
|
||||
assert "test" not in ciphertext
|
||||
|
||||
fresh_model = TestModel.objects.get(id=model.id)
|
||||
self.assertEqual(fresh_model.text, plaintext)
|
||||
assert fresh_model.text == plaintext
|
||||
|
||||
old_record = TestModel.objects.get(id=self.original.id)
|
||||
self.assertEqual(fresh_model.text, old_record.text)
|
||||
assert fresh_model.text == old_record.text
|
||||
|
||||
self.assertNotEqual(
|
||||
ciphertext, FieldTest.get_db_value(self, "text", self.original.pk)
|
||||
)
|
||||
assert ciphertext != FieldTest.get_db_value(self, "text", self.original.pk)
|
||||
|
|
|
|||
Loading…
Reference in a new issue