From cf489b71af0cd086227b9b3b284666682bef5f0a Mon Sep 17 00:00:00 2001 From: frgmt Date: Mon, 6 Jan 2025 11:16:57 +0900 Subject: [PATCH] Fix lint error --- encrypted_fields/__init__.py | 2 +- encrypted_fields/fields.py | 86 ++++++++++++-------- package_test/models.py | 14 +++- package_test/tests.py | 153 +++++++++++++++++------------------ 4 files changed, 141 insertions(+), 114 deletions(-) diff --git a/encrypted_fields/__init__.py b/encrypted_fields/__init__.py index 7746f2c..0a9ff24 100644 --- a/encrypted_fields/__init__.py +++ b/encrypted_fields/__init__.py @@ -1 +1 @@ -from .fields import * +from .fields import * # noqa: F403 diff --git a/encrypted_fields/fields.py b/encrypted_fields/fields.py index 03e5c45..82bc216 100644 --- a/encrypted_fields/fields.py +++ b/encrypted_fields/fields.py @@ -1,19 +1,27 @@ +from __future__ import annotations + import base64 import json +from typing import Any -from cryptography.fernet import Fernet, MultiFernet, InvalidToken +from cryptography.fernet import Fernet, InvalidToken, MultiFernet from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from django.conf import settings -from django.core import validators +from django.core.validators import MaxValueValidator, MinValueValidator from django.db import models +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.base.operations import BaseDatabaseOperations +from django.db.models.expressions import Expression from django.utils.functional import cached_property -class EncryptedFieldMixin(object): +_TypeAny = Any + + +class EncryptedFieldMixin: @cached_property - def keys(self): + def keys(self) -> list[bytes]: keys = [] salt_keys = ( settings.SALT_KEY @@ -37,18 +45,18 @@ class EncryptedFieldMixin(object): return keys @cached_property - def f(self): + def f(self) -> Fernet | MultiFernet: if len(self.keys) == 1: return Fernet(self.keys[0]) return MultiFernet([Fernet(k) for k in self.keys]) - def get_internal_type(self): + def get_internal_type(self) -> str: """ To treat everything as text """ return "TextField" - def get_prep_value(self, value): + def get_prep_value(self, value: _TypeAny) -> _TypeAny: value = super().get_prep_value(value) if value: if not isinstance(value, str): @@ -56,15 +64,25 @@ class EncryptedFieldMixin(object): return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8") return None - def get_db_prep_value(self, value, connection, prepared=False): + def get_db_prep_value( + self, + value: _TypeAny, + connection: BaseDatabaseWrapper, # noqa: ARG002 + prepared: bool = False, # noqa: FBT001, FBT002 + ) -> _TypeAny: if not prepared: value = self.get_prep_value(value) return value - def from_db_value(self, value, expression, connection): + def from_db_value( + self, + value: _TypeAny, + expression: Expression, # noqa: ARG002 + connection: BaseDatabaseWrapper, # noqa: ARG002 + ) -> _TypeAny: return self.to_python(value) - def to_python(self, value): + def to_python(self, value: _TypeAny) -> _TypeAny: if ( value is None or not isinstance(value, str) @@ -77,12 +95,12 @@ class EncryptedFieldMixin(object): pass except UnicodeEncodeError: pass - return super(EncryptedFieldMixin, self).to_python(value) + return super().to_python(value) - def clean(self, value, model_instance): + def clean(self, value: _TypeAny, model_instance: models.Field) -> _TypeAny: """ - Create and assign a semaphore so that to_python method will not try to decrypt an already decrypted value - during cleaning of a form + Create and assign a semaphore so that to_python method will not try + to decrypt an already decrypted value during cleaning of a form """ self._already_decrypted = True ret = super().clean(value, model_instance) @@ -104,15 +122,17 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField): class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField): @cached_property - def validators(self): + def validators(self) -> list[MinValueValidator | MaxValueValidator]: # These validators can't be added at field initialization time since # they're based on values retrieved from `connection`. validators_ = [*self.default_validators, *self._validators] internal_type = models.IntegerField().get_internal_type() - min_value, max_value = BaseDatabaseOperations.integer_field_ranges[internal_type] + min_value, max_value = BaseDatabaseOperations.integer_field_ranges[ + internal_type + ] if min_value is not None and not any( ( - isinstance(validator, validators.MinValueValidator) + isinstance(validator, MinValueValidator) and ( validator.limit_value() if callable(validator.limit_value) @@ -122,10 +142,10 @@ class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField): ) for validator in validators_ ): - validators_.append(validators.MinValueValidator(min_value)) + validators_.append(MinValueValidator(min_value)) if max_value is not None and not any( ( - isinstance(validator, validators.MaxValueValidator) + isinstance(validator, MaxValueValidator) and ( validator.limit_value() if callable(validator.limit_value) @@ -135,7 +155,7 @@ class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField): ) for validator in validators_ ): - validators_.append(validators.MaxValueValidator(max_value)) + validators_.append(MaxValueValidator(max_value)) return validators_ @@ -156,33 +176,31 @@ class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField): class EncryptedJSONField(EncryptedFieldMixin, models.JSONField): - def _encrypt_values(self, value): + def _encrypt_values(self, value: _TypeAny) -> _TypeAny: if isinstance(value, dict): return {key: self._encrypt_values(data) for key, data in value.items()} - elif isinstance(value, list): - return [self._encrypt_values(data) for data in value] - else: - value = str(value) + if isinstance(value, list): + return [self._encrypt_values(data) for data in value] + value = str(value) return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8") - def _decrypt_values(self, value): + def _decrypt_values(self, value: _TypeAny) -> _TypeAny: if value is None: return value if isinstance(value, dict): return {key: self._decrypt_values(data) for key, data in value.items()} - elif isinstance(value, list): - return [self._decrypt_values(data) for data in value] - else: - value = str(value) + if isinstance(value, list): + return [self._decrypt_values(data) for data in value] + value = str(value) return self.f.decrypt(bytes(value, "utf-8")).decode("utf-8") - def get_prep_value(self, value): + def get_prep_value(self, value: _TypeAny) -> str: return json.dumps(self._encrypt_values(value=value), cls=self.encoder) - def get_internal_type(self): + def get_internal_type(self) -> str: return "JSONField" - def to_python(self, value): + def to_python(self, value: _TypeAny) -> _TypeAny: if ( value is None or not isinstance(value, str) @@ -195,4 +213,4 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField): pass except UnicodeEncodeError: pass - return super(EncryptedFieldMixin, self).to_python(value) + return super().to_python(value) diff --git a/package_test/models.py b/package_test/models.py index 115e7c9..9d56770 100644 --- a/package_test/models.py +++ b/package_test/models.py @@ -1,4 +1,16 @@ -from encrypted_fields.fields import * +from django.db import models + +from encrypted_fields.fields import ( + EncryptedBooleanField, + EncryptedCharField, + EncryptedDateField, + EncryptedDateTimeField, + EncryptedEmailField, + EncryptedFloatField, + EncryptedIntegerField, + EncryptedJSONField, + EncryptedTextField, +) class TestModel(models.Model): diff --git a/package_test/tests.py b/package_test/tests.py index 9d2c9b1..d62e9b0 100644 --- a/package_test/tests.py +++ b/package_test/tests.py @@ -1,25 +1,24 @@ import json import re +import pytest +from django.core.exceptions import ValidationError from django.db import connection from django.test import TestCase, override_settings from django.utils import timezone -from django.core.exceptions import ValidationError from .models import TestModel class FieldTest(TestCase): - def get_db_value(self, field, model_id): + def get_db_value(self, field: str, model_id: int) -> None: cursor = connection.cursor() cursor.execute( - "select {0} " - "from package_test_testmodel " - "where id = {1};".format(field, model_id) + f"select {field} from package_test_testmodel where id = {model_id};" ) return cursor.fetchone()[0] - def test_char_field_encrypted(self): + def test_char_field_encrypted(self) -> None: plaintext = "Oh hi, test reader!" model = TestModel() @@ -29,13 +28,13 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("char", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertTrue("test" not in ciphertext) + assert plaintext != ciphertext + assert "test" not in ciphertext fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.char, plaintext) + assert fresh_model.char == plaintext - def test_text_field_encrypted(self): + def test_text_field_encrypted(self) -> None: plaintext = "Oh hi, test reader!" * 10 model = TestModel() @@ -45,13 +44,13 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("text", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertTrue("test" not in ciphertext) + assert plaintext != ciphertext + assert "test" not in ciphertext fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.text, plaintext) + assert fresh_model.text == plaintext - def test_datetime_field_encrypted(self): + def test_datetime_field_encrypted(self) -> None: plaintext = timezone.now() model = TestModel() @@ -62,19 +61,19 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("datetime", model.id) # Django's normal date serialization format - self.assertTrue(re.search("^\d\d\d\d-\d\d-\d\d", ciphertext) is None) + assert re.search(r"^\d\d\d\d-\d\d-\d\d", ciphertext) is None fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.datetime, plaintext) + assert fresh_model.datetime == plaintext plaintext = "text" + model.datetime = plaintext + model.full_clean() - with self.assertRaises(ValidationError): - model.datetime = plaintext - model.full_clean() + with pytest.raises(ValidationError): model.save() - def test_integer_field_encrypted(self): + def test_integer_field_encrypted(self) -> None: plaintext = 42 model = TestModel() @@ -84,28 +83,26 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("integer", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertNotEqual(plaintext, str(ciphertext)) + assert plaintext != ciphertext + assert plaintext != str(ciphertext) fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.integer, plaintext) + assert fresh_model.integer == plaintext # "IntegerField": (-2147483648, 2147483647) plaintext = 2147483648 + model.integer = plaintext - with self.assertRaises(ValidationError): - model.integer = plaintext + with pytest.raises(ValidationError): model.full_clean() - model.save() plaintext = "text" + model.integer = plaintext - with self.assertRaises(TypeError): - model.integer = plaintext + with pytest.raises(TypeError): model.full_clean() - model.save() - def test_date_field_encrypted(self): + def test_date_field_encrypted(self) -> None: plaintext = timezone.now().date() model = TestModel() @@ -116,17 +113,17 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("date", model.id) fresh_model = TestModel.objects.get(id=model.id) - self.assertNotEqual(ciphertext, plaintext.isoformat()) - self.assertEqual(fresh_model.date, plaintext) + assert ciphertext != plaintext.isoformat() + assert fresh_model.date == plaintext plaintext = "text" + model.date = plaintext + model.full_clean() - with self.assertRaises(ValidationError): - model.date = plaintext - model.full_clean() + with pytest.raises(ValidationError): model.save() - def test_float_field_encrypted(self): + def test_float_field_encrypted(self) -> None: plaintext = 42.44 model = TestModel() @@ -136,20 +133,20 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("floating", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertNotEqual(plaintext, str(ciphertext)) + assert plaintext != ciphertext + assert plaintext != str(ciphertext) fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.floating, plaintext) + assert fresh_model.floating == plaintext plaintext = "text" + model.floating = plaintext + model.full_clean() - with self.assertRaises(ValueError): - model.floating = plaintext - model.full_clean() + with pytest.raises(ValueError): model.save() - def test_email_field_encrypted(self): + def test_email_field_encrypted(self) -> None: plaintext = "test@gmail.com" model = TestModel() @@ -159,20 +156,19 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("email", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertTrue("aron" not in ciphertext) + assert plaintext != ciphertext + assert "aron" not in ciphertext fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.email, plaintext) + assert fresh_model.email == plaintext plaintext = "text" + model.email = plaintext - with self.assertRaises(ValidationError): - model.email = plaintext + with pytest.raises(ValidationError): model.full_clean() - model.save() - def test_boolean_field_encrypted(self): + def test_boolean_field_encrypted(self) -> None: plaintext = True model = TestModel() @@ -182,26 +178,30 @@ class FieldTest(TestCase): ciphertext = self.get_db_value("boolean", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertNotEqual(True, ciphertext) - self.assertNotEqual("True", ciphertext) - self.assertNotEqual("true", ciphertext) - self.assertNotEqual("1", ciphertext) - self.assertNotEqual(1, ciphertext) - self.assertTrue(not isinstance(ciphertext, bool)) + assert plaintext != ciphertext + assert ciphertext is not True + assert ciphertext != "True" + assert ciphertext != "true" + assert ciphertext != "1" + assert ciphertext != 1 + assert not isinstance(ciphertext, bool) fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.boolean, plaintext) + assert fresh_model.boolean == plaintext plaintext = "text" + model.boolean = plaintext + model.full_clean() - with self.assertRaises(ValidationError): - model.boolean = plaintext - model.full_clean() + with pytest.raises(ValidationError): model.save() - def test_json_field_encrypted(self): - dict_values = {"key": "value", "list": ["nested", {"key": "val"}], "nested": {"child": "sibling"}} + def test_json_field_encrypted(self) -> None: + dict_values = { + "key": "value", + "list": ["nested", {"key": "val"}], + "nested": {"child": "sibling"}, + } model = TestModel() model.json = dict_values @@ -210,15 +210,14 @@ class FieldTest(TestCase): ciphertext = json.loads(self.get_db_value("json", model.id)) - - self.assertNotEqual(dict_values, ciphertext) + assert dict_values != ciphertext fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.json, dict_values) + assert fresh_model.json == dict_values - def test_json_field_retains_keys(self): + def test_json_field_retains_keys(self) -> None: plain_value = {"key": "value", "another_key": "some value"} - + model = TestModel() model.json = plain_value model.full_clean() @@ -226,18 +225,18 @@ class FieldTest(TestCase): ciphertext = json.loads(self.get_db_value("json", model.id)) - self.assertEqual(plain_value.keys(), ciphertext.keys()) + assert plain_value.keys() == ciphertext.keys() class RotatedSaltTestCase(TestCase): @classmethod @override_settings(SALT_KEY=["abcdefghijklmnopqrstuvwxyz0123456789"]) - def setUpTestData(cls): + def setUpTestData(cls) -> None: """Create the initial record using the old salt""" cls.original = TestModel.objects.create(text="Oh hi test reader") @override_settings(SALT_KEY=["newkeyhere", "abcdefghijklmnopqrstuvwxyz0123456789"]) - def test_rotated_salt(self): + def test_rotated_salt(self) -> None: """Change the salt, keep the old one as the last in the list for reading""" plaintext = "Oh hi test reader" model = TestModel() @@ -246,15 +245,13 @@ class RotatedSaltTestCase(TestCase): ciphertext = FieldTest.get_db_value(self, "text", model.id) - self.assertNotEqual(plaintext, ciphertext) - self.assertTrue("test" not in ciphertext) + assert plaintext != ciphertext + assert "test" not in ciphertext fresh_model = TestModel.objects.get(id=model.id) - self.assertEqual(fresh_model.text, plaintext) + assert fresh_model.text == plaintext old_record = TestModel.objects.get(id=self.original.id) - self.assertEqual(fresh_model.text, old_record.text) + assert fresh_model.text == old_record.text - self.assertNotEqual( - ciphertext, FieldTest.get_db_value(self, "text", self.original.pk) - ) + assert ciphertext != FieldTest.get_db_value(self, "text", self.original.pk)