From adac76f533c68f7ad54717c86cd886f34965c813 Mon Sep 17 00:00:00 2001 From: blag Date: Tue, 10 Jun 2025 02:14:14 -0600 Subject: [PATCH] Optionally connect signal and modernize tests --- dbtemplates/models.py | 27 +++---- dbtemplates/test_cases.py | 158 +++++++++++++++++++------------------- 2 files changed, 95 insertions(+), 90 deletions(-) diff --git a/dbtemplates/models.py b/dbtemplates/models.py index aa87dcb..6a9d35d 100644 --- a/dbtemplates/models.py +++ b/dbtemplates/models.py @@ -54,13 +54,6 @@ class Template(models.Model): except TemplateDoesNotExist: pass - def save(self, *args, **kwargs): - # If content is empty look for a template with the given name and - # populate the template instance with its content. - if settings.DBTEMPLATES_AUTO_POPULATE_CONTENT and not self.content: - self.populate() - super().save(*args, **kwargs) - def add_default_site(instance, **kwargs): """ @@ -68,13 +61,21 @@ def add_default_site(instance, **kwargs): in the database was added or changed, only if DBTEMPLATES_ADD_DEFAULT_SITE setting is set. """ - if not settings.DBTEMPLATES_ADD_DEFAULT_SITE: - return - current_site = Site.objects.get_current() - if current_site not in instance.sites.all(): - instance.sites.add(current_site) + instance.sites.add(Site.objects.get_current()) -signals.post_save.connect(add_default_site, sender=Template) +def populate_empty_content(instance, **kwargs): + # If content is empty look for a template with the given name and + # populate the template instance with its content. + if not instance.content: + instance.populate() + + +if settings.DBTEMPLATES_ADD_DEFAULT_SITE: + signals.post_save.connect(add_default_site, sender=Template) + +if settings.DBTEMPLATES_AUTO_POPULATE_CONTENT: + signals.pre_save.connect(populate_empty_content, sender=Template) + signals.post_save.connect(add_template_to_cache, sender=Template) signals.pre_delete.connect(remove_cached_template, sender=Template) diff --git a/dbtemplates/test_cases.py b/dbtemplates/test_cases.py index 062e2e5..d29f376 100644 --- a/dbtemplates/test_cases.py +++ b/dbtemplates/test_cases.py @@ -1,18 +1,22 @@ import os import shutil import tempfile +from contextlib import contextmanager +from pathlib import Path from unittest import mock from django.conf import settings as django_settings from django.core.cache.backends.base import BaseCache from django.core.management import call_command +from django.db.models.signals import post_save from django.template import loader, TemplateDoesNotExist -from django.test import TestCase +from django.test import TestCase, modify_settings, override_settings +from django.test.signals import receiver, setting_changed from django.contrib.sites.models import Site from dbtemplates.conf import settings -from dbtemplates.models import Template +from dbtemplates.models import Template, add_default_site from dbtemplates.utils.cache import get_cache_backend, get_cache_key from dbtemplates.utils.template import (get_template_source, check_template_syntax) @@ -20,15 +24,32 @@ from dbtemplates.management.commands.sync_templates import (FILES_TO_DATABASE, DATABASE_TO_FILES) -class DbTemplatesTestCase(TestCase): - def setUp(self): - self.old_TEMPLATES = settings.TEMPLATES - if 'dbtemplates.loader.Loader' not in settings.TEMPLATES: - loader.template_source_loaders = None - settings.TEMPLATES = list(settings.TEMPLATES) + [ - 'dbtemplates.loader.Loader' - ] +@receiver(setting_changed) +def handle_add_default_site(sender, setting, value, **kwargs): + if setting == "DBTEMPLATES_ADD_DEFAULT_SITE": + if value: + post_save.connect(add_default_site, sender=Template) + else: + post_save.disconnect(add_default_site, sender=Template) + +@contextmanager +def temptemplate(name: str, cleanup: bool = True): + temp_template_dir = Path(tempfile.mkdtemp('dbtemplates')) + temp_template_path = temp_template_dir / name + try: + yield temp_template_path + finally: + shutil.rmtree(temp_template_dir) + + +class DbTemplatesTestCase(TestCase): + @modify_settings( + TEMPLATES={ + "append": "dbtemplates.loader.Loader", + }, + ) + def setUp(self): self.site1, created1 = Site.objects.get_or_create( domain="example.com", name="example.com") self.site2, created2 = Site.objects.get_or_create( @@ -39,48 +60,35 @@ class DbTemplatesTestCase(TestCase): name='sub.html', content='sub') self.t2.sites.add(self.site2) - def tearDown(self): - loader.template_source_loaders = None - settings.TEMPLATES = self.old_TEMPLATES - def test_basics(self): - self.assertEqual(list(self.t1.sites.all()), [self.site1]) - self.assertTrue("base" in self.t1.content) - self.assertEqual(list(Template.objects.filter(sites=self.site1)), - [self.t1, self.t2]) - self.assertEqual(list(self.t2.sites.all()), [self.site1, self.site2]) + self.assertQuerySetEqual(self.t1.sites.all(), Site.objects.filter(id=self.site1.id)) + self.assertIn("base", self.t1.content) + self.assertQuerySetEqual(Template.objects.filter(sites=self.site1), + Template.objects.filter(id__in=[self.t1.id, self.t2.id])) + self.assertQuerySetEqual(self.t2.sites.all(), Site.objects.filter(id__in=[self.site1.id, self.site2.id])) + @override_settings(DBTEMPLATES_ADD_DEFAULT_SITE=False) def test_empty_sites(self): - old_add_default_site = settings.DBTEMPLATES_ADD_DEFAULT_SITE - try: - settings.DBTEMPLATES_ADD_DEFAULT_SITE = False - self.t3 = Template.objects.create( - name='footer.html', content='footer') - self.assertEqual(list(self.t3.sites.all()), []) - finally: - settings.DBTEMPLATES_ADD_DEFAULT_SITE = old_add_default_site + self.t3 = Template.objects.create( + name='footer.html', content='footer') + self.assertQuerySetEqual(self.t3.sites.all(), self.t3.sites.none()) + @override_settings(DBTEMPLATES_ADD_DEFAULT_SITE=False) def test_load_templates_sites(self): - old_add_default_site = settings.DBTEMPLATES_ADD_DEFAULT_SITE - old_site_id = django_settings.SITE_ID - try: - settings.DBTEMPLATES_ADD_DEFAULT_SITE = False - t_site1 = Template.objects.create( - name='copyright.html', content='(c) example.com') - t_site1.sites.add(self.site1) - t_site2 = Template.objects.create( - name='copyright.html', content='(c) example.org') - t_site2.sites.add(self.site2) + t_site1 = Template.objects.create( + name='copyright.html', content='(c) example.com') + t_site1.sites.add(self.site1) + t_site2 = Template.objects.create( + name='copyright.html', content='(c) example.org') + t_site2.sites.add(self.site2) - django_settings.SITE_ID = Site.objects.create( - domain="example.net", name="example.net").id + new_site = Site.objects.create( + domain="example.net", name="example.net") + with self.settings(SITE_ID=new_site.id): Site.objects.clear_cache() self.assertRaises(TemplateDoesNotExist, loader.get_template, "copyright.html") - finally: - django_settings.SITE_ID = old_site_id - settings.DBTEMPLATES_ADD_DEFAULT_SITE = old_add_default_site def test_load_templates(self): result = loader.get_template("base.html").render() @@ -90,8 +98,8 @@ class DbTemplatesTestCase(TestCase): def test_error_templates_creation(self): call_command('create_error_templates', force=True, verbosity=0) - self.assertEqual(list(Template.objects.filter(sites=self.site1)), - list(Template.objects.filter())) + self.assertQuerySetEqual(Template.objects.filter(sites=self.site1), + Template.objects.filter()) self.assertTrue(Template.objects.filter(name='404.html').exists()) def test_automatic_sync(self): @@ -101,42 +109,38 @@ class DbTemplatesTestCase(TestCase): def test_sync_templates(self): old_template_dirs = settings.TEMPLATES[0].get('DIRS', []) - temp_template_dir = tempfile.mkdtemp('dbtemplates') - temp_template_path = os.path.join(temp_template_dir, 'temp_test.html') - temp_template = open(temp_template_path, 'w', encoding='utf-8') - try: - temp_template.write('temp test') - settings.TEMPLATES[0]['DIRS'] = (temp_template_dir,) - # these works well if is not settings patched at runtime - # for supporting django < 1.7 tests we must patch dirs in runtime - from dbtemplates.management.commands import sync_templates - sync_templates.DIRS = settings.TEMPLATES[0]['DIRS'] + with temptemplate('temp_test.html') as temp_template_path: + with open(temp_template_path, 'w', encoding='utf-8') as temp_template: + temp_template.write('temp test') + try: + settings.TEMPLATES[0]['DIRS'] = (temp_template_path.parent,) + # these works well if is not settings patched at runtime + # for supporting django < 1.7 tests we must patch dirs in runtime + from dbtemplates.management.commands import sync_templates + sync_templates.DIRS = settings.TEMPLATES[0]['DIRS'] - self.assertFalse( - Template.objects.filter(name='temp_test.html').exists()) - call_command('sync_templates', force=True, - verbosity=0, overwrite=FILES_TO_DATABASE) - self.assertTrue( - Template.objects.filter(name='temp_test.html').exists()) + self.assertFalse( + Template.objects.filter(name='temp_test.html').exists()) + call_command('sync_templates', force=True, + verbosity=0, overwrite=FILES_TO_DATABASE) + self.assertTrue( + Template.objects.filter(name='temp_test.html').exists()) - t = Template.objects.get(name='temp_test.html') - t.content = 'temp test modified' - t.save() - call_command('sync_templates', force=True, - verbosity=0, overwrite=DATABASE_TO_FILES) - self.assertEqual('temp test modified', - open(temp_template_path, - encoding='utf-8').read()) + t = Template.objects.get(name='temp_test.html') + t.content = 'temp test modified' + t.save() + call_command('sync_templates', force=True, + verbosity=0, overwrite=DATABASE_TO_FILES) + with open(temp_template_path, encoding='utf-8') as f: + self.assertEqual('temp test modified', f.read()) - call_command('sync_templates', force=True, verbosity=0, - delete=True, overwrite=DATABASE_TO_FILES) - self.assertTrue(os.path.exists(temp_template_path)) - self.assertFalse( - Template.objects.filter(name='temp_test.html').exists()) - finally: - temp_template.close() - settings.TEMPLATES[0]['DIRS'] = old_template_dirs - shutil.rmtree(temp_template_dir) + call_command('sync_templates', force=True, verbosity=0, + delete=True, overwrite=DATABASE_TO_FILES) + self.assertTrue(os.path.exists(temp_template_path)) + self.assertFalse( + Template.objects.filter(name='temp_test.html').exists()) + finally: + settings.TEMPLATES[0]['DIRS'] = old_template_dirs def test_get_cache(self): self.assertTrue(isinstance(get_cache_backend(), BaseCache))