Created BaseSearchResults class

Moved some logic from ElasticSearchResults into it
This commit is contained in:
Karl Hobley 2014-09-17 10:48:04 +01:00
parent 436116dc3d
commit ca17d062c3
2 changed files with 86 additions and 78 deletions

View file

@ -5,6 +5,90 @@ from wagtail.wagtailsearch.index import class_is_indexed
from wagtail.wagtailsearch.utils import normalise_query_string
class BaseSearchResults(object):
def __init__(self, backend, query, prefetch_related=None):
self.backend = backend
self.query = query
self.prefetch_related = prefetch_related
self.start = 0
self.stop = None
self._results_cache = None
self._count_cache = None
def _set_limits(self, start=None, stop=None):
if stop is not None:
if self.stop is not None:
self.stop = min(self.stop, self.start + stop)
else:
self.stop = self.start + stop
if start is not None:
if self.stop is not None:
self.start = min(self.stop, self.start + start)
else:
self.start = self.start + start
def _clone(self):
klass = self.__class__
new = klass(self.backend, self.query, prefetch_related=self.prefetch_related)
new.start = self.start
new.stop = self.stop
return new
def _do_search(self):
return NotImplemented
def _do_count(self):
return NotImplemented
def results(self):
if self._results_cache is None:
self._results_cache = self._do_search()
return self._results_cache
def count(self):
if self._count_cache is None:
if self._results_cache is not None:
self._count_cache = len(self._results_cache)
else:
self._count_cache = self._do_count()
return self._count_cache
def __getitem__(self, key):
new = self._clone()
if isinstance(key, slice):
# Set limits
start = int(key.start) if key.start else None
stop = int(key.stop) if key.stop else None
new._set_limits(start, stop)
# Copy results cache
if self._results_cache is not None:
new._results_cache = self._results_cache[key]
return new
else:
if self._results_cache is not None:
return self._results_cache[key]
new.start = key
new.stop = key + 1
return list(new)[0]
def __iter__(self):
return iter(self.results())
def __len__(self):
return len(self.results())
def __repr__(self):
data = list(self[:21])
if len(data) > 20:
data[-1] = "...(remaining elements truncated)..."
return repr(data)
class BaseSearch(object):
def __init__(self, params):
pass

View file

@ -17,7 +17,7 @@ except ImportError:
from elasticsearch import Elasticsearch, NotFoundError, RequestError
from elasticsearch.helpers import bulk
from wagtail.wagtailsearch.backends.base import BaseSearch
from wagtail.wagtailsearch.backends.base import BaseSearch, BaseSearchResults
from wagtail.wagtailsearch.index import Indexed, SearchField, FilterField, class_is_indexed
@ -326,36 +326,7 @@ class ElasticSearchQuery(object):
return json.dumps(self.to_es())
class ElasticSearchResults(object):
def __init__(self, backend, query, prefetch_related=None):
self.backend = backend
self.query = query
self.prefetch_related = prefetch_related
self.start = 0
self.stop = None
self._results_cache = None
self._count_cache = None
def _set_limits(self, start=None, stop=None):
if stop is not None:
if self.stop is not None:
self.stop = min(self.stop, self.start + stop)
else:
self.stop = self.start + stop
if start is not None:
if self.stop is not None:
self.start = min(self.stop, self.start + start)
else:
self.start = self.start + start
def _clone(self):
klass = self.__class__
new = klass(self.backend, self.query, prefetch_related=self.prefetch_related)
new.start = self.start
new.stop = self.stop
return new
class ElasticSearchResults(BaseSearchResults):
def _do_search(self):
# Params for elasticsearch query
params = dict(
@ -417,53 +388,6 @@ class ElasticSearchResults(object):
return max(hit_count, 0)
def results(self):
if self._results_cache is None:
self._results_cache = self._do_search()
return self._results_cache
def count(self):
if self._count_cache is None:
if self._results_cache is not None:
self._count_cache = len(self._results_cache)
else:
self._count_cache = self._do_count()
return self._count_cache
def __getitem__(self, key):
new = self._clone()
if isinstance(key, slice):
# Set limits
start = int(key.start) if key.start else None
stop = int(key.stop) if key.stop else None
new._set_limits(start, stop)
# Copy results cache
if self._results_cache is not None:
new._results_cache = self._results_cache[key]
return new
else:
if self._results_cache is not None:
return self._results_cache[key]
new.start = key
new.stop = key + 1
return list(new)[0]
def __iter__(self):
return iter(self.results())
def __len__(self):
return len(self.results())
def __repr__(self):
data = list(self[:21])
if len(data) > 20:
data[-1] = "...(remaining elements truncated)..."
return repr(data)
class ElasticSearch(BaseSearch):
def __init__(self, params):