diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index d626d1d1c..434a4e380 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -11,14 +11,12 @@ from wagtail.wagtailsearch.indexed import Indexed import string -class ElasticSearchResults(object): - def __init__(self, backend, model, query_string, fields=None, filters={}, prefetch_related=[]): - self.backend = backend +class ElasticSearchQuery(object): + def __init__(self, model, query_string, fields=None, filters={}): self.model = model self.query_string = query_string - self.fields = fields + self.fields = fields or ['_all'] self.filters = filters - self.prefetch_related = prefetch_related def _get_filters(self): # Filters @@ -83,7 +81,7 @@ class ElasticSearchResults(object): return filters - def _get_query(self): + def to_es(self): # Query query = { 'query_string': { @@ -107,26 +105,37 @@ class ElasticSearchResults(object): } } - def _get_results_pks(self, offset=0, limit=None): - query = self._get_query() - query['from'] = offset - if limit is not None: - query['size'] = limit - hits = self.backend.es.search( +class ElasticSearchResults(object): + def __init__(self, backend, query, prefetch_related=[]): + self.backend = backend + self.query = query + self.prefetch_related = prefetch_related + + def _get_results_pks(self, offset=0, limit=None): + # Params for elasticsearch query + params = dict( index=self.backend.es_index, - body=dict(query=query), + body=dict(query=self.query.to_es()), _source=False, fields='pk', + from_=offset, ) + # Add limit if set + if limit is not None: + params['size'] = limit + + # Send to ElasticSearch + hits = self.backend.es.search(**params) + pks = [hit['fields']['pk'] for hit in hits['hits']['hits']] # ElasticSearch 1.x likes to pack pks into lists, unpack them if this has happened return [pk[0] if isinstance(pk, list) else pk for pk in pks] def _get_count(self): - query = self._get_query() + query = self.query.to_es() # Elasticsearch 1.x count = self.backend.es.count( @@ -157,7 +166,7 @@ class ElasticSearchResults(object): pk_list.append(pk) # Get results - results = self.model.objects.filter(pk__in=pk_list) + results = self.query.model.objects.filter(pk__in=pk_list) # Prefetch related for prefetch in self.prefetch_related: @@ -174,7 +183,7 @@ class ElasticSearchResults(object): else: # Return a single item pk = self._get_results_pks(key, key + 1)[0] - return self.model.objects.get(pk=pk) + return self.query.model.objects.get(pk=pk) def __len__(self): return self._get_count() @@ -348,4 +357,4 @@ class ElasticSearch(BaseSearch): return [] # Return search results - return ElasticSearchResults(self, model, query_string, fields=fields, filters=filters, prefetch_related=prefetch_related) + return ElasticSearchResults(self, ElasticSearchQuery(model, query_string, fields=fields, filters=filters), prefetch_related=prefetch_related)