Implements And/Or/Not/Term in postgres_search.

This commit is contained in:
Bertrand Bordage 2017-11-23 18:37:36 +01:00
parent 71a7ca5808
commit 27bcb3f38f
7 changed files with 182 additions and 26 deletions

View file

@ -2,7 +2,10 @@
from __future__ import absolute_import, unicode_literals
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
from warnings import warn
from django.contrib.postgres.search import SearchQuery as PostgresSearchQuery
from django.contrib.postgres.search import SearchRank, SearchVector
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
from django.db.models import F, Manager, TextField, Value
from django.db.models.constants import LOOKUP_SEP
@ -12,13 +15,12 @@ from django.utils.encoding import force_text
from wagtail.wagtailsearch.backends.base import (
BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults)
from wagtail.wagtailsearch.index import RelatedFields, SearchField
from wagtail.wagtailsearch.query import MatchAll, PlainText
from wagtail.wagtailsearch.query import And, MatchAll, Not, Or, PlainText, Term
from .models import IndexEntry
from .utils import (
ADD, AND, OR, WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk,
get_descendants_content_types_pks, get_postgresql_connections, get_weight, keyword_split,
unidecode)
get_descendants_content_types_pks, get_postgresql_connections, get_weight, unidecode)
# TODO: Add autocomplete.
@ -174,12 +176,29 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
super(PostgresSearchQueryCompiler, self).__init__(*args, **kwargs)
self.search_fields = self.queryset.model.get_search_fields()
def get_search_query(self, config):
combine = OR if self.query.operator == 'or' else AND
search_terms = keyword_split(unidecode(self.query.query_string))
if not search_terms:
return SearchQuery('')
return combine(SearchQuery(q, config=config) for q in search_terms)
def build_database_query(self, query=None, config=None):
if query is None:
query = self.query
if isinstance(query, PlainText):
return self.build_database_query(query.to_combined_terms(), config)
if isinstance(query, Term):
# TODO: Find a way to use the term boosting.
if query.boost != 1:
warn('PostgreSQL search backend '
'does not support term boosting for now.')
return PostgresSearchQuery(unidecode(query.term), config=config)
if isinstance(query, Not):
return ~self.build_database_query(query.subquery, config)
if isinstance(query, And):
return AND(self.build_database_query(subquery, config)
for subquery in query.subqueries)
if isinstance(query, Or):
return OR(self.build_database_query(subquery, config)
for subquery in query.subqueries)
raise NotImplementedError(
'`%s` is not supported by the PostgreSQL search backend.'
% self.query.__class__.__name__)
def get_boost(self, field_name, fields=None):
if fields is None:
@ -198,15 +217,11 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
return field.boost
def search(self, config, start, stop):
# TODO: Handle MatchAll nested inside other search query classes.
if isinstance(self.query, MatchAll):
return self.queryset[start:stop]
if not isinstance(self.query, PlainText):
raise NotImplementedError(
'%s is not supported by the PostgreSQL search backend.'
% self.query.__class__)
search_query = self.get_search_query(config=config)
search_query = self.build_database_query(config=config)
queryset = self.queryset
query = queryset.query
if self.fields is None:

View file

@ -42,8 +42,8 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
if not isinstance(self.query, PlainText):
raise NotImplementedError(
'%s is not supported by the database search backend.'
% self.query.__class__)
'`%s` is not supported by the database search backend.'
% self.query.__class__.__name__)
# Get fields
fields = self.fields or [field.field_name for field in model.get_searchable_search_fields()]

View file

@ -377,8 +377,8 @@ class Elasticsearch2SearchQueryCompiler(BaseSearchQueryCompiler):
if not isinstance(self.query, PlainText):
raise NotImplementedError(
'%s is not supported by the Elasticsearch search backend.'
% self.query.__class__)
'`%s` is not supported by the Elasticsearch search backend.'
% self.query.__class__.__name__)
fields = self.fields or ['_all', '_partials']
operator = self.query.operator

View file

@ -36,27 +36,39 @@ class MatchAll(SearchQuery):
class PlainText(SearchQuery):
def __init__(self, query_string: str, operator: str = None,
boost: float = 1.0):
OPERATORS = {
'and': And,
'or': Or,
}
DEFAULT_OPERATOR = 'and'
def __init__(self, query_string: str, operator: str = DEFAULT_OPERATOR,
boost: float = 1):
self.query_string = query_string
if operator.lower() not in self.OPERATORS:
raise ValueError("`operator` must be either 'or' or 'and'.")
self.operator = operator
self.boost = boost
def to_combined_terms(self):
return self.OPERATORS[self.operator]([
Term(term) for term in self.query_string.split()])
class Term(SearchQuery):
def __init__(self, term: str, boost: float = 1.0):
def __init__(self, term: str, boost: float = 1):
self.term = term
self.boost = boost
class Prefix(SearchQuery):
def __init__(self, prefix: str, boost: float = 1.0):
def __init__(self, prefix: str, boost: float = 1):
self.prefix = prefix
self.boost = boost
class Fuzzy(SearchQuery):
def __init__(self, term: str, max_distance: float = 3, boost: float = 1.0):
def __init__(self, term: str, max_distance: float = 3, boost: float = 1):
self.term = term
self.max_distance = max_distance
self.boost = boost

View file

@ -17,7 +17,7 @@ from wagtail.wagtailsearch.backends import (
InvalidSearchBackendError, get_search_backend, get_search_backends)
from wagtail.wagtailsearch.backends.base import FieldError
from wagtail.wagtailsearch.backends.db import DatabaseSearchBackend
from wagtail.wagtailsearch.query import MATCH_ALL
from wagtail.wagtailsearch.query import MATCH_ALL, And, Not, Or, PlainText, Term
class BackendTests(WagtailTestUtils):
@ -430,6 +430,87 @@ class BackendTests(WagtailTestUtils):
"The Fellowship of the Ring"
])
#
# Query classes
#
def test_match_all(self):
results = self.backend.search(MATCH_ALL, models.Book.objects.all())
self.assertEqual(len(results), 13)
def test_term(self):
# Single word
results = self.backend.search(Term('Javascript'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'JavaScript: The Definitive Guide',
'JavaScript: The good parts'})
# Multiple word
results = self.backend.search(Term('Javascript Guide'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'JavaScript: The Definitive Guide'})
def test_plain_text(self):
# Single word
results = self.backend.search(PlainText('Javascript'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'JavaScript: The Definitive Guide',
'JavaScript: The good parts'})
# Multiple words (OR operator)
results = self.backend.search(PlainText('Javascript Definitive',
operator='or'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'JavaScript: The Definitive Guide',
'JavaScript: The good parts'})
# Multiple words (AND operator)
results = self.backend.search(PlainText('Javascript Definitive',
operator='and'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'JavaScript: The Definitive Guide'})
def test_and(self):
results = self.backend.search(And([Term('Javascript'),
Term('Definitive')]),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'JavaScript: The Definitive Guide'})
def test_or(self):
results = self.backend.search(Or([Term('Hobbit'), Term('Towers')]),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'The Hobbit', 'The Two Towers'})
def test_not(self):
all_other_titles = {
'A Clash of Kings',
'A Game of Thrones',
'A Storm of Swords',
'Foundation',
'Learning Python',
'The Hobbit',
'The Two Towers',
'The Fellowship of the Ring',
'The Return of the King',
'The Rust Programming Language',
'Two Scoops of Django 1.11',
}
results = self.backend.search(Not(PlainText('Javascript')),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results}, all_other_titles)
results = self.backend.search(~PlainText('Javascript'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results}, all_other_titles)
@override_settings(
WAGTAILSEARCH_BACKENDS={

View file

@ -44,3 +44,27 @@ class TestDBBackend(BackendTests, TestCase):
@unittest.expectedFailure
def test_same_rank_pages(self):
super(TestDBBackend, self).test_same_rank_pages()
#
# Query classes
#
# Not implemented yet
@unittest.expectedFailure
def test_term(self):
super().test_term()
# Not implemented yet
@unittest.expectedFailure
def test_and(self):
super().test_and()
# Not implemented yet
@unittest.expectedFailure
def test_or(self):
super().test_or()
# Not implemented yet
@unittest.expectedFailure
def test_not(self):
super().test_not()

View file

@ -37,6 +37,30 @@ class TestElasticsearch2SearchBackend(BackendTests, ElasticsearchCommonSearchBac
def test_delete(self):
super(TestElasticsearch2SearchBackend, self).test_delete()
#
# Query classes
#
# Not implemented yet
@unittest.expectedFailure
def test_term(self):
super().test_term()
# Not implemented yet
@unittest.expectedFailure
def test_and(self):
super().test_and()
# Not implemented yet
@unittest.expectedFailure
def test_or(self):
super().test_or()
# Not implemented yet
@unittest.expectedFailure
def test_not(self):
super().test_not()
class TestElasticsearch2SearchQuery(TestCase):
def assertDictEqual(self, a, b):