diff --git a/wagtail/search/backends/base.py b/wagtail/search/backends/base.py index db5bdb0ae..f7c4863db 100644 --- a/wagtail/search/backends/base.py +++ b/wagtail/search/backends/base.py @@ -11,6 +11,20 @@ class FilterError(Exception): class FieldError(Exception): + def __init__(self, *args, field_name=None, **kwargs): + self.field_name = field_name + super(FieldError, self).__init__(*args, **kwargs) + + +class SearchFieldError(FieldError): + pass + + +class FilterFieldError(FieldError): + pass + + +class OrderByFieldError(FieldError): pass @@ -39,18 +53,20 @@ class BaseSearchQuery: def _connect_filters(self, filters, connector, negated): raise NotImplementedError - def _process_filter(self, field_attname, lookup, value): + def _process_filter(self, field_attname, lookup, value, check_only=False): # Get the field field = self._get_filterable_field(field_attname) if field is None: - raise FieldError( + raise FilterFieldError( 'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' + - field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.' + field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.', + field_name=field_attname ) # Process the lookup - result = self._process_lookup(field, lookup, value) + if not check_only: + result = self._process_lookup(field, lookup, value) if result is None: raise FilterError( @@ -60,7 +76,7 @@ class BaseSearchQuery: return result - def _get_filters_from_where_node(self, where_node): + def _get_filters_from_where_node(self, where_node, check_only=False): # Check if this is a leaf node if isinstance(where_node, Lookup): field_attname = where_node.lhs.target.attname @@ -72,7 +88,7 @@ class BaseSearchQuery: return # Process the filter - return self._process_filter(field_attname, lookup, value) + return self._process_filter(field_attname, lookup, value, check_only=check_only) elif isinstance(where_node, SubqueryConstraint): raise FilterError('Could not apply filter on search results: Subqueries are not allowed.') @@ -81,9 +97,10 @@ class BaseSearchQuery: # Get child filters connector = where_node.connector child_filters = [self._get_filters_from_where_node(child) for child in where_node.children] - child_filters = [child_filter for child_filter in child_filters if child_filter] - return self._connect_filters(child_filters, connector, where_node.negated) + if not check_only: + child_filters = [child_filter for child_filter in child_filters if child_filter] + return self._connect_filters(child_filters, connector, where_node.negated) else: raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node))) @@ -105,13 +122,35 @@ class BaseSearchQuery: field = self._get_filterable_field(field_name) if field is None: - raise FieldError( + raise OrderByFieldError( 'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' + - field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.' + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.', + field_name=field_name ) yield reverse, field + def check(self): + # Check search fields + if self.fields: + allowed_fields = {field.field_name for field in self.queryset.model.get_searchable_search_fields()} + + for field_name in self.fields: + if field_name not in allowed_fields: + raise SearchFieldError( + 'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' + + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.', + field_name=field_name + ) + + # Check where clause + # Raises FilterFieldError if an unindexed field is being filtered on + self._get_filters_from_where_node(self.queryset.query.where, check_only=True) + + # Check order by + # Raises OrderByFieldError if an unindexed field is being used to order by + list(self._get_order_by()) + class BaseSearchResults: def __init__(self, backend, query, prefetch_related=None): @@ -268,17 +307,6 @@ class BaseSearchBackend: if query_string == "": return EmptySearchResults() - # Only fields that are indexed as a SearchField can be passed in fields - if fields: - allowed_fields = {field.field_name for field in model.get_searchable_search_fields()} - - for field_name in fields: - if field_name not in allowed_fields: - raise FieldError( - 'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' + - field_name + '\') to ' + model.__name__ + '.search_fields.' - ) - # Apply filters to queryset if filters: queryset = queryset.filter(**filters) @@ -298,4 +326,8 @@ class BaseSearchBackend: search_query = self.query_class( queryset, query_string, fields=fields, operator=operator, order_by_relevance=order_by_relevance ) + + # Check the query + search_query.check() + return self.results_class(self, search_query) diff --git a/wagtail/search/backends/db.py b/wagtail/search/backends/db.py index b32f3edde..d1d1e68b5 100644 --- a/wagtail/search/backends/db.py +++ b/wagtail/search/backends/db.py @@ -74,9 +74,6 @@ class DatabaseSearchResults(BaseSearchResults): def _do_search(self): queryset = self.get_queryset() - # Call query._get_order_by so it can raise errors if a non-indexed field is used for ordering - list(self.query._get_order_by()) - if self._score_field: queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())}) diff --git a/wagtail/search/backends/elasticsearch2.py b/wagtail/search/backends/elasticsearch2.py index 2b9f688db..2ecc1379c 100644 --- a/wagtail/search/backends/elasticsearch2.py +++ b/wagtail/search/backends/elasticsearch2.py @@ -276,7 +276,9 @@ class Elasticsearch2SearchQuery(BaseSearchQuery): fields.append(field_name) - self.fields = fields + self.remapped_fields = fields + else: + self.remapped_fields = None def _process_lookup(self, field, lookup, value): column_name = self.mapping.get_field_column_name(field) @@ -371,7 +373,7 @@ class Elasticsearch2SearchQuery(BaseSearchQuery): def get_inner_query(self): if self.query_string is not None: - fields = self.fields or ['_all', '_partials'] + fields = self.remapped_fields or ['_all', '_partials'] if len(fields) == 1: if self.operator == 'or':