Generic backend registration

This commit is contained in:
Alexander Gerbik 2024-02-29 12:49:25 +01:00
parent bcf163e3ab
commit b27bd63e82
3 changed files with 286 additions and 137 deletions

View file

@ -22,12 +22,6 @@ also a `conn_max_age` argument to easily enable Django's connection pool.
If you'd rather not use an environment variable, you can pass a URL in directly If you'd rather not use an environment variable, you can pass a URL in directly
instead to ``dj_database_url.parse``. instead to ``dj_database_url.parse``.
Supported Databases
-------------------
Support currently exists for PostgreSQL, PostGIS, MySQL, MySQL (GIS),
Oracle, Oracle (GIS), Redshift, CockroachDB, Timescale, Timescale (GIS) and SQLite.
Installation Installation
------------ ------------
@ -148,6 +142,63 @@ and should instead be passed as:
DATABASES['default'] = dj_database_url.config(default='postgres://...', test_options={'NAME': 'mytestdatabase'}) DATABASES['default'] = dj_database_url.config(default='postgres://...', test_options={'NAME': 'mytestdatabase'})
Supported Databases
-------------------
Support currently exists for PostgreSQL, PostGIS, MySQL, MySQL (GIS),
Oracle, Oracle (GIS), Redshift, CockroachDB, Timescale, Timescale (GIS) and SQLite.
To use a backend that is not in the list above,
register it first:
.. code-block:: python
import dj_database_url
# registration should be performed only once
dj_database_url.register("mysql-connector", "mysql.connector.django")
assert dj_database_url.parse("mysql-connector://user:password@host:port/db-name") == {
"ENGINE": "mysql.connector.django",
# ...other connection params
}
Some backends need further adjustments to the parsed config (e.g. oracle
and mssql expect ``PORT`` to be a string). In such cases, provide a
post-processing function by using ``register()`` as a **decorator**
on that function:
.. code-block:: python
import dj_database_url
@dj_database_url.register("mssql", "sql_server.pyodbc")
def stringify_port(config):
config["PORT"] = str(config["PORT"])
@dj_database_url.register("redshift", "django_redshift_backend")
def apply_current_schema(config):
options = config["OPTIONS"]
schema = options.pop("currentSchema", None)
if schema:
options["options"] = f"-c search_path={schema}"
@dj_database_url.register("snowflake", "django_snowflake")
def adjust_snowflake_config(config):
config.pop("PORT", None)
config["ACCOUNT"] = config.pop("HOST")
name, _, schema = config["NAME"].partition("/")
if schema:
config["SCHEMA"] = schema
config["NAME"] = name
options = config.get("OPTIONS", {})
warehouse = options.pop("warehouse", None)
if warehouse:
config["WAREHOUSE"] = warehouse
role = options.pop("role", None)
if role:
config["ROLE"] = role
URL schema URL schema
---------- ----------

View file

@ -1,50 +1,15 @@
import logging import logging
import os import os
import urllib.parse as urlparse import urllib.parse as urlparse
from typing import Any, Dict, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from typing_extensions import TypedDict from typing_extensions import TypedDict
DEFAULT_ENV = "DATABASE_URL" DEFAULT_ENV = "DATABASE_URL"
# Registry of known URL schemes: register() stores one Engine per scheme here.
ENGINE_SCHEMES: Dict[str, "Engine"] = {}
SCHEMES = {
"postgres": "django.db.backends.postgresql",
"postgresql": "django.db.backends.postgresql",
"pgsql": "django.db.backends.postgresql",
"postgis": "django.contrib.gis.db.backends.postgis",
"mysql": "django.db.backends.mysql",
"mysql2": "django.db.backends.mysql",
"mysqlgis": "django.contrib.gis.db.backends.mysql",
"mysql-connector": "mysql.connector.django",
"mssql": "sql_server.pyodbc",
"mssqlms": "mssql",
"spatialite": "django.contrib.gis.db.backends.spatialite",
"sqlite": "django.db.backends.sqlite3",
"oracle": "django.db.backends.oracle",
"oraclegis": "django.contrib.gis.db.backends.oracle",
"redshift": "django_redshift_backend",
"cockroach": "django_cockroachdb",
"timescale": "timescale.db.backends.postgresql",
"timescalegis": "timescale.db.backends.postgis",
}
SCHEMES_WITH_SEARCH_PATH = [
"postgres",
"postgresql",
"pgsql",
"postgis",
"redshift",
"timescale",
"timescalegis",
]
# Register database schemes in URLs.
for key in SCHEMES.keys():
urlparse.uses_netloc.append(key)
del key
# From https://docs.djangoproject.com/en/4.0/ref/settings/#databases # From https://docs.djangoproject.com/en/stable/ref/settings/#databases
class DBConfig(TypedDict, total=False): class DBConfig(TypedDict, total=False):
ATOMIC_REQUESTS: bool ATOMIC_REQUESTS: bool
AUTOCOMMIT: bool AUTOCOMMIT: bool
@ -62,11 +27,109 @@ class DBConfig(TypedDict, total=False):
USER: str USER: str
# A post-processing hook: called with the parsed config, returns nothing.
PostprocessCallable = Callable[[DBConfig], None]
# Type of a single option value produced by _parse_value().
OptionType = Union[int, str, bool]
class ParseError(ValueError):
    """Error whose message reports that a string is not a valid URL."""

    def __str__(self) -> str:
        # The message is fixed; any constructor arguments are ignored here.
        hint = "possibly because some of its parts is not properly urllib.parse.quote()'ed."
        return f"This string is not a valid url, {hint}"
class UnknownSchemeError(ValueError):
    """Error for a URL scheme that has no backend in ``ENGINE_SCHEMES``.

    Attributes:
        scheme: The offending URL scheme, without the trailing ``://``.
    """

    def __init__(self, scheme: str) -> None:
        # Forward the scheme to the base class so ``args`` is populated:
        # copy and pickle rebuild exceptions as ``cls(*args)`` and would
        # otherwise fail on the required ``scheme`` argument.
        super().__init__(scheme)
        self.scheme = scheme

    def __str__(self) -> str:
        # Looked up at formatting time, so the list reflects every scheme
        # registered up to that moment.
        known = ", ".join(sorted(ENGINE_SCHEMES))
        return (
            f"Scheme '{self.scheme}://' is unknown."
            " Did you forget to register custom backend?"
            f" Following schemes have registered backends: {known}."
        )
def default_postprocess(parsed_config: DBConfig) -> None:
    """Do nothing; this is the hook an ``Engine`` uses when none is given."""
    return None
class Engine:
    """Pairs a backend path with the hook applied to parsed configs."""

    def __init__(
        self,
        backend: str,
        postprocess: PostprocessCallable = default_postprocess,
    ) -> None:
        # backend: value reported as the config's "ENGINE".
        # postprocess: hook called with the parsed config; a no-op by default.
        self.backend, self.postprocess = backend, postprocess
def register(
    scheme: str, backend: str
) -> Callable[[PostprocessCallable], PostprocessCallable]:
    """Register ``backend`` as the engine for ``scheme://`` URLs.

    Can be called plainly, or used as a decorator to also attach a
    post-processing hook for the scheme. Registering a scheme again replaces
    its previous engine together with that engine's hook.

    Args:
        scheme: URL scheme to recognise, e.g. ``"postgres"``.
        backend: Backend path reported as ``ENGINE`` in the parsed config.

    Returns:
        A decorator that installs the decorated function as the scheme's
        post-processing hook and returns that function unchanged, so several
        ``@register(...)`` lines can be stacked on one function.
    """
    # Check the list itself rather than ENGINE_SCHEMES: urllib ships
    # uses_netloc pre-populated and other code may append to it as well, so
    # only a direct membership test keeps the list free of duplicates.
    if scheme not in urlparse.uses_netloc:
        urlparse.uses_netloc.append(scheme)
    engine = Engine(backend)
    ENGINE_SCHEMES[scheme] = engine

    def inner(func: PostprocessCallable) -> PostprocessCallable:
        engine.postprocess = func
        return func

    return inner
# Backends registered without a hook: their Engine keeps default_postprocess.
register("spatialite", "django.contrib.gis.db.backends.spatialite")
register("mysql-connector", "mysql.connector.django")
register("mysqlgis", "django.contrib.gis.db.backends.mysql")
register("oraclegis", "django.contrib.gis.db.backends.oracle")
register("cockroach", "django_cockroachdb")
@register("sqlite", "django.db.backends.sqlite3")
def default_to_in_memory_db(parsed_config: DBConfig) -> None:
    """Fall back to an in-memory database when ``NAME`` is empty."""
    # Same convention as sqlalchemy: no database name means ":memory:".
    if parsed_config["NAME"] != "":
        return
    parsed_config["NAME"] = ":memory:"
@register("oracle", "django.db.backends.oracle")
@register("mssqlms", "mssql")
@register("mssql", "sql_server.pyodbc")
def stringify_port(parsed_config: DBConfig) -> None:
    """Replace ``PORT`` in ``parsed_config`` with its ``str()`` form."""
    port = parsed_config["PORT"]
    parsed_config["PORT"] = str(port)
@register("mysql", "django.db.backends.mysql")
@register("mysql2", "django.db.backends.mysql")
def apply_ssl_ca(parsed_config: DBConfig) -> None:
    """Move a ``ssl-ca`` option into the nested ``ssl`` option."""
    opts = parsed_config["OPTIONS"]
    # The flat key is always removed; the nested one is only set when the
    # removed value is truthy.
    ca_value = opts.pop("ssl-ca", None)
    if not ca_value:
        return
    opts["ssl"] = {"ca": ca_value}
@register("postgres", "django.db.backends.postgresql")
@register("postgresql", "django.db.backends.postgresql")
@register("pgsql", "django.db.backends.postgresql")
@register("postgis", "django.contrib.gis.db.backends.postgis")
@register("redshift", "django_redshift_backend")
@register("timescale", "timescale.db.backends.postgresql")
@register("timescalegis", "timescale.db.backends.postgis")
def apply_current_schema(parsed_config: DBConfig) -> None:
    """Turn a ``currentSchema`` option into a ``search_path`` setting."""
    opts = parsed_config["OPTIONS"]
    # ``currentSchema`` is always removed; ``options`` is only written when
    # the removed value is truthy.
    current_schema = opts.pop("currentSchema", None)
    if not current_schema:
        return
    opts["options"] = f"-c search_path={current_schema}"
def config( def config(
env: str = DEFAULT_ENV, env: str = DEFAULT_ENV,
default: Optional[str] = None, default: Optional[str] = None,
engine: Optional[str] = None, engine: Optional[str] = None,
conn_max_age: Optional[int] = 0, conn_max_age: int = 0,
conn_health_checks: bool = False, conn_health_checks: bool = False,
disable_server_side_cursors: bool = False, disable_server_side_cursors: bool = False,
ssl_require: bool = False, ssl_require: bool = False,
@ -77,7 +140,7 @@ def config(
if s is None: if s is None:
logging.warning( logging.warning(
"No %s environment variable set, and so no databases setup" % env "No %s environment variable set, and so no databases setup", env
) )
if s: if s:
@ -97,107 +160,95 @@ def config(
def parse( def parse(
url: str, url: str,
engine: Optional[str] = None, engine: Optional[str] = None,
conn_max_age: Optional[int] = 0, conn_max_age: int = 0,
conn_health_checks: bool = False, conn_health_checks: bool = False,
disable_server_side_cursors: bool = False, disable_server_side_cursors: bool = False,
ssl_require: bool = False, ssl_require: bool = False,
test_options: Optional[dict] = None, test_options: Optional[dict] = None,
) -> DBConfig: ) -> DBConfig:
"""Parses a database URL.""" """Parses a database URL and returns configured DATABASE dictionary."""
settings = _convert_to_settings(
engine,
conn_max_age,
conn_health_checks,
disable_server_side_cursors,
ssl_require,
test_options,
)
if url == "sqlite://:memory:": if url == "sqlite://:memory:":
# this is a special case, because if we pass this URL into # this is a special case, because if we pass this URL into
# urlparse, urlparse will choke trying to interpret "memory" # urlparse, urlparse will choke trying to interpret "memory"
# as a port number # as a port number
return {"ENGINE": SCHEMES["sqlite"], "NAME": ":memory:"} return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"}
# note: no other settings are required for sqlite # note: no other settings are required for sqlite
# otherwise parse the url as normal try:
parsed_config: DBConfig = {} split_result = urlparse.urlsplit(url)
engine_obj = ENGINE_SCHEMES.get(split_result.scheme)
if test_options is None: if engine_obj is None:
test_options = {} raise UnknownSchemeError(split_result.scheme)
path = split_result.path[1:]
spliturl = urlparse.urlsplit(url) query = urlparse.parse_qs(split_result.query)
options = {k: _parse_option_values(v) for k, v in query.items()}
# Split query strings from path. parsed_config: DBConfig = {
path = spliturl.path[1:] "ENGINE": engine_obj.backend,
query = urlparse.parse_qs(spliturl.query) "USER": urlparse.unquote(split_result.username or ""),
"PASSWORD": urlparse.unquote(split_result.password or ""),
# If we are using sqlite and we have no path, then assume we "HOST": urlparse.unquote(split_result.hostname or ""),
# want an in-memory database (this is the behaviour of sqlalchemy) "PORT": split_result.port or "",
if spliturl.scheme == "sqlite" and path == "": "NAME": urlparse.unquote(path),
path = ":memory:" "OPTIONS": options,
# Handle postgres percent-encoded paths.
hostname = spliturl.hostname or ""
if "%" in hostname:
# Switch to url.netloc to avoid lower cased paths
hostname = spliturl.netloc
if "@" in hostname:
hostname = hostname.rsplit("@", 1)[1]
# Use URL Parse library to decode % encodes
hostname = urlparse.unquote(hostname)
# Lookup specified engine.
if engine is None:
engine = SCHEMES.get(spliturl.scheme)
if engine is None:
raise ValueError(
"No support for '%s'. We support: %s"
% (spliturl.scheme, ", ".join(sorted(SCHEMES.keys())))
)
port = (
str(spliturl.port)
if spliturl.port
and engine in (SCHEMES["oracle"], SCHEMES["mssql"], SCHEMES["mssqlms"])
else spliturl.port
)
# Update with environment configuration.
parsed_config.update(
{
"NAME": urlparse.unquote(path or ""),
"USER": urlparse.unquote(spliturl.username or ""),
"PASSWORD": urlparse.unquote(spliturl.password or ""),
"HOST": hostname,
"PORT": port or "",
"CONN_MAX_AGE": conn_max_age,
"CONN_HEALTH_CHECKS": conn_health_checks,
"DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors,
"ENGINE": engine,
} }
) except UnknownSchemeError:
if test_options: raise
parsed_config.update( except ValueError:
{ raise ParseError() from None
'TEST': test_options,
}
)
# Pass the query string into OPTIONS. # Guarantee that config has options, possibly empty, when postprocess() is called
options: Dict[str, Any] = {} assert isinstance(parsed_config["OPTIONS"], dict)
for key, values in query.items(): engine_obj.postprocess(parsed_config)
if spliturl.scheme == "mysql" and key == "ssl-ca":
options["ssl"] = {"ca": values[-1]}
continue
value = values[-1] # Update the final config with any settings passed in explicitly.
if value.isdigit(): parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {}))
options[key] = int(value) parsed_config.update(settings)
elif value.lower() in ("true", "false"):
options[key] = value.lower() == "true"
else:
options[key] = value
if ssl_require:
options["sslmode"] = "require"
# Support for Postgres Schema URLs
if "currentSchema" in options and spliturl.scheme in SCHEMES_WITH_SEARCH_PATH:
options["options"] = "-c search_path={0}".format(options.pop("currentSchema"))
if options:
parsed_config["OPTIONS"] = options
if not parsed_config["OPTIONS"]:
parsed_config.pop("OPTIONS")
return parsed_config return parsed_config
def _parse_option_values(values: List[str]) -> Union[OptionType, List[OptionType]]:
    """Convert all raw values given for one query parameter.

    Exactly one value is returned bare; any other number of values is
    returned as a list of converted values.
    """
    converted = list(map(_parse_value, values))
    if len(converted) == 1:
        return converted[0]
    return converted
def _parse_value(value: str) -> OptionType:
if value.isdigit():
return int(value)
if value.lower() in ("true", "false"):
return value.lower() == "true"
return value
def _convert_to_settings(
    engine: Optional[str],
    conn_max_age: int,
    conn_health_checks: bool,
    disable_server_side_cursors: bool,
    ssl_require: bool,
    test_options: Optional[dict],
) -> DBConfig:
    """Collect the explicitly passed arguments into a settings dict.

    The three connection settings are always present. ``ENGINE``, ``OPTIONS``
    and ``TEST`` are only added when the matching argument is truthy.
    """
    overrides: DBConfig = {
        "CONN_MAX_AGE": conn_max_age,
        "CONN_HEALTH_CHECKS": conn_health_checks,
        "DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors,
    }
    if engine:
        overrides["ENGINE"] = engine
    if ssl_require:
        # Requiring SSL is expressed through the ``sslmode`` option.
        overrides["OPTIONS"] = {"sslmode": "require"}
    if test_options:
        overrides["TEST"] = test_options
    return overrides

View file

@ -1,6 +1,8 @@
import os import os
import re
import unittest import unittest
from unittest import mock from unittest import mock
from urllib.parse import uses_netloc
import dj_database_url import dj_database_url
@ -203,6 +205,24 @@ class DatabaseTestSuite(unittest.TestCase):
assert url["ENGINE"] == "django.db.backends.sqlite3" assert url["ENGINE"] == "django.db.backends.sqlite3"
assert url["NAME"] == ":memory:" assert url["NAME"] == ":memory:"
def test_sqlite_relative_url(self):
    # Three slashes: the path after the empty host is a relative file name.
    parsed = dj_database_url.parse("sqlite:///db.sqlite3")

    assert parsed["ENGINE"] == "django.db.backends.sqlite3"
    assert parsed["NAME"] == "db.sqlite3"
def test_sqlite_absolute_url(self):
    # Four slashes in total: two follow the scheme, the third separates the
    # (empty) host:port from the path, and the fourth is the leading slash
    # of the "NAME" value itself.
    parsed = dj_database_url.parse("sqlite:////db.sqlite3")

    assert parsed["ENGINE"] == "django.db.backends.sqlite3"
    assert parsed["NAME"] == "/db.sqlite3"
def test_parse_engine_setting(self): def test_parse_engine_setting(self):
engine = "django_mysqlpool.backends.mysqlpool" engine = "django_mysqlpool.backends.mysqlpool"
url = "mysql://bea6eb025ca0d8:69772142@us-cdbr-east.cleardb.com/heroku_97681db3eff7580?reconnect=true" url = "mysql://bea6eb025ca0d8:69772142@us-cdbr-east.cleardb.com/heroku_97681db3eff7580?reconnect=true"
@ -588,9 +608,36 @@ class DatabaseTestSuite(unittest.TestCase):
'WARNING:root:No DATABASE_URL environment variable set, and so no databases setup' 'WARNING:root:No DATABASE_URL environment variable set, and so no databases setup'
], cm.output ], cm.output
def test_bad_url_parsing(self): def test_credentials_unquoted__raise_value_error(self):
with self.assertRaisesRegex(ValueError, "No support for 'foo'. We support: "): expected_message = (
dj_database_url.parse("foo://bar") "This string is not a valid url, possibly because some of its parts "
r"is not properly urllib.parse.quote()'ed."
)
with self.assertRaisesRegex(ValueError, re.escape(expected_message)):
dj_database_url.parse("postgres://user:passw#ord!@localhost/foobar")
def test_credentials_quoted__ok(self):
    # Percent-encoded "@" and "#" in the credentials must come back decoded.
    parsed = dj_database_url.parse(
        "postgres://user%40domain:p%23ssword!@localhost/foobar"
    )
    assert parsed["USER"] == "user@domain"
    assert parsed["PASSWORD"] == "p#ssword!"
def test_unknown_scheme__raise_value_error(self):
    # Only the start of the message is matched; the list of registered
    # schemes that follows it is not checked.
    message_start = (
        "Scheme 'unknown-scheme://' is unknown. "
        "Did you forget to register custom backend? Following schemes have registered backends:"
    )
    pattern = re.escape(message_start)
    with self.assertRaisesRegex(ValueError, pattern):
        dj_database_url.parse("unknown-scheme://user:password@localhost/foobar")
def test_register_multiple_times__no_duplicates_in_uses_netloc(self):
    # Registering the same scheme repeatedly is a misuse of register(), but
    # it must not add duplicates to urllib.parse.uses_netloc: code that does
    # a linear search over that list may assume it stays short.
    # register() takes the URL scheme first and the backend path second.
    dj_database_url.register("bag-end", "django.contrib.db.backends.bag_end")
    dj_database_url.register("bag-end", "django.contrib.db.backends.bag_end")

    assert len(uses_netloc) == len(set(uses_netloc))
@mock.patch.dict( @mock.patch.dict(
os.environ, os.environ,