From 2926304e8418f2ee6cb0367ffd15aa8b8d87428e Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Fri, 8 Jun 2018 15:09:46 +0100 Subject: [PATCH] Rewrite Boost to no longer be a shortcut --- wagtail/contrib/postgres_search/backend.py | 27 +++++++------ .../postgres_search/tests/test_backend.py | 7 ++++ wagtail/search/backends/db.py | 24 ++++++----- wagtail/search/query.py | 26 +++++------- wagtail/search/tests/test_backends.py | 40 +++++++------------ wagtail/search/tests/test_db_backend.py | 5 +++ 6 files changed, 64 insertions(+), 65 deletions(-) diff --git a/wagtail/contrib/postgres_search/backend.py b/wagtail/contrib/postgres_search/backend.py index 1664e5355..9fe0091b3 100644 --- a/wagtail/contrib/postgres_search/backend.py +++ b/wagtail/contrib/postgres_search/backend.py @@ -11,7 +11,7 @@ from django.utils.encoding import force_text from wagtail.search.backends.base import ( BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults) from wagtail.search.index import RelatedFields, SearchField, get_indexed_models -from wagtail.search.query import And, MatchAll, Not, Or, PlainText, Prefix, SearchQueryShortcut, Term +from wagtail.search.query import And, Boost, MatchAll, Not, Or, PlainText, Prefix, SearchQueryShortcut, Term from wagtail.search.utils import ADD, AND, OR from .models import SearchAutocomplete as PostgresSearchAutocomplete @@ -226,12 +226,12 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): return self.get_search_field(sub_field_name, field.fields) # TODO: Find a way to use the term boosting. - def check_boost(self, query): - if query.boost != 1: + def check_boost(self, query, boost=1.0): + if query.boost * boost != 1.0: warn('PostgreSQL search backend ' 'does not support term boosting for now.') - def build_database_query(self, query=None, config=None): + def build_database_query(self, query=None, config=None, boost=1.0): if query is None: query = self.query @@ -244,28 +244,31 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): Term(term, boost=query.boost) for term in query.query_string.split() ]) - return self.build_database_query(q, config) + return self.build_database_query(q, config, boost=boost) + if isinstance(query, Boost): + boost *= query.boost + return self.build_database_query(query.subquery, config, boost=boost) if isinstance(query, SearchQueryShortcut): - return self.build_database_query(query.get_equivalent(), config) + return self.build_database_query(query.get_equivalent(), config, boost=boost) if isinstance(query, Prefix): - self.check_boost(query) + self.check_boost(query, boost=boost) self.is_autocomplete = True return PostgresSearchAutocomplete(unidecode(query.prefix), config=config) if isinstance(query, Term): - self.check_boost(query) + self.check_boost(query, boost=boost) return PostgresSearchQuery(unidecode(query.term), config=config) if isinstance(query, Not): - return ~self.build_database_query(query.subquery, config) + return ~self.build_database_query(query.subquery, config, boost=boost) if isinstance(query, And): - return AND(self.build_database_query(subquery, config) + return AND(self.build_database_query(subquery, config, boost=boost) for subquery in query.subqueries) if isinstance(query, Or): - return OR(self.build_database_query(subquery, config) + return OR(self.build_database_query(subquery, config, boost=boost) for subquery in query.subqueries) raise NotImplementedError( '`%s` is not supported by the PostgreSQL search backend.' - % self.query.__class__.__name__) + % query.__class__.__name__) def search(self, config, start, stop, score_field=None): # TODO: Handle MatchAll nested inside other search query classes. diff --git a/wagtail/contrib/postgres_search/tests/test_backend.py b/wagtail/contrib/postgres_search/tests/test_backend.py index 50447365a..632b44bd9 100644 --- a/wagtail/contrib/postgres_search/tests/test_backend.py +++ b/wagtail/contrib/postgres_search/tests/test_backend.py @@ -1,3 +1,5 @@ +import unittest + from django.test import TestCase from wagtail.search.tests.test_backends import BackendTests @@ -35,3 +37,8 @@ class TestPostgresSearchBackend(BackendTests, TestCase): [(6, 'A'), (4, 'B'), (2, 'C'), (0, 'D')]) self.assertListEqual(determine_boosts_weights([-2, -1, 0, 1, 2, 3, 4]), [(4, 'A'), (2, 'B'), (0, 'C'), (-2, 'D')]) + + # Doesn't support Boost() query class + @unittest.expectedFailure + def test_boost(self): + super().test_boost() diff --git a/wagtail/search/backends/db.py b/wagtail/search/backends/db.py index 3e92c66c6..227132671 100644 --- a/wagtail/search/backends/db.py +++ b/wagtail/search/backends/db.py @@ -5,7 +5,7 @@ from django.db.models.expressions import Value from wagtail.search.backends.base import ( BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults) -from wagtail.search.query import And, MatchAll, Not, Or, PlainText, Prefix, SearchQueryShortcut, Term +from wagtail.search.query import And, Boost, MatchAll, Not, Or, PlainText, Prefix, SearchQueryShortcut, Term from wagtail.search.utils import AND, OR @@ -51,11 +51,11 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler): term_query |= models.Q(**{field_name + '__icontains': term}) return term_query - def check_boost(self, query): - if query.boost != 1: + def check_boost(self, query, boost=1.0): + if query.boost * boost != 1.0: warn('Database search backend does not support term boosting.') - def build_database_filter(self, query=None): + def build_database_filter(self, query=None, boost=1.0): if query is None: query = self.query @@ -68,13 +68,17 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler): Term(term, boost=query.boost) for term in query.query_string.split() ]) - return self.build_database_filter(q) + return self.build_database_filter(q, boost=boost) + + if isinstance(query, Boost): + boost *= query.boost + return self.build_database_filter(query.subquery, boost=boost) if isinstance(self.query, MatchAll): return models.Q() if isinstance(query, SearchQueryShortcut): - return self.build_database_filter(query.get_equivalent()) + return self.build_database_filter(query.get_equivalent(), boost=boost) if isinstance(query, Term): self.check_boost(query) return self.build_single_term_filter(query.term) @@ -82,16 +86,16 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler): self.check_boost(query) return self.build_single_term_filter(query.prefix) if isinstance(query, Not): - return ~self.build_database_filter(query.subquery) + return ~self.build_database_filter(query.subquery, boost=boost) if isinstance(query, And): - return AND(self.build_database_filter(subquery) + return AND(self.build_database_filter(subquery, boost=boost) for subquery in query.subqueries) if isinstance(query, Or): - return OR(self.build_database_filter(subquery) + return OR(self.build_database_filter(subquery, boost=boost) for subquery in query.subqueries) raise NotImplementedError( '`%s` is not supported by the database search backend.' - % self.query.__class__.__name__) + % query.__class__.__name__) class DatabaseSearchResults(BaseSearchResults): diff --git a/wagtail/search/query.py b/wagtail/search/query.py index ca3557c83..9321c4930 100644 --- a/wagtail/search/query.py +++ b/wagtail/search/query.py @@ -112,6 +112,15 @@ class MatchAll(SearchQuery): return self.__class__() +class Boost(SearchQuery): + def __init__(self, subquery: SearchQuery, boost: float): + self.subquery = subquery + self.boost = boost + + def apply(self, func): + return func(self.__class__(self.subquery.apply(func), self.boost)) + + class Term(SearchQuery): def __init__(self, term: str, boost: float = 1): self.term = term @@ -145,21 +154,4 @@ class Fuzzy(SearchQuery): # -class Boost(SearchQueryShortcut): - def __init__(self, subquery: SearchQuery, boost: float): - self.subquery = subquery - self.boost = boost - - def apply(self, func): - return func(self.__class__(self.subquery.apply(func), self.boost)) - - def get_equivalent(self): - def boost_child(child): - if isinstance(child, (PlainText, Fuzzy, Prefix, Term)): - child.boost *= self.boost - return child - - return self.subquery.apply(boost_child) - - MATCH_ALL = MatchAll() diff --git a/wagtail/search/tests/test_backends.py b/wagtail/search/tests/test_backends.py index be4147817..366462f5a 100644 --- a/wagtail/search/tests/test_backends.py +++ b/wagtail/search/tests/test_backends.py @@ -598,35 +598,23 @@ class BackendTests(WagtailTestUtils): self.backend.search('Guide', models.Book.objects.all(), operator='xor') - def test_boost_equivalent(self): - boost = Boost(Term('guide'), 5) - equivalent = boost.children[0] - self.assertIsInstance(equivalent, Term) - self.assertAlmostEqual(equivalent.boost, 5) + def test_boost(self): + results = self.backend.search(PlainText('JavaScript Definitive') | Boost(PlainText('Learning Python'), 2.0), models.Book.objects.all()) - boost = Boost(Term('guide', boost=0.5), 5) - equivalent = boost.children[0] - self.assertIsInstance(equivalent, Term) - self.assertAlmostEqual(equivalent.boost, 2.5) + # Both python and JavaScript should be returned with Python at the top + self.assertEqual([r.title for r in results], [ + "Learning Python", + "JavaScript: The Definitive Guide", + ]) - boost = Boost(Boost(Term('guide', 0.1), 3), 5) - sub_boost = boost.children[0] - self.assertIsInstance(sub_boost, Boost) - sub_boost = sub_boost.children[0] - self.assertIsInstance(sub_boost, Term) - self.assertAlmostEqual(sub_boost.boost, 1.5) - boost = Boost(And([Boost(Term('guide', 0.1), 3), Term('two', 2)]), 5) - and_obj = boost.children[0] - self.assertIsInstance(and_obj, And) - sub_boost = and_obj.children[0] - self.assertIsInstance(sub_boost, Boost) - guide = sub_boost.children[0] - self.assertIsInstance(guide, Term) - self.assertAlmostEqual(guide.boost, 1.5) - two = and_obj.children[1] - self.assertIsInstance(two, Term) - self.assertAlmostEqual(two.boost, 10) + results = self.backend.search(PlainText('JavaScript Definitive') | Boost(PlainText('Learning Python'), 0.5), models.Book.objects.all()) + + # Now they should be swapped + self.assertEqual([r.title for r in results], [ + "JavaScript: The Definitive Guide", + "Learning Python", + ]) @override_settings( diff --git a/wagtail/search/tests/test_db_backend.py b/wagtail/search/tests/test_db_backend.py index a8fd04c04..f684401d0 100644 --- a/wagtail/search/tests/test_db_backend.py +++ b/wagtail/search/tests/test_db_backend.py @@ -57,3 +57,8 @@ class TestDBBackend(BackendTests, TestCase): @unittest.expectedFailure def test_incomplete_plain_text(self): super().test_incomplete_plain_text() + + # Database backend doesn't support Boost() query class + @unittest.expectedFailure + def test_boost(self): + super().test_boost()