diff --git a/encrypted_fields/fields.py b/encrypted_fields/fields.py index 03e5c45..383d734 100644 --- a/encrypted_fields/fields.py +++ b/encrypted_fields/fields.py @@ -6,12 +6,12 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from django.conf import settings -from django.core import validators from django.db import models -from django.db.backends.base.operations import BaseDatabaseOperations from django.utils.functional import cached_property +from django.utils.encoding import force_bytes, force_str -class EncryptedFieldMixin(object): + +class EncryptedFieldMixin: @cached_property def keys(self): keys = [] @@ -21,7 +21,7 @@ class EncryptedFieldMixin(object): else [settings.SALT_KEY] ) for salt_key in salt_keys: - salt = bytes(salt_key, "utf-8") + salt = force_bytes(salt_key) kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, @@ -30,9 +30,7 @@ class EncryptedFieldMixin(object): backend=default_backend(), ) keys.append( - base64.urlsafe_b64encode( - kdf.derive(settings.SECRET_KEY.encode("utf-8")) - ) + base64.urlsafe_b64encode(kdf.derive(force_bytes(settings.SECRET_KEY))) ) return keys @@ -46,14 +44,14 @@ class EncryptedFieldMixin(object): """ To treat everything as text """ - return "TextField" + return getattr(self, "_internal_type", "TextField") def get_prep_value(self, value): value = super().get_prep_value(value) if value: if not isinstance(value, str): value = str(value) - return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8") + return force_str(self.f.encrypt(force_bytes(value))) return None def get_db_prep_value(self, value, connection, prepared=False): @@ -64,6 +62,13 @@ class EncryptedFieldMixin(object): def from_db_value(self, value, expression, connection): return self.to_python(value) + def decrypt(self, value): + try: + value = force_str(self.f.decrypt(force_bytes(value))) + except (InvalidToken, UnicodeEncodeError): + pass + return value + def to_python(self, value): if ( value is None @@ -71,13 +76,9 @@ class EncryptedFieldMixin(object): or hasattr(self, "_already_decrypted") ): return value - try: - value = self.f.decrypt(bytes(value, "utf-8")).decode("utf-8") - except InvalidToken: - pass - except UnicodeEncodeError: - pass - return super(EncryptedFieldMixin, self).to_python(value) + + value = self.decrypt(value) + return super().to_python(value) def clean(self, value, model_instance): """ @@ -89,6 +90,17 @@ class EncryptedFieldMixin(object): del self._already_decrypted return ret + @cached_property + def validators(self): + # Temporarily pretend to be whatever type of field we're masquerading + # as, for purposes of constructing validators (needed for + # IntegerField and subclasses). + self._internal_type = super().get_internal_type() + try: + return super().validators + finally: + del self._internal_type + class EncryptedCharField(EncryptedFieldMixin, models.CharField): pass @@ -103,40 +115,7 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField): class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField): - @cached_property - def validators(self): - # These validators can't be added at field initialization time since - # they're based on values retrieved from `connection`. - validators_ = [*self.default_validators, *self._validators] - internal_type = models.IntegerField().get_internal_type() - min_value, max_value = BaseDatabaseOperations.integer_field_ranges[internal_type] - if min_value is not None and not any( - ( - isinstance(validator, validators.MinValueValidator) - and ( - validator.limit_value() - if callable(validator.limit_value) - else validator.limit_value - ) - >= min_value - ) - for validator in validators_ - ): - validators_.append(validators.MinValueValidator(min_value)) - if max_value is not None and not any( - ( - isinstance(validator, validators.MaxValueValidator) - and ( - validator.limit_value() - if callable(validator.limit_value) - else validator.limit_value - ) - <= max_value - ) - for validator in validators_ - ): - validators_.append(validators.MaxValueValidator(max_value)) - return validators_ + pass class EncryptedDateField(EncryptedFieldMixin, models.DateField): @@ -160,10 +139,10 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField): if isinstance(value, dict): return {key: self._encrypt_values(data) for key, data in value.items()} elif isinstance(value, list): - return [self._encrypt_values(data) for data in value] + return [self._encrypt_values(data) for data in value] else: value = str(value) - return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8") + return force_str(self.f.encrypt(force_bytes(value))) def _decrypt_values(self, value): if value is None: @@ -171,10 +150,10 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField): if isinstance(value, dict): return {key: self._decrypt_values(data) for key, data in value.items()} elif isinstance(value, list): - return [self._decrypt_values(data) for data in value] + return [self._decrypt_values(data) for data in value] else: value = str(value) - return self.f.decrypt(bytes(value, "utf-8")).decode("utf-8") + return force_str(self.f.decrypt(force_bytes(value))) def get_prep_value(self, value): return json.dumps(self._encrypt_values(value=value), cls=self.encoder) @@ -182,17 +161,9 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField): def get_internal_type(self): return "JSONField" - def to_python(self, value): - if ( - value is None - or not isinstance(value, str) - or hasattr(self, "_already_decrypted") - ): - return value + def decrypt(self, value): try: value = self._decrypt_values(value=json.loads(value)) - except InvalidToken: + except (InvalidToken, UnicodeEncodeError): pass - except UnicodeEncodeError: - pass - return super(EncryptedFieldMixin, self).to_python(value) + return value diff --git a/tests/sqlite3/__init__.py b/tests/sqlite3/__init__.py new file mode 100644 index 0000000..e9001e5 --- /dev/null +++ b/tests/sqlite3/__init__.py @@ -0,0 +1 @@ +"""Override to default django SQLite backend required for integer validation test.""" diff --git a/tests/sqlite3/base.py b/tests/sqlite3/base.py new file mode 100644 index 0000000..73ad84e --- /dev/null +++ b/tests/sqlite3/base.py @@ -0,0 +1,16 @@ +from django.db.backends.sqlite3.base import DatabaseWrapper as BaseDatabaseWrapper +from django.db.backends.sqlite3.operations import ( + DatabaseOperations as BaseDatabaseOperations, +) + + +class DatabaseOperations(BaseDatabaseOperations): + def integer_field_range(self, internal_type): + # by default django does not enforce size on SQLite integers + # because it does not + # this is required to pass tests without using a real DB + return self.integer_field_ranges[internal_type] + + +class DatabaseWrapper(BaseDatabaseWrapper): + ops_class = DatabaseOperations