diff --git a/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py b/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py index 7cccf6ba4..df928a08d 100644 --- a/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py +++ b/wagtail/wagtailsearch/tests/test_elasticsearch_backend.py @@ -5,6 +5,8 @@ import unittest import datetime import json +import mock + from django.test import TestCase from django.db.models import Q @@ -334,6 +336,191 @@ class TestElasticSearchQuery(TestCase): self.assertDictEqual(query.to_es(), expected_result) +class TestElasticSearchResults(TestCase): + def assertDictEqual(self, a, b): + default = self.JSONSerializer().default + self.assertEqual(json.dumps(a, sort_keys=True, default=default), json.dumps(b, sort_keys=True, default=default)) + + def setUp(self): + # Import using a try-catch block to prevent crashes if the elasticsearch-py + # module is not installed + try: + from wagtail.wagtailsearch.backends.elasticsearch import ElasticSearch + from wagtail.wagtailsearch.backends.elasticsearch import ElasticSearchResults + from elasticsearch.serializer import JSONSerializer + except ImportError: + raise unittest.SkipTest("elasticsearch-py not installed") + + self.ElasticSearch = ElasticSearch + self.ElasticSearchResults = ElasticSearchResults + self.JSONSerializer = JSONSerializer + + self.objects = [] + + for i in range(3): + self.objects.append(models.SearchTest.objects.create(title=str(i))) + + def get_results(self): + backend = self.ElasticSearch({}) + query = mock.MagicMock() + query.queryset = models.SearchTest.objects.all() + query.to_es.return_value = 'QUERY' + return self.ElasticSearchResults(backend, query) + + def construct_search_response(self, results): + return { + '_shards': {'failed': 0, 'successful': 5, 'total': 5}, + 'hits': { + 'hits': [ + { + '_id': 'searchtests_searchtest:' + str(result), + '_index': 'wagtail', + '_score': 1, + '_type': 'searchtests_searchtest', + 'fields': { + 'pk': [str(result)], + } + } + for result in results + ], + 'max_score': 1, + 'total': len(results) + }, + 'timed_out': False, + 'took': 2 + } + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_basic_search(self, search): + search.return_value = self.construct_search_response([]) + results = self.get_results() + + list(results) # Performs search + + search.assert_any_call( + from_=0, + body={'query': 'QUERY'}, + _source=False, + fields='pk', + index='wagtail' + ) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_get_single_item(self, search): + # Need to return something to prevent index error + search.return_value = self.construct_search_response([self.objects[0].id]) + results = self.get_results() + + results[10] # Performs search + + search.assert_any_call( + from_=10, + body={'query': 'QUERY'}, + _source=False, + fields='pk', + index='wagtail', + size=1 + ) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_slice_results(self, search): + search.return_value = self.construct_search_response([]) + results = self.get_results()[1:4] + + list(results) # Performs search + + search.assert_any_call( + from_=1, + body={'query': 'QUERY'}, + _source=False, + fields='pk', + index='wagtail', + size=3 + ) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_slice_results_multiple_times(self, search): + search.return_value = self.construct_search_response([]) + results = self.get_results()[10:][:10] + + list(results) # Performs search + + search.assert_any_call( + from_=10, + body={'query': 'QUERY'}, + _source=False, + fields='pk', + index='wagtail', + size=10 + ) + + @unittest.expectedFailure # 1271 + @mock.patch('elasticsearch.Elasticsearch.search') + def test_slice_results_and_get_item(self, search): + # Need to return something to prevent index error + search.return_value = self.construct_search_response([self.objects[0].id]) + results = self.get_results()[10:] + + results[10] # Performs search + + search.assert_any_call( + from_=20, + body={'query': 'QUERY'}, + _source=False, + fields='pk', + index='wagtail', + size=1 + ) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_result_returned(self, search): + search.return_value = self.construct_search_response([self.objects[0].id]) + results = self.get_results() + + self.assertEqual(results[0], self.objects[0]) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_len_1(self, search): + search.return_value = self.construct_search_response([self.objects[0].id]) + results = self.get_results() + + self.assertEqual(len(results), 1) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_len_2(self, search): + search.return_value = self.construct_search_response([self.objects[0].id, self.objects[1].id]) + results = self.get_results() + + self.assertEqual(len(results), 2) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_duplicate_results(self, search): # Duplicates will not be removed + search.return_value = self.construct_search_response([self.objects[0].id, self.objects[0].id]) + results = list(self.get_results()) # Must cast to list so we only create one query + + self.assertEqual(len(results), 2) + self.assertEqual(results[0], self.objects[0]) + self.assertEqual(results[1], self.objects[0]) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_result_order(self, search): + search.return_value = self.construct_search_response([self.objects[0].id, self.objects[1].id, self.objects[2].id]) + results = list(self.get_results()) # Must cast to list so we only create one query + + self.assertEqual(results[0], self.objects[0]) + self.assertEqual(results[1], self.objects[1]) + self.assertEqual(results[2], self.objects[2]) + + @mock.patch('elasticsearch.Elasticsearch.search') + def test_result_order_2(self, search): + search.return_value = self.construct_search_response([self.objects[2].id, self.objects[1].id, self.objects[0].id]) + results = list(self.get_results()) # Must cast to list so we only create one query + + self.assertEqual(results[0], self.objects[2]) + self.assertEqual(results[1], self.objects[1]) + self.assertEqual(results[2], self.objects[0]) + + class TestElasticSearchMapping(TestCase): def assertDictEqual(self, a, b): default = self.JSONSerializer().default