Change validation of integer field to reflect current database connection and avoid duplicating code from django

This commit is contained in:
Davide 2024-04-09 11:18:52 +02:00
parent 470ecbaccb
commit 844e0f783a
No known key found for this signature in database
GPG key ID: D939AF7A93A9C178
3 changed files with 53 additions and 65 deletions

View file

@ -6,12 +6,12 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from django.conf import settings
from django.core import validators
from django.db import models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.utils.functional import cached_property
from django.utils.encoding import force_bytes, force_str
class EncryptedFieldMixin(object):
class EncryptedFieldMixin:
@cached_property
def keys(self):
keys = []
@ -21,7 +21,7 @@ class EncryptedFieldMixin(object):
else [settings.SALT_KEY]
)
for salt_key in salt_keys:
salt = bytes(salt_key, "utf-8")
salt = force_bytes(salt_key)
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
@ -30,9 +30,7 @@ class EncryptedFieldMixin(object):
backend=default_backend(),
)
keys.append(
base64.urlsafe_b64encode(
kdf.derive(settings.SECRET_KEY.encode("utf-8"))
)
base64.urlsafe_b64encode(kdf.derive(force_bytes(settings.SECRET_KEY)))
)
return keys
@ -46,14 +44,14 @@ class EncryptedFieldMixin(object):
"""
To treat everything as text
"""
return "TextField"
return getattr(self, "_internal_type", "TextField")
def get_prep_value(self, value):
value = super().get_prep_value(value)
if value:
if not isinstance(value, str):
value = str(value)
return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8")
return force_str(self.f.encrypt(force_bytes(value)))
return None
def get_db_prep_value(self, value, connection, prepared=False):
@ -64,6 +62,13 @@ class EncryptedFieldMixin(object):
def from_db_value(self, value, expression, connection):
return self.to_python(value)
def decrypt(self, value):
try:
value = force_str(self.f.decrypt(force_bytes(value)))
except (InvalidToken, UnicodeEncodeError):
pass
return value
def to_python(self, value):
if (
value is None
@ -71,13 +76,9 @@ class EncryptedFieldMixin(object):
or hasattr(self, "_already_decrypted")
):
return value
try:
value = self.f.decrypt(bytes(value, "utf-8")).decode("utf-8")
except InvalidToken:
pass
except UnicodeEncodeError:
pass
return super(EncryptedFieldMixin, self).to_python(value)
value = self.decrypt(value)
return super().to_python(value)
def clean(self, value, model_instance):
"""
@ -89,6 +90,17 @@ class EncryptedFieldMixin(object):
del self._already_decrypted
return ret
@cached_property
def validators(self):
# Temporarily pretend to be whatever type of field we're masquerading
# as, for purposes of constructing validators (needed for
# IntegerField and subclasses).
self._internal_type = super().get_internal_type()
try:
return super().validators
finally:
del self._internal_type
class EncryptedCharField(EncryptedFieldMixin, models.CharField):
pass
@ -103,40 +115,7 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
@cached_property
def validators(self):
# These validators can't be added at field initialization time since
# they're based on values retrieved from `connection`.
validators_ = [*self.default_validators, *self._validators]
internal_type = models.IntegerField().get_internal_type()
min_value, max_value = BaseDatabaseOperations.integer_field_ranges[internal_type]
if min_value is not None and not any(
(
isinstance(validator, validators.MinValueValidator)
and (
validator.limit_value()
if callable(validator.limit_value)
else validator.limit_value
)
>= min_value
)
for validator in validators_
):
validators_.append(validators.MinValueValidator(min_value))
if max_value is not None and not any(
(
isinstance(validator, validators.MaxValueValidator)
and (
validator.limit_value()
if callable(validator.limit_value)
else validator.limit_value
)
<= max_value
)
for validator in validators_
):
validators_.append(validators.MaxValueValidator(max_value))
return validators_
pass
class EncryptedDateField(EncryptedFieldMixin, models.DateField):
@ -160,10 +139,10 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
if isinstance(value, dict):
return {key: self._encrypt_values(data) for key, data in value.items()}
elif isinstance(value, list):
return [self._encrypt_values(data) for data in value]
return [self._encrypt_values(data) for data in value]
else:
value = str(value)
return self.f.encrypt(bytes(value, "utf-8")).decode("utf-8")
return force_str(self.f.encrypt(force_bytes(value)))
def _decrypt_values(self, value):
if value is None:
@ -171,10 +150,10 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
if isinstance(value, dict):
return {key: self._decrypt_values(data) for key, data in value.items()}
elif isinstance(value, list):
return [self._decrypt_values(data) for data in value]
return [self._decrypt_values(data) for data in value]
else:
value = str(value)
return self.f.decrypt(bytes(value, "utf-8")).decode("utf-8")
return force_str(self.f.decrypt(force_bytes(value)))
def get_prep_value(self, value):
return json.dumps(self._encrypt_values(value=value), cls=self.encoder)
@ -182,17 +161,9 @@ class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):
def get_internal_type(self):
return "JSONField"
def to_python(self, value):
if (
value is None
or not isinstance(value, str)
or hasattr(self, "_already_decrypted")
):
return value
def decrypt(self, value):
try:
value = self._decrypt_values(value=json.loads(value))
except InvalidToken:
except (InvalidToken, UnicodeEncodeError):
pass
except UnicodeEncodeError:
pass
return super(EncryptedFieldMixin, self).to_python(value)
return value

View file

@ -0,0 +1 @@
"""Override to default django SQLite backend required for integer validation test."""

16
tests/sqlite3/base.py Normal file
View file

@ -0,0 +1,16 @@
from django.db.backends.sqlite3.base import DatabaseWrapper as BaseDatabaseWrapper
from django.db.backends.sqlite3.operations import (
DatabaseOperations as BaseDatabaseOperations,
)
class DatabaseOperations(BaseDatabaseOperations):
def integer_field_range(self, internal_type):
# by default django does not enforce size on SQLite integers
# because it does not
# this is required to pass tests without using a real DB
return self.integer_field_ranges[internal_type]
class DatabaseWrapper(BaseDatabaseWrapper):
ops_class = DatabaseOperations