diff --git a/wagtail/wagtailsearch/backends/base.py b/wagtail/wagtailsearch/backends/base.py
index 29dc9161e..beb63365e 100644
--- a/wagtail/wagtailsearch/backends/base.py
+++ b/wagtail/wagtailsearch/backends/base.py
@@ -5,6 +5,90 @@ from wagtail.wagtailsearch.index import class_is_indexed
 from wagtail.wagtailsearch.utils import normalise_query_string
 
 
+class BaseSearchResults(object):
+    def __init__(self, backend, query, prefetch_related=None):
+        self.backend = backend
+        self.query = query
+        self.prefetch_related = prefetch_related
+        self.start = 0
+        self.stop = None
+        self._results_cache = None
+        self._count_cache = None
+
+    def _set_limits(self, start=None, stop=None):
+        if stop is not None:
+            if self.stop is not None:
+                self.stop = min(self.stop, self.start + stop)
+            else:
+                self.stop = self.start + stop
+
+        if start is not None:
+            if self.stop is not None:
+                self.start = min(self.stop, self.start + start)
+            else:
+                self.start = self.start + start
+
+    def _clone(self):
+        klass = self.__class__
+        new = klass(self.backend, self.query, prefetch_related=self.prefetch_related)
+        new.start = self.start
+        new.stop = self.stop
+        return new
+
+    def _do_search(self):
+        return NotImplemented
+
+    def _do_count(self):
+        return NotImplemented
+
+    def results(self):
+        if self._results_cache is None:
+            self._results_cache = self._do_search()
+        return self._results_cache
+
+    def count(self):
+        if self._count_cache is None:
+            if self._results_cache is not None:
+                self._count_cache = len(self._results_cache)
+            else:
+                self._count_cache = self._do_count()
+        return self._count_cache
+
+    def __getitem__(self, key):
+        new = self._clone()
+
+        if isinstance(key, slice):
+            # Set limits
+            start = int(key.start) if key.start else None
+            stop = int(key.stop) if key.stop else None
+            new._set_limits(start, stop)
+
+            # Copy results cache
+            if self._results_cache is not None:
+                new._results_cache = self._results_cache[key]
+
+            return new
+        else:
+            if self._results_cache is not None:
+                return self._results_cache[key]
+
+            new.start = key
+            new.stop = key + 1
+            return list(new)[0]
+
+    def __iter__(self):
+        return iter(self.results())
+
+    def __len__(self):
+        return len(self.results())
+
+    def __repr__(self):
+        data = list(self[:21])
+        if len(data) > 20:
+            data[-1] = "...(remaining elements truncated)..."
+        return repr(data)
+
+
 class BaseSearch(object):
     def __init__(self, params):
         pass
diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py
index 7a1eaf6f1..f053b5721 100644
--- a/wagtail/wagtailsearch/backends/elasticsearch.py
+++ b/wagtail/wagtailsearch/backends/elasticsearch.py
@@ -17,7 +17,7 @@ except ImportError:
 from elasticsearch import Elasticsearch, NotFoundError, RequestError
 from elasticsearch.helpers import bulk
 
-from wagtail.wagtailsearch.backends.base import BaseSearch
+from wagtail.wagtailsearch.backends.base import BaseSearch, BaseSearchResults
 from wagtail.wagtailsearch.index import Indexed, SearchField, FilterField, class_is_indexed
 
 
@@ -326,36 +326,7 @@ class ElasticSearchQuery(object):
         return json.dumps(self.to_es())
 
 
-class ElasticSearchResults(object):
-    def __init__(self, backend, query, prefetch_related=None):
-        self.backend = backend
-        self.query = query
-        self.prefetch_related = prefetch_related
-        self.start = 0
-        self.stop = None
-        self._results_cache = None
-        self._count_cache = None
-
-    def _set_limits(self, start=None, stop=None):
-        if stop is not None:
-            if self.stop is not None:
-                self.stop = min(self.stop, self.start + stop)
-            else:
-                self.stop = self.start + stop
-
-        if start is not None:
-            if self.stop is not None:
-                self.start = min(self.stop, self.start + start)
-            else:
-                self.start = self.start + start
-
-    def _clone(self):
-        klass = self.__class__
-        new = klass(self.backend, self.query, prefetch_related=self.prefetch_related)
-        new.start = self.start
-        new.stop = self.stop
-        return new
-
+class ElasticSearchResults(BaseSearchResults):
     def _do_search(self):
         # Params for elasticsearch query
         params = dict(
@@ -417,53 +417,6 @@ class ElasticSearchResults(object):
 
         return max(hit_count, 0)
 
-    def results(self):
-        if self._results_cache is None:
-            self._results_cache = self._do_search()
-        return self._results_cache
-
-    def count(self):
-        if self._count_cache is None:
-            if self._results_cache is not None:
-                self._count_cache = len(self._results_cache)
-            else:
-                self._count_cache = self._do_count()
-        return self._count_cache
-
-    def __getitem__(self, key):
-        new = self._clone()
-
-        if isinstance(key, slice):
-            # Set limits
-            start = int(key.start) if key.start else None
-            stop = int(key.stop) if key.stop else None
-            new._set_limits(start, stop)
-
-            # Copy results cache
-            if self._results_cache is not None:
-                new._results_cache = self._results_cache[key]
-
-            return new
-        else:
-            if self._results_cache is not None:
-                return self._results_cache[key]
-
-            new.start = key
-            new.stop = key + 1
-            return list(new)[0]
-
-    def __iter__(self):
-        return iter(self.results())
-
-    def __len__(self):
-        return len(self.results())
-
-    def __repr__(self):
-        data = list(self[:21])
-        if len(data) > 20:
-            data[-1] = "...(remaining elements truncated)..."
-        return repr(data)
-
 
 class ElasticSearch(BaseSearch):
     def __init__(self, params):