Fix lint error

This commit is contained in:
frgmt 2025-01-06 11:16:57 +09:00
parent 755636a27a
commit cf489b71af
4 changed files with 141 additions and 114 deletions

View file

@ -1 +1 @@
from .fields import *
from .fields import * # noqa: F403

View file

@ -1,19 +1,27 @@
from __future__ import annotations
import base64
import json
from typing import Any
from cryptography.fernet import Fernet, MultiFernet, InvalidToken
from cryptography.fernet import Fernet, InvalidToken, MultiFernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from django.conf import settings
from django.core import validators
from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.expressions import Expression
from django.utils.functional import cached_property
class EncryptedFieldMixin(object):
_TypeAny = Any
class EncryptedFieldMixin:
@cached_property
def keys(self):
def keys(self) -> list[bytes]:
keys = []
salt_keys = (
settings.SALT_KEY
@ -37,18 +45,18 @@ class EncryptedFieldMixin(object):
return keys
@cached_property
def f(self):
def f(self) -> Fernet | MultiFernet:
if len(self.keys) == 1:
return Fernet(self.keys[0])
return MultiFernet([Fernet(k) for k in self.keys])
def get_internal_type(self):
def get_internal_type(self) -> str:
"""
To treat everything as text
"""
return "TextField"
def get_prep_value(self, value):
def get_prep_value(self, value: _TypeAny) -> _TypeAny:
value = super().get_prep_value(value)
if value:
if not isinstance(value, str):
@ -56,15 +64,25 @@ class EncryptedFieldMixin(object):
return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8")
return None
def get_db_prep_value(self, value, connection, prepared=False):
def get_db_prep_value(
self,
value: _TypeAny,
connection: BaseDatabaseWrapper, # noqa: ARG002
prepared: bool = False, # noqa: FBT001, FBT002
) -> _TypeAny:
if not prepared:
value = self.get_prep_value(value)
return value
def from_db_value(self, value, expression, connection):
def from_db_value(
self,
value: _TypeAny,
expression: Expression, # noqa: ARG002
connection: BaseDatabaseWrapper, # noqa: ARG002
) -> _TypeAny:
return self.to_python(value)
def to_python(self, value):
def to_python(self, value: _TypeAny) -> _TypeAny:
if (
value is None
or not isinstance(value, str)
@ -77,12 +95,12 @@ class EncryptedFieldMixin(object):
pass
except UnicodeEncodeError:
pass
return super(EncryptedFieldMixin, self).to_python(value)
return super().to_python(value)
def clean(self, value, model_instance):
def clean(self, value: _TypeAny, model_instance: models.Field) -> _TypeAny:
"""
Create and assign a semaphore so that to_python method will not try to decrypt an already decrypted value
during cleaning of a form
Create and assign a semaphore so that to_python method will not try
to decrypt an already decrypted value during cleaning of a form
"""
self._already_decrypted = True
ret = super().clean(value, model_instance)
@ -104,15 +122,17 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
@cached_property
def validators(self):
def validators(self) -> list[MinValueValidator | MaxValueValidator]:
# These validators can't be added at field initialization time since
# they're based on values retrieved from `connection`.
validators_ = [*self.default_validators, *self._validators]
internal_type = models.IntegerField().get_internal_type()
min_value, max_value = BaseDatabaseOperations.integer_field_ranges[internal_type]
min_value, max_value = BaseDatabaseOperations.integer_field_ranges[
internal_type
]
if min_value is not None and not any(
(
isinstance(validator, validators.MinValueValidator)
isinstance(validator, MinValueValidator)
and (
validator.limit_value()
if callable(validator.limit_value)
@ -122,10 +142,10 @@ class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
)
for validator in validators_
):
validators_.append(validators.MinValueValidator(min_value))
validators_.append(MinValueValidator(min_value))
if max_value is not None and not any(
(
isinstance(validator, validators.MaxValueValidator)
isinstance(validator, MaxValueValidator)
and (
validator.limit_value()
if callable(validator.limit_value)
@ -135,7 +155,7 @@ class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
)
for validator in validators_
):
validators_.append(validators.MaxValueValidator(max_value))
validators_.append(MaxValueValidator(max_value))
return validators_
@ -156,33 +176,31 @@ class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField):
class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
def _encrypt_values(self, value):
def _encrypt_values(self, value: _TypeAny) -> _TypeAny:
if isinstance(value, dict):
return {key: self._encrypt_values(data) for key, data in value.items()}
elif isinstance(value, list):
return [self._encrypt_values(data) for data in value]
else:
value = str(value)
if isinstance(value, list):
return [self._encrypt_values(data) for data in value]
value = str(value)
return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8")
def _decrypt_values(self, value):
def _decrypt_values(self, value: _TypeAny) -> _TypeAny:
if value is None:
return value
if isinstance(value, dict):
return {key: self._decrypt_values(data) for key, data in value.items()}
elif isinstance(value, list):
return [self._decrypt_values(data) for data in value]
else:
value = str(value)
if isinstance(value, list):
return [self._decrypt_values(data) for data in value]
value = str(value)
return self.f.decrypt(bytes(value, "utf-8")).decode("utf-8")
def get_prep_value(self, value):
def get_prep_value(self, value: _TypeAny) -> str:
return json.dumps(self._encrypt_values(value=value), cls=self.encoder)
def get_internal_type(self):
def get_internal_type(self) -> str:
return "JSONField"
def to_python(self, value):
def to_python(self, value: _TypeAny) -> _TypeAny:
if (
value is None
or not isinstance(value, str)
@ -195,4 +213,4 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
pass
except UnicodeEncodeError:
pass
return super(EncryptedFieldMixin, self).to_python(value)
return super().to_python(value)

View file

@ -1,4 +1,16 @@
from encrypted_fields.fields import *
from django.db import models
from encrypted_fields.fields import (
EncryptedBooleanField,
EncryptedCharField,
EncryptedDateField,
EncryptedDateTimeField,
EncryptedEmailField,
EncryptedFloatField,
EncryptedIntegerField,
EncryptedJSONField,
EncryptedTextField,
)
class TestModel(models.Model):

View file

@ -1,25 +1,24 @@
import json
import re
import pytest
from django.core.exceptions import ValidationError
from django.db import connection
from django.test import TestCase, override_settings
from django.utils import timezone
from django.core.exceptions import ValidationError
from .models import TestModel
class FieldTest(TestCase):
def get_db_value(self, field, model_id):
def get_db_value(self, field: str, model_id: int) -> None:
cursor = connection.cursor()
cursor.execute(
"select {0} "
"from package_test_testmodel "
"where id = {1};".format(field, model_id)
f"select {field} from package_test_testmodel where id = {model_id};"
)
return cursor.fetchone()[0]
def test_char_field_encrypted(self):
def test_char_field_encrypted(self) -> None:
plaintext = "Oh hi, test reader!"
model = TestModel()
@ -29,13 +28,13 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("char", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertTrue("test" not in ciphertext)
assert plaintext != ciphertext
assert "test" not in ciphertext
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.char, plaintext)
assert fresh_model.char == plaintext
def test_text_field_encrypted(self):
def test_text_field_encrypted(self) -> None:
plaintext = "Oh hi, test reader!" * 10
model = TestModel()
@ -45,13 +44,13 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("text", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertTrue("test" not in ciphertext)
assert plaintext != ciphertext
assert "test" not in ciphertext
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.text, plaintext)
assert fresh_model.text == plaintext
def test_datetime_field_encrypted(self):
def test_datetime_field_encrypted(self) -> None:
plaintext = timezone.now()
model = TestModel()
@ -62,19 +61,19 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("datetime", model.id)
# Django's normal date serialization format
self.assertTrue(re.search("^\d\d\d\d-\d\d-\d\d", ciphertext) is None)
assert re.search(r"^\d\d\d\d-\d\d-\d\d", ciphertext) is None
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.datetime, plaintext)
assert fresh_model.datetime == plaintext
plaintext = "text"
model.datetime = plaintext
model.full_clean()
with self.assertRaises(ValidationError):
model.datetime = plaintext
model.full_clean()
with pytest.raises(ValidationError):
model.save()
def test_integer_field_encrypted(self):
def test_integer_field_encrypted(self) -> None:
plaintext = 42
model = TestModel()
@ -84,28 +83,26 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("integer", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertNotEqual(plaintext, str(ciphertext))
assert plaintext != ciphertext
assert plaintext != str(ciphertext)
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.integer, plaintext)
assert fresh_model.integer == plaintext
# "IntegerField": (-2147483648, 2147483647)
plaintext = 2147483648
model.integer = plaintext
with self.assertRaises(ValidationError):
model.integer = plaintext
with pytest.raises(ValidationError):
model.full_clean()
model.save()
plaintext = "text"
model.integer = plaintext
with self.assertRaises(TypeError):
model.integer = plaintext
with pytest.raises(TypeError):
model.full_clean()
model.save()
def test_date_field_encrypted(self):
def test_date_field_encrypted(self) -> None:
plaintext = timezone.now().date()
model = TestModel()
@ -116,17 +113,17 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("date", model.id)
fresh_model = TestModel.objects.get(id=model.id)
self.assertNotEqual(ciphertext, plaintext.isoformat())
self.assertEqual(fresh_model.date, plaintext)
assert ciphertext != plaintext.isoformat()
assert fresh_model.date == plaintext
plaintext = "text"
model.date = plaintext
model.full_clean()
with self.assertRaises(ValidationError):
model.date = plaintext
model.full_clean()
with pytest.raises(ValidationError):
model.save()
def test_float_field_encrypted(self):
def test_float_field_encrypted(self) -> None:
plaintext = 42.44
model = TestModel()
@ -136,20 +133,20 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("floating", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertNotEqual(plaintext, str(ciphertext))
assert plaintext != ciphertext
assert plaintext != str(ciphertext)
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.floating, plaintext)
assert fresh_model.floating == plaintext
plaintext = "text"
model.floating = plaintext
model.full_clean()
with self.assertRaises(ValueError):
model.floating = plaintext
model.full_clean()
with pytest.raises(ValueError):
model.save()
def test_email_field_encrypted(self):
def test_email_field_encrypted(self) -> None:
plaintext = "test@gmail.com"
model = TestModel()
@ -159,20 +156,19 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("email", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertTrue("aron" not in ciphertext)
assert plaintext != ciphertext
assert "aron" not in ciphertext
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.email, plaintext)
assert fresh_model.email == plaintext
plaintext = "text"
model.email = plaintext
with self.assertRaises(ValidationError):
model.email = plaintext
with pytest.raises(ValidationError):
model.full_clean()
model.save()
def test_boolean_field_encrypted(self):
def test_boolean_field_encrypted(self) -> None:
plaintext = True
model = TestModel()
@ -182,26 +178,30 @@ class FieldTest(TestCase):
ciphertext = self.get_db_value("boolean", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertNotEqual(True, ciphertext)
self.assertNotEqual("True", ciphertext)
self.assertNotEqual("true", ciphertext)
self.assertNotEqual("1", ciphertext)
self.assertNotEqual(1, ciphertext)
self.assertTrue(not isinstance(ciphertext, bool))
assert plaintext != ciphertext
assert ciphertext is not True
assert ciphertext != "True"
assert ciphertext != "true"
assert ciphertext != "1"
assert ciphertext != 1
assert not isinstance(ciphertext, bool)
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.boolean, plaintext)
assert fresh_model.boolean == plaintext
plaintext = "text"
model.boolean = plaintext
model.full_clean()
with self.assertRaises(ValidationError):
model.boolean = plaintext
model.full_clean()
with pytest.raises(ValidationError):
model.save()
def test_json_field_encrypted(self):
dict_values = {"key": "value", "list": ["nested", {"key": "val"}], "nested": {"child": "sibling"}}
def test_json_field_encrypted(self) -> None:
dict_values = {
"key": "value",
"list": ["nested", {"key": "val"}],
"nested": {"child": "sibling"},
}
model = TestModel()
model.json = dict_values
@ -210,15 +210,14 @@ class FieldTest(TestCase):
ciphertext = json.loads(self.get_db_value("json", model.id))
self.assertNotEqual(dict_values, ciphertext)
assert dict_values != ciphertext
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.json, dict_values)
assert fresh_model.json == dict_values
def test_json_field_retains_keys(self):
def test_json_field_retains_keys(self) -> None:
plain_value = {"key": "value", "another_key": "some value"}
model = TestModel()
model.json = plain_value
model.full_clean()
@ -226,18 +225,18 @@ class FieldTest(TestCase):
ciphertext = json.loads(self.get_db_value("json", model.id))
self.assertEqual(plain_value.keys(), ciphertext.keys())
assert plain_value.keys() == ciphertext.keys()
class RotatedSaltTestCase(TestCase):
@classmethod
@override_settings(SALT_KEY=["abcdefghijklmnopqrstuvwxyz0123456789"])
def setUpTestData(cls):
def setUpTestData(cls) -> None:
"""Create the initial record using the old salt"""
cls.original = TestModel.objects.create(text="Oh hi test reader")
@override_settings(SALT_KEY=["newkeyhere", "abcdefghijklmnopqrstuvwxyz0123456789"])
def test_rotated_salt(self):
def test_rotated_salt(self) -> None:
"""Change the salt, keep the old one as the last in the list for reading"""
plaintext = "Oh hi test reader"
model = TestModel()
@ -246,15 +245,13 @@ class RotatedSaltTestCase(TestCase):
ciphertext = FieldTest.get_db_value(self, "text", model.id)
self.assertNotEqual(plaintext, ciphertext)
self.assertTrue("test" not in ciphertext)
assert plaintext != ciphertext
assert "test" not in ciphertext
fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.text, plaintext)
assert fresh_model.text == plaintext
old_record = TestModel.objects.get(id=self.original.id)
self.assertEqual(fresh_model.text, old_record.text)
assert fresh_model.text == old_record.text
self.assertNotEqual(
ciphertext, FieldTest.get_db_value(self, "text", self.original.pk)
)
assert ciphertext != FieldTest.get_db_value(self, "text", self.original.pk)