diff --git a/wagtail/contrib/postgres_search/backend.py b/wagtail/contrib/postgres_search/backend.py index 9fe538980..9084de260 100644 --- a/wagtail/contrib/postgres_search/backend.py +++ b/wagtail/contrib/postgres_search/backend.py @@ -2,7 +2,10 @@ from __future__ import absolute_import, unicode_literals -from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector +from warnings import warn + +from django.contrib.postgres.search import SearchQuery as PostgresSearchQuery +from django.contrib.postgres.search import SearchRank, SearchVector from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction from django.db.models import F, Manager, TextField, Value from django.db.models.constants import LOOKUP_SEP @@ -12,13 +15,12 @@ from django.utils.encoding import force_text from wagtail.wagtailsearch.backends.base import ( BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults) from wagtail.wagtailsearch.index import RelatedFields, SearchField -from wagtail.wagtailsearch.query import MatchAll, PlainText +from wagtail.wagtailsearch.query import And, MatchAll, Not, Or, PlainText, Term from .models import IndexEntry from .utils import ( ADD, AND, OR, WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk, - get_descendants_content_types_pks, get_postgresql_connections, get_weight, keyword_split, - unidecode) + get_descendants_content_types_pks, get_postgresql_connections, get_weight, unidecode) # TODO: Add autocomplete. @@ -174,12 +176,29 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): super(PostgresSearchQueryCompiler, self).__init__(*args, **kwargs) self.search_fields = self.queryset.model.get_search_fields() - def get_search_query(self, config): - combine = OR if self.query.operator == 'or' else AND - search_terms = keyword_split(unidecode(self.query.query_string)) - if not search_terms: - return SearchQuery('') - return combine(SearchQuery(q, config=config) for q in search_terms) + def build_database_query(self, query=None, config=None): + if query is None: + query = self.query + + if isinstance(query, PlainText): + return self.build_database_query(query.to_combined_terms(), config) + if isinstance(query, Term): + # TODO: Find a way to use the term boosting. + if query.boost != 1: + warn('PostgreSQL search backend ' + 'does not support term boosting for now.') + return PostgresSearchQuery(unidecode(query.term), config=config) + if isinstance(query, Not): + return ~self.build_database_query(query.subquery, config) + if isinstance(query, And): + return AND(self.build_database_query(subquery, config) + for subquery in query.subqueries) + if isinstance(query, Or): + return OR(self.build_database_query(subquery, config) + for subquery in query.subqueries) + raise NotImplementedError( + '`%s` is not supported by the PostgreSQL search backend.' + % self.query.__class__.__name__) def get_boost(self, field_name, fields=None): if fields is None: @@ -198,15 +217,11 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): return field.boost def search(self, config, start, stop): + # TODO: Handle MatchAll nested inside other search query classes. if isinstance(self.query, MatchAll): return self.queryset[start:stop] - if not isinstance(self.query, PlainText): - raise NotImplementedError( - '%s is not supported by the PostgreSQL search backend.' - % self.query.__class__) - - search_query = self.get_search_query(config=config) + search_query = self.build_database_query(config=config) queryset = self.queryset query = queryset.query if self.fields is None: diff --git a/wagtail/wagtailsearch/backends/db.py b/wagtail/wagtailsearch/backends/db.py index cc2ae42bd..30831b2a1 100644 --- a/wagtail/wagtailsearch/backends/db.py +++ b/wagtail/wagtailsearch/backends/db.py @@ -42,8 +42,8 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler): if not isinstance(self.query, PlainText): raise NotImplementedError( - '%s is not supported by the database search backend.' - % self.query.__class__) + '`%s` is not supported by the database search backend.' + % self.query.__class__.__name__) # Get fields fields = self.fields or [field.field_name for field in model.get_searchable_search_fields()] diff --git a/wagtail/wagtailsearch/backends/elasticsearch2.py b/wagtail/wagtailsearch/backends/elasticsearch2.py index 320479f1a..f265774fd 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch2.py +++ b/wagtail/wagtailsearch/backends/elasticsearch2.py @@ -377,8 +377,8 @@ class Elasticsearch2SearchQueryCompiler(BaseSearchQueryCompiler): if not isinstance(self.query, PlainText): raise NotImplementedError( - '%s is not supported by the Elasticsearch search backend.' - % self.query.__class__) + '`%s` is not supported by the Elasticsearch search backend.' + % self.query.__class__.__name__) fields = self.fields or ['_all', '_partials'] operator = self.query.operator diff --git a/wagtail/wagtailsearch/query.py b/wagtail/wagtailsearch/query.py index 3df8e4bb3..813c87e1b 100644 --- a/wagtail/wagtailsearch/query.py +++ b/wagtail/wagtailsearch/query.py @@ -36,27 +36,39 @@ class MatchAll(SearchQuery): class PlainText(SearchQuery): - def __init__(self, query_string: str, operator: str = None, - boost: float = 1.0): + OPERATORS = { + 'and': And, + 'or': Or, + } + DEFAULT_OPERATOR = 'and' + + def __init__(self, query_string: str, operator: str = DEFAULT_OPERATOR, + boost: float = 1): self.query_string = query_string + if operator.lower() not in self.OPERATORS: + raise ValueError("`operator` must be either 'or' or 'and'.") self.operator = operator self.boost = boost + def to_combined_terms(self): + return self.OPERATORS[self.operator]([ + Term(term) for term in self.query_string.split()]) + class Term(SearchQuery): - def __init__(self, term: str, boost: float = 1.0): + def __init__(self, term: str, boost: float = 1): self.term = term self.boost = boost class Prefix(SearchQuery): - def __init__(self, prefix: str, boost: float = 1.0): + def __init__(self, prefix: str, boost: float = 1): self.prefix = prefix self.boost = boost class Fuzzy(SearchQuery): - def __init__(self, term: str, max_distance: float = 3, boost: float = 1.0): + def __init__(self, term: str, max_distance: float = 3, boost: float = 1): self.term = term self.max_distance = max_distance self.boost = boost diff --git a/wagtail/wagtailsearch/tests/test_backends.py b/wagtail/wagtailsearch/tests/test_backends.py index 8e79b4876..d3295b1f7 100644 --- a/wagtail/wagtailsearch/tests/test_backends.py +++ b/wagtail/wagtailsearch/tests/test_backends.py @@ -17,7 +17,7 @@ from wagtail.wagtailsearch.backends import ( InvalidSearchBackendError, get_search_backend, get_search_backends) from wagtail.wagtailsearch.backends.base import FieldError from wagtail.wagtailsearch.backends.db import DatabaseSearchBackend -from wagtail.wagtailsearch.query import MATCH_ALL +from wagtail.wagtailsearch.query import MATCH_ALL, And, Not, Or, PlainText, Term class BackendTests(WagtailTestUtils): @@ -430,6 +430,87 @@ class BackendTests(WagtailTestUtils): "The Fellowship of the Ring" ]) + # + # Query classes + # + + def test_match_all(self): + results = self.backend.search(MATCH_ALL, models.Book.objects.all()) + self.assertEqual(len(results), 13) + + def test_term(self): + # Single word + results = self.backend.search(Term('Javascript'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'JavaScript: The Definitive Guide', + 'JavaScript: The good parts'}) + + # Multiple word + results = self.backend.search(Term('Javascript Guide'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'JavaScript: The Definitive Guide'}) + + def test_plain_text(self): + # Single word + results = self.backend.search(PlainText('Javascript'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'JavaScript: The Definitive Guide', + 'JavaScript: The good parts'}) + + # Multiple words (OR operator) + results = self.backend.search(PlainText('Javascript Definitive', + operator='or'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'JavaScript: The Definitive Guide', + 'JavaScript: The good parts'}) + + # Multiple words (AND operator) + results = self.backend.search(PlainText('Javascript Definitive', + operator='and'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'JavaScript: The Definitive Guide'}) + + def test_and(self): + results = self.backend.search(And([Term('Javascript'), + Term('Definitive')]), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'JavaScript: The Definitive Guide'}) + + def test_or(self): + results = self.backend.search(Or([Term('Hobbit'), Term('Towers')]), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'The Hobbit', 'The Two Towers'}) + + def test_not(self): + all_other_titles = { + 'A Clash of Kings', + 'A Game of Thrones', + 'A Storm of Swords', + 'Foundation', + 'Learning Python', + 'The Hobbit', + 'The Two Towers', + 'The Fellowship of the Ring', + 'The Return of the King', + 'The Rust Programming Language', + 'Two Scoops of Django 1.11', + } + + results = self.backend.search(Not(PlainText('Javascript')), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, all_other_titles) + + results = self.backend.search(~PlainText('Javascript'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, all_other_titles) + @override_settings( WAGTAILSEARCH_BACKENDS={ diff --git a/wagtail/wagtailsearch/tests/test_db_backend.py b/wagtail/wagtailsearch/tests/test_db_backend.py index 5344bac77..47ce49f5f 100644 --- a/wagtail/wagtailsearch/tests/test_db_backend.py +++ b/wagtail/wagtailsearch/tests/test_db_backend.py @@ -44,3 +44,27 @@ class TestDBBackend(BackendTests, TestCase): @unittest.expectedFailure def test_same_rank_pages(self): super(TestDBBackend, self).test_same_rank_pages() + + # + # Query classes + # + + # Not implemented yet + @unittest.expectedFailure + def test_term(self): + super().test_term() + + # Not implemented yet + @unittest.expectedFailure + def test_and(self): + super().test_and() + + # Not implemented yet + @unittest.expectedFailure + def test_or(self): + super().test_or() + + # Not implemented yet + @unittest.expectedFailure + def test_not(self): + super().test_not() diff --git a/wagtail/wagtailsearch/tests/test_elasticsearch2_backend.py b/wagtail/wagtailsearch/tests/test_elasticsearch2_backend.py index b82d8bb2d..e9759b6db 100644 --- a/wagtail/wagtailsearch/tests/test_elasticsearch2_backend.py +++ b/wagtail/wagtailsearch/tests/test_elasticsearch2_backend.py @@ -37,6 +37,30 @@ class TestElasticsearch2SearchBackend(BackendTests, ElasticsearchCommonSearchBac def test_delete(self): super(TestElasticsearch2SearchBackend, self).test_delete() + # + # Query classes + # + + # Not implemented yet + @unittest.expectedFailure + def test_term(self): + super().test_term() + + # Not implemented yet + @unittest.expectedFailure + def test_and(self): + super().test_and() + + # Not implemented yet + @unittest.expectedFailure + def test_or(self): + super().test_or() + + # Not implemented yet + @unittest.expectedFailure + def test_not(self): + super().test_not() + class TestElasticsearch2SearchQuery(TestCase): def assertDictEqual(self, a, b):