diff --git a/package_test/tests.py b/package_test/tests.py index abeb65a..f1e7a4a 100644 --- a/package_test/tests.py +++ b/package_test/tests.py @@ -195,6 +195,19 @@ class FieldTest(TestCase): with pytest.raises(ValidationError): model.save() + + def test_encrypted_boolean_field_preserves_true_false_and_none(self) -> None: + true_obj = TestModel.objects.create(boolean=True) + false_obj = TestModel.objects.create(boolean=False) + none_obj = TestModel.objects.create(boolean=None) + + true_obj.refresh_from_db() + false_obj.refresh_from_db() + none_obj.refresh_from_db() + + assert true_obj.boolean is True + assert false_obj.boolean is False + assert none_obj.boolean is None def test_json_field_encrypted(self) -> None: dict_values = { diff --git a/src/encrypted_fields/fields.py b/src/encrypted_fields/fields.py index 33a0622..77984f8 100644 --- a/src/encrypted_fields/fields.py +++ b/src/encrypted_fields/fields.py @@ -174,7 +174,11 @@ class EncryptedEmailField(EncryptedFieldMixin, models.EmailField): class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField): - pass + def get_prep_value(self, value: _TypeAny) -> _TypeAny: + value = models.BooleanField.get_prep_value(self, value) + if value is None: + return None + return self.f.encrypt(str(value).encode("utf-8")).decode("utf-8") class EncryptedJSONField(EncryptedFieldMixin, models.JSONField):