dj-database-url/dj_database_url/__init__.py

303 lines
8.9 KiB
Python

import logging
import os
import urllib.parse as urlparse
from collections.abc import Callable
from typing import Any, TypedDict
# Environment variable that config() reads by default.
DEFAULT_ENV = "DATABASE_URL"

# Registry of URL scheme -> Engine, filled in by register().
ENGINE_SCHEMES: dict[str, "Engine"] = {}

# Dotted import paths of the supported Django database backends.
# -- MySQL family
MYSQL_BACKEND = "django.db.backends.mysql"
MYSQL_GIS_BACKEND = "django.contrib.gis.db.backends.mysql"
MYSQL_CONNECTOR_BACKEND = "mysql.connector.django"
# -- PostgreSQL and wire-compatible databases
POSTGRESQL_BACKEND = "django.db.backends.postgresql"
POSTGIS_BACKEND = "django.contrib.gis.db.backends.postgis"
REDSHIFT_BACKEND = "django_redshift_backend"
COCKROACH_BACKEND = "django_cockroachdb"
TIMESCALE_BACKEND = "timescale.db.backends.postgresql"
TIMESCALE_GIS_BACKEND = "timescale.db.backends.postgis"
# -- Microsoft SQL Server
MSSQL_BACKEND = "mssql"
MSSQL_PYODBC_BACKEND = "sql_server.pyodbc"
# -- SQLite
SQLITE_BACKEND = "django.db.backends.sqlite3"
SPATIALITE_GIS_BACKEND = "django.contrib.gis.db.backends.spatialite"
# -- Oracle
ORACLE_BACKEND = "django.db.backends.oracle"
ORACLE_GIS_BACKEND = "django.contrib.gis.db.backends.oracle"

# Backend groups that _configure_ssl() treats alike when ssl_require is set.
MYSQL_BACKENDS = (MYSQL_BACKEND, MYSQL_GIS_BACKEND)
POSTGRES_BACKENDS = (
    POSTGRESQL_BACKEND,
    POSTGIS_BACKEND,
    REDSHIFT_BACKEND,
    COCKROACH_BACKEND,
    TIMESCALE_BACKEND,
    TIMESCALE_GIS_BACKEND,
)
MSSQL_BACKENDS = (MSSQL_BACKEND, MSSQL_PYODBC_BACKEND)
# From https://docs.djangoproject.com/en/stable/ref/settings/#databases
class DBConfig(TypedDict, total=False):
ATOMIC_REQUESTS: bool
AUTOCOMMIT: bool
CONN_MAX_AGE: int | None
CONN_HEALTH_CHECKS: bool
DISABLE_SERVER_SIDE_CURSORS: bool
ENGINE: str
HOST: str
NAME: str
OPTIONS: dict[str, Any]
PASSWORD: str
PORT: str | int
TEST: dict[str, Any]
TIME_ZONE: str
USER: str
PostprocessCallable = Callable[[DBConfig], None]
OptionType = int | str | bool
class ParseError(ValueError):
    """Raised by parse() when splitting or decoding the URL fails.

    The message is fixed regardless of constructor arguments and points at
    the most common cause: URL parts that were not percent-encoded.
    """

    def __str__(self) -> str:
        message = (
            "This string is not a valid url, possibly because some of its "
            "parts is not properly urllib.parse.quote()'ed."
        )
        return message
class UnknownSchemeError(ValueError):
    """Raised by parse() when the URL scheme has no registered Engine.

    Attributes:
        scheme: the unrecognised scheme as returned by ``urlsplit``.
    """

    def __init__(self, scheme: str):
        # Forward ``scheme`` to ValueError so ``self.args`` is populated.
        # Without this, repr() omits the scheme, and pickling or copying the
        # exception fails because it is rebuilt as UnknownSchemeError().
        super().__init__(scheme)
        self.scheme = scheme

    def __str__(self) -> str:
        # Computed lazily so the list reflects every scheme registered so far.
        schemes = ", ".join(sorted(ENGINE_SCHEMES))
        return (
            f"Scheme '{self.scheme}://' is unknown."
            " Did you forget to register custom backend?"
            f" Following schemes have registered backends: {schemes}."
        )
def default_postprocess(parsed_config: DBConfig) -> None:
    """No-op hook for engines registered without a custom postprocess.

    ``parsed_config`` is left untouched.
    """
    return None
class Engine:
    """A Django backend path paired with the hook that adjusts its config.

    register() stores instances in ENGINE_SCHEMES. The decorator it returns
    may later replace ``postprocess``.
    """

    def __init__(
        self,
        backend: str,
        postprocess: PostprocessCallable = default_postprocess,
    ):
        self.backend, self.postprocess = backend, postprocess
def register(
    scheme: str, backend: str
) -> Callable[[PostprocessCallable], PostprocessCallable]:
    """Map URL ``scheme`` to Django ``backend``; return a hook decorator.

    Registering an existing scheme again replaces its Engine. The first
    registration of a scheme also appends it to ``urllib.parse.uses_netloc``.
    The returned decorator installs the decorated function as the new
    Engine's postprocess hook and returns the function unchanged, so several
    ``@register(...)`` lines can be stacked on one function. If the
    decorator is not used, the Engine keeps its default hook.
    """
    new_engine = Engine(backend)
    first_time = scheme not in ENGINE_SCHEMES
    if first_time:
        urlparse.uses_netloc.append(scheme)
    ENGINE_SCHEMES[scheme] = new_engine

    def set_postprocess(func: PostprocessCallable) -> PostprocessCallable:
        new_engine.postprocess = func
        return func

    return set_postprocess
# Schemes that need no postprocess hook keep Engine's default (no-op) hook.
for _scheme, _backend in (
    ("spatialite", SPATIALITE_GIS_BACKEND),
    ("mysql-connector", MYSQL_CONNECTOR_BACKEND),
    ("mysqlgis", MYSQL_GIS_BACKEND),
    ("oraclegis", ORACLE_GIS_BACKEND),
    ("cockroach", COCKROACH_BACKEND),
):
    register(_scheme, _backend)
del _scheme, _backend
@register("sqlite", SQLITE_BACKEND)
def default_to_in_memory_db(parsed_config: DBConfig) -> None:
    """Use an in-memory SQLite database when NAME is missing or empty.

    Mimics SQLAlchemy, where a bare ``sqlite://`` URL means ``:memory:``.
    """
    if parsed_config.get("NAME"):
        return
    parsed_config["NAME"] = ":memory:"
@register("oracle", ORACLE_BACKEND)
@register("mssqlms", MSSQL_BACKEND)
@register("mssql", MSSQL_PYODBC_BACKEND)
def stringify_port(parsed_config: DBConfig) -> None:
    """Store PORT as a ``str`` for the Oracle and MS SQL schemes.

    parse() sets PORT to an int, or to "" when the URL has none. A missing
    PORT key becomes "".
    """
    port = parsed_config.get("PORT", "")
    parsed_config["PORT"] = str(port)
@register("mysql", MYSQL_BACKEND)
@register("mysql2", MYSQL_BACKEND)
def apply_ssl_ca(parsed_config: DBConfig) -> None:
    """Turn an ``ssl-ca`` URL option into the nested ``ssl`` option.

    ``?ssl-ca=/path/ca.pem`` becomes ``OPTIONS["ssl"] == {"ca": "/path/ca.pem"}``.
    The ``ssl-ca`` key is always removed; a falsy value is simply dropped.
    """
    opts = parsed_config.get("OPTIONS", {})
    ca_path = opts.pop("ssl-ca", None)
    if not ca_path:
        return
    opts["ssl"] = {"ca": ca_path}
@register("postgres", POSTGRESQL_BACKEND)
@register("postgresql", POSTGRESQL_BACKEND)
@register("pgsql", POSTGRESQL_BACKEND)
@register("postgis", POSTGIS_BACKEND)
@register("redshift", REDSHIFT_BACKEND)
@register("timescale", TIMESCALE_BACKEND)
@register("timescalegis", TIMESCALE_GIS_BACKEND)
def apply_current_schema(parsed_config: DBConfig) -> None:
    """Turn a ``currentSchema`` URL option into a libpq ``options`` string.

    ``?currentSchema=myschema`` becomes
    ``OPTIONS["options"] == "-c search_path=myschema"``. The
    ``currentSchema`` key is always removed.

    If the URL already supplied its own ``options`` string, the search_path
    switch is appended to it rather than replacing it. A repeated
    ``currentSchema`` (``?currentSchema=a&currentSchema=b``) becomes a
    comma-separated search_path (``a,b``).
    """
    options = parsed_config.get("OPTIONS", {})
    schema = options.pop("currentSchema", None)
    if not schema:
        return
    if isinstance(schema, list):
        # Repeated query keys arrive as a list (see _parse_option_values);
        # search_path takes a comma-separated list of schema names.
        schema = ",".join(str(item) for item in schema)
    search_path = f"-c search_path={schema}"
    existing = options.get("options")
    if isinstance(existing, str) and existing:
        # libpq "options" holds whitespace-separated server command-line
        # switches, so keep whatever the URL already set.
        options["options"] = f"{existing} {search_path}"
    else:
        options["options"] = search_path
def config(
    env: str = DEFAULT_ENV,
    default: str | None = None,
    engine: str | None = None,
    conn_max_age: int | None = 0,
    conn_health_checks: bool = False,
    disable_server_side_cursors: bool = False,
    ssl_require: bool = False,
    test_options: dict[str, Any] | None = None,
) -> DBConfig:
    """Returns configured DATABASE dictionary from DATABASE_URL.

    Reads the URL from the environment variable named ``env``, falling back
    to ``default``, and hands it to parse() together with the remaining
    arguments.

    Returns:
        The parsed config, or ``{}`` when no URL is available (the variable
        is unset and there is no default, or the value is an empty string).
        A warning is logged only in the unset-and-no-default case.

    Raises:
        Whatever parse() raises for a bad URL (ParseError,
        UnknownSchemeError).
    """
    url = os.environ.get(env, default)
    if url is None:
        # Use a named logger. The module-level logging.warning() helper calls
        # logging.basicConfig() when the root logger has no handlers, which
        # would configure the application's logging as a side effect of
        # reading settings.
        logging.getLogger(__name__).warning(
            "No %s environment variable set, and so no databases setup", env
        )
        return {}
    if not url:
        return {}
    return parse(
        url,
        engine=engine,
        conn_max_age=conn_max_age,
        conn_health_checks=conn_health_checks,
        disable_server_side_cursors=disable_server_side_cursors,
        ssl_require=ssl_require,
        test_options=test_options,
    )
def parse(
    url: str,
    engine: str | None = None,
    conn_max_age: int | None = 0,
    conn_health_checks: bool = False,
    disable_server_side_cursors: bool = False,
    ssl_require: bool = False,
    test_options: dict[str, Any] | None = None,
) -> DBConfig:
    """Parses a database URL and returns configured DATABASE dictionary.

    Args:
        url: ``scheme://user:password@host:port/name?key=value...``. The
            scheme selects a registered Engine; query keys become OPTIONS.
        engine: if truthy, overrides the ENGINE derived from the scheme.
        conn_max_age, conn_health_checks, disable_server_side_cursors:
            copied into the matching upper-case keys.
        ssl_require: add the backend-specific "SSL required" option.
        test_options: stored under TEST when truthy.

    Raises:
        UnknownSchemeError: the scheme has no registered Engine.
        ParseError: any other ValueError while splitting or decoding the URL.

    ``sqlite://:memory:`` is special-cased: the result contains only ENGINE
    and NAME, and none of the other arguments are applied.
    """
    explicit = _convert_to_settings(
        engine,
        conn_max_age,
        conn_health_checks,
        disable_server_side_cursors,
        test_options,
    )

    # urlsplit() would try to read "memory" as a port number here, so the
    # in-memory SQLite URL is answered directly.
    if url == "sqlite://:memory:":
        return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"}

    try:
        parts = urlparse.urlsplit(url)
        scheme_engine = ENGINE_SCHEMES.get(parts.scheme)
        if scheme_engine is None:
            raise UnknownSchemeError(parts.scheme)
        query_options = {
            key: _parse_option_values(raw_values)
            for key, raw_values in urlparse.parse_qs(parts.query).items()
        }
        parsed_config: DBConfig = {
            "ENGINE": scheme_engine.backend,
            "USER": urlparse.unquote(parts.username or ""),
            "PASSWORD": urlparse.unquote(parts.password or ""),
            "HOST": urlparse.unquote(parts.hostname or ""),
            "PORT": parts.port or "",
            "NAME": urlparse.unquote(parts.path[1:]),
            "OPTIONS": query_options,
        }
    except UnknownSchemeError:
        # A ValueError subclass: propagate it instead of wrapping it.
        raise
    except ValueError:
        raise ParseError() from None

    # Postprocess hooks may rely on OPTIONS being present (possibly empty).
    assert isinstance(parsed_config["OPTIONS"], dict)
    scheme_engine.postprocess(parsed_config)

    # Explicitly passed settings win over anything derived from the URL.
    parsed_config["OPTIONS"].update(explicit.pop("OPTIONS", {}))
    parsed_config.update(explicit)

    if ssl_require:
        _configure_ssl(parsed_config)

    if not parsed_config["OPTIONS"]:
        del parsed_config["OPTIONS"]
    return parsed_config
def _parse_option_values(values: list[str]) -> OptionType | list[OptionType]:
    """Coerce every value of one query-string key with _parse_value().

    Exactly one value is returned bare. Any other count (a repeated key, or
    an empty list) is returned as a list.
    """
    if len(values) == 1:
        return _parse_value(values[0])
    return [_parse_value(item) for item in values]
def _parse_value(value: str) -> OptionType:
    """Coerce one query-string value to ``int`` or ``bool``, else keep ``str``.

    - only decimal digits -> ``int`` (``"5432"`` -> ``5432``)
    - ``"true"`` / ``"false"``, any case -> ``bool``
    - anything else is returned unchanged
    """
    # isdecimal(), not isdigit(): isdigit() also accepts characters such as
    # superscripts ("²") that int() rejects. The resulting ValueError turned
    # an otherwise valid URL into a ParseError in parse().
    if value.isdecimal():
        return int(value)
    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    return value
def _configure_ssl(parsed_config: DBConfig) -> None:
    """Add the backend-specific "SSL required" option to OPTIONS in place.

    - MySQL backends: ``ssl_mode = "REQUIRED"``
    - PostgreSQL-family backends: ``sslmode = "require"``
    - MS SQL backends: ``Encrypt=yes`` is appended to the ``;``-separated
      ``extra_params`` unless that exact text is already present.

    Any other backend is left untouched.
    """
    assert "OPTIONS" in parsed_config
    opts = parsed_config["OPTIONS"]
    assert "ENGINE" in parsed_config
    backend_path = parsed_config["ENGINE"]

    if backend_path in MYSQL_BACKENDS:
        opts["ssl_mode"] = "REQUIRED"
        return
    if backend_path in POSTGRES_BACKENDS:
        opts["sslmode"] = "require"
        return
    if backend_path not in MSSQL_BACKENDS:
        return

    extra = opts.get("extra_params", "")
    if "Encrypt=yes" in extra:
        return
    opts["extra_params"] = f"{extra};Encrypt=yes" if extra else "Encrypt=yes"
def _convert_to_settings(
    engine: str | None,
    conn_max_age: int | None,
    conn_health_checks: bool,
    disable_server_side_cursors: bool,
    test_options: dict[str, Any] | None,
) -> DBConfig:
    """Build the explicit-settings fragment that parse() merges over the URL.

    The three connection keys are always present. ENGINE and TEST are added
    only when ``engine`` / ``test_options`` are truthy.
    """
    settings = DBConfig(
        CONN_MAX_AGE=conn_max_age,
        CONN_HEALTH_CHECKS=conn_health_checks,
        DISABLE_SERVER_SIDE_CURSORS=disable_server_side_cursors,
    )
    if engine:
        settings["ENGINE"] = engine
    if test_options:
        settings["TEST"] = test_options
    return settings