From 6933ca58058546fe99b46409eafed75b59b97ac9 Mon Sep 17 00:00:00 2001 From: aminechm Date: Thu, 8 Jan 2026 14:56:44 +0100 Subject: [PATCH] Implement ssl_require parameter mapping for various database backends --- dj_database_url/__init__.py | 96 ++++++++++++++++++++++++++--------- tests/test_dj_database_url.py | 47 ++++++++++++++++- 2 files changed, 119 insertions(+), 24 deletions(-) diff --git a/dj_database_url/__init__.py b/dj_database_url/__init__.py index 6328005..524a439 100644 --- a/dj_database_url/__init__.py +++ b/dj_database_url/__init__.py @@ -7,6 +7,39 @@ from typing import Any, TypedDict DEFAULT_ENV = "DATABASE_URL" ENGINE_SCHEMES: dict[str, "Engine"] = {} +MYSQL_BACKEND = "django.db.backends.mysql" +MYSQL_GIS_BACKEND = "django.contrib.gis.db.backends.mysql" +POSTGRESQL_BACKEND = "django.db.backends.postgresql" +POSTGIS_BACKEND = "django.contrib.gis.db.backends.postgis" +REDSHIFT_BACKEND = "django_redshift_backend" +COCKROACH_BACKEND = "django_cockroachdb" +TIMESCALE_BACKEND = "timescale.db.backends.postgresql" +TIMESCALE_GIS_BACKEND = "timescale.db.backends.postgis" +MSSQL_BACKEND = "mssql" +MSSQL_PYODBC_BACKEND = "sql_server.pyodbc" +SQLITE_BACKEND = "django.db.backends.sqlite3" +SPATIALITE_GIS_BACKEND = "django.contrib.gis.db.backends.spatialite" +ORACLE_BACKEND = "django.db.backends.oracle" +ORACLE_GIS_BACKEND = "django.contrib.gis.db.backends.oracle" +MYSQL_CONNECTOR_BACKEND = "mysql.connector.django" + +MYSQL_BACKENDS = ( + MYSQL_BACKEND, + MYSQL_GIS_BACKEND, +) +POSTGRES_BACKENDS = ( + POSTGRESQL_BACKEND, + POSTGIS_BACKEND, + REDSHIFT_BACKEND, + COCKROACH_BACKEND, + TIMESCALE_BACKEND, + TIMESCALE_GIS_BACKEND, +) +MSSQL_BACKENDS = ( + MSSQL_BACKEND, + MSSQL_PYODBC_BACKEND, +) + # From https://docs.djangoproject.com/en/stable/ref/settings/#databases class DBConfig(TypedDict, total=False): @@ -80,29 +113,29 @@ def register( return inner -register("spatialite", "django.contrib.gis.db.backends.spatialite") -register("mysql-connector", "mysql.connector.django") -register("mysqlgis", "django.contrib.gis.db.backends.mysql") -register("oraclegis", "django.contrib.gis.db.backends.oracle") -register("cockroach", "django_cockroachdb") +register("spatialite", SPATIALITE_GIS_BACKEND) +register("mysql-connector", MYSQL_CONNECTOR_BACKEND) +register("mysqlgis", MYSQL_GIS_BACKEND) +register("oraclegis", ORACLE_GIS_BACKEND) +register("cockroach", COCKROACH_BACKEND) -@register("sqlite", "django.db.backends.sqlite3") +@register("sqlite", SQLITE_BACKEND) def default_to_in_memory_db(parsed_config: DBConfig) -> None: # mimic sqlalchemy behaviour if not parsed_config.get("NAME"): parsed_config["NAME"] = ":memory:" -@register("oracle", "django.db.backends.oracle") -@register("mssqlms", "mssql") -@register("mssql", "sql_server.pyodbc") +@register("oracle", ORACLE_BACKEND) +@register("mssqlms", MSSQL_BACKEND) +@register("mssql", MSSQL_PYODBC_BACKEND) def stringify_port(parsed_config: DBConfig) -> None: parsed_config["PORT"] = str(parsed_config.get("PORT", "")) -@register("mysql", "django.db.backends.mysql") -@register("mysql2", "django.db.backends.mysql") +@register("mysql", MYSQL_BACKEND) +@register("mysql2", MYSQL_BACKEND) def apply_ssl_ca(parsed_config: DBConfig) -> None: options = parsed_config.get("OPTIONS", {}) ca = options.pop("ssl-ca", None) @@ -110,13 +143,13 @@ def apply_ssl_ca(parsed_config: DBConfig) -> None: options["ssl"] = {"ca": ca} -@register("postgres", "django.db.backends.postgresql") -@register("postgresql", "django.db.backends.postgresql") -@register("pgsql", "django.db.backends.postgresql") -@register("postgis", "django.contrib.gis.db.backends.postgis") -@register("redshift", "django_redshift_backend") -@register("timescale", "timescale.db.backends.postgresql") -@register("timescalegis", "timescale.db.backends.postgis") +@register("postgres", POSTGRESQL_BACKEND) +@register("postgresql", POSTGRESQL_BACKEND) +@register("pgsql", POSTGRESQL_BACKEND) +@register("postgis", POSTGIS_BACKEND) +@register("redshift", REDSHIFT_BACKEND) +@register("timescale", TIMESCALE_BACKEND) +@register("timescalegis", TIMESCALE_GIS_BACKEND) def apply_current_schema(parsed_config: DBConfig) -> None: options = parsed_config.get("OPTIONS", {}) schema = options.pop("currentSchema", None) @@ -171,7 +204,6 @@ def parse( conn_max_age, conn_health_checks, disable_server_side_cursors, - ssl_require, test_options, ) @@ -212,6 +244,9 @@ def parse( parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {})) parsed_config.update(settings) + if ssl_require: + _configure_ssl(parsed_config) + if not parsed_config["OPTIONS"]: parsed_config.pop("OPTIONS") return parsed_config @@ -230,12 +265,30 @@ def _parse_value(value: str) -> OptionType: return value +def _configure_ssl(parsed_config: DBConfig) -> None: + assert "OPTIONS" in parsed_config + options = parsed_config["OPTIONS"] + assert "ENGINE" in parsed_config + backend = parsed_config["ENGINE"] + + if backend in MYSQL_BACKENDS: + options["ssl_mode"] = "REQUIRED" + elif backend in POSTGRES_BACKENDS: + options["sslmode"] = "require" + elif backend in MSSQL_BACKENDS: + current_extra = options.get("extra_params", "") + if "Encrypt=yes" not in current_extra: + if current_extra: + options["extra_params"] = f"{current_extra};Encrypt=yes" + else: + options["extra_params"] = "Encrypt=yes" + + def _convert_to_settings( engine: str | None, conn_max_age: int | None, conn_health_checks: bool, disable_server_side_cursors: bool, - ssl_require: bool, test_options: dict[str, Any] | None, ) -> DBConfig: settings: DBConfig = { @@ -245,9 +298,6 @@ def _convert_to_settings( } if engine: settings["ENGINE"] = engine - if ssl_require: - settings["OPTIONS"] = {} - settings["OPTIONS"]["sslmode"] = "require" if test_options: settings["TEST"] = test_options return settings diff --git a/tests/test_dj_database_url.py b/tests/test_dj_database_url.py index 093083e..5cb353e 100644 --- a/tests/test_dj_database_url.py +++ b/tests/test_dj_database_url.py @@ -673,10 +673,55 @@ class DatabaseTestSuite(unittest.TestCase): os.environ, {"DATABASE_URL": "postgres://user:password@instance.amazonaws.com:5431/d8r8?"}, ) - def test_ssl_require(self) -> None: + def test_ssl_require_postgres(self) -> None: url = dj_database_url.config(ssl_require=True) assert url["OPTIONS"] == {'sslmode': 'require'} + @mock.patch.dict( + os.environ, + {"DATABASE_URL": "mysql://user:password@instance.amazonaws.com:3306/dbname"}, + ) + def test_ssl_require_mysql(self) -> None: + url = dj_database_url.config(ssl_require=True) + assert url["OPTIONS"] == {"ssl_mode": "REQUIRED"} + + @mock.patch.dict( + os.environ, + {"DATABASE_URL": "mssqlms://user:password@instance.amazonaws.com:1234/dbname"}, + ) + def test_ssl_require_mssql(self) -> None: + url = dj_database_url.config(ssl_require=True) + assert url["OPTIONS"] == {"extra_params": "Encrypt=yes"} + + @mock.patch.dict( + os.environ, + { + "DATABASE_URL": ( + "mssql://user:password@instance.amazonaws.com:1234/dbname" + "?extra_params=TrustServerCertificate=yes" + ) + }, + ) + def test_ssl_require_mssql_existing_extra_params(self) -> None: + url = dj_database_url.config(ssl_require=True) + assert url["OPTIONS"] == { + "extra_params": "TrustServerCertificate=yes;Encrypt=yes" + } + + @mock.patch.dict( + os.environ, + { + "DATABASE_URL": ( + "mssql://user:password@instance.amazonaws.com:1234/dbname" + "?extra_params=Encrypt=yes" + ) + }, + ) + def test_ssl_require_mssql_already_encrypted(self) -> None: + url = dj_database_url.config(ssl_require=True) + # Should NOT append Encrypt=yes again + assert url["OPTIONS"] == {"extra_params": "Encrypt=yes"} + def test_options_int_values(self) -> None: """Ensure that options with integer values are parsed correctly.""" url = dj_database_url.parse(