From a5834513c31b2260287c2411ec3a3dcb25e11d12 Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Fri, 20 Jun 2014 11:51:38 +0100 Subject: [PATCH] Implemented results caching on ElasticSearchResults --- .../wagtailsearch/backends/elasticsearch.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index e9db7fbf1..29141542a 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -118,6 +118,8 @@ class ElasticSearchResults(object): self.query = query self.start = 0 self.stop = None + self._results_cache = None + self._count_cache = None def _set_limits(self, start=None, stop=None): if stop is not None: @@ -173,7 +175,7 @@ class ElasticSearchResults(object): # Return results in order given by ElasticSearch return [results[str(pk)] for pk in pks if results[str(pk)]] - def count(self): + def _do_count(self): # Get query query = self.query.to_es() @@ -200,6 +202,19 @@ class ElasticSearchResults(object): return max(hit_count, 0) + def results(self): + if self._results_cache is None: + self._results_cache = self._do_search() + return self._results_cache + + def count(self): + if self._count_cache is None: + if self._results_cache is not None: + self._count_cache = len(self._results_cache) + else: + self._count_cache = self._do_count() + return self._count_cache + def __getitem__(self, key): new = self._clone() @@ -209,17 +224,24 @@ class ElasticSearchResults(object): stop = int(key.stop) if key.stop else None new._set_limits(start, stop) + # Copy results cache + if self._results_cache is not None: + new._results_cache = self._results_cache[key] + return new else: + if self._results_cache is not None: + return self._results_cache[key] + new.start = key new.stop = key + 1 return list(new)[0] def __iter__(self): - return iter(self._do_search()) + return iter(self.results()) def __len__(self): - return len(self._do_search()) + return len(self.results()) def __repr__(self): data = list(self[:21])