From dbe89153a68eaf77bcb336d5ecf65715f7819160 Mon Sep 17 00:00:00 2001 From: Bertrand Bordage Date: Thu, 4 Jan 2018 19:19:50 +0100 Subject: [PATCH] Implements `Prefix` on PostgreSQL. --- wagtail/contrib/postgres_search/apps.py | 8 +- wagtail/contrib/postgres_search/backend.py | 255 ++++++++++-------- .../migrations/0001_initial.py | 6 - .../migrations/0002_add_autocomplete.py | 49 ++++ wagtail/contrib/postgres_search/models.py | 22 +- wagtail/contrib/postgres_search/utils.py | 17 ++ wagtail/search/backends/db.py | 13 +- .../tests/elasticsearch_common_tests.py | 12 + wagtail/search/tests/test_backends.py | 22 +- wagtail/search/tests/test_db_backend.py | 5 + 10 files changed, 278 insertions(+), 131 deletions(-) create mode 100644 wagtail/contrib/postgres_search/migrations/0002_add_autocomplete.py diff --git a/wagtail/contrib/postgres_search/apps.py b/wagtail/contrib/postgres_search/apps.py index 169c2fb23..3054518ec 100644 --- a/wagtail/contrib/postgres_search/apps.py +++ b/wagtail/contrib/postgres_search/apps.py @@ -1,8 +1,7 @@ from django.apps import AppConfig from django.core.checks import Error, Tags, register -from .utils import ( - BOOSTS_WEIGHTS, WEIGHTS_VALUES, determine_boosts_weights, get_postgresql_connections) +from .utils import get_postgresql_connections, set_weights class PostgresSearchConfig(AppConfig): @@ -17,7 +16,4 @@ class PostgresSearchConfig(AppConfig): 'to use PostgreSQL search.', id='wagtail.contrib.postgres_search.E001')] - BOOSTS_WEIGHTS.extend(determine_boosts_weights()) - max_weight = BOOSTS_WEIGHTS[0][0] - WEIGHTS_VALUES.extend([v / max_weight - for v, w in reversed(BOOSTS_WEIGHTS)]) + set_weights() diff --git a/wagtail/contrib/postgres_search/backend.py b/wagtail/contrib/postgres_search/backend.py index e66eea1a2..3e93cdde4 100644 --- a/wagtail/contrib/postgres_search/backend.py +++ b/wagtail/contrib/postgres_search/backend.py @@ -10,33 +10,31 @@ from django.utils.encoding import force_text from wagtail.search.backends.base import ( BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults) -from wagtail.search.index import RelatedFields, SearchField -from wagtail.search.query import And, MatchAll, Not, Or, SearchQueryShortcut, Term +from wagtail.search.index import get_indexed_models, RelatedFields, SearchField +from wagtail.search.query import ( + And, MatchAll, Not, Or, Prefix, SearchQueryShortcut, Term) from wagtail.search.utils import ADD, AND, OR -from .models import IndexEntry +from .models import IndexEntry, SearchAutocomplete as PostgresSearchAutocomplete from .utils import ( - WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk, - get_descendants_content_types_pks, get_postgresql_connections, get_weight, unidecode) + get_content_type_pk, get_descendants_content_types_pks, + get_postgresql_connections, get_sql_weights, get_weight, unidecode) -# TODO: Add autocomplete. +EMPTY_VECTOR = SearchVector(Value('')) class Index: - def __init__(self, backend, model, db_alias=None): + def __init__(self, backend, db_alias=None): self.backend = backend - self.model = model - if db_alias is None: - db_alias = DEFAULT_DB_ALIAS - if connections[db_alias].vendor != 'postgresql': + self.name = self.backend.index_name + self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias + self.connection = connections[self.db_alias] + if self.connection.vendor != 'postgresql': raise NotSupportedError( 'You must select a PostgreSQL database ' 'to use PostgreSQL search.') - self.db_alias = db_alias - self.index_entries = IndexEntry._default_manager.using(self.db_alias) - self.name = model._meta.label - self.search_fields = self.model.get_search_fields() + self.entries = IndexEntry._default_manager.using(self.db_alias) def add_model(self, model): pass @@ -44,20 +42,23 @@ class Index: def refresh(self): pass - def delete_stale_entries(self): - if self.model._meta.parents: - # We don’t need to delete stale entries for non-root models, - # since we already delete them by deleting roots. - return - existing_pks = (self.model._default_manager.using(self.db_alias) + def delete_stale_model_entries(self, model): + existing_pks = (model._default_manager.using(self.db_alias) .annotate(object_id=Cast('pk', TextField())) .values('object_id')) - content_type_ids = get_descendants_content_types_pks(self.model) + content_types_pks = get_descendants_content_types_pks(model) stale_entries = ( - self.index_entries.filter(content_type_id__in=content_type_ids) + self.entries.filter(content_type_id__in=content_types_pks) .exclude(object_id__in=existing_pks)) stale_entries.delete() + def delete_stale_entries(self): + for model in get_indexed_models(): + # We don’t need to delete stale entries for non-root models, + # since we already delete them by deleting roots. + if not model._meta.parents: + self.delete_stale_model_entries(model) + def prepare_value(self, value): if isinstance(value, str): return value @@ -70,8 +71,8 @@ class Index: def prepare_field(self, obj, field): if isinstance(field, SearchField): - yield (unidecode(self.prepare_value(field.get_value(obj))), - get_weight(field.boost)) + yield (field, get_weight(field.boost), + unidecode(self.prepare_value(field.get_value(obj)))) elif isinstance(field, RelatedFields): sub_obj = field.get_value(obj) if sub_obj is None: @@ -84,83 +85,101 @@ class Index: sub_objs = [sub_obj] for sub_obj in sub_objs: for sub_field in field.fields: - for value in self.prepare_field(sub_obj, sub_field): - yield value + yield from self.prepare_field(sub_obj, sub_field) - def prepare_body(self, obj): - return [(value, boost) for field in self.search_fields - for value, boost in self.prepare_field(obj, field)] + def prepare_obj(self, obj, search_fields): + obj._object_id_ = force_text(obj.pk) + obj._autocomplete_ = [] + obj._body_ = [] + for field in search_fields: + for current_field, boost, value in self.prepare_field(obj, field): + if isinstance(current_field, SearchField) and \ + current_field.partial_match: + obj._autocomplete_.append((value, boost)) + else: + obj._body_.append((value, boost)) def add_item(self, obj): - self.add_items(self.model, [obj]) + self.add_items(obj._meta.model, [obj]) - def add_items_upsert(self, connection, content_type_pk, objs, config): - vectors_sql = [] + def add_items_upsert(self, content_type_pk, objs): + config = self.backend.config + autocomplete_sql = [] + body_sql = [] data_params = [] sql_template = ('to_tsvector(%s)' if config is None else "to_tsvector('%s', %%s)" % config) sql_template = 'setweight(%s, %%s)' % sql_template for obj in objs: - data_params.extend((content_type_pk, obj._object_id)) + data_params.extend((content_type_pk, obj._object_id_)) + if obj._autocomplete_: + autocomplete_sql.append('||'.join(sql_template + for _ in obj._autocomplete_)) + data_params.extend([v for t in obj._autocomplete_ for v in t]) + else: + autocomplete_sql.append("''::tsvector") if obj._body_: - vectors_sql.append('||'.join(sql_template for _ in obj._body_)) + body_sql.append('||'.join(sql_template for _ in obj._body_)) data_params.extend([v for t in obj._body_ for v in t]) else: - vectors_sql.append("''::tsvector") - data_sql = ', '.join(['(%%s, %%s, %s)' % s for s in vectors_sql]) - with connection.cursor() as cursor: + body_sql.append("''::tsvector") + data_sql = ', '.join(['(%%s, %%s, %s, %s)' % (a, b) + for a, b in zip(autocomplete_sql, body_sql)]) + with self.connection.cursor() as cursor: cursor.execute(""" - INSERT INTO %s(content_type_id, object_id, body_search) + INSERT INTO %s (content_type_id, object_id, autocomplete, body) (VALUES %s) ON CONFLICT (content_type_id, object_id) - DO UPDATE SET body_search = EXCLUDED.body_search + DO UPDATE SET autocomplete = EXCLUDED.autocomplete, + body = EXCLUDED.body """ % (IndexEntry._meta.db_table, data_sql), data_params) - def add_items_update_then_create(self, content_type_pk, objs, config): + def add_items_update_then_create(self, content_type_pk, objs): + config = self.backend.config ids_and_objs = {} for obj in objs: - obj._search_vector = ( + obj._autocomplete_ = ( + ADD([SearchVector(Value(text), weight=weight, config=config) + for text, weight in obj._autocomplete_]) + if obj._autocomplete_ else EMPTY_VECTOR) + obj._body_ = ( ADD([SearchVector(Value(text), weight=weight, config=config) for text, weight in obj._body_]) - if obj._body_ else SearchVector(Value(''))) - ids_and_objs[obj._object_id] = obj - index_entries_for_ct = self.index_entries.filter( + if obj._body_ else EMPTY_VECTOR) + ids_and_objs[obj._object_id_] = obj + index_entries_for_ct = self.entries.filter( content_type_id=content_type_pk) indexed_ids = frozenset( index_entries_for_ct.filter(object_id__in=ids_and_objs) .values_list('object_id', flat=True)) for indexed_id in indexed_ids: obj = ids_and_objs[indexed_id] - index_entries_for_ct.filter(object_id=obj._object_id) \ - .update(body_search=obj._search_vector) + index_entries_for_ct.filter(object_id=obj._object_id_) \ + .update(autocomplete=obj._autocomplete_, body=obj._body_) to_be_created = [] for object_id in ids_and_objs: if object_id not in indexed_ids: + obj = ids_and_objs[object_id] to_be_created.append(IndexEntry( - content_type_id=content_type_pk, - object_id=object_id, - body_search=ids_and_objs[object_id]._search_vector, - )) - self.index_entries.bulk_create(to_be_created) + content_type_id=content_type_pk, object_id=object_id, + autocomplete=obj._autocomplete_, body=obj._body_)) + self.entries.bulk_create(to_be_created) def add_items(self, model, objs): - content_type_pk = get_content_type_pk(model) - config = self.backend.get_config() + search_fields = model.get_search_fields() + if not search_fields: + return for obj in objs: - obj._object_id = force_text(obj.pk) - obj._body_ = self.prepare_body(obj) + self.prepare_obj(obj, search_fields) - # Removes index entries of an ancestor model in case the descendant - # model instance was created since. - self.index_entries.filter( - content_type_id__in=get_ancestors_content_types_pks(model) - ).filter(object_id__in=[obj._object_id for obj in objs]).delete() - - connection = connections[self.db_alias] - if connection.pg_version >= 90500: # PostgreSQL >= 9.5 - self.add_items_upsert(connection, content_type_pk, objs, config) - else: - self.add_items_update_then_create(content_type_pk, objs, config) + # TODO: Delete unindexed objects while dealing with proxy models. + if objs: + content_type_pk = get_content_type_pk(model) + # Use a faster method for PostgreSQL >= 9.5 + update_method = ( + self.add_items_upsert if self.connection.pg_version >= 90500 + else self.add_items_update_then_create) + update_method(content_type_pk, objs) def delete_item(self, item): item.index_entries.using(self.db_alias).delete() @@ -174,7 +193,40 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.search_fields = self.queryset.model.get_searchable_search_fields() + self.sql_weights = get_sql_weights() + # TODO: Better handle mixed queries containing + # both autocomplete and search. + self.is_autocomplete = False + if self.fields is not None: + search_fields = self.queryset.model.get_searchable_search_fields() + self.search_fields = { + field_lookup: self.get_search_field(field_lookup, + fields=search_fields) + for field_lookup in self.fields} + + def get_search_field(self, field_lookup, fields=None): + if fields is None: + fields = self.search_fields + if LOOKUP_SEP in field_lookup: + field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1) + else: + sub_field_name = None + for field in fields: + if isinstance(field, SearchField) \ + and field.field_name == field_lookup: + return field + # Note: Searching on a specific related field using + # `.search(fields=…)` is not yet supported by Wagtail. + # This method anticipates by already implementing it. + if isinstance(field, RelatedFields) \ + and field.field_name == field_lookup: + return self.get_search_field(sub_field_name, field.fields) + + # TODO: Find a way to use the term boosting. + def check_boost(self, query): + if query.boost != 1: + warn('PostgreSQL search backend ' + 'does not support term boosting for now.') def build_database_query(self, query=None, config=None): if query is None: @@ -182,11 +234,13 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): if isinstance(query, SearchQueryShortcut): return self.build_database_query(query.get_equivalent(), config) + if isinstance(query, Prefix): + self.check_boost(query) + self.is_autocomplete = True + return PostgresSearchAutocomplete(unidecode(query.prefix), + config=config) if isinstance(query, Term): - # TODO: Find a way to use the term boosting. - if query.boost != 1: - warn('PostgreSQL search backend ' - 'does not support term boosting for now.') + self.check_boost(query) return PostgresSearchQuery(unidecode(query.term), config=config) if isinstance(query, Not): return ~self.build_database_query(query.subquery, config) @@ -200,49 +254,28 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): '`%s` is not supported by the PostgreSQL search backend.' % self.query.__class__.__name__) - def get_boost(self, field_name, fields=None): - if fields is None: - fields = self.search_fields - if LOOKUP_SEP in field_name: - field_name, sub_field_name = field_name.split(LOOKUP_SEP, 1) - else: - sub_field_name = None - for field in fields: - if isinstance(field, SearchField) \ - and field.field_name == field_name: - # Note: Searching on a specific related field using - # `.search(fields=…)` is not yet supported by Wagtail. - # This method anticipates by already implementing it. - if isinstance(field, RelatedFields): - return self.get_boost(sub_field_name, field.fields) - return field.boost - def search(self, config, start, stop): # TODO: Handle MatchAll nested inside other search query classes. if isinstance(self.query, MatchAll): return self.queryset[start:stop] search_query = self.build_database_query(config=config) - queryset = self.queryset - query = queryset.query if self.fields is None: - vector = F('index_entries__body_search') + vector = F('index_entries__autocomplete') + if not self.is_autocomplete: + vector = vector._combine(F('index_entries__body'), '||', False) else: vector = ADD( - SearchVector(field, config=search_query.config, - weight=get_weight(self.get_boost(field))) - for field in self.fields) - vector = vector.resolve_expression(query) - search_query = search_query.resolve_expression(query) - lookup = IndexEntry._meta.get_field('body_search').get_lookup('exact')( - vector, search_query) - query.where.add(lookup, 'AND') + SearchVector(field_lookup, config=search_query.config, + weight=get_weight(search_field.boost)) + for field_lookup, search_field in self.search_fields.items() + if not self.is_autocomplete or search_field.partial_match) + queryset = self.queryset.annotate( + _vector_=vector).filter(_vector_=search_query) if self.order_by_relevance: - # Due to a Django bug, arrays are not automatically converted here. - converted_weights = '{' + ','.join(map(str, WEIGHTS_VALUES)) + '}' - queryset = queryset.order_by(SearchRank(vector, search_query, - weights=converted_weights).desc(), - '-pk') + rank_expression = SearchRank(F('_vector_'), search_query, + weights=self.sql_weights) + queryset = queryset.order_by(rank_expression.desc(), '-pk') elif not queryset.query.order_by: # Adds a default ordering to avoid issue #3729. queryset = queryset.order_by('-pk') @@ -268,11 +301,11 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler): class PostgresSearchResults(BaseSearchResults): def _do_search(self): - return list(self.query_compiler.search(self.backend.get_config(), + return list(self.query_compiler.search(self.backend.config, self.start, self.stop)) def _do_count(self): - return self.query_compiler.search(self.backend.get_config(), None, None).count() + return self.query_compiler.search(self.backend.config, None, None).count() class PostgresSearchRebuilder: @@ -317,16 +350,14 @@ class PostgresSearchBackend(BaseSearchBackend): def __init__(self, params): super().__init__(params) - self.params = params + self.index_name = params.get('INDEX', 'default') + self.config = params.get('SEARCH_CONFIG') if params.get('ATOMIC_REBUILD', False): self.rebuilder_class = self.atomic_rebuilder_class IndexEntry.add_generic_relations() - def get_config(self): - return self.params.get('SEARCH_CONFIG') - def get_index_for_model(self, model, db_alias=None): - return Index(self, model, db_alias) + return Index(self, db_alias) def get_index_for_object(self, obj): return self.get_index_for_model(obj._meta.model, obj._state.db) diff --git a/wagtail/contrib/postgres_search/migrations/0001_initial.py b/wagtail/contrib/postgres_search/migrations/0001_initial.py index 73fbd0efd..811a374ad 100644 --- a/wagtail/contrib/postgres_search/migrations/0001_initial.py +++ b/wagtail/contrib/postgres_search/migrations/0001_initial.py @@ -42,11 +42,5 @@ class Migration(migrations.Migration): 'CREATE INDEX {0}_body_search ON {0} ' 'USING GIN(body_search);'.format(table), 'DROP INDEX {}_body_search;'.format(table), - state_operations=[migrations.AddIndex( - model_name='indexentry', - index=django.contrib.postgres.indexes.GinIndex( - fields=['body_search'], - name='postgres_se_body_se_70ba1a_gin'), - )], ), ] diff --git a/wagtail/contrib/postgres_search/migrations/0002_add_autocomplete.py b/wagtail/contrib/postgres_search/migrations/0002_add_autocomplete.py new file mode 100644 index 000000000..353846bb0 --- /dev/null +++ b/wagtail/contrib/postgres_search/migrations/0002_add_autocomplete.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.5 on 2017-10-19 14:53 +from __future__ import unicode_literals + +import django.contrib.postgres.search +from django.db import migrations + +from ..models import IndexEntry + + +table = IndexEntry._meta.db_table + + +class Migration(migrations.Migration): + + dependencies = [ + ('postgres_search', '0001_initial'), + ] + + operations = [ + migrations.RunSQL( + 'DROP INDEX {}_body_search;'.format(table), + 'CREATE INDEX {0}_body_search ON {0} ' + 'USING GIN(body_search);'.format(table), + ), + migrations.RenameField( + model_name='indexentry', + old_name='body_search', + new_name='body', + ), + migrations.AddField( + model_name='indexentry', + name='autocomplete', + field=django.contrib.postgres.search.SearchVectorField(default=''), + preserve_default=False, + ), + migrations.AddIndex( + model_name='indexentry', + index=django.contrib.postgres.indexes.GinIndex( + fields=['autocomplete'], + name='postgres_search_autocomplete'), + ), + migrations.AddIndex( + model_name='indexentry', + index=django.contrib.postgres.indexes.GinIndex( + fields=['body'], + name='postgres_search_body'), + ), + ] diff --git a/wagtail/contrib/postgres_search/models.py b/wagtail/contrib/postgres_search/models.py index cbe9cc5da..48e147699 100644 --- a/wagtail/contrib/postgres_search/models.py +++ b/wagtail/contrib/postgres_search/models.py @@ -2,7 +2,7 @@ from django.apps import apps from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.models import ContentType from django.contrib.postgres.indexes import GinIndex -from django.contrib.postgres.search import SearchVectorField +from django.contrib.postgres.search import SearchQuery, SearchVectorField from django.db.models import CASCADE, ForeignKey, Model, TextField from django.db.models.functions import Cast from django.utils.translation import ugettext_lazy as _ @@ -11,6 +11,20 @@ from wagtail.search.index import class_is_indexed from .utils import get_descendants_content_types_pks +class SearchAutocomplete(SearchQuery): + def as_sql(self, compiler, connection): + params = [self.value] + if self.config: + config_sql, config_params = compiler.compile(self.config) + template = "to_tsquery({}::regconfig, ''%s':*')".format(config_sql) + params = config_params + [self.value] + else: + template = "to_tsquery(''%s':*')" + if self.invert: + template = '!!({})'.format(template) + return template, params + + class TextIDGenericRelation(GenericRelation): auto_created = True @@ -46,13 +60,15 @@ class IndexEntry(Model): content_object = GenericForeignKey() # TODO: Add per-object boosting. - body_search = SearchVectorField() + autocomplete = SearchVectorField() + body = SearchVectorField() class Meta: unique_together = ('content_type', 'object_id') verbose_name = _('index entry') verbose_name_plural = _('index entries') - indexes = [GinIndex(fields=['body_search'])] + indexes = [GinIndex(fields=['autocomplete']), + GinIndex(fields=['body'])] def __str__(self): return '%s: %s' % (self.content_type.name, self.content_object) diff --git a/wagtail/contrib/postgres_search/utils.py b/wagtail/contrib/postgres_search/utils.py index ef4a5e9c5..cc04fb4d7 100644 --- a/wagtail/contrib/postgres_search/utils.py +++ b/wagtail/contrib/postgres_search/utils.py @@ -94,6 +94,19 @@ def determine_boosts_weights(boosts=()): for i, weight in enumerate(WEIGHTS)] +def set_weights(): + BOOSTS_WEIGHTS.extend(determine_boosts_weights()) + weights = [w for w, c in BOOSTS_WEIGHTS] + min_weight = min(weights) + max_weight = max(weights) + if min_weight <= 0: + if min_weight == 0: + min_weight = -0.1 + weights = [w - min_weight for w in weights] + WEIGHTS_VALUES.extend([w / max_weight + for w in reversed(weights)]) + + def get_weight(boost): if boost is None: return WEIGHTS[-1] @@ -101,3 +114,7 @@ def get_weight(boost): if boost >= max_boost: return weight return weight + + +def get_sql_weights(): + return '{' + ','.join(map(str, WEIGHTS_VALUES)) + '}' diff --git a/wagtail/search/backends/db.py b/wagtail/search/backends/db.py index b25aa669e..22194f3da 100644 --- a/wagtail/search/backends/db.py +++ b/wagtail/search/backends/db.py @@ -5,7 +5,8 @@ from django.db.models.expressions import Value from wagtail.search.backends.base import ( BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults) -from wagtail.search.query import And, MatchAll, Not, Or, SearchQueryShortcut, Term +from wagtail.search.query import ( + And, MatchAll, Not, Or, Prefix, SearchQueryShortcut, Term) from wagtail.search.utils import AND, OR @@ -51,6 +52,10 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler): term_query |= models.Q(**{field_name + '__icontains': term}) return term_query + def check_boost(self, query): + if query.boost != 1: + warn('Database search backend does not support term boosting.') + def build_database_filter(self, query=None): if query is None: query = self.query @@ -61,9 +66,11 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler): if isinstance(query, SearchQueryShortcut): return self.build_database_filter(query.get_equivalent()) if isinstance(query, Term): - if query.boost != 1: - warn('Database search backend does not support term boosting.') + self.check_boost(query) return self.build_single_term_filter(query.term) + if isinstance(query, Prefix): + self.check_boost(query) + return self.build_single_term_filter(query.prefix) if isinstance(query, Not): return ~self.build_database_filter(query.subquery) if isinstance(query, And): diff --git a/wagtail/search/tests/elasticsearch_common_tests.py b/wagtail/search/tests/elasticsearch_common_tests.py index 0262fece4..df7a1d24f 100644 --- a/wagtail/search/tests/elasticsearch_common_tests.py +++ b/wagtail/search/tests/elasticsearch_common_tests.py @@ -1,3 +1,4 @@ +import unittest from datetime import date from io import StringIO @@ -171,3 +172,14 @@ class ElasticsearchCommonSearchBackendTests(BackendTests): results = self.backend.search(MATCH_ALL, models.Book)[110:] self.assertEqual(len(results), 53) + + # Elasticsearch always does prefix matching on `partial_match` fields, + # even when we don’t use `Prefix`. + @unittest.expectedFailure + def test_incomplete_term(self): + super().test_incomplete_term() + + # Elasticsearch does not accept prefix for multiple words + @unittest.expectedFailure + def test_prefix_multiple_words(self): + super().test_prefix_multiple_words() diff --git a/wagtail/search/tests/test_backends.py b/wagtail/search/tests/test_backends.py index 92530ecc2..bb47f3cd7 100644 --- a/wagtail/search/tests/test_backends.py +++ b/wagtail/search/tests/test_backends.py @@ -15,7 +15,9 @@ from wagtail.search.backends import ( InvalidSearchBackendError, get_search_backend, get_search_backends) from wagtail.search.backends.base import FieldError from wagtail.search.backends.db import DatabaseSearchBackend -from wagtail.search.query import MATCH_ALL, And, Boost, Filter, Not, Or, PlainText, Term +from wagtail.search.query import ( + MATCH_ALL, And, Boost, Filter, Not, Or, PlainText, Prefix, Term, +) class BackendTests(WagtailTestUtils): @@ -448,6 +450,13 @@ class BackendTests(WagtailTestUtils): {'JavaScript: The Definitive Guide', 'JavaScript: The good parts'}) + def test_incomplete_term(self): + # Single word + results = self.backend.search(Term('pro'), + models.Book.objects.all()) + + self.assertSetEqual({r.title for r in results}, set()) + def test_and(self): results = self.backend.search(And([Term('javascript'), Term('definitive')]), @@ -507,6 +516,17 @@ class BackendTests(WagtailTestUtils): 'The Rust Programming Language', 'Two Scoops of Django 1.11'}) + def test_prefix_single_word(self): + results = self.backend.search(Prefix('pro'), models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'The Rust Programming Language'}) + + def test_prefix_multiple_words(self): + results = self.backend.search(Prefix('rust pro'), + models.Book.objects.all()) + self.assertSetEqual({r.title for r in results}, + {'The Rust Programming Language'}) + # # Shortcut query classes # diff --git a/wagtail/search/tests/test_db_backend.py b/wagtail/search/tests/test_db_backend.py index fba3332d7..8222d33ed 100644 --- a/wagtail/search/tests/test_db_backend.py +++ b/wagtail/search/tests/test_db_backend.py @@ -37,3 +37,8 @@ class TestDBBackend(BackendTests, TestCase): @unittest.expectedFailure def test_search_callable_field(self): super().test_search_callable_field() + + # Database backend always uses `icontains`, so always autocomplete + @unittest.expectedFailure + def test_incomplete_term(self): + super().test_incomplete_term()