This commit is contained in:
Amine 2026-03-11 15:13:11 +01:00 committed by GitHub
commit 0fc628c21a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 119 additions and 24 deletions

View file

@ -7,6 +7,39 @@ from typing import Any, TypedDict
DEFAULT_ENV = "DATABASE_URL"
ENGINE_SCHEMES: dict[str, "Engine"] = {}
MYSQL_BACKEND = "django.db.backends.mysql"
MYSQL_GIS_BACKEND = "django.contrib.gis.db.backends.mysql"
POSTGRESQL_BACKEND = "django.db.backends.postgresql"
POSTGIS_BACKEND = "django.contrib.gis.db.backends.postgis"
REDSHIFT_BACKEND = "django_redshift_backend"
COCKROACH_BACKEND = "django_cockroachdb"
TIMESCALE_BACKEND = "timescale.db.backends.postgresql"
TIMESCALE_GIS_BACKEND = "timescale.db.backends.postgis"
MSSQL_BACKEND = "mssql"
MSSQL_PYODBC_BACKEND = "sql_server.pyodbc"
SQLITE_BACKEND = "django.db.backends.sqlite3"
SPATIALITE_GIS_BACKEND = "django.contrib.gis.db.backends.spatialite"
ORACLE_BACKEND = "django.db.backends.oracle"
ORACLE_GIS_BACKEND = "django.contrib.gis.db.backends.oracle"
MYSQL_CONNECTOR_BACKEND = "mysql.connector.django"
MYSQL_BACKENDS = (
MYSQL_BACKEND,
MYSQL_GIS_BACKEND,
)
POSTGRES_BACKENDS = (
POSTGRESQL_BACKEND,
POSTGIS_BACKEND,
REDSHIFT_BACKEND,
COCKROACH_BACKEND,
TIMESCALE_BACKEND,
TIMESCALE_GIS_BACKEND,
)
MSSQL_BACKENDS = (
MSSQL_BACKEND,
MSSQL_PYODBC_BACKEND,
)
# From https://docs.djangoproject.com/en/stable/ref/settings/#databases
class DBConfig(TypedDict, total=False):
@ -80,29 +113,29 @@ def register(
return inner
register("spatialite", "django.contrib.gis.db.backends.spatialite")
register("mysql-connector", "mysql.connector.django")
register("mysqlgis", "django.contrib.gis.db.backends.mysql")
register("oraclegis", "django.contrib.gis.db.backends.oracle")
register("cockroach", "django_cockroachdb")
register("spatialite", SPATIALITE_GIS_BACKEND)
register("mysql-connector", MYSQL_CONNECTOR_BACKEND)
register("mysqlgis", MYSQL_GIS_BACKEND)
register("oraclegis", ORACLE_GIS_BACKEND)
register("cockroach", COCKROACH_BACKEND)
@register("sqlite", "django.db.backends.sqlite3")
@register("sqlite", SQLITE_BACKEND)
def default_to_in_memory_db(parsed_config: DBConfig) -> None:
# mimic sqlalchemy behaviour
if not parsed_config.get("NAME"):
parsed_config["NAME"] = ":memory:"
@register("oracle", "django.db.backends.oracle")
@register("mssqlms", "mssql")
@register("mssql", "sql_server.pyodbc")
@register("oracle", ORACLE_BACKEND)
@register("mssqlms", MSSQL_BACKEND)
@register("mssql", MSSQL_PYODBC_BACKEND)
def stringify_port(parsed_config: DBConfig) -> None:
parsed_config["PORT"] = str(parsed_config.get("PORT", ""))
@register("mysql", "django.db.backends.mysql")
@register("mysql2", "django.db.backends.mysql")
@register("mysql", MYSQL_BACKEND)
@register("mysql2", MYSQL_BACKEND)
def apply_ssl_ca(parsed_config: DBConfig) -> None:
options = parsed_config.get("OPTIONS", {})
ca = options.pop("ssl-ca", None)
@ -110,13 +143,13 @@ def apply_ssl_ca(parsed_config: DBConfig) -> None:
options["ssl"] = {"ca": ca}
@register("postgres", "django.db.backends.postgresql")
@register("postgresql", "django.db.backends.postgresql")
@register("pgsql", "django.db.backends.postgresql")
@register("postgis", "django.contrib.gis.db.backends.postgis")
@register("redshift", "django_redshift_backend")
@register("timescale", "timescale.db.backends.postgresql")
@register("timescalegis", "timescale.db.backends.postgis")
@register("postgres", POSTGRESQL_BACKEND)
@register("postgresql", POSTGRESQL_BACKEND)
@register("pgsql", POSTGRESQL_BACKEND)
@register("postgis", POSTGIS_BACKEND)
@register("redshift", REDSHIFT_BACKEND)
@register("timescale", TIMESCALE_BACKEND)
@register("timescalegis", TIMESCALE_GIS_BACKEND)
def apply_current_schema(parsed_config: DBConfig) -> None:
options = parsed_config.get("OPTIONS", {})
schema = options.pop("currentSchema", None)
@ -171,7 +204,6 @@ def parse(
conn_max_age,
conn_health_checks,
disable_server_side_cursors,
ssl_require,
test_options,
)
@ -212,6 +244,9 @@ def parse(
parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {}))
parsed_config.update(settings)
if ssl_require:
_configure_ssl(parsed_config)
if not parsed_config["OPTIONS"]:
parsed_config.pop("OPTIONS")
return parsed_config
@ -230,12 +265,30 @@ def _parse_value(value: str) -> OptionType:
return value
def _configure_ssl(parsed_config: DBConfig) -> None:
assert "OPTIONS" in parsed_config
options = parsed_config["OPTIONS"]
assert "ENGINE" in parsed_config
backend = parsed_config["ENGINE"]
if backend in MYSQL_BACKENDS:
options["ssl_mode"] = "REQUIRED"
elif backend in POSTGRES_BACKENDS:
options["sslmode"] = "require"
elif backend in MSSQL_BACKENDS:
current_extra = options.get("extra_params", "")
if "Encrypt=yes" not in current_extra:
if current_extra:
options["extra_params"] = f"{current_extra};Encrypt=yes"
else:
options["extra_params"] = "Encrypt=yes"
def _convert_to_settings(
engine: str | None,
conn_max_age: int | None,
conn_health_checks: bool,
disable_server_side_cursors: bool,
ssl_require: bool,
test_options: dict[str, Any] | None,
) -> DBConfig:
settings: DBConfig = {
@ -245,9 +298,6 @@ def _convert_to_settings(
}
if engine:
settings["ENGINE"] = engine
if ssl_require:
settings["OPTIONS"] = {}
settings["OPTIONS"]["sslmode"] = "require"
if test_options:
settings["TEST"] = test_options
return settings

View file

@ -673,10 +673,55 @@ class DatabaseTestSuite(unittest.TestCase):
os.environ,
{"DATABASE_URL": "postgres://user:password@instance.amazonaws.com:5431/d8r8?"},
)
def test_ssl_require(self) -> None:
def test_ssl_require_postgres(self) -> None:
url = dj_database_url.config(ssl_require=True)
assert url["OPTIONS"] == {'sslmode': 'require'}
@mock.patch.dict(
os.environ,
{"DATABASE_URL": "mysql://user:password@instance.amazonaws.com:3306/dbname"},
)
def test_ssl_require_mysql(self) -> None:
url = dj_database_url.config(ssl_require=True)
assert url["OPTIONS"] == {"ssl_mode": "REQUIRED"}
@mock.patch.dict(
os.environ,
{"DATABASE_URL": "mssqlms://user:password@instance.amazonaws.com:1234/dbname"},
)
def test_ssl_require_mssql(self) -> None:
url = dj_database_url.config(ssl_require=True)
assert url["OPTIONS"] == {"extra_params": "Encrypt=yes"}
@mock.patch.dict(
os.environ,
{
"DATABASE_URL": (
"mssql://user:password@instance.amazonaws.com:1234/dbname"
"?extra_params=TrustServerCertificate=yes"
)
},
)
def test_ssl_require_mssql_existing_extra_params(self) -> None:
url = dj_database_url.config(ssl_require=True)
assert url["OPTIONS"] == {
"extra_params": "TrustServerCertificate=yes;Encrypt=yes"
}
@mock.patch.dict(
os.environ,
{
"DATABASE_URL": (
"mssql://user:password@instance.amazonaws.com:1234/dbname"
"?extra_params=Encrypt=yes"
)
},
)
def test_ssl_require_mssql_already_encrypted(self) -> None:
url = dj_database_url.config(ssl_require=True)
# Should NOT append Encrypt=yes again
assert url["OPTIONS"] == {"extra_params": "Encrypt=yes"}
def test_options_int_values(self) -> None:
"""Ensure that options with integer values are parsed correctly."""
url = dj_database_url.parse(