mirror of
https://github.com/Hopiu/wagtail.git
synced 2026-05-11 16:53:10 +00:00
Make the search() method raise FieldErrors
Before, they were being raised at the point of evaluating the SearchResults set. This makes the errors get raised at the point of constructing the set so it's easier to catch and handle them.
This commit is contained in:
parent
aefa7b5469
commit
25901aad05
3 changed files with 57 additions and 26 deletions
|
|
@ -11,6 +11,20 @@ class FilterError(Exception):
|
|||
|
||||
|
||||
class FieldError(Exception):
|
||||
def __init__(self, *args, field_name=None, **kwargs):
|
||||
self.field_name = field_name
|
||||
super(FieldError, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SearchFieldError(FieldError):
|
||||
pass
|
||||
|
||||
|
||||
class FilterFieldError(FieldError):
|
||||
pass
|
||||
|
||||
|
||||
class OrderByFieldError(FieldError):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -39,18 +53,20 @@ class BaseSearchQuery:
|
|||
def _connect_filters(self, filters, connector, negated):
|
||||
raise NotImplementedError
|
||||
|
||||
def _process_filter(self, field_attname, lookup, value):
|
||||
def _process_filter(self, field_attname, lookup, value, check_only=False):
|
||||
# Get the field
|
||||
field = self._get_filterable_field(field_attname)
|
||||
|
||||
if field is None:
|
||||
raise FieldError(
|
||||
raise FilterFieldError(
|
||||
'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' +
|
||||
field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.'
|
||||
field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
|
||||
field_name=field_attname
|
||||
)
|
||||
|
||||
# Process the lookup
|
||||
result = self._process_lookup(field, lookup, value)
|
||||
if not check_only:
|
||||
result = self._process_lookup(field, lookup, value)
|
||||
|
||||
if result is None:
|
||||
raise FilterError(
|
||||
|
|
@ -60,7 +76,7 @@ class BaseSearchQuery:
|
|||
|
||||
return result
|
||||
|
||||
def _get_filters_from_where_node(self, where_node):
|
||||
def _get_filters_from_where_node(self, where_node, check_only=False):
|
||||
# Check if this is a leaf node
|
||||
if isinstance(where_node, Lookup):
|
||||
field_attname = where_node.lhs.target.attname
|
||||
|
|
@ -72,7 +88,7 @@ class BaseSearchQuery:
|
|||
return
|
||||
|
||||
# Process the filter
|
||||
return self._process_filter(field_attname, lookup, value)
|
||||
return self._process_filter(field_attname, lookup, value, check_only=check_only)
|
||||
|
||||
elif isinstance(where_node, SubqueryConstraint):
|
||||
raise FilterError('Could not apply filter on search results: Subqueries are not allowed.')
|
||||
|
|
@ -81,9 +97,10 @@ class BaseSearchQuery:
|
|||
# Get child filters
|
||||
connector = where_node.connector
|
||||
child_filters = [self._get_filters_from_where_node(child) for child in where_node.children]
|
||||
child_filters = [child_filter for child_filter in child_filters if child_filter]
|
||||
|
||||
return self._connect_filters(child_filters, connector, where_node.negated)
|
||||
if not check_only:
|
||||
child_filters = [child_filter for child_filter in child_filters if child_filter]
|
||||
return self._connect_filters(child_filters, connector, where_node.negated)
|
||||
|
||||
else:
|
||||
raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node)))
|
||||
|
|
@ -105,13 +122,35 @@ class BaseSearchQuery:
|
|||
field = self._get_filterable_field(field_name)
|
||||
|
||||
if field is None:
|
||||
raise FieldError(
|
||||
raise OrderByFieldError(
|
||||
'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' +
|
||||
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.'
|
||||
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
|
||||
field_name=field_name
|
||||
)
|
||||
|
||||
yield reverse, field
|
||||
|
||||
def check(self):
|
||||
# Check search fields
|
||||
if self.fields:
|
||||
allowed_fields = {field.field_name for field in self.queryset.model.get_searchable_search_fields()}
|
||||
|
||||
for field_name in self.fields:
|
||||
if field_name not in allowed_fields:
|
||||
raise SearchFieldError(
|
||||
'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
|
||||
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
|
||||
field_name=field_name
|
||||
)
|
||||
|
||||
# Check where clause
|
||||
# Raises FilterFieldError if an unindexed field is being filtered on
|
||||
self._get_filters_from_where_node(self.queryset.query.where, check_only=True)
|
||||
|
||||
# Check order by
|
||||
# Raises OrderByFieldError if an unindexed field is being used to order by
|
||||
list(self._get_order_by())
|
||||
|
||||
|
||||
class BaseSearchResults:
|
||||
def __init__(self, backend, query, prefetch_related=None):
|
||||
|
|
@ -268,17 +307,6 @@ class BaseSearchBackend:
|
|||
if query_string == "":
|
||||
return EmptySearchResults()
|
||||
|
||||
# Only fields that are indexed as a SearchField can be passed in fields
|
||||
if fields:
|
||||
allowed_fields = {field.field_name for field in model.get_searchable_search_fields()}
|
||||
|
||||
for field_name in fields:
|
||||
if field_name not in allowed_fields:
|
||||
raise FieldError(
|
||||
'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
|
||||
field_name + '\') to ' + model.__name__ + '.search_fields.'
|
||||
)
|
||||
|
||||
# Apply filters to queryset
|
||||
if filters:
|
||||
queryset = queryset.filter(**filters)
|
||||
|
|
@ -298,4 +326,8 @@ class BaseSearchBackend:
|
|||
search_query = self.query_class(
|
||||
queryset, query_string, fields=fields, operator=operator, order_by_relevance=order_by_relevance
|
||||
)
|
||||
|
||||
# Check the query
|
||||
search_query.check()
|
||||
|
||||
return self.results_class(self, search_query)
|
||||
|
|
|
|||
|
|
@ -74,9 +74,6 @@ class DatabaseSearchResults(BaseSearchResults):
|
|||
def _do_search(self):
|
||||
queryset = self.get_queryset()
|
||||
|
||||
# Call query._get_order_by so it can raise errors if a non-indexed field is used for ordering
|
||||
list(self.query._get_order_by())
|
||||
|
||||
if self._score_field:
|
||||
queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())})
|
||||
|
||||
|
|
|
|||
|
|
@ -276,7 +276,9 @@ class Elasticsearch2SearchQuery(BaseSearchQuery):
|
|||
|
||||
fields.append(field_name)
|
||||
|
||||
self.fields = fields
|
||||
self.remapped_fields = fields
|
||||
else:
|
||||
self.remapped_fields = None
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
column_name = self.mapping.get_field_column_name(field)
|
||||
|
|
@ -371,7 +373,7 @@ class Elasticsearch2SearchQuery(BaseSearchQuery):
|
|||
|
||||
def get_inner_query(self):
|
||||
if self.query_string is not None:
|
||||
fields = self.fields or ['_all', '_partials']
|
||||
fields = self.remapped_fields or ['_all', '_partials']
|
||||
|
||||
if len(fields) == 1:
|
||||
if self.operator == 'or':
|
||||
|
|
|
|||
Loading…
Reference in a new issue