diff --git a/dbtemplates/test_cases.py b/dbtemplates/test_cases.py index ab553e5..66eb674 100644 --- a/dbtemplates/test_cases.py +++ b/dbtemplates/test_cases.py @@ -1,6 +1,7 @@ import os import shutil import tempfile +from unittest import mock from django.conf import settings as django_settings from django.core.cache.backends.base import BaseCache @@ -151,3 +152,27 @@ class DbTemplatesTestCase(TestCase): def test_get_cache_name(self): self.assertEqual(get_cache_key('name with spaces'), 'dbtemplates::name-with-spaces::1') + + def test_cache_invalidation(self): + # Add t1 into the cache of site2 + self.t1.sites.add(self.site2) + with mock.patch('django.contrib.sites.models.SiteManager.get_current', + return_value=self.site2): + result = loader.get_template("base.html").render() + self.assertEqual(result, 'base') + + # Update content + self.t1.content = 'new content' + self.t1.save() + result = loader.get_template("base.html").render() + self.assertEqual(result, 'new content') + + # Cache invalidation should work across sites. + # Site2 should see the new content. + with mock.patch('django.contrib.sites.models.SiteManager.get_current', + return_value=self.site2): + result = loader.get_template("base.html").render() + self.assertEqual(result, 'new content') + + + diff --git a/dbtemplates/utils/cache.py b/dbtemplates/utils/cache.py index 923ee59..89039ab 100644 --- a/dbtemplates/utils/cache.py +++ b/dbtemplates/utils/cache.py @@ -21,9 +21,10 @@ def get_cache_backend(): cache = get_cache_backend() -def get_cache_key(name): - current_site = Site.objects.get_current() - return f"dbtemplates::{slugify(name)}::{current_site.pk}" +def get_cache_key(name, site=None): + if site is None: + site = Site.objects.get_current() + return f"dbtemplates::{slugify(name)}::{site.pk}" def get_cache_notfound_key(name): @@ -57,4 +58,5 @@ def remove_cached_template(instance, **kwargs): Called via Django's signals to remove cached templates, if the template in the database was changed or deleted. """ - cache.delete(get_cache_key(instance.name)) + for site in instance.sites.all(): + cache.delete(get_cache_key(instance.name, site=site))