diff --git a/manage.py b/manage.py index dbe5fea..4a8dc4c 100644 --- a/manage.py +++ b/manage.py @@ -2,7 +2,7 @@ import os import sys if __name__ == "__main__": - os.environ["DJANGO_SETTINGS_MODULE"] = "package_test.settings" + os.environ["DJANGO_SETTINGS_MODULE"] = "tests.settings" from django.core.management import execute_from_command_line diff --git a/package_test/__init__.py b/tests/__init__.py similarity index 100% rename from package_test/__init__.py rename to tests/__init__.py diff --git a/package_test/models.py b/tests/models.py similarity index 100% rename from package_test/models.py rename to tests/models.py diff --git a/package_test/settings.py b/tests/settings.py similarity index 63% rename from package_test/settings.py rename to tests/settings.py index e169bcb..72b4ee6 100644 --- a/package_test/settings.py +++ b/tests/settings.py @@ -1,6 +1,6 @@ DATABASES = { "default": { - "ENGINE": "django.db.backends.sqlite3", + "ENGINE": "tests.sqlite3", "NAME": ":memory:", }, } @@ -8,7 +8,7 @@ DATABASES = { SECRET_KEY = "abc" SALT_KEY = "xyz" -INSTALLED_APPS = ("encrypted_fields", "package_test") +INSTALLED_APPS = ("encrypted_fields", "tests") MIDDLEWARE_CLASSES = [] DEFAULT_AUTO_FIELD = "django.db.models.AutoField" diff --git a/package_test/tests.py b/tests/tests.py similarity index 74% rename from package_test/tests.py rename to tests/tests.py index 9d2c9b1..f0a9dce 100644 --- a/package_test/tests.py +++ b/tests/tests.py @@ -1,7 +1,7 @@ import json -import re -from django.db import connection +from django.db.models import CharField +from django.db.models.functions import Cast from django.test import TestCase, override_settings from django.utils import timezone from django.core.exceptions import ValidationError @@ -10,29 +10,28 @@ from .models import TestModel class FieldTest(TestCase): - def get_db_value(self, field, model_id): - cursor = connection.cursor() - cursor.execute( - "select {0} " - "from package_test_testmodel " - "where id = {1};".format(field, model_id) + def get_db_value(self, field, pk): + queryset = ( + TestModel.objects.filter(pk=pk) + .annotate(raw_field=Cast(field, CharField())) + .values_list("raw_field", flat=True) ) - return cursor.fetchone()[0] + + return queryset.first() def test_char_field_encrypted(self): plaintext = "Oh hi, test reader!" model = TestModel() model.char = plaintext - model.full_clean() model.save() - ciphertext = self.get_db_value("char", model.id) + ciphertext = self.get_db_value("char", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertTrue("test" not in ciphertext) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.char, plaintext) def test_text_field_encrypted(self): @@ -40,15 +39,14 @@ class FieldTest(TestCase): model = TestModel() model.text = plaintext - model.full_clean() model.save() - ciphertext = self.get_db_value("text", model.id) + ciphertext = self.get_db_value("text", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertTrue("test" not in ciphertext) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.text, plaintext) def test_datetime_field_encrypted(self): @@ -56,22 +54,20 @@ class FieldTest(TestCase): model = TestModel() model.datetime = plaintext - model.full_clean() model.save() - ciphertext = self.get_db_value("datetime", model.id) + ciphertext = self.get_db_value("datetime", model.pk) # Django's normal date serialization format - self.assertTrue(re.search("^\d\d\d\d-\d\d-\d\d", ciphertext) is None) + self.assertNotRegex(ciphertext, r"^\d\d\d\d-\d\d-\d\d") - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.datetime, plaintext) plaintext = "text" with self.assertRaises(ValidationError): model.datetime = plaintext - model.full_clean() model.save() def test_integer_field_encrypted(self): @@ -79,15 +75,14 @@ class FieldTest(TestCase): model = TestModel() model.integer = plaintext - model.full_clean() model.save() - ciphertext = self.get_db_value("integer", model.id) + ciphertext = self.get_db_value("integer", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertNotEqual(plaintext, str(ciphertext)) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.integer, plaintext) # "IntegerField": (-2147483648, 2147483647) @@ -96,7 +91,6 @@ class FieldTest(TestCase): with self.assertRaises(ValidationError): model.integer = plaintext model.full_clean() - model.save() plaintext = "text" @@ -110,11 +104,10 @@ class FieldTest(TestCase): model = TestModel() model.date = plaintext - model.full_clean() model.save() - ciphertext = self.get_db_value("date", model.id) - fresh_model = TestModel.objects.get(id=model.id) + ciphertext = self.get_db_value("date", model.pk) + fresh_model = TestModel.objects.get(id=model.pk) self.assertNotEqual(ciphertext, plaintext.isoformat()) self.assertEqual(fresh_model.date, plaintext) @@ -134,19 +127,18 @@ class FieldTest(TestCase): model.full_clean() model.save() - ciphertext = self.get_db_value("floating", model.id) + ciphertext = self.get_db_value("floating", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertNotEqual(plaintext, str(ciphertext)) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.floating, plaintext) plaintext = "text" with self.assertRaises(ValueError): model.floating = plaintext - model.full_clean() model.save() def test_email_field_encrypted(self): @@ -154,15 +146,14 @@ class FieldTest(TestCase): model = TestModel() model.email = plaintext - model.full_clean() model.save() - ciphertext = self.get_db_value("email", model.id) + ciphertext = self.get_db_value("email", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertTrue("aron" not in ciphertext) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.email, plaintext) plaintext = "text" @@ -170,7 +161,6 @@ class FieldTest(TestCase): with self.assertRaises(ValidationError): model.email = plaintext model.full_clean() - model.save() def test_boolean_field_encrypted(self): plaintext = True @@ -180,7 +170,7 @@ class FieldTest(TestCase): model.full_clean() model.save() - ciphertext = self.get_db_value("boolean", model.id) + ciphertext = self.get_db_value("boolean", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertNotEqual(True, ciphertext) @@ -190,41 +180,41 @@ class FieldTest(TestCase): self.assertNotEqual(1, ciphertext) self.assertTrue(not isinstance(ciphertext, bool)) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.boolean, plaintext) plaintext = "text" with self.assertRaises(ValidationError): model.boolean = plaintext - model.full_clean() model.save() def test_json_field_encrypted(self): - dict_values = {"key": "value", "list": ["nested", {"key": "val"}], "nested": {"child": "sibling"}} + dict_values = { + "key": "value", + "list": ["nested", {"key": "val"}], + "nested": {"child": "sibling"}, + } model = TestModel() model.json = dict_values - model.full_clean() model.save() - ciphertext = json.loads(self.get_db_value("json", model.id)) - + ciphertext = json.loads(self.get_db_value("json", model.pk)) self.assertNotEqual(dict_values, ciphertext) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.json, dict_values) def test_json_field_retains_keys(self): plain_value = {"key": "value", "another_key": "some value"} - + model = TestModel() model.json = plain_value - model.full_clean() model.save() - ciphertext = json.loads(self.get_db_value("json", model.id)) + ciphertext = json.loads(self.get_db_value("json", model.pk)) self.assertEqual(plain_value.keys(), ciphertext.keys()) @@ -244,15 +234,15 @@ class RotatedSaltTestCase(TestCase): model.text = plaintext model.save() - ciphertext = FieldTest.get_db_value(self, "text", model.id) + ciphertext = FieldTest.get_db_value(self, "text", model.pk) self.assertNotEqual(plaintext, ciphertext) self.assertTrue("test" not in ciphertext) - fresh_model = TestModel.objects.get(id=model.id) + fresh_model = TestModel.objects.get(id=model.pk) self.assertEqual(fresh_model.text, plaintext) - old_record = TestModel.objects.get(id=self.original.id) + old_record = TestModel.objects.get(id=self.original.pk) self.assertEqual(fresh_model.text, old_record.text) self.assertNotEqual(