Fix EncryptedIntegerField

This commit is contained in:
naohide anahara 2022-05-06 22:38:10 +09:00
parent 5c00880dd1
commit 9ec3563bfa
2 changed files with 47 additions and 1 deletions

View file

@ -1,9 +1,12 @@
import base64 import base64
from django.conf import settings from django.utils import timezone
import warnings
from cryptography.fernet import Fernet, MultiFernet from cryptography.fernet import Fernet, MultiFernet
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from django.conf import settings
from django.db import models from django.db import models
from django.utils.functional import cached_property from django.utils.functional import cached_property
@ -94,6 +97,18 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField): class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
def get_prep_value(self, value):
if value is None:
return None
try:
value = int(value)
except (TypeError, ValueError) as e:
raise e.__class__(
"Field '%s' expected a number but got %r." % (self.name, value),
) from e
else:
return super().get_prep_value(value)
@cached_property @cached_property
def validators(self): def validators(self):
return [*self.default_validators, *self._validators] return [*self.default_validators, *self._validators]

View file

@ -62,6 +62,11 @@ class FieldTest(TestCase):
fresh_model = TestModel.objects.get(id=model.id) fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.datetime, plaintext) self.assertEqual(fresh_model.datetime, plaintext)
plaintext = "text"
model.datetime = plaintext
model.save()
def test_integer_field_encrypted(self): def test_integer_field_encrypted(self):
plaintext = 42 plaintext = 42
@ -77,6 +82,12 @@ class FieldTest(TestCase):
fresh_model = TestModel.objects.get(id=model.id) fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.integer, plaintext) self.assertEqual(fresh_model.integer, plaintext)
plaintext = "text"
with self.assertRaises(ValueError):
model.integer = plaintext
model.save()
def test_date_field_encrypted(self): def test_date_field_encrypted(self):
plaintext = timezone.now().date() plaintext = timezone.now().date()
@ -90,6 +101,11 @@ class FieldTest(TestCase):
self.assertNotEqual(ciphertext, plaintext.isoformat()) self.assertNotEqual(ciphertext, plaintext.isoformat())
self.assertEqual(fresh_model.date, plaintext) self.assertEqual(fresh_model.date, plaintext)
plaintext = "text"
model.date = plaintext
model.save()
def test_float_field_encrypted(self): def test_float_field_encrypted(self):
plaintext = 42.44 plaintext = 42.44
@ -105,6 +121,11 @@ class FieldTest(TestCase):
fresh_model = TestModel.objects.get(id=model.id) fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.floating, plaintext) self.assertEqual(fresh_model.floating, plaintext)
plaintext = "text"
model.floating = plaintext
model.save()
def test_email_field_encrypted(self): def test_email_field_encrypted(self):
plaintext = "test@gmail.com" plaintext = "test@gmail.com"
@ -120,6 +141,11 @@ class FieldTest(TestCase):
fresh_model = TestModel.objects.get(id=model.id) fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.email, plaintext) self.assertEqual(fresh_model.email, plaintext)
plaintext = "text"
model.email = plaintext
model.save()
def test_boolean_field_encrypted(self): def test_boolean_field_encrypted(self):
plaintext = True plaintext = True
@ -140,6 +166,11 @@ class FieldTest(TestCase):
fresh_model = TestModel.objects.get(id=model.id) fresh_model = TestModel.objects.get(id=model.id)
self.assertEqual(fresh_model.boolean, plaintext) self.assertEqual(fresh_model.boolean, plaintext)
plaintext = "text"
model.boolean = plaintext
model.save()
class RotatedSaltTestCase(TestCase): class RotatedSaltTestCase(TestCase):
@classmethod @classmethod