diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index 8fc2950dd..a8c590de2 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -3,6 +3,7 @@ from __future__ import absolute_import import json from django.db import models +from django.db.models.query import QuerySet from elasticsearch import Elasticsearch, NotFoundError, RequestError from elasticsearch.helpers import bulk @@ -109,12 +110,123 @@ class ElasticSearchMapping(object): return '' % (self.model.__name__, ) +class FilterError(Exception): + pass + + +class FieldError(Exception): + pass + + class ElasticSearchQuery(object): - def __init__(self, model, query_string, fields=None, filters={}): - self.model = model + def __init__(self, queryset, query_string, fields=None): + self.queryset = queryset self.query_string = query_string - self.fields = fields or ['_all', '_partials'] - self.filters = filters + self.fields = fields or ['_all', 'partials'] + + def _get_filters_from_where(self, where_node): + # Check if this is a leaf node + if isinstance(where_node, tuple): + field_name = where_node[0].col + lookup = where_node[1] + value = where_node[3] + + # Get field + field = dict( + (field.get_attname(self.queryset.model), field) + for field in self.queryset.model.get_filterable_search_fields() + ).get(field_name, None) + + # Give error if the field doesn't exist + if field is None: + raise FieldError('Cannot filter ElasticSearch results with field "' + field_name + '". Please add FilterField(\'' + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.') + + # Get the name of the field in the index + field_index_name = field.get_index_name(self.queryset.model) + + # Find lookup + if lookup == 'exact': + if value is None: + return { + 'missing': { + 'field': field_index_name, + } + } + else: + return { + 'term': { + field_index_name: value, + } + } + + if lookup == 'isnull': + if value: + return { + 'missing': { + 'field': field_index_name, + } + } + else: + return { + 'not': { + 'missing': { + 'field': field_index_name, + } + } + } + + if lookup in ['startswith', 'prefix']: + return { + 'prefix': { + field_index_name: value, + } + } + + if lookup in ['gt', 'gte', 'lt', 'lte']: + return { + 'range': { + field_index_name: { + lookup: value, + } + } + } + + if lookup == 'range': + lower, upper = value + + return { + 'range': { + field_index_name: { + 'gte': lower, + 'lte': upper, + } + } + } + + raise FilterError('Could not apply filter on ElasticSearch results "' + field_name + '__' + lookup + ' = ' + unicode(value) + '". Lookup "' + lookup + '"" not recognosed.') + + # Get child filters + connector = where_node.connector + child_filters = [self._get_filters_from_where(child) for child in where_node.children] + child_filters = [child_filter for child_filter in child_filters if child_filter] + + # Connect them + if child_filters: + if len(child_filters) == 1: + filter_out = child_filters[0] + else: + filter_out = { + connector.lower(): [ + fil for fil in child_filters if fil is not None + ] + } + + if where_node.negated: + filter_out = { + 'not': filter_out + } + + return filter_out def _get_filters(self): # Filters @@ -123,59 +235,14 @@ class ElasticSearchQuery(object): # Filter by content type filters.append({ 'prefix': { - 'content_type': self.model.indexed_get_content_type() + 'content_type': self.queryset.model.indexed_get_content_type() } }) - # Extra filters - if self.filters: - for key, value in self.filters.items(): - if '__' in key: - field, lookup = key.split('__') - else: - field = key - lookup = None - - if lookup is None: - if value is None: - filters.append({ - 'missing': { - 'field': field, - } - }) - else: - filters.append({ - 'term': { - field: value - } - }) - - if lookup in ['startswith', 'prefix']: - filters.append({ - 'prefix': { - field: value - } - }) - - if lookup in ['gt', 'gte', 'lt', 'lte']: - filters.append({ - 'range': { - field: { - lookup: value, - } - } - }) - - if lookup == 'range': - lower, upper = value - filters.append({ - 'range': { - field: { - 'gte': lower, - 'lte': upper, - } - } - }) + # Apply filters from queryset + queryset_filters = self._get_filters_from_where(self.queryset.query.where) + if queryset_filters: + filters.append(queryset_filters) return filters @@ -263,15 +330,8 @@ class ElasticSearchResults(object): # Initialise results dictionary results = dict((str(pk), None) for pk in pks) - # Get queryset - queryset = self.query.model.objects.filter(pk__in=pks) - - # Add prefetch related - if self.prefetch_related: - for prefetch in self.prefetch_related: - queryset = queryset.prefetch_related(prefetch) - # Find objects in database and add them to dict + queryset = self.query.queryset.filter(pk__in=pks) for obj in queryset: results[str(obj.pk)] = obj @@ -502,7 +562,15 @@ class ElasticSearch(BaseSearch): except NotFoundError: pass # Document doesn't exist, ignore this exception - def search(self, query_string, model, fields=None, filters={}, prefetch_related=[]): + def search(self, query_string, model_or_queryset, fields=None, filters={}, prefetch_related=[]): + # Find model/queryset + if isinstance(model_or_queryset, QuerySet): + model = model_or_queryset.model + queryset = model_or_queryset + else: + model = model_or_queryset + queryset = model_or_queryset.objects.all() + # Model must be a descendant of Indexed and be a django model if not issubclass(model, Indexed) or not issubclass(model, models.Model): return [] @@ -514,5 +582,13 @@ class ElasticSearch(BaseSearch): if not query_string: return [] + # Apply filters to queryset + if filters: + queryset = queryset.filter(**filters) + + # Prefetch related + for prefetch in prefetch_related: + queryset = queryset.prefetch_related(prefetch) + # Return search results - return ElasticSearchResults(self, ElasticSearchQuery(model, query_string, fields=fields, filters=filters), prefetch_related=prefetch_related) + return ElasticSearchResults(self, ElasticSearchQuery(queryset, query_string, fields=fields)) diff --git a/wagtail/wagtailsearch/indexed.py b/wagtail/wagtailsearch/indexed.py index b290f440b..a2cb54d14 100644 --- a/wagtail/wagtailsearch/indexed.py +++ b/wagtail/wagtailsearch/indexed.py @@ -112,13 +112,16 @@ class Indexed(object): @classmethod def get_searchable_search_fields(cls): - return filter(lambda field: field.searchable, cls.get_search_fields()) + return filter(lambda field: isinstance(field, SearchField), cls.get_search_fields()) + + @classmethod + def get_filterable_search_fields(cls): + return filter(lambda field: isinstance(field, FilterField), cls.get_search_fields()) indexed_fields = () class BaseField(object): - searchable = False suffix = '' def __init__(self, field_name, **kwargs): @@ -163,8 +166,6 @@ class BaseField(object): class SearchField(BaseField): - searchable = True - def __init__(self, field_name, boost=None, partial_match=False, **kwargs): super(SearchField, self).__init__(field_name, **kwargs) self.boost = boost diff --git a/wagtail/wagtailsearch/models.py b/wagtail/wagtailsearch/models.py index 509368463..7cb8546d7 100644 --- a/wagtail/wagtailsearch/models.py +++ b/wagtail/wagtailsearch/models.py @@ -91,7 +91,7 @@ class SearchTest(models.Model, indexed.Indexed): indexed.SearchField('title'), indexed.SearchField('content'), indexed.SearchField('callable_indexed_field'), - indexed.SearchField('live'), + indexed.FilterField('live'), ) def callable_indexed_field(self):