From 0e2b621d39b3b54725d92ae20ad755f486a3a2a3 Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Wed, 23 Jul 2014 13:12:01 +0100 Subject: [PATCH] Added Django 1.7 support into Elasticsearch query building --- .../wagtailsearch/backends/elasticsearch.py | 219 ++++++++++-------- 1 file changed, 120 insertions(+), 99 deletions(-) diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index 8f8a78dbb..9e0d9b612 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -5,7 +5,13 @@ import json from six.moves.urllib.parse import urlparse from django.db import models -from django.db.models.sql.where import SubqueryConstraint +from django.db.models.sql.where import SubqueryConstraint, WhereNode + +# Django 1.7 lookups +try: + from django.db.models.lookups import Lookup +except ImportError: + Lookup = None from elasticsearch import Elasticsearch, NotFoundError, RequestError from elasticsearch.helpers import bulk @@ -125,118 +131,133 @@ class ElasticSearchQuery(object): self.query_string = query_string self.fields = fields + def _process_lookup(self, field_attname, lookup, value): + # Get field + field = dict( + (field.get_attname(self.queryset.model), field) + for field in self.queryset.model.get_filterable_search_fields() + ).get(field_attname, None) + + # Give error if the field doesn't exist + if field is None: + raise FieldError('Cannot filter ElasticSearch results with field "' + field_name + '". Please add FilterField(\'' + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.') + + # Get the name of the field in the index + field_index_name = field.get_index_name(self.queryset.model) + + if lookup == 'exact': + if value is None: + return { + 'missing': { + 'field': field_index_name, + } + } + else: + return { + 'term': { + field_index_name: value, + } + } + + if lookup == 'isnull': + if value: + return { + 'missing': { + 'field': field_index_name, + } + } + else: + return { + 'not': { + 'missing': { + 'field': field_index_name, + } + } + } + + if lookup in ['startswith', 'prefix']: + return { + 'prefix': { + field_index_name: value, + } + } + + if lookup in ['gt', 'gte', 'lt', 'lte']: + return { + 'range': { + field_index_name: { + lookup: value, + } + } + } + + if lookup == 'range': + lower, upper = value + + return { + 'range': { + field_index_name: { + 'gte': lower, + 'lte': upper, + } + } + } + + if lookup == 'in': + return { + 'terms': { + field_index_name: value, + } + } + + raise FilterError('Could not apply filter on ElasticSearch results: "' + field_name + '__' + lookup + ' = ' + unicode(value) + '". Lookup "' + lookup + '"" not recognosed.') + def _get_filters_from_where(self, where_node): # Check if this is a leaf node - if isinstance(where_node, tuple): - field_name = where_node[0].col + if isinstance(where_node, tuple): # Django 1.6 and below + field_attname = where_node[0].col lookup = where_node[1] value = where_node[3] - # Get field - field = dict( - (field.get_attname(self.queryset.model), field) - for field in self.queryset.model.get_filterable_search_fields() - ).get(field_name, None) + # Process the filter + return self._process_lookup(field_attname, lookup, value) - # Give error if the field doesn't exist - if field is None: - raise FieldError('Cannot filter ElasticSearch results with field "' + field_name + '". Please add FilterField(\'' + field_name + '\') to ' + self.queryset.model.__name__ + '.search_fields.') + elif Lookup is not None and isinstance(where_node, Lookup): # Django 1.7 and above + field_attname = where_node.lhs.target.attname + lookup = where_node.lookup_name + value = where_node.rhs - # Get the name of the field in the index - field_index_name = field.get_index_name(self.queryset.model) + # Process the filter + return self._process_lookup(field_attname, lookup, value) - # Find lookup - if lookup == 'exact': - if value is None: - return { - 'missing': { - 'field': field_index_name, - } - } - else: - return { - 'term': { - field_index_name: value, - } - } - - if lookup == 'isnull': - if value: - return { - 'missing': { - 'field': field_index_name, - } - } - else: - return { - 'not': { - 'missing': { - 'field': field_index_name, - } - } - } - - if lookup in ['startswith', 'prefix']: - return { - 'prefix': { - field_index_name: value, - } - } - - if lookup in ['gt', 'gte', 'lt', 'lte']: - return { - 'range': { - field_index_name: { - lookup: value, - } - } - } - - if lookup == 'range': - lower, upper = value - - return { - 'range': { - field_index_name: { - 'gte': lower, - 'lte': upper, - } - } - } - - if lookup == 'in': - return { - 'terms': { - field_index_name: value, - } - } - - raise FilterError('Could not apply filter on ElasticSearch results: "' + field_name + '__' + lookup + ' = ' + unicode(value) + '". Lookup "' + lookup + '"" not recognosed.') elif isinstance(where_node, SubqueryConstraint): raise FilterError('Could not apply filter on ElasticSearch results: Subqueries are not allowed.') - # Get child filters - connector = where_node.connector - child_filters = [self._get_filters_from_where(child) for child in where_node.children] - child_filters = [child_filter for child_filter in child_filters if child_filter] + elif isinstance(where_node, WhereNode): + # Get child filters + connector = where_node.connector + child_filters = [self._get_filters_from_where(child) for child in where_node.children] + child_filters = [child_filter for child_filter in child_filters if child_filter] - # Connect them - if child_filters: - if len(child_filters) == 1: - filter_out = child_filters[0] - else: - filter_out = { - connector.lower(): [ - fil for fil in child_filters if fil is not None - ] - } + # Connect them + if child_filters: + if len(child_filters) == 1: + filter_out = child_filters[0] + else: + filter_out = { + connector.lower(): [ + fil for fil in child_filters if fil is not None + ] + } - if where_node.negated: - filter_out = { - 'not': filter_out - } + if where_node.negated: + filter_out = { + 'not': filter_out + } - return filter_out + return filter_out + else: + raise FilterError('Could not apply filter on ElasticSearch results: Unknown where node: ' + str(type(where_node))) def _get_filters(self): # Filters