From 9ec3563bfa1d27803004776b0bded64670da13cd Mon Sep 17 00:00:00 2001 From: naohide anahara <57.x.mas@gmail.com> Date: Fri, 6 May 2022 22:38:10 +0900 Subject: [PATCH] Fix EncryptedIntegerField --- encrypted_fields/fields.py | 17 ++++++++++++++++- package_test/tests.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/encrypted_fields/fields.py b/encrypted_fields/fields.py index 4ef0bba..8e4dc2b 100644 --- a/encrypted_fields/fields.py +++ b/encrypted_fields/fields.py @@ -1,9 +1,12 @@ import base64 -from django.conf import settings +from django.utils import timezone + +import warnings from cryptography.fernet import Fernet, MultiFernet from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from django.conf import settings from django.db import models from django.utils.functional import cached_property @@ -94,6 +97,18 @@ class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField): class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField): + def get_prep_value(self, value): + if value is None: + return None + try: + value = int(value) + except (TypeError, ValueError) as e: + raise e.__class__( + "Field '%s' expected a number but got %r." % (self.name, value), + ) from e + else: + return super().get_prep_value(value) + @cached_property def validators(self): return [*self.default_validators, *self._validators] diff --git a/package_test/tests.py b/package_test/tests.py index 4372e4a..9e9c2c2 100644 --- a/package_test/tests.py +++ b/package_test/tests.py @@ -62,6 +62,11 @@ class FieldTest(TestCase): fresh_model = TestModel.objects.get(id=model.id) self.assertEqual(fresh_model.datetime, plaintext) + plaintext = "text" + + model.datetime = plaintext + model.save() + def test_integer_field_encrypted(self): plaintext = 42 @@ -77,6 +82,12 @@ class FieldTest(TestCase): fresh_model = TestModel.objects.get(id=model.id) self.assertEqual(fresh_model.integer, plaintext) + plaintext = "text" + + with self.assertRaises(ValueError): + model.integer = plaintext + model.save() + def test_date_field_encrypted(self): plaintext = timezone.now().date() @@ -90,6 +101,11 @@ class FieldTest(TestCase): self.assertNotEqual(ciphertext, plaintext.isoformat()) self.assertEqual(fresh_model.date, plaintext) + plaintext = "text" + + model.date = plaintext + model.save() + def test_float_field_encrypted(self): plaintext = 42.44 @@ -105,6 +121,11 @@ class FieldTest(TestCase): fresh_model = TestModel.objects.get(id=model.id) self.assertEqual(fresh_model.floating, plaintext) + plaintext = "text" + + model.floating = plaintext + model.save() + def test_email_field_encrypted(self): plaintext = "test@gmail.com" @@ -120,6 +141,11 @@ class FieldTest(TestCase): fresh_model = TestModel.objects.get(id=model.id) self.assertEqual(fresh_model.email, plaintext) + plaintext = "text" + + model.email = plaintext + model.save() + def test_boolean_field_encrypted(self): plaintext = True @@ -140,6 +166,11 @@ class FieldTest(TestCase): fresh_model = TestModel.objects.get(id=model.id) self.assertEqual(fresh_model.boolean, plaintext) + plaintext = "text" + + model.boolean = plaintext + model.save() + class RotatedSaltTestCase(TestCase): @classmethod