From bc0760d23d3c75bb43cb26b3f6f8fef21ec8cf2b Mon Sep 17 00:00:00 2001 From: Karl Hobley Date: Mon, 12 Oct 2015 12:14:13 +0100 Subject: [PATCH] Allow operator to be specified when searching --- wagtail/wagtailcore/query.py | 4 ++-- wagtail/wagtaildocs/models.py | 4 ++-- wagtail/wagtailimages/models.py | 4 ++-- wagtail/wagtailsearch/backends/base.py | 15 ++++++++++++--- wagtail/wagtailsearch/backends/db.py | 7 ++++++- wagtail/wagtailsearch/backends/elasticsearch.py | 8 ++++++++ wagtail/wagtailsearch/tests/test_backends.py | 12 ++++++++++++ .../tests/test_elasticsearch_backend.py | 16 ++++++++++++++++ 8 files changed, 60 insertions(+), 10 deletions(-) diff --git a/wagtail/wagtailcore/query.py b/wagtail/wagtailcore/query.py index 1d9f409f0..a1e1a8979 100644 --- a/wagtail/wagtailcore/query.py +++ b/wagtail/wagtailcore/query.py @@ -197,12 +197,12 @@ class PageQuerySet(MP_NodeQuerySet): """ return self.exclude(self.public_q()) - def search(self, query_string, fields=None, backend='default'): + def search(self, query_string, fields=None, operator=None, backend='default'): """ This runs a search query on all the pages in the QuerySet """ search_backend = get_search_backend(backend) - return search_backend.search(query_string, self, fields=fields) + return search_backend.search(query_string, self, fields=fields, operator=operator) def unpublish(self): """ diff --git a/wagtail/wagtaildocs/models.py b/wagtail/wagtaildocs/models.py index 61d44fb22..044786483 100644 --- a/wagtail/wagtaildocs/models.py +++ b/wagtail/wagtaildocs/models.py @@ -20,12 +20,12 @@ from wagtail.wagtailsearch.backends import get_search_backend class DocumentQuerySet(models.QuerySet): - def search(self, query_string, fields=None, backend='default'): + def search(self, query_string, fields=None, operator=None, backend='default'): """ This runs a search query on all the documents in the QuerySet """ search_backend = get_search_backend(backend) - return search_backend.search(query_string, self, fields=fields) + return search_backend.search(query_string, self, fields=fields, operator=operator) @python_2_unicode_compatible diff --git a/wagtail/wagtailimages/models.py b/wagtail/wagtailimages/models.py index ac8eb521f..8d3d56bdd 100644 --- a/wagtail/wagtailimages/models.py +++ b/wagtail/wagtailimages/models.py @@ -43,12 +43,12 @@ class SourceImageIOError(IOError): class ImageQuerySet(models.QuerySet): - def search(self, query_string, fields=None, backend='default'): + def search(self, query_string, fields=None, operator=None, backend='default'): """ This runs a search query on all the images in the QuerySet """ search_backend = get_search_backend(backend) - return search_backend.search(query_string, self, fields=fields) + return search_backend.search(query_string, self, fields=fields, operator=operator) def get_upload_to(instance, filename): diff --git a/wagtail/wagtailsearch/backends/base.py b/wagtail/wagtailsearch/backends/base.py index a2d088d5b..abf7e5609 100644 --- a/wagtail/wagtailsearch/backends/base.py +++ b/wagtail/wagtailsearch/backends/base.py @@ -16,10 +16,13 @@ class FieldError(Exception): class BaseSearchQuery(object): - def __init__(self, queryset, query_string, fields=None): + DEFAULT_OPERATOR = 'or' + + def __init__(self, queryset, query_string, fields=None, operator=None): self.queryset = queryset self.query_string = query_string self.fields = fields + self.operator = operator or self.DEFAULT_OPERATOR def _get_searchable_field(self, field_attname): # Get field @@ -200,7 +203,7 @@ class BaseSearch(object): def delete(self, obj): raise NotImplementedError - def search(self, query_string, model_or_queryset, fields=None, filters=None, prefetch_related=None): + def search(self, query_string, model_or_queryset, fields=None, filters=None, prefetch_related=None, operator=None): # Find model/queryset if isinstance(model_or_queryset, QuerySet): model = model_or_queryset.model @@ -226,6 +229,12 @@ class BaseSearch(object): for prefetch in prefetch_related: queryset = queryset.prefetch_related(prefetch) + # Check operator + if operator is not None: + operator = operator.lower() + if operator not in ['or', 'and']: + raise ValueError("operator must be either 'or' or 'and'") + # Search - search_query = self.search_query_class(queryset, query_string, fields=fields) + search_query = self.search_query_class(queryset, query_string, fields=fields, operator=operator) return self.search_results_class(self, search_query) diff --git a/wagtail/wagtailsearch/backends/db.py b/wagtail/wagtailsearch/backends/db.py index 538f2b54c..9e9a183ef 100644 --- a/wagtail/wagtailsearch/backends/db.py +++ b/wagtail/wagtailsearch/backends/db.py @@ -4,6 +4,8 @@ from wagtail.wagtailsearch.backends.base import BaseSearch, BaseSearchQuery, Bas class DBSearchQuery(BaseSearchQuery): + DEFAULT_OPERATOR = 'and' + def _process_lookup(self, field, lookup, value): return models.Q(**{field.get_attname(self.queryset.model) + '__' + lookup: value}) @@ -52,7 +54,10 @@ class DBSearchQuery(BaseSearchQuery): # Filter on this field term_query |= models.Q(**{'%s__icontains' % field_name: term}) - q &= term_query + if self.operator == 'or': + q |= term_query + elif self.operator == 'and': + q &= term_query return q diff --git a/wagtail/wagtailsearch/backends/elasticsearch.py b/wagtail/wagtailsearch/backends/elasticsearch.py index 586d76a81..03da5e2ee 100644 --- a/wagtail/wagtailsearch/backends/elasticsearch.py +++ b/wagtail/wagtailsearch/backends/elasticsearch.py @@ -156,6 +156,8 @@ class ElasticSearchMapping(object): class ElasticSearchQuery(BaseSearchQuery): + DEFAULT_OPERATOR = 'or' + def _process_lookup(self, field, lookup, value): # Get the name of the field in the index field_index_name = field.get_index_name(self.queryset.model) @@ -254,6 +256,9 @@ class ElasticSearchQuery(BaseSearchQuery): fields[0]: self.query_string, } } + + if self.operator != 'or': + query['match']['operator'] = self.operator else: query = { 'multi_match': { @@ -261,6 +266,9 @@ class ElasticSearchQuery(BaseSearchQuery): 'fields': fields, } } + + if self.operator != 'or': + query['multi_match']['operator'] = self.operator else: query = { 'match_all': {} diff --git a/wagtail/wagtailsearch/tests/test_backends.py b/wagtail/wagtailsearch/tests/test_backends.py index 759d58c79..75d2939ba 100644 --- a/wagtail/wagtailsearch/tests/test_backends.py +++ b/wagtail/wagtailsearch/tests/test_backends.py @@ -76,6 +76,18 @@ class BackendTests(WagtailTestUtils): results = self.backend.search("World", models.SearchTest) self.assertEqual(set(results), {self.testa, self.testd.searchtest_ptr}) + def test_operator_or(self): + # All records that match any term should be returned + results = self.backend.search("Hello world", models.SearchTest, operator='or') + + self.assertEqual(set(results), {self.testa, self.testb, self.testc.searchtest_ptr, self.testd.searchtest_ptr}) + + def test_operator_and(self): + # Records must match all search terms to be returned + results = self.backend.search("Hello world", models.SearchTest, operator='and') + + self.assertEqual(set(results), {self.testa}) + def test_callable_indexed_field(self): results = self.backend.search("Callable", models.SearchTest) self.assertEqual(set(results), {self.testa, self.testb, self.testc.searchtest_ptr, self.testd.searchtest_ptr}) diff --git a/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py b/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py index 47cde1ccb..eddd7ad6f 100644 --- a/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py +++ b/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py @@ -201,6 +201,14 @@ class TestElasticSearchQuery(TestCase): expected_result = {'filtered': {'filter': {'prefix': {'content_type': 'searchtests_searchtest'}}, 'query': {'match_all': {}}}} self.assertDictEqual(query.to_es(), expected_result) + def test_and_operator(self): + # Create a query + query = self.ElasticSearchQuery(models.SearchTest.objects.all(), "Hello", operator='and') + + # Check it + expected_result = {'filtered': {'filter': {'prefix': {'content_type': 'searchtests_searchtest'}}, 'query': {'multi_match': {'query': 'Hello', 'fields': ['_all', '_partials'], 'operator': 'and'}}}} + self.assertDictEqual(query.to_es(), expected_result) + def test_filter(self): # Create a query query = self.ElasticSearchQuery(models.SearchTest.objects.filter(title="Test"), "Hello") @@ -252,6 +260,14 @@ class TestElasticSearchQuery(TestCase): expected_result = {'filtered': {'filter': {'prefix': {'content_type': 'searchtests_searchtest'}}, 'query': {'match': {'title': 'Hello'}}}} self.assertDictEqual(query.to_es(), expected_result) + def test_fields_with_and_operator(self): + # Create a query + query = self.ElasticSearchQuery(models.SearchTest.objects.all(), "Hello", fields=['title'], operator='and') + + # Check it + expected_result = {'filtered': {'filter': {'prefix': {'content_type': 'searchtests_searchtest'}}, 'query': {'match': {'title': 'Hello', 'operator': 'and'}}}} + self.assertDictEqual(query.to_es(), expected_result) + def test_exact_lookup(self): # Create a query query = self.ElasticSearchQuery(models.SearchTest.objects.filter(title__exact="Test"), "Hello")