Implements Prefix on PostgreSQL.

This commit is contained in:
Bertrand Bordage 2018-01-04 19:19:50 +01:00
parent c3b6966b31
commit dbe89153a6
10 changed files with 278 additions and 131 deletions
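
In short, this change teaches the PostgreSQL search backend to handle the new `Prefix` query class by storing a separate `autocomplete` tsvector per index entry. A minimal usage sketch of the behaviour being added (the `Book` model and module path are borrowed from the test suite touched below and are illustrative only):

# Sketch of the behaviour added by this commit (names are illustrative,
# borrowed from the tests updated below; module path is an assumption).
from wagtail.search.backends import get_search_backend
from wagtail.search.query import Prefix
from wagtail.tests.search import models

backend = get_search_backend()  # assumed to point at postgres_search

# Prefix matching against fields indexed with partial_match=True:
# "pro" should now find "The Rust Programming Language".
results = backend.search(Prefix('pro'), models.Book.objects.all())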

View file

@@ -1,8 +1,7 @@
from django.apps import AppConfig
from django.core.checks import Error, Tags, register
from .utils import (
BOOSTS_WEIGHTS, WEIGHTS_VALUES, determine_boosts_weights, get_postgresql_connections)
from .utils import get_postgresql_connections, set_weights
class PostgresSearchConfig(AppConfig):
@@ -17,7 +16,4 @@ class PostgresSearchConfig(AppConfig):
'to use PostgreSQL search.',
id='wagtail.contrib.postgres_search.E001')]
BOOSTS_WEIGHTS.extend(determine_boosts_weights())
max_weight = BOOSTS_WEIGHTS[0][0]
WEIGHTS_VALUES.extend([v / max_weight
for v, w in reversed(BOOSTS_WEIGHTS)])
set_weights()

View file

@@ -10,33 +10,31 @@ from django.utils.encoding import force_text
from wagtail.search.backends.base import (
BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults)
from wagtail.search.index import RelatedFields, SearchField
from wagtail.search.query import And, MatchAll, Not, Or, SearchQueryShortcut, Term
from wagtail.search.index import get_indexed_models, RelatedFields, SearchField
from wagtail.search.query import (
And, MatchAll, Not, Or, Prefix, SearchQueryShortcut, Term)
from wagtail.search.utils import ADD, AND, OR
from .models import IndexEntry
from .models import IndexEntry, SearchAutocomplete as PostgresSearchAutocomplete
from .utils import (
WEIGHTS_VALUES, get_ancestors_content_types_pks, get_content_type_pk,
get_descendants_content_types_pks, get_postgresql_connections, get_weight, unidecode)
get_content_type_pk, get_descendants_content_types_pks,
get_postgresql_connections, get_sql_weights, get_weight, unidecode)
# TODO: Add autocomplete.
EMPTY_VECTOR = SearchVector(Value(''))
class Index:
def __init__(self, backend, model, db_alias=None):
def __init__(self, backend, db_alias=None):
self.backend = backend
self.model = model
if db_alias is None:
db_alias = DEFAULT_DB_ALIAS
if connections[db_alias].vendor != 'postgresql':
self.name = self.backend.index_name
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
self.connection = connections[self.db_alias]
if self.connection.vendor != 'postgresql':
raise NotSupportedError(
'You must select a PostgreSQL database '
'to use PostgreSQL search.')
self.db_alias = db_alias
self.index_entries = IndexEntry._default_manager.using(self.db_alias)
self.name = model._meta.label
self.search_fields = self.model.get_search_fields()
self.entries = IndexEntry._default_manager.using(self.db_alias)
def add_model(self, model):
pass
@@ -44,20 +42,23 @@ class Index:
def refresh(self):
pass
def delete_stale_entries(self):
if self.model._meta.parents:
# We don't need to delete stale entries for non-root models,
# since we already delete them by deleting roots.
return
existing_pks = (self.model._default_manager.using(self.db_alias)
def delete_stale_model_entries(self, model):
existing_pks = (model._default_manager.using(self.db_alias)
.annotate(object_id=Cast('pk', TextField()))
.values('object_id'))
content_type_ids = get_descendants_content_types_pks(self.model)
content_types_pks = get_descendants_content_types_pks(model)
stale_entries = (
self.index_entries.filter(content_type_id__in=content_type_ids)
self.entries.filter(content_type_id__in=content_types_pks)
.exclude(object_id__in=existing_pks))
stale_entries.delete()
def delete_stale_entries(self):
for model in get_indexed_models():
# We don't need to delete stale entries for non-root models,
# since we already delete them by deleting roots.
if not model._meta.parents:
self.delete_stale_model_entries(model)
def prepare_value(self, value):
if isinstance(value, str):
return value
@@ -70,8 +71,8 @@ class Index:
def prepare_field(self, obj, field):
if isinstance(field, SearchField):
yield (unidecode(self.prepare_value(field.get_value(obj))),
get_weight(field.boost))
yield (field, get_weight(field.boost),
unidecode(self.prepare_value(field.get_value(obj))))
elif isinstance(field, RelatedFields):
sub_obj = field.get_value(obj)
if sub_obj is None:
@@ -84,83 +85,101 @@ class Index:
sub_objs = [sub_obj]
for sub_obj in sub_objs:
for sub_field in field.fields:
for value in self.prepare_field(sub_obj, sub_field):
yield value
yield from self.prepare_field(sub_obj, sub_field)
def prepare_body(self, obj):
return [(value, boost) for field in self.search_fields
for value, boost in self.prepare_field(obj, field)]
def prepare_obj(self, obj, search_fields):
obj._object_id_ = force_text(obj.pk)
obj._autocomplete_ = []
obj._body_ = []
for field in search_fields:
for current_field, boost, value in self.prepare_field(obj, field):
if isinstance(current_field, SearchField) and \
current_field.partial_match:
obj._autocomplete_.append((value, boost))
else:
obj._body_.append((value, boost))
def add_item(self, obj):
self.add_items(self.model, [obj])
self.add_items(obj._meta.model, [obj])
def add_items_upsert(self, connection, content_type_pk, objs, config):
vectors_sql = []
def add_items_upsert(self, content_type_pk, objs):
config = self.backend.config
autocomplete_sql = []
body_sql = []
data_params = []
sql_template = ('to_tsvector(%s)' if config is None
else "to_tsvector('%s', %%s)" % config)
sql_template = 'setweight(%s, %%s)' % sql_template
for obj in objs:
data_params.extend((content_type_pk, obj._object_id))
data_params.extend((content_type_pk, obj._object_id_))
if obj._autocomplete_:
autocomplete_sql.append('||'.join(sql_template
for _ in obj._autocomplete_))
data_params.extend([v for t in obj._autocomplete_ for v in t])
else:
autocomplete_sql.append("''::tsvector")
if obj._body_:
vectors_sql.append('||'.join(sql_template for _ in obj._body_))
body_sql.append('||'.join(sql_template for _ in obj._body_))
data_params.extend([v for t in obj._body_ for v in t])
else:
vectors_sql.append("''::tsvector")
data_sql = ', '.join(['(%%s, %%s, %s)' % s for s in vectors_sql])
with connection.cursor() as cursor:
body_sql.append("''::tsvector")
data_sql = ', '.join(['(%%s, %%s, %s, %s)' % (a, b)
for a, b in zip(autocomplete_sql, body_sql)])
with self.connection.cursor() as cursor:
cursor.execute("""
INSERT INTO %s(content_type_id, object_id, body_search)
INSERT INTO %s (content_type_id, object_id, autocomplete, body)
(VALUES %s)
ON CONFLICT (content_type_id, object_id)
DO UPDATE SET body_search = EXCLUDED.body_search
DO UPDATE SET autocomplete = EXCLUDED.autocomplete,
body = EXCLUDED.body
""" % (IndexEntry._meta.db_table, data_sql), data_params)
def add_items_update_then_create(self, content_type_pk, objs, config):
def add_items_update_then_create(self, content_type_pk, objs):
config = self.backend.config
ids_and_objs = {}
for obj in objs:
obj._search_vector = (
obj._autocomplete_ = (
ADD([SearchVector(Value(text), weight=weight, config=config)
for text, weight in obj._autocomplete_])
if obj._autocomplete_ else EMPTY_VECTOR)
obj._body_ = (
ADD([SearchVector(Value(text), weight=weight, config=config)
for text, weight in obj._body_])
if obj._body_ else SearchVector(Value('')))
ids_and_objs[obj._object_id] = obj
index_entries_for_ct = self.index_entries.filter(
if obj._body_ else EMPTY_VECTOR)
ids_and_objs[obj._object_id_] = obj
index_entries_for_ct = self.entries.filter(
content_type_id=content_type_pk)
indexed_ids = frozenset(
index_entries_for_ct.filter(object_id__in=ids_and_objs)
.values_list('object_id', flat=True))
for indexed_id in indexed_ids:
obj = ids_and_objs[indexed_id]
index_entries_for_ct.filter(object_id=obj._object_id) \
.update(body_search=obj._search_vector)
index_entries_for_ct.filter(object_id=obj._object_id_) \
.update(autocomplete=obj._autocomplete_, body=obj._body_)
to_be_created = []
for object_id in ids_and_objs:
if object_id not in indexed_ids:
obj = ids_and_objs[object_id]
to_be_created.append(IndexEntry(
content_type_id=content_type_pk,
object_id=object_id,
body_search=ids_and_objs[object_id]._search_vector,
))
self.index_entries.bulk_create(to_be_created)
content_type_id=content_type_pk, object_id=object_id,
autocomplete=obj._autocomplete_, body=obj._body_))
self.entries.bulk_create(to_be_created)
def add_items(self, model, objs):
content_type_pk = get_content_type_pk(model)
config = self.backend.get_config()
search_fields = model.get_search_fields()
if not search_fields:
return
for obj in objs:
obj._object_id = force_text(obj.pk)
obj._body_ = self.prepare_body(obj)
self.prepare_obj(obj, search_fields)
# Removes index entries of an ancestor model in case the descendant
# model instance was created since.
self.index_entries.filter(
content_type_id__in=get_ancestors_content_types_pks(model)
).filter(object_id__in=[obj._object_id for obj in objs]).delete()
connection = connections[self.db_alias]
if connection.pg_version >= 90500: # PostgreSQL >= 9.5
self.add_items_upsert(connection, content_type_pk, objs, config)
else:
self.add_items_update_then_create(content_type_pk, objs, config)
# TODO: Delete unindexed objects while dealing with proxy models.
if objs:
content_type_pk = get_content_type_pk(model)
# Use a faster method for PostgreSQL >= 9.5
update_method = (
self.add_items_upsert if self.connection.pg_version >= 90500
else self.add_items_update_then_create)
update_method(content_type_pk, objs)
def delete_item(self, item):
item.index_entries.using(self.db_alias).delete()
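
For readability, the statement that `add_items_upsert` ends up executing for a single object with one autocomplete fragment and one body fragment looks roughly like this (a sketch; the `'english'` config and the default `IndexEntry` table name are assumptions, and the `%s` placeholders are bound by the driver from `data_params`):

# Rough shape of the SQL built by add_items_upsert for one object with one
# autocomplete fragment and one body fragment (sketch; assumes config='english'
# and the default IndexEntry table name).
EXAMPLE_UPSERT_SQL = """
INSERT INTO postgres_search_indexentry (content_type_id, object_id, autocomplete, body)
(VALUES (%s, %s,
         setweight(to_tsvector('english', %s), %s),
         setweight(to_tsvector('english', %s), %s)))
ON CONFLICT (content_type_id, object_id)
DO UPDATE SET autocomplete = EXCLUDED.autocomplete,
              body = EXCLUDED.body
"""
# data_params for it: [content_type_pk, object_id,
#                      autocomplete_text, autocomplete_weight,
#                      body_text, body_weight]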
@@ -174,7 +193,40 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.search_fields = self.queryset.model.get_searchable_search_fields()
self.sql_weights = get_sql_weights()
# TODO: Better handle mixed queries containing
# both autocomplete and search.
self.is_autocomplete = False
if self.fields is not None:
search_fields = self.queryset.model.get_searchable_search_fields()
self.search_fields = {
field_lookup: self.get_search_field(field_lookup,
fields=search_fields)
for field_lookup in self.fields}
def get_search_field(self, field_lookup, fields=None):
if fields is None:
fields = self.search_fields
if LOOKUP_SEP in field_lookup:
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
else:
sub_field_name = None
for field in fields:
if isinstance(field, SearchField) \
and field.field_name == field_lookup:
return field
# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
if isinstance(field, RelatedFields) \
and field.field_name == field_lookup:
return self.get_search_field(sub_field_name, field.fields)
# TODO: Find a way to use the term boosting.
def check_boost(self, query):
if query.boost != 1:
warn('PostgreSQL search backend '
'does not support term boosting for now.')
def build_database_query(self, query=None, config=None):
if query is None:
@@ -182,11 +234,13 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
if isinstance(query, SearchQueryShortcut):
return self.build_database_query(query.get_equivalent(), config)
if isinstance(query, Prefix):
self.check_boost(query)
self.is_autocomplete = True
return PostgresSearchAutocomplete(unidecode(query.prefix),
config=config)
if isinstance(query, Term):
# TODO: Find a way to use the term boosting.
if query.boost != 1:
warn('PostgreSQL search backend '
'does not support term boosting for now.')
self.check_boost(query)
return PostgresSearchQuery(unidecode(query.term), config=config)
if isinstance(query, Not):
return ~self.build_database_query(query.subquery, config)
@@ -200,49 +254,28 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
'`%s` is not supported by the PostgreSQL search backend.'
% self.query.__class__.__name__)
def get_boost(self, field_name, fields=None):
if fields is None:
fields = self.search_fields
if LOOKUP_SEP in field_name:
field_name, sub_field_name = field_name.split(LOOKUP_SEP, 1)
else:
sub_field_name = None
for field in fields:
if isinstance(field, SearchField) \
and field.field_name == field_name:
# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
if isinstance(field, RelatedFields):
return self.get_boost(sub_field_name, field.fields)
return field.boost
def search(self, config, start, stop):
# TODO: Handle MatchAll nested inside other search query classes.
if isinstance(self.query, MatchAll):
return self.queryset[start:stop]
search_query = self.build_database_query(config=config)
queryset = self.queryset
query = queryset.query
if self.fields is None:
vector = F('index_entries__body_search')
vector = F('index_entries__autocomplete')
if not self.is_autocomplete:
vector = vector._combine(F('index_entries__body'), '||', False)
else:
vector = ADD(
SearchVector(field, config=search_query.config,
weight=get_weight(self.get_boost(field)))
for field in self.fields)
vector = vector.resolve_expression(query)
search_query = search_query.resolve_expression(query)
lookup = IndexEntry._meta.get_field('body_search').get_lookup('exact')(
vector, search_query)
query.where.add(lookup, 'AND')
SearchVector(field_lookup, config=search_query.config,
weight=get_weight(search_field.boost))
for field_lookup, search_field in self.search_fields.items()
if not self.is_autocomplete or search_field.partial_match)
queryset = self.queryset.annotate(
_vector_=vector).filter(_vector_=search_query)
if self.order_by_relevance:
# Due to a Django bug, arrays are not automatically converted here.
converted_weights = '{' + ','.join(map(str, WEIGHTS_VALUES)) + '}'
queryset = queryset.order_by(SearchRank(vector, search_query,
weights=converted_weights).desc(),
'-pk')
rank_expression = SearchRank(F('_vector_'), search_query,
weights=self.sql_weights)
queryset = queryset.order_by(rank_expression.desc(), '-pk')
elif not queryset.query.order_by:
# Adds a default ordering to avoid issue #3729.
queryset = queryset.order_by('-pk')
@@ -268,11 +301,11 @@ class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
class PostgresSearchResults(BaseSearchResults):
def _do_search(self):
return list(self.query_compiler.search(self.backend.get_config(),
return list(self.query_compiler.search(self.backend.config,
self.start, self.stop))
def _do_count(self):
return self.query_compiler.search(self.backend.get_config(), None, None).count()
return self.query_compiler.search(self.backend.config, None, None).count()
class PostgresSearchRebuilder:
@@ -317,16 +350,14 @@ class PostgresSearchBackend(BaseSearchBackend):
def __init__(self, params):
super().__init__(params)
self.params = params
self.index_name = params.get('INDEX', 'default')
self.config = params.get('SEARCH_CONFIG')
if params.get('ATOMIC_REBUILD', False):
self.rebuilder_class = self.atomic_rebuilder_class
IndexEntry.add_generic_relations()
def get_config(self):
return self.params.get('SEARCH_CONFIG')
def get_index_for_model(self, model, db_alias=None):
return Index(self, model, db_alias)
return Index(self, db_alias)
def get_index_for_object(self, obj):
return self.get_index_for_model(obj._meta.model, obj._state.db)
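
Roughly, the compiler changes above mean that a `Prefix('pro')` search with `fields=None` is built as the following ORM query (a sketch; `Book` comes from the test models, the `'english'` config is an assumption, and the import paths for the helpers added in this commit are assumed):

# Roughly what PostgresSearchQueryCompiler.search() builds for Prefix('pro')
# when no explicit fields are given (sketch only).
from django.contrib.postgres.search import SearchRank
from django.db.models import F

from wagtail.contrib.postgres_search.models import SearchAutocomplete
from wagtail.contrib.postgres_search.utils import get_sql_weights
from wagtail.tests.search import models  # assumed test models

search_query = SearchAutocomplete('pro', config='english')
queryset = (models.Book.objects
            .annotate(_vector_=F('index_entries__autocomplete'))
            .filter(_vector_=search_query)
            .order_by(SearchRank(F('_vector_'), search_query,
                                 weights=get_sql_weights()).desc(),
                      '-pk'))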

View file

@@ -42,11 +42,5 @@ class Migration(migrations.Migration):
'CREATE INDEX {0}_body_search ON {0} '
'USING GIN(body_search);'.format(table),
'DROP INDEX {}_body_search;'.format(table),
state_operations=[migrations.AddIndex(
model_name='indexentry',
index=django.contrib.postgres.indexes.GinIndex(
fields=['body_search'],
name='postgres_se_body_se_70ba1a_gin'),
)],
),
]

View file

@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.5 on 2017-10-19 14:53
from __future__ import unicode_literals
import django.contrib.postgres.search
from django.db import migrations
from ..models import IndexEntry
table = IndexEntry._meta.db_table
class Migration(migrations.Migration):
dependencies = [
('postgres_search', '0001_initial'),
]
operations = [
migrations.RunSQL(
'DROP INDEX {}_body_search;'.format(table),
'CREATE INDEX {0}_body_search ON {0} '
'USING GIN(body_search);'.format(table),
),
migrations.RenameField(
model_name='indexentry',
old_name='body_search',
new_name='body',
),
migrations.AddField(
model_name='indexentry',
name='autocomplete',
field=django.contrib.postgres.search.SearchVectorField(default=''),
preserve_default=False,
),
migrations.AddIndex(
model_name='indexentry',
index=django.contrib.postgres.indexes.GinIndex(
fields=['autocomplete'],
name='postgres_search_autocomplete'),
),
migrations.AddIndex(
model_name='indexentry',
index=django.contrib.postgres.indexes.GinIndex(
fields=['body'],
name='postgres_search_body'),
),
]

View file

@@ -2,7 +2,7 @@ from django.apps import apps
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.search import SearchVectorField
from django.contrib.postgres.search import SearchQuery, SearchVectorField
from django.db.models import CASCADE, ForeignKey, Model, TextField
from django.db.models.functions import Cast
from django.utils.translation import ugettext_lazy as _
@@ -11,6 +11,20 @@ from wagtail.search.index import class_is_indexed
from .utils import get_descendants_content_types_pks
class SearchAutocomplete(SearchQuery):
def as_sql(self, compiler, connection):
params = [self.value]
if self.config:
config_sql, config_params = compiler.compile(self.config)
template = "to_tsquery({}::regconfig, ''%s':*')".format(config_sql)
params = config_params + [self.value]
else:
template = "to_tsquery(''%s':*')"
if self.invert:
template = '!!({})'.format(template)
return template, params
class TextIDGenericRelation(GenericRelation):
auto_created = True
@@ -46,13 +60,15 @@ class IndexEntry(Model):
content_object = GenericForeignKey()
# TODO: Add per-object boosting.
body_search = SearchVectorField()
autocomplete = SearchVectorField()
body = SearchVectorField()
class Meta:
unique_together = ('content_type', 'object_id')
verbose_name = _('index entry')
verbose_name_plural = _('index entries')
indexes = [GinIndex(fields=['body_search'])]
indexes = [GinIndex(fields=['autocomplete']),
GinIndex(fields=['body'])]
def __str__(self):
return '%s: %s' % (self.content_type.name, self.content_object)
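
The `:*` suffix in the `SearchAutocomplete` template is PostgreSQL's prefix marker for tsquery lexemes. A quick standalone check of that semantics (a sketch assuming the `'english'` configuration, independent of the exact quoting used in the template above):

# Demonstrates the tsquery prefix semantics that SearchAutocomplete relies on:
# 'pro:*' matches the stemmed lexeme 'program' in the indexed title.
from django.db import connection

with connection.cursor() as cursor:
    cursor.execute(
        "SELECT to_tsvector('english', %s) @@ to_tsquery('english', %s)",
        ['The Rust Programming Language', 'pro:*'],
    )
    assert cursor.fetchone()[0] is True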

View file

@@ -94,6 +94,19 @@ def determine_boosts_weights(boosts=()):
for i, weight in enumerate(WEIGHTS)]
def set_weights():
BOOSTS_WEIGHTS.extend(determine_boosts_weights())
weights = [w for w, c in BOOSTS_WEIGHTS]
min_weight = min(weights)
max_weight = max(weights)
if min_weight <= 0:
if min_weight == 0:
min_weight = -0.1
weights = [w - min_weight for w in weights]
WEIGHTS_VALUES.extend([w / max_weight
for w in reversed(weights)])
def get_weight(boost):
if boost is None:
return WEIGHTS[-1]
@@ -101,3 +114,7 @@ def get_weight(boost):
if boost >= max_boost:
return weight
return weight
def get_sql_weights():
return '{' + ','.join(map(str, WEIGHTS_VALUES)) + '}'
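
To make the new weighting helpers concrete, here is the normalisation that `set_weights()` performs, run on made-up boost/weight pairs (a standalone sketch; the real pairs come from `determine_boosts_weights()` and are stored in the module-level `WEIGHTS_VALUES` list):

# Standalone illustration of set_weights() and get_sql_weights()
# (hypothetical boosts; the real ones come from determine_boosts_weights()).
boosts_weights = [(10.0, 'A'), (2.0, 'B'), (0.5, 'C'), (0.0, 'D')]

weights = [w for w, _letter in boosts_weights]
min_weight = min(weights)
max_weight = max(weights)
if min_weight <= 0:
    if min_weight == 0:
        min_weight = -0.1
    weights = [w - min_weight for w in weights]

# Reversed so the array is in PostgreSQL's {D, C, B, A} order for ts_rank.
weights_values = [w / max_weight for w in reversed(weights)]
# -> roughly [0.01, 0.06, 0.21, 1.01]
sql_weights = '{' + ','.join(map(str, weights_values)) + '}'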

View file

@@ -5,7 +5,8 @@ from django.db.models.expressions import Value
from wagtail.search.backends.base import (
BaseSearchBackend, BaseSearchQueryCompiler, BaseSearchResults)
from wagtail.search.query import And, MatchAll, Not, Or, SearchQueryShortcut, Term
from wagtail.search.query import (
And, MatchAll, Not, Or, Prefix, SearchQueryShortcut, Term)
from wagtail.search.utils import AND, OR
@@ -51,6 +52,10 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
term_query |= models.Q(**{field_name + '__icontains': term})
return term_query
def check_boost(self, query):
if query.boost != 1:
warn('Database search backend does not support term boosting.')
def build_database_filter(self, query=None):
if query is None:
query = self.query
@@ -61,9 +66,11 @@ class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
if isinstance(query, SearchQueryShortcut):
return self.build_database_filter(query.get_equivalent())
if isinstance(query, Term):
if query.boost != 1:
warn('Database search backend does not support term boosting.')
self.check_boost(query)
return self.build_single_term_filter(query.term)
if isinstance(query, Prefix):
self.check_boost(query)
return self.build_single_term_filter(query.prefix)
if isinstance(query, Not):
return ~self.build_database_filter(query.subquery)
if isinstance(query, And):

View file

@@ -1,3 +1,4 @@
import unittest
from datetime import date
from io import StringIO
@@ -171,3 +172,14 @@ class ElasticsearchCommonSearchBackendTests(BackendTests):
results = self.backend.search(MATCH_ALL, models.Book)[110:]
self.assertEqual(len(results), 53)
# Elasticsearch always does prefix matching on `partial_match` fields,
# even when we don't use `Prefix`.
@unittest.expectedFailure
def test_incomplete_term(self):
super().test_incomplete_term()
# Elasticsearch does not accept prefix for multiple words
@unittest.expectedFailure
def test_prefix_multiple_words(self):
super().test_prefix_multiple_words()

View file

@@ -15,7 +15,9 @@ from wagtail.search.backends import (
InvalidSearchBackendError, get_search_backend, get_search_backends)
from wagtail.search.backends.base import FieldError
from wagtail.search.backends.db import DatabaseSearchBackend
from wagtail.search.query import MATCH_ALL, And, Boost, Filter, Not, Or, PlainText, Term
from wagtail.search.query import (
MATCH_ALL, And, Boost, Filter, Not, Or, PlainText, Prefix, Term,
)
class BackendTests(WagtailTestUtils):
@@ -448,6 +450,13 @@ class BackendTests(WagtailTestUtils):
{'JavaScript: The Definitive Guide',
'JavaScript: The good parts'})
def test_incomplete_term(self):
# Single word
results = self.backend.search(Term('pro'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results}, set())
def test_and(self):
results = self.backend.search(And([Term('javascript'),
Term('definitive')]),
@@ -507,6 +516,17 @@ class BackendTests(WagtailTestUtils):
'The Rust Programming Language',
'Two Scoops of Django 1.11'})
def test_prefix_single_word(self):
results = self.backend.search(Prefix('pro'), models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'The Rust Programming Language'})
def test_prefix_multiple_words(self):
results = self.backend.search(Prefix('rust pro'),
models.Book.objects.all())
self.assertSetEqual({r.title for r in results},
{'The Rust Programming Language'})
#
# Shortcut query classes
#

View file

@@ -37,3 +37,8 @@ class TestDBBackend(BackendTests, TestCase):
@unittest.expectedFailure
def test_search_callable_field(self):
super().test_search_callable_field()
# Database backend always uses `icontains`, so always autocomplete
@unittest.expectedFailure
def test_incomplete_term(self):
super().test_incomplete_term()