mirror of
https://github.com/jazzband/dj-database-url.git
synced 2026-03-16 22:20:24 +00:00
Implement ssl_require parameter mapping for various database backends
This commit is contained in:
parent
dba6077081
commit
6933ca5805
2 changed files with 119 additions and 24 deletions
|
|
@ -7,6 +7,39 @@ from typing import Any, TypedDict
|
|||
DEFAULT_ENV = "DATABASE_URL"
|
||||
ENGINE_SCHEMES: dict[str, "Engine"] = {}
|
||||
|
||||
MYSQL_BACKEND = "django.db.backends.mysql"
|
||||
MYSQL_GIS_BACKEND = "django.contrib.gis.db.backends.mysql"
|
||||
POSTGRESQL_BACKEND = "django.db.backends.postgresql"
|
||||
POSTGIS_BACKEND = "django.contrib.gis.db.backends.postgis"
|
||||
REDSHIFT_BACKEND = "django_redshift_backend"
|
||||
COCKROACH_BACKEND = "django_cockroachdb"
|
||||
TIMESCALE_BACKEND = "timescale.db.backends.postgresql"
|
||||
TIMESCALE_GIS_BACKEND = "timescale.db.backends.postgis"
|
||||
MSSQL_BACKEND = "mssql"
|
||||
MSSQL_PYODBC_BACKEND = "sql_server.pyodbc"
|
||||
SQLITE_BACKEND = "django.db.backends.sqlite3"
|
||||
SPATIALITE_GIS_BACKEND = "django.contrib.gis.db.backends.spatialite"
|
||||
ORACLE_BACKEND = "django.db.backends.oracle"
|
||||
ORACLE_GIS_BACKEND = "django.contrib.gis.db.backends.oracle"
|
||||
MYSQL_CONNECTOR_BACKEND = "mysql.connector.django"
|
||||
|
||||
MYSQL_BACKENDS = (
|
||||
MYSQL_BACKEND,
|
||||
MYSQL_GIS_BACKEND,
|
||||
)
|
||||
POSTGRES_BACKENDS = (
|
||||
POSTGRESQL_BACKEND,
|
||||
POSTGIS_BACKEND,
|
||||
REDSHIFT_BACKEND,
|
||||
COCKROACH_BACKEND,
|
||||
TIMESCALE_BACKEND,
|
||||
TIMESCALE_GIS_BACKEND,
|
||||
)
|
||||
MSSQL_BACKENDS = (
|
||||
MSSQL_BACKEND,
|
||||
MSSQL_PYODBC_BACKEND,
|
||||
)
|
||||
|
||||
|
||||
# From https://docs.djangoproject.com/en/stable/ref/settings/#databases
|
||||
class DBConfig(TypedDict, total=False):
|
||||
|
|
@ -80,29 +113,29 @@ def register(
|
|||
return inner
|
||||
|
||||
|
||||
register("spatialite", "django.contrib.gis.db.backends.spatialite")
|
||||
register("mysql-connector", "mysql.connector.django")
|
||||
register("mysqlgis", "django.contrib.gis.db.backends.mysql")
|
||||
register("oraclegis", "django.contrib.gis.db.backends.oracle")
|
||||
register("cockroach", "django_cockroachdb")
|
||||
register("spatialite", SPATIALITE_GIS_BACKEND)
|
||||
register("mysql-connector", MYSQL_CONNECTOR_BACKEND)
|
||||
register("mysqlgis", MYSQL_GIS_BACKEND)
|
||||
register("oraclegis", ORACLE_GIS_BACKEND)
|
||||
register("cockroach", COCKROACH_BACKEND)
|
||||
|
||||
|
||||
@register("sqlite", "django.db.backends.sqlite3")
|
||||
@register("sqlite", SQLITE_BACKEND)
|
||||
def default_to_in_memory_db(parsed_config: DBConfig) -> None:
|
||||
# mimic sqlalchemy behaviour
|
||||
if not parsed_config.get("NAME"):
|
||||
parsed_config["NAME"] = ":memory:"
|
||||
|
||||
|
||||
@register("oracle", "django.db.backends.oracle")
|
||||
@register("mssqlms", "mssql")
|
||||
@register("mssql", "sql_server.pyodbc")
|
||||
@register("oracle", ORACLE_BACKEND)
|
||||
@register("mssqlms", MSSQL_BACKEND)
|
||||
@register("mssql", MSSQL_PYODBC_BACKEND)
|
||||
def stringify_port(parsed_config: DBConfig) -> None:
|
||||
parsed_config["PORT"] = str(parsed_config.get("PORT", ""))
|
||||
|
||||
|
||||
@register("mysql", "django.db.backends.mysql")
|
||||
@register("mysql2", "django.db.backends.mysql")
|
||||
@register("mysql", MYSQL_BACKEND)
|
||||
@register("mysql2", MYSQL_BACKEND)
|
||||
def apply_ssl_ca(parsed_config: DBConfig) -> None:
|
||||
options = parsed_config.get("OPTIONS", {})
|
||||
ca = options.pop("ssl-ca", None)
|
||||
|
|
@ -110,13 +143,13 @@ def apply_ssl_ca(parsed_config: DBConfig) -> None:
|
|||
options["ssl"] = {"ca": ca}
|
||||
|
||||
|
||||
@register("postgres", "django.db.backends.postgresql")
|
||||
@register("postgresql", "django.db.backends.postgresql")
|
||||
@register("pgsql", "django.db.backends.postgresql")
|
||||
@register("postgis", "django.contrib.gis.db.backends.postgis")
|
||||
@register("redshift", "django_redshift_backend")
|
||||
@register("timescale", "timescale.db.backends.postgresql")
|
||||
@register("timescalegis", "timescale.db.backends.postgis")
|
||||
@register("postgres", POSTGRESQL_BACKEND)
|
||||
@register("postgresql", POSTGRESQL_BACKEND)
|
||||
@register("pgsql", POSTGRESQL_BACKEND)
|
||||
@register("postgis", POSTGIS_BACKEND)
|
||||
@register("redshift", REDSHIFT_BACKEND)
|
||||
@register("timescale", TIMESCALE_BACKEND)
|
||||
@register("timescalegis", TIMESCALE_GIS_BACKEND)
|
||||
def apply_current_schema(parsed_config: DBConfig) -> None:
|
||||
options = parsed_config.get("OPTIONS", {})
|
||||
schema = options.pop("currentSchema", None)
|
||||
|
|
@ -171,7 +204,6 @@ def parse(
|
|||
conn_max_age,
|
||||
conn_health_checks,
|
||||
disable_server_side_cursors,
|
||||
ssl_require,
|
||||
test_options,
|
||||
)
|
||||
|
||||
|
|
@ -212,6 +244,9 @@ def parse(
|
|||
parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {}))
|
||||
parsed_config.update(settings)
|
||||
|
||||
if ssl_require:
|
||||
_configure_ssl(parsed_config)
|
||||
|
||||
if not parsed_config["OPTIONS"]:
|
||||
parsed_config.pop("OPTIONS")
|
||||
return parsed_config
|
||||
|
|
@ -230,12 +265,30 @@ def _parse_value(value: str) -> OptionType:
|
|||
return value
|
||||
|
||||
|
||||
def _configure_ssl(parsed_config: DBConfig) -> None:
|
||||
assert "OPTIONS" in parsed_config
|
||||
options = parsed_config["OPTIONS"]
|
||||
assert "ENGINE" in parsed_config
|
||||
backend = parsed_config["ENGINE"]
|
||||
|
||||
if backend in MYSQL_BACKENDS:
|
||||
options["ssl_mode"] = "REQUIRED"
|
||||
elif backend in POSTGRES_BACKENDS:
|
||||
options["sslmode"] = "require"
|
||||
elif backend in MSSQL_BACKENDS:
|
||||
current_extra = options.get("extra_params", "")
|
||||
if "Encrypt=yes" not in current_extra:
|
||||
if current_extra:
|
||||
options["extra_params"] = f"{current_extra};Encrypt=yes"
|
||||
else:
|
||||
options["extra_params"] = "Encrypt=yes"
|
||||
|
||||
|
||||
def _convert_to_settings(
|
||||
engine: str | None,
|
||||
conn_max_age: int | None,
|
||||
conn_health_checks: bool,
|
||||
disable_server_side_cursors: bool,
|
||||
ssl_require: bool,
|
||||
test_options: dict[str, Any] | None,
|
||||
) -> DBConfig:
|
||||
settings: DBConfig = {
|
||||
|
|
@ -245,9 +298,6 @@ def _convert_to_settings(
|
|||
}
|
||||
if engine:
|
||||
settings["ENGINE"] = engine
|
||||
if ssl_require:
|
||||
settings["OPTIONS"] = {}
|
||||
settings["OPTIONS"]["sslmode"] = "require"
|
||||
if test_options:
|
||||
settings["TEST"] = test_options
|
||||
return settings
|
||||
|
|
|
|||
|
|
@ -673,10 +673,55 @@ class DatabaseTestSuite(unittest.TestCase):
|
|||
os.environ,
|
||||
{"DATABASE_URL": "postgres://user:password@instance.amazonaws.com:5431/d8r8?"},
|
||||
)
|
||||
def test_ssl_require(self) -> None:
|
||||
def test_ssl_require_postgres(self) -> None:
|
||||
url = dj_database_url.config(ssl_require=True)
|
||||
assert url["OPTIONS"] == {'sslmode': 'require'}
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"DATABASE_URL": "mysql://user:password@instance.amazonaws.com:3306/dbname"},
|
||||
)
|
||||
def test_ssl_require_mysql(self) -> None:
|
||||
url = dj_database_url.config(ssl_require=True)
|
||||
assert url["OPTIONS"] == {"ssl_mode": "REQUIRED"}
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{"DATABASE_URL": "mssqlms://user:password@instance.amazonaws.com:1234/dbname"},
|
||||
)
|
||||
def test_ssl_require_mssql(self) -> None:
|
||||
url = dj_database_url.config(ssl_require=True)
|
||||
assert url["OPTIONS"] == {"extra_params": "Encrypt=yes"}
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DATABASE_URL": (
|
||||
"mssql://user:password@instance.amazonaws.com:1234/dbname"
|
||||
"?extra_params=TrustServerCertificate=yes"
|
||||
)
|
||||
},
|
||||
)
|
||||
def test_ssl_require_mssql_existing_extra_params(self) -> None:
|
||||
url = dj_database_url.config(ssl_require=True)
|
||||
assert url["OPTIONS"] == {
|
||||
"extra_params": "TrustServerCertificate=yes;Encrypt=yes"
|
||||
}
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"DATABASE_URL": (
|
||||
"mssql://user:password@instance.amazonaws.com:1234/dbname"
|
||||
"?extra_params=Encrypt=yes"
|
||||
)
|
||||
},
|
||||
)
|
||||
def test_ssl_require_mssql_already_encrypted(self) -> None:
|
||||
url = dj_database_url.config(ssl_require=True)
|
||||
# Should NOT append Encrypt=yes again
|
||||
assert url["OPTIONS"] == {"extra_params": "Encrypt=yes"}
|
||||
|
||||
def test_options_int_values(self) -> None:
|
||||
"""Ensure that options with integer values are parsed correctly."""
|
||||
url = dj_database_url.parse(
|
||||
|
|
|
|||
Loading…
Reference in a new issue