Make the search() method raise FieldErrors

Before, they were being raised at the point of evaluating the
SearchResults set. This makes the errors get raised at the point of
constructing the set so it's easier to catch and handle them.
This commit is contained in:
Karl Hobley 2017-10-20 12:19:58 +01:00
parent aefa7b5469
commit 25901aad05
3 changed files with 57 additions and 26 deletions

View file

@ -11,6 +11,20 @@ class FilterError(Exception):
class FieldError(Exception):
def __init__(self, *args, field_name=None, **kwargs):
self.field_name = field_name
super(FieldError, self).__init__(*args, **kwargs)
class SearchFieldError(FieldError):
pass
class FilterFieldError(FieldError):
pass
class OrderByFieldError(FieldError):
pass
@ -39,18 +53,20 @@ class BaseSearchQuery:
def _connect_filters(self, filters, connector, negated):
raise NotImplementedError
def _process_filter(self, field_attname, lookup, value):
def _process_filter(self, field_attname, lookup, value, check_only=False):
# Get the field
field = self._get_filterable_field(field_attname)
if field is None:
raise FieldError(
raise FilterFieldError(
'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' +
field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.'
field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
field_name=field_attname
)
# Process the lookup
result = self._process_lookup(field, lookup, value)
if not check_only:
result = self._process_lookup(field, lookup, value)
if result is None:
raise FilterError(
@ -60,7 +76,7 @@ class BaseSearchQuery:
return result
def _get_filters_from_where_node(self, where_node):
def _get_filters_from_where_node(self, where_node, check_only=False):
# Check if this is a leaf node
if isinstance(where_node, Lookup):
field_attname = where_node.lhs.target.attname
@ -72,7 +88,7 @@ class BaseSearchQuery:
return
# Process the filter
return self._process_filter(field_attname, lookup, value)
return self._process_filter(field_attname, lookup, value, check_only=check_only)
elif isinstance(where_node, SubqueryConstraint):
raise FilterError('Could not apply filter on search results: Subqueries are not allowed.')
@ -81,9 +97,10 @@ class BaseSearchQuery:
# Get child filters
connector = where_node.connector
child_filters = [self._get_filters_from_where_node(child) for child in where_node.children]
child_filters = [child_filter for child_filter in child_filters if child_filter]
return self._connect_filters(child_filters, connector, where_node.negated)
if not check_only:
child_filters = [child_filter for child_filter in child_filters if child_filter]
return self._connect_filters(child_filters, connector, where_node.negated)
else:
raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node)))
@ -105,13 +122,35 @@ class BaseSearchQuery:
field = self._get_filterable_field(field_name)
if field is None:
raise FieldError(
raise OrderByFieldError(
'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' +
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.'
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
field_name=field_name
)
yield reverse, field
def check(self):
# Check search fields
if self.fields:
allowed_fields = {field.field_name for field in self.queryset.model.get_searchable_search_fields()}
for field_name in self.fields:
if field_name not in allowed_fields:
raise SearchFieldError(
'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
field_name=field_name
)
# Check where clause
# Raises FilterFieldError if an unindexed field is being filtered on
self._get_filters_from_where_node(self.queryset.query.where, check_only=True)
# Check order by
# Raises OrderByFieldError if an unindexed field is being used to order by
list(self._get_order_by())
class BaseSearchResults:
def __init__(self, backend, query, prefetch_related=None):
@ -268,17 +307,6 @@ class BaseSearchBackend:
if query_string == "":
return EmptySearchResults()
# Only fields that are indexed as a SearchField can be passed in fields
if fields:
allowed_fields = {field.field_name for field in model.get_searchable_search_fields()}
for field_name in fields:
if field_name not in allowed_fields:
raise FieldError(
'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
field_name + '\') to ' + model.__name__ + '.search_fields.'
)
# Apply filters to queryset
if filters:
queryset = queryset.filter(**filters)
@ -298,4 +326,8 @@ class BaseSearchBackend:
search_query = self.query_class(
queryset, query_string, fields=fields, operator=operator, order_by_relevance=order_by_relevance
)
# Check the query
search_query.check()
return self.results_class(self, search_query)

View file

@ -74,9 +74,6 @@ class DatabaseSearchResults(BaseSearchResults):
def _do_search(self):
queryset = self.get_queryset()
# Call query._get_order_by so it can raise errors if a non-indexed field is used for ordering
list(self.query._get_order_by())
if self._score_field:
queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())})

View file

@ -276,7 +276,9 @@ class Elasticsearch2SearchQuery(BaseSearchQuery):
fields.append(field_name)
self.fields = fields
self.remapped_fields = fields
else:
self.remapped_fields = None
def _process_lookup(self, field, lookup, value):
column_name = self.mapping.get_field_column_name(field)
@ -371,7 +373,7 @@ class Elasticsearch2SearchQuery(BaseSearchQuery):
def get_inner_query(self):
if self.query_string is not None:
fields = self.fields or ['_all', '_partials']
fields = self.remapped_fields or ['_all', '_partials']
if len(fields) == 1:
if self.operator == 'or':