diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index 29141542a..5561a3cee 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -13,7 +13,6 @@ from wagtail.wagtailsearch.backends.base import BaseSearch from wagtail.wagtailsearch.indexed import Indexed - class ElasticSearchQuery(object): def __init__(self, model, query_string, fields=None, filters={}): self.model = model @@ -113,9 +112,10 @@ class ElasticSearchQuery(object): class ElasticSearchResults(object): - def __init__(self, backend, query): + def __init__(self, backend, query, prefetch_related=None): self.backend = backend self.query = query + self.prefetch_related = prefetch_related self.start = 0 self.stop = None self._results_cache = None @@ -136,7 +136,7 @@ class ElasticSearchResults(object): def _clone(self): klass = self.__class__ - new = klass(self.backend, self.query) + new = klass(self.backend, self.query, prefetch_related=self.prefetch_related) new.start = self.start new.stop = self.stop return new @@ -167,8 +167,15 @@ class ElasticSearchResults(object): # Initialise results dictionary results = dict((str(pk), None) for pk in pks) - # Find objects in database and add them to dict + # Get queryset queryset = self.query.model.objects.filter(pk__in=pks) + + # Add prefetch related + if self.prefetch_related: + for prefetch in self.prefetch_related: + queryset = queryset.prefetch_related(prefetch) + + # Find objects in database and add them to dict for obj in queryset: results[str(obj.pk)] = obj @@ -417,8 +424,5 @@ class ElasticSearch(BaseSearch): if not query_string: return [] - # Give deprecation warning if prefetch_related was used - warnings.warn("prefetch_related on search queries is no longer implemented. ", DeprecationWarning) - # Return search results - return ElasticSearchResults(self, ElasticSearchQuery(model, query_string, fields=fields, filters=filters)) + return ElasticSearchResults(self, ElasticSearchQuery(model, query_string, fields=fields, filters=filters), prefetch_related=prefetch_related)