diff --git a/imagekit/specs/sourcegroups.py b/imagekit/specs/sourcegroups.py index 1693470..69f2daf 100644 --- a/imagekit/specs/sourcegroups.py +++ b/imagekit/specs/sourcegroups.py @@ -13,8 +13,10 @@ have two responsibilities: from django.db.models.signals import post_init, post_save, post_delete from django.utils.functional import wraps +import inspect from ..cachefiles import LazyImageCacheFile from ..signals import source_created, source_changed, source_deleted +from ..utils import get_nonabstract_descendants def ik_model_receiver(fn): @@ -25,8 +27,15 @@ def ik_model_receiver(fn): """ @wraps(fn) def receiver(self, sender, **kwargs): - if sender in (src.model_class for src in self._source_groups): - fn(self, sender=sender, **kwargs) + if not inspect.isclass(sender): + return + for src in self._source_groups: + if issubclass(sender, src.model_class): + fn(self, sender=sender, **kwargs) + + # If we find a match, return. We don't want to handle the signal + # more than once. + return return receiver @@ -76,7 +85,7 @@ class ModelSignalRouter(object): """ return dict((src.image_field, getattr(instance, src.image_field)) for - src in self._source_groups if src.model_class is instance.__class__) + src in self._source_groups if isinstance(instance, src.model_class)) @ik_model_receiver def post_save_receiver(self, sender, instance=None, created=False, raw=False, **kwargs): @@ -109,14 +118,14 @@ class ModelSignalRouter(object): """ for source_group in self._source_groups: - if source_group.model_class is model_class and source_group.image_field == attname: + if issubclass(model_class, source_group.model_class) and source_group.image_field == attname: signal.send(sender=source_group, source=file) class ImageFieldSourceGroup(object): """ A source group that repesents a particular field across all instances of a - model. + model and its subclasses. """ def __init__(self, model_class, image_field): @@ -128,11 +137,12 @@ class ImageFieldSourceGroup(object): """ A generator that returns the source files that this source group represents; in this case, a particular field of every instance of a - particular model. + particular model and its subclasses. """ - for instance in self.model_class.objects.all(): - yield getattr(instance, self.image_field) + for model in get_nonabstract_descendants(self.model_class): + for instance in model.objects.all().iterator(): + yield getattr(instance, self.image_field) class SourceGroupFilesGenerator(object): diff --git a/imagekit/utils.py b/imagekit/utils.py index 6f4fd52..91f3344 100644 --- a/imagekit/utils.py +++ b/imagekit/utils.py @@ -23,6 +23,15 @@ def _get_models(apps): return models +def get_nonabstract_descendants(model): + """ Returns all non-abstract descendants of the model. """ + if not model._meta.abstract: + yield model + for s in model.__subclasses__(): + for m in get_nonabstract_descendants(s): + yield m + + def get_by_qname(path, desc): try: dot = path.rindex('.') diff --git a/tests/models.py b/tests/models.py index 17e887b..7520a5c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -26,17 +26,35 @@ class ProcessedImageFieldModel(models.Model): options={'quality': 90}, upload_to='p') +class CountingCacheFileStrategy(object): + def __init__(self): + self.before_access_count = 0 + self.on_source_changed_count = 0 + self.on_source_created_count = 0 + + def before_access(self, file): + self.before_access_count += 1 + + def on_source_changed(self, file): + self.on_source_changed_count += 1 + + def on_source_created(self, file): + self.on_source_created_count += 1 + + class AbstractImageModel(models.Model): original_image = models.ImageField(upload_to='photos') - abstract_class_spec = ImageSpecField() + abstract_class_spec = ImageSpecField(source='original_image', + format='JPEG', + cachefile_strategy=CountingCacheFileStrategy()) class Meta: abstract = True -class ConcreteImageModel1(AbstractImageModel): - first_spec = ImageSpecField() +class ConcreteImageModel(AbstractImageModel): + pass -class ConcreteImageModel2(AbstractImageModel): - second_spec = ImageSpecField() +class ConcreteImageModelSubclass(ConcreteImageModel): + pass diff --git a/tests/test_abstract_models.py b/tests/test_abstract_models.py new file mode 100644 index 0000000..a3ec774 --- /dev/null +++ b/tests/test_abstract_models.py @@ -0,0 +1,29 @@ +from django.core.files import File +from imagekit.signals import source_created +from imagekit.specs.sourcegroups import ImageFieldSourceGroup +from imagekit.utils import get_nonabstract_descendants +from nose.tools import eq_ +from . models import (AbstractImageModel, ConcreteImageModel, + ConcreteImageModelSubclass) +from .utils import get_image_file + + +def test_source_created_signal(): + source_group = ImageFieldSourceGroup(AbstractImageModel, 'original_image') + count = [0] + + def receiver(sender, *args, **kwargs): + if sender is source_group: + count[0] += 1 + + source_created.connect(receiver, dispatch_uid='test_source_created') + instance = ConcreteImageModel() + img = File(get_image_file()) + instance.original_image.save('test_source_created_signal.jpg', img) + + eq_(count[0], 1) + + +def test_nonabstract_descendants_generator(): + descendants = list(get_nonabstract_descendants(AbstractImageModel)) + eq_(descendants, [ConcreteImageModel, ConcreteImageModelSubclass])