diff --git a/cachalot/tests/db_router.py b/cachalot/tests/db_router.py index e7e3267..de711ee 100644 --- a/cachalot/tests/db_router.py +++ b/cachalot/tests/db_router.py @@ -13,5 +13,6 @@ class PostgresRouter(object): else 'default') def allow_migrate(self, db, app_label, model=None, **hints): - if model is not None and self.in_postgres(model): + if hints.get('extension') in ('hstore', 'unaccent') \ + or (model is not None and self.in_postgres(model)): return db == self.get_postgresql_alias() diff --git a/cachalot/tests/migrations/0001_initial.py b/cachalot/tests/migrations/0001_initial.py index 0087931..71a2e48 100644 --- a/cachalot/tests/migrations/0001_initial.py +++ b/cachalot/tests/migrations/0001_initial.py @@ -15,6 +15,10 @@ class Migration(migrations.Migration): ] operations = [ + migrations.RunSQL('CREATE EXTENSION hstore;', + hints={'extension': 'hstore'}), + migrations.RunSQL('CREATE EXTENSION unaccent;', + hints={'extension': 'unaccent'}), migrations.CreateModel( name='Test', fields=[ @@ -50,7 +54,9 @@ class Migration(migrations.Migration): ] if django_version >= (1, 8): - from django.contrib.postgres.fields import ArrayField, IntegerRangeField + from django.contrib.postgres.fields import ( + ArrayField, HStoreField, + IntegerRangeField, FloatRangeField, DateRangeField, DateTimeRangeField) Migration.operations.append( migrations.CreateModel( @@ -61,7 +67,11 @@ if django_version >= (1, 8): ('int_array', ArrayField( models.IntegerField(null=True, blank=True), size=3, null=True, blank=True)), + ('hstore', HStoreField(null=True, blank=True)), ('int_range', IntegerRangeField(null=True, blank=True)), + ('float_range', FloatRangeField(null=True, blank=True)), + ('date_range', DateRangeField(null=True, blank=True)), + ('datetime_range', DateTimeRangeField(null=True, blank=True)), ], ) ) diff --git a/cachalot/tests/models.py b/cachalot/tests/models.py index dcb6eea..fe88fa1 100644 --- a/cachalot/tests/models.py +++ b/cachalot/tests/models.py @@ -32,9 +32,18 @@ class TestChild(TestParent): if django_version >= (1, 8): - from django.contrib.postgres.fields import ArrayField, IntegerRangeField + from django.contrib.postgres.fields import ( + ArrayField, HStoreField, + IntegerRangeField, FloatRangeField, DateRangeField, DateTimeRangeField) + class PostgresModel(Model): int_array = ArrayField(IntegerField(null=True, blank=True), size=3, null=True, blank=True) + + hstore = HStoreField(null=True, blank=True) + int_range = IntegerRangeField(null=True, blank=True) + float_range = FloatRangeField(null=True, blank=True) + date_range = DateRangeField(null=True, blank=True) + datetime_range = DateTimeRangeField(null=True, blank=True) diff --git a/cachalot/tests/postgres.py b/cachalot/tests/postgres.py index 275f04c..f4d95de 100644 --- a/cachalot/tests/postgres.py +++ b/cachalot/tests/postgres.py @@ -1,6 +1,8 @@ # coding: utf-8 from __future__ import unicode_literals +from datetime import date, datetime +from decimal import Decimal from platform import python_version_tuple from unittest import skipUnless, skipIf @@ -8,10 +10,11 @@ from django import VERSION as django_version from django.core.cache import caches from django.core.cache.backends.filebased import FileBasedCache from django.db import connection -from django.test import TransactionTestCase -from psycopg2._range import NumericRange +from django.test import TransactionTestCase, override_settings +from psycopg2._range import NumericRange, DateRange, DateTimeTZRange +from pytz import timezone -from .models import PostgresModel +from .models import PostgresModel, Test @skipUnless(connection.vendor == 'postgresql' and django_version[:2] >= (1, 8), @@ -20,23 +23,292 @@ from .models import PostgresModel and python_version_tuple()[:2] == ('2', '7'), 'Caching psycopg2 objects is not working with file-based cache ' 'and Python 2.7.') +@override_settings(USE_TZ=True) class PostgresReadTest(TransactionTestCase): def setUp(self): - self.obj = PostgresModel.objects.create( - int_array=[1, 2, 3], int_range=[1900, 2000]) + PostgresModel.objects.create( + int_array=[1, 2, 3], + hstore={'a': 'b', 'c': None}, + int_range=[1900, 2000], float_range=[-1e3, 9.87654321], + date_range=['1678-03-04', '1741-07-28'], + datetime_range=[datetime(1989, 1, 30, 12, 20, + tzinfo=timezone('Europe/Paris')), None]) + PostgresModel.objects.create( + int_array=[4, None, 6], + hstore={'a': '1', 'b': '2'}, + int_range=[1989, None], float_range=[0.0, None], + date_range=['1989-01-30', None], + datetime_range=[None, None]) + + def test_unaccent(self): + Test.objects.create(name='Clémentine') + Test.objects.create(name='Clementine') + qs = Test.objects.filter(name__unaccent='Clémentine') + with self.assertNumQueries(1): + data1 = [t.name for t in qs.all()] + with self.assertNumQueries(0): + data2 = [t.name for t in qs.all()] + self.assertListEqual(data2, data1) + self.assertListEqual(data2, ['Clementine', 'Clémentine']) def test_int_array(self): + qs = PostgresModel.objects.all() with self.assertNumQueries(1): - data1 = [o.int_array for o in PostgresModel.objects.all()] + data1 = [o.int_array for o in qs.all()] with self.assertNumQueries(0): - data2 = [o.int_array for o in PostgresModel.objects.all()] + data2 = [o.int_array for o in qs.all()] self.assertListEqual(data2, data1) - self.assertListEqual(data2, [[1, 2, 3]]) + self.assertListEqual(data2, [[1, 2, 3], [4, None, 6]]) + + qs = PostgresModel.objects.filter(int_array__contains=[3]) + with self.assertNumQueries(1): + data3 = [o.int_array for o in qs.all()] + with self.assertNumQueries(0): + data4 = [o.int_array for o in qs.all()] + self.assertListEqual(data4, data3) + self.assertListEqual(data4, [[1, 2, 3]]) + + qs = PostgresModel.objects.filter(int_array__contained_by=[1, 2, 3, + 4, 5, 6]) + with self.assertNumQueries(1): + data7 = [o.int_array for o in qs.all()] + with self.assertNumQueries(0): + data8 = [o.int_array for o in qs.all()] + self.assertListEqual(data8, data7) + self.assertListEqual(data8, [[1, 2, 3]]) + + qs = PostgresModel.objects.filter(int_array__overlap=[3, 4]) + with self.assertNumQueries(1): + data9 = [o.int_array for o in qs.all()] + with self.assertNumQueries(0): + data10 = [o.int_array for o in qs.all()] + self.assertListEqual(data10, data9) + self.assertListEqual(data10, [[1, 2, 3], [4, None, 6]]) + + qs = PostgresModel.objects.filter(int_array__len__in=(2, 3)) + with self.assertNumQueries(1): + data11 = [o.int_array for o in qs.all()] + with self.assertNumQueries(0): + data12 = [o.int_array for o in qs.all()] + self.assertListEqual(data12, data11) + self.assertListEqual(data12, [[1, 2, 3], [4, None, 6]]) + + qs = PostgresModel.objects.filter(int_array__2=6) + with self.assertNumQueries(1): + data13 = [o.int_array for o in qs.all()] + with self.assertNumQueries(0): + data14 = [o.int_array for o in qs.all()] + self.assertListEqual(data14, data13) + self.assertListEqual(data14, [[4, None, 6]]) + + qs = PostgresModel.objects.filter(int_array__0_2=(1, 2)) + with self.assertNumQueries(1): + data15 = [o.int_array for o in qs.all()] + with self.assertNumQueries(0): + data16 = [o.int_array for o in qs.all()] + self.assertListEqual(data16, data15) + self.assertListEqual(data16, [[1, 2, 3]]) + + def test_hstore(self): + qs = PostgresModel.objects.all() + with self.assertNumQueries(1): + data1 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data2 = [o.hstore for o in qs.all()] + self.assertListEqual(data2, data1) + self.assertListEqual(data2, [{'a': 'b', 'c': None}, + {'a': '1', 'b': '2'}]) + + qs = PostgresModel.objects.filter(hstore__a='1') + with self.assertNumQueries(1): + data3 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data4 = [o.hstore for o in qs.all()] + self.assertListEqual(data4, data3) + self.assertListEqual(data4, [{'a': '1', 'b': '2'}]) + + qs = PostgresModel.objects.filter( + hstore__contains={'a': 'b'}) + with self.assertNumQueries(1): + data5 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data6 = [o.hstore for o in qs.all()] + self.assertListEqual(data6, data5) + self.assertListEqual(data6, [{'a': 'b', 'c': None}]) + + qs = PostgresModel.objects.filter( + hstore__contained_by={'a': 'b', 'c': None, 'b': '2'}) + with self.assertNumQueries(1): + data7 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data8 = [o.hstore for o in qs.all()] + self.assertListEqual(data8, data7) + self.assertListEqual(data8, [{'a': 'b', 'c': None}]) + + qs = PostgresModel.objects.filter(hstore__has_key='c') + with self.assertNumQueries(1): + data9 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data10 = [o.hstore for o in qs.all()] + self.assertListEqual(data10, data9) + self.assertListEqual(data10, [{'a': 'b', 'c': None}]) + + qs = PostgresModel.objects.filter(hstore__has_keys=['a', 'b']) + with self.assertNumQueries(1): + data11 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data12 = [o.hstore for o in qs.all()] + self.assertListEqual(data12, data11) + self.assertListEqual(data12, [{'a': '1', 'b': '2'}]) + + qs = PostgresModel.objects.filter(hstore__keys=['a', 'b']) + with self.assertNumQueries(1): + data13 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data14 = [o.hstore for o in qs.all()] + self.assertListEqual(data14, data13) + self.assertListEqual(data14, [{'a': '1', 'b': '2'}]) + + qs = PostgresModel.objects.filter(hstore__values=['1', '2']) + with self.assertNumQueries(1): + data15 = [o.hstore for o in qs.all()] + with self.assertNumQueries(0): + data16 = [o.hstore for o in qs.all()] + self.assertListEqual(data16, data15) + self.assertListEqual(data16, [{'a': '1', 'b': '2'}]) def test_int_range(self): + qs = PostgresModel.objects.all() with self.assertNumQueries(1): - data1 = [o.int_range for o in PostgresModel.objects.all()] + data1 = [o.int_range for o in qs.all()] with self.assertNumQueries(0): - data2 = [o.int_range for o in PostgresModel.objects.all()] + data2 = [o.int_range for o in qs.all()] self.assertListEqual(data2, data1) - self.assertListEqual(data2, [NumericRange(1900, 2000)]) + self.assertListEqual(data2, [NumericRange(1900, 2000), + NumericRange(1989)]) + + qs = PostgresModel.objects.filter(int_range__contains=2015) + with self.assertNumQueries(1): + data3 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data4 = [o.int_range for o in qs.all()] + self.assertListEqual(data4, data3) + self.assertListEqual(data4, [NumericRange(1989)]) + + qs = PostgresModel.objects.filter( + int_range__contains=NumericRange(1950, 1990)) + with self.assertNumQueries(1): + data5 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data6 = [o.int_range for o in qs.all()] + self.assertListEqual(data6, data5) + self.assertListEqual(data6, [NumericRange(1900, 2000)]) + + qs = PostgresModel.objects.filter( + int_range__contained_by=NumericRange(0, 2050)) + with self.assertNumQueries(1): + data5 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data6 = [o.int_range for o in qs.all()] + self.assertListEqual(data6, data5) + self.assertListEqual(data6, [NumericRange(1900, 2000)]) + + qs = PostgresModel.objects.filter(int_range__fully_lt=(2015, None)) + with self.assertNumQueries(1): + data7 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data8 = [o.int_range for o in qs.all()] + self.assertListEqual(data8, data7) + self.assertListEqual(data8, [NumericRange(1900, 2000)]) + + qs = PostgresModel.objects.filter(int_range__fully_gt=(1970, 1980)) + with self.assertNumQueries(1): + data9 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data10 = [o.int_range for o in qs.all()] + self.assertListEqual(data10, data9) + self.assertListEqual(data10, [NumericRange(1989)]) + + qs = PostgresModel.objects.filter(int_range__not_lt=(1970, 1980)) + with self.assertNumQueries(1): + data11 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data12 = [o.int_range for o in qs.all()] + self.assertListEqual(data12, data11) + self.assertListEqual(data12, [NumericRange(1989)]) + + qs = PostgresModel.objects.filter(int_range__not_gt=(1970, 1980)) + with self.assertNumQueries(1): + data13 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data14 = [o.int_range for o in qs.all()] + self.assertListEqual(data14, data13) + self.assertListEqual(data14, []) + + qs = PostgresModel.objects.filter(int_range__adjacent_to=(1900, 1989)) + with self.assertNumQueries(1): + data15 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data16 = [o.int_range for o in qs.all()] + self.assertListEqual(data16, data15) + self.assertListEqual(data16, [NumericRange(1989)]) + + qs = PostgresModel.objects.filter(int_range__startswith=1900) + with self.assertNumQueries(1): + data17 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data18 = [o.int_range for o in qs.all()] + self.assertListEqual(data18, data17) + self.assertListEqual(data18, [NumericRange(1900, 2000)]) + + qs = PostgresModel.objects.filter(int_range__endswith=2000) + with self.assertNumQueries(1): + data19 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data20 = [o.int_range for o in qs.all()] + self.assertListEqual(data20, data19) + self.assertListEqual(data20, [NumericRange(1900, 2000)]) + + PostgresModel.objects.create(int_range=[1900, 1900]) + + qs = PostgresModel.objects.filter(int_range__isempty=True) + with self.assertNumQueries(1): + data21 = [o.int_range for o in qs.all()] + with self.assertNumQueries(0): + data22 = [o.int_range for o in qs.all()] + self.assertListEqual(data22, data21) + self.assertListEqual(data22, [NumericRange(empty=True)]) + + def test_float_range(self): + with self.assertNumQueries(1): + data1 = [o.float_range for o in PostgresModel.objects.all()] + with self.assertNumQueries(0): + data2 = [o.float_range for o in PostgresModel.objects.all()] + self.assertListEqual(data2, data1) + # For a strange reason, probably a misconception in psycopg2 + # or a bad name in django.contrib.postgres (less probable), + # FloatRange returns decimals instead of floats. + self.assertListEqual(data2, [ + NumericRange(Decimal('-1000.0'), Decimal('9.87654321')), + NumericRange(Decimal('0.0'))]) + + def test_date_range(self): + with self.assertNumQueries(1): + data1 = [o.date_range for o in PostgresModel.objects.all()] + with self.assertNumQueries(0): + data2 = [o.date_range for o in PostgresModel.objects.all()] + self.assertListEqual(data2, data1) + self.assertListEqual(data2, [ + DateRange(date(1678, 3, 4), date(1741, 7, 28)), + DateRange(date(1989, 1, 30))]) + + def test_datetime_range(self): + with self.assertNumQueries(1): + data1 = [o.datetime_range for o in PostgresModel.objects.all()] + with self.assertNumQueries(0): + data2 = [o.datetime_range for o in PostgresModel.objects.all()] + self.assertListEqual(data2, data1) + self.assertListEqual(data2, [ + DateTimeTZRange(datetime(1989, 1, 30, 12, 20, + tzinfo=timezone('Europe/Paris'))), + DateTimeTZRange(bounds='()')]) diff --git a/cachalot/utils.py b/cachalot/utils.py index df0c75f..65be424 100644 --- a/cachalot/utils.py +++ b/cachalot/utils.py @@ -19,10 +19,32 @@ class UncachableQuery(Exception): pass -CACHABLE_PARAM_TYPES = frozenset(( +CACHABLE_PARAM_TYPES = { bool, int, binary_type, text_type, type(None), - datetime.date, datetime.time, datetime.datetime, datetime.timedelta -)) + datetime.date, datetime.time, datetime.datetime, datetime.timedelta, +} + +try: + from psycopg2._range import ( + NumericRange, DateRange, DateTimeRange, DateTimeTZRange) +except ImportError: + pass +else: + CACHABLE_PARAM_TYPES.update(( + NumericRange, DateRange, DateTimeRange, DateTimeTZRange)) + + +def check_parameter_types(params): + for p in params: + cl = p.__class__ + if cl not in CACHABLE_PARAM_TYPES: + if cl is list or cl is tuple: + check_parameter_types(p) + elif cl is dict: + check_parameter_types(p.items()) + else: + raise UncachableQuery + def get_query_cache_key(compiler): """ @@ -38,8 +60,7 @@ def get_query_cache_key(compiler): :rtype: str """ sql, params = compiler.as_sql() - if any(p.__class__ not in CACHABLE_PARAM_TYPES for p in params): - raise UncachableQuery + check_parameter_types(params) cache_key = '%s:%s:%s' % (compiler.using, sql, params) return sha1(cache_key.encode('utf-8')).hexdigest() diff --git a/settings.py b/settings.py index 8927e9e..e775579 100644 --- a/settings.py +++ b/settings.py @@ -89,6 +89,7 @@ INSTALLED_APPS = [ 'cachalot', 'django.contrib.auth', 'django.contrib.contenttypes', + 'django.contrib.postgres', # Enables the unaccent lookup. ]