mirror of
https://github.com/Hopiu/wagtail.git
synced 2026-05-11 16:53:10 +00:00
Implements And/Or/Not/Term in postgres_search.
This commit is contained in:
parent
71a7ca5808
commit
27bcb3f38f
7 changed files with 182 additions and 26 deletions
|
|
@ -2,7 +2,10 @@
|
|||
|
||||
from __future__ import absolute_import, unicode_literals
|
||||
|
||||
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
|
||||
from warnings import warn
|
||||
|
||||
from django.contrib.postgres.search import SearchQuery as PostgresSearchQuery
|
||||
from django.contrib.postgres.search import SearchRank, SearchVector
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
|
||||
from django.db.models import F, Manager, TextField, Value
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
|
|
@ -12,13 +15,12 @@ from django.utils.encoding import force_text
|
|||
from wagtail.wagtailsearch.backends.base import (
|
||||
BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults)
|
||||
from wagtail.wagtailsearch.index import RelatedFields, SearchField
|
||||
from wagtail.wagtailsearch.query import MatchAll, PlainText
|
||||
from wagtail.wagtailsearch.query import And, MatchAll, Not, Or, PlainText, Term
|
||||
|
||||
from .models import IndexEntry
|
||||
from .utils import (
|
||||
ADD, AND, OR, WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk,
|
||||
get_descendants_content_types_pks, get_postgresql_connections, get_weight, keyword_split,
|
||||
unidecode)
|
||||
get_descendants_content_types_pks, get_postgresql_connections, get_weight, unidecode)
|
||||
|
||||
|
||||
# TODO: Add autocomplete.
|
||||
|
|
@ -174,12 +176,29 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
|
|||
super(PostgresSearchQueryCompiler, self).__init__(*args, **kwargs)
|
||||
self.search_fields = self.queryset.model.get_search_fields()
|
||||
|
||||
def get_search_query(self, config):
|
||||
combine = OR if self.query.operator == 'or' else AND
|
||||
search_terms = keyword_split(unidecode(self.query.query_string))
|
||||
if not search_terms:
|
||||
return SearchQuery('')
|
||||
return combine(SearchQuery(q, config=config) for q in search_terms)
|
||||
def build_database_query(self, query=None, config=None):
|
||||
if query is None:
|
||||
query = self.query
|
||||
|
||||
if isinstance(query, PlainText):
|
||||
return self.build_database_query(query.to_combined_terms(), config)
|
||||
if isinstance(query, Term):
|
||||
# TODO: Find a way to use the term boosting.
|
||||
if query.boost != 1:
|
||||
warn('PostgreSQL search backend '
|
||||
'does not support term boosting for now.')
|
||||
return PostgresSearchQuery(unidecode(query.term), config=config)
|
||||
if isinstance(query, Not):
|
||||
return ~self.build_database_query(query.subquery, config)
|
||||
if isinstance(query, And):
|
||||
return AND(self.build_database_query(subquery, config)
|
||||
for subquery in query.subqueries)
|
||||
if isinstance(query, Or):
|
||||
return OR(self.build_database_query(subquery, config)
|
||||
for subquery in query.subqueries)
|
||||
raise NotImplementedError(
|
||||
'`%s` is not supported by the PostgreSQL search backend.'
|
||||
% self.query.__class__.__name__)
|
||||
|
||||
def get_boost(self, field_name, fields=None):
|
||||
if fields is None:
|
||||
|
|
@ -198,15 +217,11 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
|
|||
return field.boost
|
||||
|
||||
def search(self, config, start, stop):
|
||||
# TODO: Handle MatchAll nested inside other search query classes.
|
||||
if isinstance(self.query, MatchAll):
|
||||
return self.queryset[start:stop]
|
||||
|
||||
if not isinstance(self.query, PlainText):
|
||||
raise NotImplementedError(
|
||||
'%s is not supported by the PostgreSQL search backend.'
|
||||
% self.query.__class__)
|
||||
|
||||
search_query = self.get_search_query(config=config)
|
||||
search_query = self.build_database_query(config=config)
|
||||
queryset = self.queryset
|
||||
query = queryset.query
|
||||
if self.fields is None:
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
|
|||
|
||||
if not isinstance(self.query, PlainText):
|
||||
raise NotImplementedError(
|
||||
'%s is not supported by the database search backend.'
|
||||
% self.query.__class__)
|
||||
'`%s` is not supported by the database search backend.'
|
||||
% self.query.__class__.__name__)
|
||||
|
||||
# Get fields
|
||||
fields = self.fields or [field.field_name for field in model.get_searchable_search_fields()]
|
||||
|
|
|
|||
|
|
@ -377,8 +377,8 @@ class Elasticsearch2SearchQueryCompiler(BaseSearchQueryCompiler):
|
|||
|
||||
if not isinstance(self.query, PlainText):
|
||||
raise NotImplementedError(
|
||||
'%s is not supported by the Elasticsearch search backend.'
|
||||
% self.query.__class__)
|
||||
'`%s` is not supported by the Elasticsearch search backend.'
|
||||
% self.query.__class__.__name__)
|
||||
|
||||
fields = self.fields or ['_all', '_partials']
|
||||
operator = self.query.operator
|
||||
|
|
|
|||
|
|
@ -36,27 +36,39 @@ class MatchAll(SearchQuery):
|
|||
|
||||
|
||||
class PlainText(SearchQuery):
|
||||
def __init__(self, query_string: str, operator: str = None,
|
||||
boost: float = 1.0):
|
||||
OPERATORS = {
|
||||
'and': And,
|
||||
'or': Or,
|
||||
}
|
||||
DEFAULT_OPERATOR = 'and'
|
||||
|
||||
def __init__(self, query_string: str, operator: str = DEFAULT_OPERATOR,
|
||||
boost: float = 1):
|
||||
self.query_string = query_string
|
||||
if operator.lower() not in self.OPERATORS:
|
||||
raise ValueError("`operator` must be either 'or' or 'and'.")
|
||||
self.operator = operator
|
||||
self.boost = boost
|
||||
|
||||
def to_combined_terms(self):
|
||||
return self.OPERATORS[self.operator]([
|
||||
Term(term) for term in self.query_string.split()])
|
||||
|
||||
|
||||
class Term(SearchQuery):
|
||||
def __init__(self, term: str, boost: float = 1.0):
|
||||
def __init__(self, term: str, boost: float = 1):
|
||||
self.term = term
|
||||
self.boost = boost
|
||||
|
||||
|
||||
class Prefix(SearchQuery):
|
||||
def __init__(self, prefix: str, boost: float = 1.0):
|
||||
def __init__(self, prefix: str, boost: float = 1):
|
||||
self.prefix = prefix
|
||||
self.boost = boost
|
||||
|
||||
|
||||
class Fuzzy(SearchQuery):
|
||||
def __init__(self, term: str, max_distance: float = 3, boost: float = 1.0):
|
||||
def __init__(self, term: str, max_distance: float = 3, boost: float = 1):
|
||||
self.term = term
|
||||
self.max_distance = max_distance
|
||||
self.boost = boost
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from wagtail.wagtailsearch.backends import (
|
|||
InvalidSearchBackendError, get_search_backend, get_search_backends)
|
||||
from wagtail.wagtailsearch.backends.base import FieldError
|
||||
from wagtail.wagtailsearch.backends.db import DatabaseSearchBackend
|
||||
from wagtail.wagtailsearch.query import MATCH_ALL
|
||||
from wagtail.wagtailsearch.query import MATCH_ALL, And, Not, Or, PlainText, Term
|
||||
|
||||
|
||||
class BackendTests(WagtailTestUtils):
|
||||
|
|
@ -430,6 +430,87 @@ class BackendTests(WagtailTestUtils):
|
|||
"The Fellowship of the Ring"
|
||||
])
|
||||
|
||||
#
|
||||
# Query classes
|
||||
#
|
||||
|
||||
def test_match_all(self):
|
||||
results = self.backend.search(MATCH_ALL, models.Book.objects.all())
|
||||
self.assertEqual(len(results), 13)
|
||||
|
||||
def test_term(self):
|
||||
# Single word
|
||||
results = self.backend.search(Term('Javascript'),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'JavaScript: The Definitive Guide',
|
||||
'JavaScript: The good parts'})
|
||||
|
||||
# Multiple word
|
||||
results = self.backend.search(Term('Javascript Guide'),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'JavaScript: The Definitive Guide'})
|
||||
|
||||
def test_plain_text(self):
|
||||
# Single word
|
||||
results = self.backend.search(PlainText('Javascript'),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'JavaScript: The Definitive Guide',
|
||||
'JavaScript: The good parts'})
|
||||
|
||||
# Multiple words (OR operator)
|
||||
results = self.backend.search(PlainText('Javascript Definitive',
|
||||
operator='or'),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'JavaScript: The Definitive Guide',
|
||||
'JavaScript: The good parts'})
|
||||
|
||||
# Multiple words (AND operator)
|
||||
results = self.backend.search(PlainText('Javascript Definitive',
|
||||
operator='and'),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'JavaScript: The Definitive Guide'})
|
||||
|
||||
def test_and(self):
|
||||
results = self.backend.search(And([Term('Javascript'),
|
||||
Term('Definitive')]),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'JavaScript: The Definitive Guide'})
|
||||
|
||||
def test_or(self):
|
||||
results = self.backend.search(Or([Term('Hobbit'), Term('Towers')]),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results},
|
||||
{'The Hobbit', 'The Two Towers'})
|
||||
|
||||
def test_not(self):
|
||||
all_other_titles = {
|
||||
'A Clash of Kings',
|
||||
'A Game of Thrones',
|
||||
'A Storm of Swords',
|
||||
'Foundation',
|
||||
'Learning Python',
|
||||
'The Hobbit',
|
||||
'The Two Towers',
|
||||
'The Fellowship of the Ring',
|
||||
'The Return of the King',
|
||||
'The Rust Programming Language',
|
||||
'Two Scoops of Django 1.11',
|
||||
}
|
||||
|
||||
results = self.backend.search(Not(PlainText('Javascript')),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results}, all_other_titles)
|
||||
|
||||
results = self.backend.search(~PlainText('Javascript'),
|
||||
models.Book.objects.all())
|
||||
self.assertSetEqual({r.title for r in results}, all_other_titles)
|
||||
|
||||
|
||||
@override_settings(
|
||||
WAGTAILSEARCH_BACKENDS={
|
||||
|
|
|
|||
|
|
@ -44,3 +44,27 @@ class TestDBBackend(BackendTests, TestCase):
|
|||
@unittest.expectedFailure
|
||||
def test_same_rank_pages(self):
|
||||
super(TestDBBackend, self).test_same_rank_pages()
|
||||
|
||||
#
|
||||
# Query classes
|
||||
#
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_term(self):
|
||||
super().test_term()
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_and(self):
|
||||
super().test_and()
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_or(self):
|
||||
super().test_or()
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_not(self):
|
||||
super().test_not()
|
||||
|
|
|
|||
|
|
@ -37,6 +37,30 @@ class TestElasticsearch2SearchBackend(BackendTests, ElasticsearchCommonSearchBac
|
|||
def test_delete(self):
|
||||
super(TestElasticsearch2SearchBackend, self).test_delete()
|
||||
|
||||
#
|
||||
# Query classes
|
||||
#
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_term(self):
|
||||
super().test_term()
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_and(self):
|
||||
super().test_and()
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_or(self):
|
||||
super().test_or()
|
||||
|
||||
# Not implemented yet
|
||||
@unittest.expectedFailure
|
||||
def test_not(self):
|
||||
super().test_not()
|
||||
|
||||
|
||||
class TestElasticsearch2SearchQuery(TestCase):
|
||||
def assertDictEqual(self, a, b):
|
||||
|
|
|
|||
Loading…
Reference in a new issue