import logging import os import urllib.parse as urlparse from collections.abc import Callable from typing import Any, TypedDict DEFAULT_ENV = "DATABASE_URL" ENGINE_SCHEMES: dict[str, "Engine"] = {} MYSQL_BACKEND = "django.db.backends.mysql" MYSQL_GIS_BACKEND = "django.contrib.gis.db.backends.mysql" POSTGRESQL_BACKEND = "django.db.backends.postgresql" POSTGIS_BACKEND = "django.contrib.gis.db.backends.postgis" REDSHIFT_BACKEND = "django_redshift_backend" COCKROACH_BACKEND = "django_cockroachdb" TIMESCALE_BACKEND = "timescale.db.backends.postgresql" TIMESCALE_GIS_BACKEND = "timescale.db.backends.postgis" MSSQL_BACKEND = "mssql" MSSQL_PYODBC_BACKEND = "sql_server.pyodbc" SQLITE_BACKEND = "django.db.backends.sqlite3" SPATIALITE_GIS_BACKEND = "django.contrib.gis.db.backends.spatialite" ORACLE_BACKEND = "django.db.backends.oracle" ORACLE_GIS_BACKEND = "django.contrib.gis.db.backends.oracle" MYSQL_CONNECTOR_BACKEND = "mysql.connector.django" MYSQL_BACKENDS = ( MYSQL_BACKEND, MYSQL_GIS_BACKEND, ) POSTGRES_BACKENDS = ( POSTGRESQL_BACKEND, POSTGIS_BACKEND, REDSHIFT_BACKEND, COCKROACH_BACKEND, TIMESCALE_BACKEND, TIMESCALE_GIS_BACKEND, ) MSSQL_BACKENDS = ( MSSQL_BACKEND, MSSQL_PYODBC_BACKEND, ) # From https://docs.djangoproject.com/en/stable/ref/settings/#databases class DBConfig(TypedDict, total=False): ATOMIC_REQUESTS: bool AUTOCOMMIT: bool CONN_MAX_AGE: int | None CONN_HEALTH_CHECKS: bool DISABLE_SERVER_SIDE_CURSORS: bool ENGINE: str HOST: str NAME: str OPTIONS: dict[str, Any] PASSWORD: str PORT: str | int TEST: dict[str, Any] TIME_ZONE: str USER: str PostprocessCallable = Callable[[DBConfig], None] OptionType = int | str | bool class ParseError(ValueError): def __str__(self) -> str: return ( "This string is not a valid url, possibly because some of its parts" " is not properly urllib.parse.quote()'ed." ) class UnknownSchemeError(ValueError): def __init__(self, scheme: str): self.scheme = scheme def __str__(self) -> str: schemes = ", ".join(sorted(ENGINE_SCHEMES.keys())) return ( f"Scheme '{self.scheme}://' is unknown." " Did you forget to register custom backend?" f" Following schemes have registered backends: {schemes}." ) def default_postprocess(parsed_config: DBConfig) -> None: pass class Engine: def __init__( self, backend: str, postprocess: PostprocessCallable = default_postprocess, ): self.backend = backend self.postprocess = postprocess def register( scheme: str, backend: str ) -> Callable[[PostprocessCallable], PostprocessCallable]: engine = Engine(backend) if scheme not in ENGINE_SCHEMES: urlparse.uses_netloc.append(scheme) ENGINE_SCHEMES[scheme] = engine def inner(func: PostprocessCallable) -> PostprocessCallable: engine.postprocess = func return func return inner register("spatialite", SPATIALITE_GIS_BACKEND) register("mysql-connector", MYSQL_CONNECTOR_BACKEND) register("mysqlgis", MYSQL_GIS_BACKEND) register("oraclegis", ORACLE_GIS_BACKEND) register("cockroach", COCKROACH_BACKEND) @register("sqlite", SQLITE_BACKEND) def default_to_in_memory_db(parsed_config: DBConfig) -> None: # mimic sqlalchemy behaviour if not parsed_config.get("NAME"): parsed_config["NAME"] = ":memory:" @register("oracle", ORACLE_BACKEND) @register("mssqlms", MSSQL_BACKEND) @register("mssql", MSSQL_PYODBC_BACKEND) def stringify_port(parsed_config: DBConfig) -> None: parsed_config["PORT"] = str(parsed_config.get("PORT", "")) @register("mysql", MYSQL_BACKEND) @register("mysql2", MYSQL_BACKEND) def apply_ssl_ca(parsed_config: DBConfig) -> None: options = parsed_config.get("OPTIONS", {}) ca = options.pop("ssl-ca", None) if ca: options["ssl"] = {"ca": ca} @register("postgres", POSTGRESQL_BACKEND) @register("postgresql", POSTGRESQL_BACKEND) @register("pgsql", POSTGRESQL_BACKEND) @register("postgis", POSTGIS_BACKEND) @register("redshift", REDSHIFT_BACKEND) @register("timescale", TIMESCALE_BACKEND) @register("timescalegis", TIMESCALE_GIS_BACKEND) def apply_current_schema(parsed_config: DBConfig) -> None: options = parsed_config.get("OPTIONS", {}) schema = options.pop("currentSchema", None) if schema: options["options"] = f"-c search_path={schema}" def config( env: str = DEFAULT_ENV, default: str | None = None, engine: str | None = None, conn_max_age: int | None = 0, conn_health_checks: bool = False, disable_server_side_cursors: bool = False, ssl_require: bool = False, test_options: dict[str, Any] | None = None, ) -> DBConfig: """Returns configured DATABASE dictionary from DATABASE_URL.""" s = os.environ.get(env, default) if s is None: logging.warning( "No %s environment variable set, and so no databases setup", env ) if s: return parse( s, engine, conn_max_age, conn_health_checks, disable_server_side_cursors, ssl_require, test_options, ) return {} def parse( url: str, engine: str | None = None, conn_max_age: int | None = 0, conn_health_checks: bool = False, disable_server_side_cursors: bool = False, ssl_require: bool = False, test_options: dict[str, Any] | None = None, ) -> DBConfig: """Parses a database URL and returns configured DATABASE dictionary.""" settings = _convert_to_settings( engine, conn_max_age, conn_health_checks, disable_server_side_cursors, test_options, ) if url == "sqlite://:memory:": # this is a special case, because if we pass this URL into # urlparse, urlparse will choke trying to interpret "memory" # as a port number return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"} # note: no other settings are required for sqlite try: split_result = urlparse.urlsplit(url) engine_obj = ENGINE_SCHEMES.get(split_result.scheme) if engine_obj is None: raise UnknownSchemeError(split_result.scheme) path = split_result.path[1:] query = urlparse.parse_qs(split_result.query) options = {k: _parse_option_values(v) for k, v in query.items()} parsed_config: DBConfig = { "ENGINE": engine_obj.backend, "USER": urlparse.unquote(split_result.username or ""), "PASSWORD": urlparse.unquote(split_result.password or ""), "HOST": urlparse.unquote(split_result.hostname or ""), "PORT": split_result.port or "", "NAME": urlparse.unquote(path), "OPTIONS": options, } except UnknownSchemeError: raise except ValueError: raise ParseError() from None # Guarantee that config has options, possibly empty, when postprocess() is called assert isinstance(parsed_config["OPTIONS"], dict) engine_obj.postprocess(parsed_config) # Update the final config with any settings passed in explicitly. parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {})) parsed_config.update(settings) if ssl_require: _configure_ssl(parsed_config) if not parsed_config["OPTIONS"]: parsed_config.pop("OPTIONS") return parsed_config def _parse_option_values(values: list[str]) -> OptionType | list[OptionType]: parsed_values = [_parse_value(v) for v in values] return parsed_values[0] if len(parsed_values) == 1 else parsed_values def _parse_value(value: str) -> OptionType: if value.isdigit(): return int(value) if value.lower() in ("true", "false"): return value.lower() == "true" return value def _configure_ssl(parsed_config: DBConfig) -> None: assert "OPTIONS" in parsed_config options = parsed_config["OPTIONS"] assert "ENGINE" in parsed_config backend = parsed_config["ENGINE"] if backend in MYSQL_BACKENDS: options["ssl_mode"] = "REQUIRED" elif backend in POSTGRES_BACKENDS: options["sslmode"] = "require" elif backend in MSSQL_BACKENDS: current_extra = options.get("extra_params", "") if "Encrypt=yes" not in current_extra: if current_extra: options["extra_params"] = f"{current_extra};Encrypt=yes" else: options["extra_params"] = "Encrypt=yes" def _convert_to_settings( engine: str | None, conn_max_age: int | None, conn_health_checks: bool, disable_server_side_cursors: bool, test_options: dict[str, Any] | None, ) -> DBConfig: settings: DBConfig = { "CONN_MAX_AGE": conn_max_age, "CONN_HEALTH_CHECKS": conn_health_checks, "DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors, } if engine: settings["ENGINE"] = engine if test_options: settings["TEST"] = test_options return settings