Make the search() method raise FieldErrors

Before, they were being raised at the point of evaluating the
SearchResults set. This makes the errors get raised at the point of
constructing the set so it's easier to catch and handle them.
This commit is contained in:
Karl Hobley 2017-10-20 12:19:58 +01:00
parent aefa7b5469
commit 25901aad05
3 changed files with 57 additions and 26 deletions

View file

@ -11,6 +11,20 @@ class FilterError(Exception):
class FieldError(Exception): class FieldError(Exception):
def __init__(self, *args, field_name=None, **kwargs):
self.field_name = field_name
super(FieldError, self).__init__(*args, **kwargs)
class SearchFieldError(FieldError):
pass
class FilterFieldError(FieldError):
pass
class OrderByFieldError(FieldError):
pass pass
@ -39,18 +53,20 @@ class BaseSearchQuery:
def _connect_filters(self, filters, connector, negated): def _connect_filters(self, filters, connector, negated):
raise NotImplementedError raise NotImplementedError
def _process_filter(self, field_attname, lookup, value): def _process_filter(self, field_attname, lookup, value, check_only=False):
# Get the field # Get the field
field = self._get_filterable_field(field_attname) field = self._get_filterable_field(field_attname)
if field is None: if field is None:
raise FieldError( raise FilterFieldError(
'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' + 'Cannot filter search results with field "' + field_attname + '". Please add index.FilterField(\'' +
field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.' field_attname + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
field_name=field_attname
) )
# Process the lookup # Process the lookup
result = self._process_lookup(field, lookup, value) if not check_only:
result = self._process_lookup(field, lookup, value)
if result is None: if result is None:
raise FilterError( raise FilterError(
@ -60,7 +76,7 @@ class BaseSearchQuery:
return result return result
def _get_filters_from_where_node(self, where_node): def _get_filters_from_where_node(self, where_node, check_only=False):
# Check if this is a leaf node # Check if this is a leaf node
if isinstance(where_node, Lookup): if isinstance(where_node, Lookup):
field_attname = where_node.lhs.target.attname field_attname = where_node.lhs.target.attname
@ -72,7 +88,7 @@ class BaseSearchQuery:
return return
# Process the filter # Process the filter
return self._process_filter(field_attname, lookup, value) return self._process_filter(field_attname, lookup, value, check_only=check_only)
elif isinstance(where_node, SubqueryConstraint): elif isinstance(where_node, SubqueryConstraint):
raise FilterError('Could not apply filter on search results: Subqueries are not allowed.') raise FilterError('Could not apply filter on search results: Subqueries are not allowed.')
@ -81,9 +97,10 @@ class BaseSearchQuery:
# Get child filters # Get child filters
connector = where_node.connector connector = where_node.connector
child_filters = [self._get_filters_from_where_node(child) for child in where_node.children] child_filters = [self._get_filters_from_where_node(child) for child in where_node.children]
child_filters = [child_filter for child_filter in child_filters if child_filter]
return self._connect_filters(child_filters, connector, where_node.negated) if not check_only:
child_filters = [child_filter for child_filter in child_filters if child_filter]
return self._connect_filters(child_filters, connector, where_node.negated)
else: else:
raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node))) raise FilterError('Could not apply filter on search results: Unknown where node: ' + str(type(where_node)))
@ -105,13 +122,35 @@ class BaseSearchQuery:
field = self._get_filterable_field(field_name) field = self._get_filterable_field(field_name)
if field is None: if field is None:
raise FieldError( raise OrderByFieldError(
'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' + 'Cannot sort search results with field "' + field_name + '". Please add index.FilterField(\'' +
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.' field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
field_name=field_name
) )
yield reverse, field yield reverse, field
def check(self):
# Check search fields
if self.fields:
allowed_fields = {field.field_name for field in self.queryset.model.get_searchable_search_fields()}
for field_name in self.fields:
if field_name not in allowed_fields:
raise SearchFieldError(
'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.',
field_name=field_name
)
# Check where clause
# Raises FilterFieldError if an unindexed field is being filtered on
self._get_filters_from_where_node(self.queryset.query.where, check_only=True)
# Check order by
# Raises OrderByFieldError if an unindexed field is being used to order by
list(self._get_order_by())
class BaseSearchResults: class BaseSearchResults:
def __init__(self, backend, query, prefetch_related=None): def __init__(self, backend, query, prefetch_related=None):
@ -268,17 +307,6 @@ class BaseSearchBackend:
if query_string == "": if query_string == "":
return EmptySearchResults() return EmptySearchResults()
# Only fields that are indexed as a SearchField can be passed in fields
if fields:
allowed_fields = {field.field_name for field in model.get_searchable_search_fields()}
for field_name in fields:
if field_name not in allowed_fields:
raise FieldError(
'Cannot search with field "' + field_name + '". Please add index.SearchField(\'' +
field_name + '\') to ' + model.__name__ + '.search_fields.'
)
# Apply filters to queryset # Apply filters to queryset
if filters: if filters:
queryset = queryset.filter(**filters) queryset = queryset.filter(**filters)
@ -298,4 +326,8 @@ class BaseSearchBackend:
search_query = self.query_class( search_query = self.query_class(
queryset, query_string, fields=fields, operator=operator, order_by_relevance=order_by_relevance queryset, query_string, fields=fields, operator=operator, order_by_relevance=order_by_relevance
) )
# Check the query
search_query.check()
return self.results_class(self, search_query) return self.results_class(self, search_query)

View file

@ -74,9 +74,6 @@ class DatabaseSearchResults(BaseSearchResults):
def _do_search(self): def _do_search(self):
queryset = self.get_queryset() queryset = self.get_queryset()
# Call query._get_order_by so it can raise errors if a non-indexed field is used for ordering
list(self.query._get_order_by())
if self._score_field: if self._score_field:
queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())}) queryset = queryset.annotate(**{self._score_field: Value(None, output_field=models.FloatField())})

View file

@ -276,7 +276,9 @@ class Elasticsearch2SearchQuery(BaseSearchQuery):
fields.append(field_name) fields.append(field_name)
self.fields = fields self.remapped_fields = fields
else:
self.remapped_fields = None
def _process_lookup(self, field, lookup, value): def _process_lookup(self, field, lookup, value):
column_name = self.mapping.get_field_column_name(field) column_name = self.mapping.get_field_column_name(field)
@ -371,7 +373,7 @@ class Elasticsearch2SearchQuery(BaseSearchQuery):
def get_inner_query(self): def get_inner_query(self):
if self.query_string is not None: if self.query_string is not None:
fields = self.fields or ['_all', '_partials'] fields = self.remapped_fields or ['_all', '_partials']
if len(fields) == 1: if len(fields) == 1:
if self.operator == 'or': if self.operator == 'or':