From 25901aad05997223dddb9fd7b32550d31071a594 Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Fri, 20 Oct 2017 12:19:58 +0100 Subject: [PATCH] Make the search() method raise FieldErrors Before, they were being raised at the point of evaluating the SearchResults set. This makes the errors get raised at the point of constructing the set so it's easier to catch and handle them. --- wagtail/search/backends/base.py | 74 ++++++++++++++++------- wagtail/search/backends/db.py | 3 - wagtail/search/backends/elasticsearch2.py | 6 +- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/wagtail/search/backends/base.py b/wagtail/search/backends/base.py index db5bdb0ae..f7c4863db 100644 --- a/wagtail/search/backends/base.py +++ b/wagtail/search/backends/base.py @@ -11,6 +11,20 @@ class FilterError(Exception): class FieldError(Exception): + def __init__(self, *args, field_name=None, **kwargs): + self.field_name = field_name + super(FieldError, self).__init__(*args, **kwargs) + + +class SearchFieldError(FieldError): + pass + + +class FilterFieldError(FieldError): + pass + + +class OrderByFieldError(FieldError): pass @@ -39,18 +53,20 @@ class BaseSearchQuery: def _connect_filters(self, filters, connector, negated): raise NotImplementedError - def _process_filter(self, field_attname, lookup, value): + def _process_filter(self, field_attname, lookup, value, check_only=False): # Get the field field = self._get_filterable_field(field_attname) if field is None: - raise FieldError( + raise FilterFieldError( 'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' + - field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.' + field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.', + field_name=field_attname ) # Process the lookup - result = self._process_lookup(field, lookup, value) + if not check_only: + result = self._process_lookup(field, lookup, value) if result is None: raise FilterError( @@ -60,7 +76,7 @@ class BaseSearchQuery: return result - def _get_filters_from_where_node(self, where_node): + def _get_filters_from_where_node(self, where_node, check_only=False): # Check if this is a leaf node if isinstance(where_node, Lookup): field_attname = where_node.lhs.target.attname @@ -72,7 +88,7 @@ class BaseSearchQuery: return # Process the filter - return self._process_filter(field_attname, lookup, value) + return self._process_filter(field_attname, lookup, value, check_only=check_only) elif isinstance(where_node, SubqueryConstraint): raise FilterError('Could not apply filter on search results: Subqueries are not allowed.') @@ -81,9 +97,10 @@ class BaseSearchQuery: # Get child filters connector = where_node.connector child_filters = [self._get_filters_from_where_node(child) for child in where_node.children] - child_filters = [child_filter for child_filter in child_filters if child_filter] - return self._connect_filters(child_filters, connector, where_node.negated) + if not check_only: + child_filters = [child_filter for child_filter in child_filters if child_filter] + return self._connect_filters(child_filters, connector, where_node.negated) else: raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node))) @@ -105,13 +122,35 @@ class BaseSearchQuery: field = self._get_filterable_field(field_name) if field is None: - raise FieldError( + raise OrderByFieldError( 'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' + - field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.' + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.', + field_name=field_name ) yield reverse, field + def check(self): + # Check search fields + if self.fields: + allowed_fields = {field.field_name for field in self.queryset.model.get_searchable_search_fields()} + + for field_name in self.fields: + if field_name not in allowed_fields: + raise SearchFieldError( + 'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' + + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.', + field_name=field_name + ) + + # Check where clause + # Raises FilterFieldError if an unindexed field is being filtered on + self._get_filters_from_where_node(self.queryset.query.where, check_only=True) + + # Check order by + # Raises OrderByFieldError if an unindexed field is being used to order by + list(self._get_order_by()) + class BaseSearchResults: def __init__(self, backend, query, prefetch_related=None): @@ -268,17 +307,6 @@ class BaseSearchBackend: if query_string == "": return EmptySearchResults() - # Only fields that are indexed as a SearchField can be passed in fields - if fields: - allowed_fields = {field.field_name for field in model.get_searchable_search_fields()} - - for field_name in fields: - if field_name not in allowed_fields: - raise FieldError( - 'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' + - field_name + '\') to ' + model.__name__ + '.search_fields.' - ) - # Apply filters to queryset if filters: queryset = queryset.filter(**filters) @@ -298,4 +326,8 @@ class BaseSearchBackend: search_query = self.query_class( queryset, query_string, fields=fields, operator=operator, order_by_relevance=order_by_relevance ) + + # Check the query + search_query.check() + return self.results_class(self, search_query) diff --git a/wagtail/search/backends/db.py b/wagtail/search/backends/db.py index b32f3edde..d1d1e68b5 100644 --- a/wagtail/search/backends/db.py +++ b/wagtail/search/backends/db.py @@ -74,9 +74,6 @@ class DatabaseSearchResults(BaseSearchResults): def _do_search(self): queryset = self.get_queryset() - # Call query._get_order_by so it can raise errors if a non-indexed field is used for ordering - list(self.query._get_order_by()) - if self._score_field: queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())}) diff --git a/wagtail/search/backends/elasticsearch2.py b/wagtail/search/backends/elasticsearch2.py index 2b9f688db..2ecc1379c 100644 --- a/wagtail/search/backends/elasticsearch2.py +++ b/wagtail/search/backends/elasticsearch2.py @@ -276,7 +276,9 @@ class Elasticsearch2SearchQuery(BaseSearchQuery): fields.append(field_name) - self.fields = fields + self.remapped_fields = fields + else: + self.remapped_fields = None def _process_lookup(self, field, lookup, value): column_name = self.mapping.get_field_column_name(field) @@ -371,7 +373,7 @@ class Elasticsearch2SearchQuery(BaseSearchQuery): def get_inner_query(self): if self.query_string is not None: - fields = self.fields or ['_all', '_partials'] + fields = self.remapped_fields or ['_all', '_partials'] if len(fields) == 1: if self.operator == 'or':