from django import VERSION as DJANGO_VERSION
from django.core.management.color import no_style
from django.db import connection, transaction

from ..utils import _get_tables
from .models import PostgresModel


class TestUtilsMixin:
    """Mixin providing database-aware helpers and assertions for test cases.

    It is meant to be combined with a ``unittest``-style test case class,
    since it relies on ``assertNumQueries``, ``assertSetEqual`` and the other
    ``assert*`` methods being available on ``self``.
    """

    def setUp(self):
        """Record the database vendor / Django version and open the connection."""
        self.is_sqlite = connection.vendor == 'sqlite'
        self.is_mysql = connection.vendor == 'mysql'
        self.is_postgresql = connection.vendor == 'postgresql'
        self.django_version = DJANGO_VERSION
        self.force_reopen_connection()

    # TODO: Remove this workaround when this issue is fixed:
    #       https://code.djangoproject.com/ticket/29494
    def tearDown(self):
        """Flush the ``PostgresModel`` table when running on PostgreSQL."""
        if connection.vendor != 'postgresql':
            return

        flush_args = [no_style(), (PostgresModel._meta.db_table,)]
        # Compare version tuples directly: converting "major.minor" to a
        # float breaks as soon as a minor version reaches 10 (3.10 -> 3.1).
        if DJANGO_VERSION[:2] < (3, 1):
            # Before Django 3.1, an extra positional argument is passed
            # to ``sql_flush`` (an empty tuple here).
            flush_args.append(())
        flush_sql_list = connection.ops.sql_flush(*flush_args)

        with transaction.atomic():
            for sql in flush_sql_list:
                with connection.cursor() as cursor:
                    cursor.execute(sql)

    def force_reopen_connection(self):
        """Open the database connection ahead of the test on MySQL/PostgreSQL."""
        if connection.vendor in ('mysql', 'postgresql'):
            # We need to reopen the connection or Django
            # will execute an extra SQL request below.
            # The cursor itself is not needed, so close it right away
            # instead of leaving it open.
            with connection.cursor():
                pass

    def assert_tables(self, queryset, *tables):
        """Assert that ``queryset`` involves exactly the given tables.

        Each item of ``tables`` is either a table name (``str``) or an object
        exposing ``_meta.db_table`` (e.g. a model class). The SQL of the
        query is used as the failure message.
        """
        expected_tables = {
            table if isinstance(table, str) else table._meta.db_table
            for table in tables
        }
        self.assertSetEqual(
            _get_tables(queryset.db, queryset.query),
            expected_tables,
            str(queryset.query),
        )

    def assert_query_cached(self, queryset, result=None, result_type=None,
                            compare_results=True, before=1, after=0):
        """Run ``queryset`` twice and check the number of SQL queries issued.

        :param queryset: object whose ``.all()`` is called for each run.
        :param result: optional expected value; when given, the second run's
            data is also compared against it.
        :param result_type: type used to pick the comparison assertion.
            Defaults to ``list`` when ``result`` is ``None``, otherwise to
            ``type(result)``. When it is ``list``, the data of each run is
            materialised with ``list()`` inside the query-counting block.
        :param compare_results: when false, only the query counts are checked.
        :param before: number of queries expected during the first run.
        :param after: number of queries expected during the second run.
        """
        if result_type is None:
            result_type = list if result is None else type(result)

        with self.assertNumQueries(before):
            data1 = queryset.all()
            if result_type is list:
                data1 = list(data1)

        with self.assertNumQueries(after):
            data2 = queryset.all()
            if result_type is list:
                data2 = list(data2)

        if not compare_results:
            return

        # Use the type-specific assertion when there is one; fall back to
        # plain equality for every other result type.
        assert_functions = {
            list: self.assertListEqual,
            set: self.assertSetEqual,
            dict: self.assertDictEqual,
        }
        assert_function = assert_functions.get(result_type, self.assertEqual)

        assert_function(data2, data1)
        if result is not None:
            assert_function(data2, result)