From 16b385da974402c20df7e7ef8a1a3b1595ba049e Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Fri, 20 Jun 2014 11:35:19 +0100 Subject: [PATCH] Slicing ESResults now returns a new ESResults object Previously, slicing an ESResults object made it run a query against ElasticSearch and return the results This commit changes this by making slice return a new ESResults object with start and stop limits applied. To get results, you now have to iterate the ESResults object. --- .../wagtailsearch/backends/elasticsearch.py | 62 ++++++++++++++++--- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index eecbd2f1d..6bd48aafe 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -116,24 +116,47 @@ class ElasticSearchResults(object): def __init__(self, backend, query): self.backend = backend self.query = query + self.start = 0 + self.stop = None - def _do_search(self, offset=0, limit=None): + def _set_limits(self, start=None, stop=None): + if stop is not None: + if self.stop is not None: + self.stop = min(self.stop, self.start + stop) + else: + self.stop = self.start + stop + + if start is not None: + if self.stop is not None: + self.start = min(self.stop, self.start + start) + else: + self.start = self.start + start + + def _clone(self): + klass = self.__class__ + new = klass(self.backend, self.query) + new.start = self.start + new.stop = self.stop + return new + + def _do_search(self): # Params for elasticsearch query params = dict( index=self.backend.es_index, body=dict(query=self.query.to_es()), _source=False, fields='pk', - from_=offset, + from_=self.start, ) - # Add limit if set - if limit is not None: - params['size'] = limit + # Add size if set + if self.stop is not None: + params['size'] = self.stop - self.start # Send to ElasticSearch hits = self.backend.es.search(**params) + # Get pks from results pks = [hit['fields']['pk'] for hit in hits['hits']['hits']] # ElasticSearch 1.x likes to pack pks into lists, unpack them if this has happened @@ -151,6 +174,7 @@ class ElasticSearchResults(object): return [results[str(pk)] for pk in pks if results[str(pk)]] def _do_count(self): + # Get query query = self.query.to_es() # Elasticsearch 1.x @@ -166,15 +190,33 @@ class ElasticSearchResults(object): body=query, ) - return count['count'] + # Get count + hit_count = count['count'] + + # Add limits + hit_count -= self.start + if self.stop is not None: + hit_count = min(hit_count, self.stop - self.start) + + return max(hit_count, 0) def __getitem__(self, key): + new = self._clone() + if isinstance(key, slice): - # Run query - return self._do_search(key.start, key.stop - key.start) + # Set limits + start = int(key.start) if key.start else None + stop = int(key.stop) if key.stop else None + new._set_limits(start, stop) + + return new else: - # Return a single item - return self._do_search(key, key + 1)[0] + new.start = key + new.stop = key + 1 + return list(new)[0] + + def __iter__(self): + return iter(self._do_search()) def __len__(self): return self._do_count()