diff --git a/wagtail/contrib/postgres_search/backend.py b/wagtail/contrib/postgres_search/backend.py
index 9084de260..eb7ca7306 100644
--- a/wagtail/contrib/postgres_search/backend.py
+++ b/wagtail/contrib/postgres_search/backend.py
@@ -16,10 +16,11 @@ from wagtail.wagtailsearch.backends.base import (
     BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults)
 from wagtail.wagtailsearch.index import RelatedFields, SearchField
 from wagtail.wagtailsearch.query import And, MatchAll, Not, Or, PlainText, Term
+from wagtail.wagtailsearch.utils import ADD, AND, OR
 
 from .models import IndexEntry
 from .utils import (
-    ADD, AND, OR, WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk,
+    WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk,
     get_descendants_content_types_pks, get_postgresql_connections, get_weight, unidecode)
 
 
diff --git a/wagtail/contrib/postgres_search/utils.py b/wagtail/contrib/postgres_search/utils.py
index 0e3e25960..fd0dc06cb 100644
--- a/wagtail/contrib/postgres_search/utils.py
+++ b/wagtail/contrib/postgres_search/utils.py
@@ -1,7 +1,5 @@
 from __future__ import absolute_import, division, unicode_literals
 
-import operator
-from functools import partial, reduce
 from itertools import zip_longest
 
 from django.apps import apps
@@ -22,14 +20,6 @@ def get_postgresql_connections():
             if connection.vendor == 'postgresql']
 
 
-# Reduce any iterable to a single value using a logical OR e.g. (a | b | ...)
-OR = partial(reduce, operator.or_)
-# Reduce any iterable to a single value using a logical AND e.g. (a & b & ...)
-AND = partial(reduce, operator.and_)
-# Reduce any iterable to a single value using an addition
-ADD = partial(reduce, operator.add)
-
-
 def get_descendant_models(model):
     """
     Returns all descendants of a model, including the model itself.
diff --git a/wagtail/wagtailsearch/backends/db.py b/wagtail/wagtailsearch/backends/db.py
index 30831b2a1..bcfd89ce2 100644
--- a/wagtail/wagtailsearch/backends/db.py
+++ b/wagtail/wagtailsearch/backends/db.py
@@ -1,16 +1,36 @@
 from __future__ import absolute_import, unicode_literals
 
+from warnings import warn
+
 from django.db import models
 from django.db.models.expressions import Value
 
 from wagtail.wagtailsearch.backends.base import (
     BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults)
-from wagtail.wagtailsearch.query import MatchAll, PlainText
+from wagtail.wagtailsearch.query import And, MatchAll, Not, Or, PlainText, Term
+from wagtail.wagtailsearch.utils import AND, OR
 
 
 class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
     DEFAULT_OPERATOR = 'and'
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.fields_names = list(self.get_fields_names())
+
+    def get_fields_names(self):
+        model = self.queryset.model
+        fields_names = self.fields or [field.field_name for field in
+                                       model.get_searchable_search_fields()]
+        # Check if the field exists (this will filter out indexed callables)
+        for field_name in fields_names:
+            try:
+                model._meta.get_field(field_name)
+            except models.fields.FieldDoesNotExist:
+                continue
+            else:
+                yield field_name
+
     def _process_lookup(self, field, lookup, value):
         return models.Q(**{field.get_attname(self.queryset.model) + '__' + lookup: value})
 
@@ -29,57 +49,54 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
 
         return q
 
-    def get_extra_q(self):
-        # Run _get_filters_from_queryset to test that no fields that are not
-        # a FilterField have been used in the query.
-        self._get_filters_from_queryset()
+    def build_single_term_filter(self, term):
+        term_query = models.Q()
+        for field_name in self.fields_names:
+            term_query |= models.Q(**{field_name + '__icontains': term})
+        return term_query
 
-        q = models.Q()
-        model = self.queryset.model
+    def build_database_filter(self, query=None):
+        """
+        Recursively convert a search query into a Django ``Q`` filter.
+
+        ``query`` defaults to ``self.query``; recursive calls pass nested
+        queries explicitly. Raises ``NotImplementedError`` for unsupported
+        query classes.
+        """
+        if query is None:
+            query = self.query
 
-        if isinstance(self.query, MatchAll):
-            return q
+        if isinstance(query, MatchAll):
+            return models.Q()
 
-        if not isinstance(self.query, PlainText):
-            raise NotImplementedError(
-                '`%s` is not supported by the database search backend.'
-                % self.query.__class__.__name__)
-
-        # Get fields
-        fields = self.fields or [field.field_name for field in model.get_searchable_search_fields()]
-
-        # Get terms
-        terms = self.query.query_string.split()
-        if not terms:
-            return model.objects.none()
-
-        # Filter by terms
-        for term in terms:
-            term_query = models.Q()
-            for field_name in fields:
-                # Check if the field exists (this will filter out indexed callables)
-                try:
-                    model._meta.get_field(field_name)
-                except models.fields.FieldDoesNotExist:
-                    continue
-
-                # Filter on this field
-                term_query |= models.Q(**{'%s__icontains' % field_name: term})
-
-            operator = self.query.operator
-
-            if operator == 'or':
-                q |= term_query
-            elif operator == 'and':
-                q &= term_query
-
-        return q
+        if isinstance(query, PlainText):
+            return self.build_database_filter(query.to_combined_terms())
+        if isinstance(query, Term):
+            if query.boost != 1:
+                warn('Database search backend does not support term boosting.')
+            return self.build_single_term_filter(query.term)
+        if isinstance(query, Not):
+            return ~self.build_database_filter(query.subquery)
+        if isinstance(query, And):
+            return AND(self.build_database_filter(subquery)
+                       for subquery in query.subqueries)
+        if isinstance(query, Or):
+            return OR(self.build_database_filter(subquery)
+                      for subquery in query.subqueries)
+        raise NotImplementedError(
+            '`%s` is not supported by the database search backend.'
+            % query.__class__.__name__)
 
 
 class DatabaseSearchResults(BaseSearchResults):
     def get_queryset(self):
         queryset = self.query_compiler.queryset
-        q = self.query_compiler.get_extra_q()
+
+        # Run _get_filters_from_queryset to test that no fields that are not
+        # a FilterField have been used in the query.
+        self.query_compiler._get_filters_from_queryset()
+
+        q = self.query_compiler.build_database_filter()
 
         return queryset.filter(q).distinct()[self.start:self.stop]
 
diff --git a/wagtail/wagtailsearch/query.py b/wagtail/wagtailsearch/query.py
index 813c87e1b..de341ce44 100644
--- a/wagtail/wagtailsearch/query.py
+++ b/wagtail/wagtailsearch/query.py
@@ -52,7 +52,8 @@ class PlainText(SearchQuery):
 
     def to_combined_terms(self):
         return self.OPERATORS[self.operator]([
-            Term(term) for term in self.query_string.split()])
+            Term(term, boost=self.boost)
+            for term in self.query_string.split()])
 
 
 class Term(SearchQuery):
diff --git a/wagtail/wagtailsearch/tests/test_backends.py b/wagtail/wagtailsearch/tests/test_backends.py
index 4ef4fbeca..66ffe6156 100644
--- a/wagtail/wagtailsearch/tests/test_backends.py
+++ b/wagtail/wagtailsearch/tests/test_backends.py
@@ -447,7 +447,7 @@ class BackendTests(WagtailTestUtils):
                              'JavaScript: The good parts'})
 
         # Multiple word
-        results = self.backend.search(Term('Javascript Guide'),
+        results = self.backend.search(Term('Definitive Guide'),
                                       models.Book.objects.all())
         self.assertSetEqual({r.title for r in results},
                             {'JavaScript: The Definitive Guide'})
diff --git a/wagtail/wagtailsearch/tests/test_db_backend.py b/wagtail/wagtailsearch/tests/test_db_backend.py
index 3db2bd16c..5344bac77 100644
--- a/wagtail/wagtailsearch/tests/test_db_backend.py
+++ b/wagtail/wagtailsearch/tests/test_db_backend.py
@@ -44,32 +44,3 @@ class TestDBBackend(BackendTests, TestCase):
     @unittest.expectedFailure
     def test_same_rank_pages(self):
         super(TestDBBackend, self).test_same_rank_pages()
-
-    #
-    # Query classes
-    #
-
-    # Not implemented yet
-    @unittest.expectedFailure
-    def test_term(self):
-        super().test_term()
-
-    # Not implemented yet
-    @unittest.expectedFailure
-    def test_and(self):
-        super().test_and()
-
-    # Not implemented yet
-    @unittest.expectedFailure
-    def test_or(self):
-        super().test_or()
-
-    # Not implemented yet
-    @unittest.expectedFailure
-    def test_not(self):
-        super().test_not()
-
-    # Not implemented yet
-    @unittest.expectedFailure
-    def test_operators_combination(self):
-        super().test_operators_combination()
diff --git a/wagtail/wagtailsearch/utils.py b/wagtail/wagtailsearch/utils.py
index 495a58833..5adddaec9 100644
--- a/wagtail/wagtailsearch/utils.py
+++ b/wagtail/wagtailsearch/utils.py
@@ -1,7 +1,16 @@
 from __future__ import absolute_import, unicode_literals
 
+import operator
 import re
 import string
+from functools import partial, reduce
+
+# Reduce any iterable to a single value using a logical OR e.g. (a | b | ...)
+OR = partial(reduce, operator.or_)
+# Reduce any iterable to a single value using a logical AND e.g. (a & b & ...)
+AND = partial(reduce, operator.and_)
+# Reduce any iterable to a single value using an addition
+ADD = partial(reduce, operator.add)
 
 MAX_QUERY_STRING_LENGTH = 255
 