mirror of
https://github.com/jazzband/dj-database-url.git
synced 2026-03-16 22:20:24 +00:00
Generic backend registration
This commit is contained in:
parent
bcf163e3ab
commit
b27bd63e82
3 changed files with 286 additions and 137 deletions
63
README.rst
63
README.rst
|
|
@ -22,12 +22,6 @@ also a `conn_max_age` argument to easily enable Django's connection pool.
|
|||
If you'd rather not use an environment variable, you can pass a URL in directly
|
||||
instead to ``dj_database_url.parse``.
|
||||
|
||||
Supported Databases
|
||||
-------------------
|
||||
|
||||
Support currently exists for PostgreSQL, PostGIS, MySQL, MySQL (GIS),
|
||||
Oracle, Oracle (GIS), Redshift, CockroachDB, Timescale, Timescale (GIS) and SQLite.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
|
|
@ -148,6 +142,63 @@ and should instead be passed as:
|
|||
DATABASES['default'] = dj_database_url.config(default='postgres://...', test_options={'NAME': 'mytestdatabase'})
|
||||
|
||||
|
||||
Supported Databases
|
||||
-------------------
|
||||
|
||||
Support currently exists for PostgreSQL, PostGIS, MySQL, MySQL (GIS),
|
||||
Oracle, Oracle (GIS), Redshift, CockroachDB, Timescale, Timescale (GIS) and SQLite.
|
||||
|
||||
If you want to use
|
||||
some non-default backends, you need to register them first:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import dj_database_url
|
||||
|
||||
# registration should be performed only once
|
||||
dj_database_url.register("mysql-connector", "mysql.connector.django")
|
||||
|
||||
assert dj_database_url.parse("mysql-connector://user:password@host:port/db-name") == {
|
||||
"ENGINE": "mysql.connector.django",
|
||||
# ...other connection params
|
||||
}
|
||||
|
||||
Some backends need further config adjustments (e.g. oracle and mssql
|
||||
expect ``PORT`` to be a string). For such cases you can provide a
|
||||
post-processing function to ``register()`` (note that ``register()`` is
|
||||
used as a **decorator(!)** in this case):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import dj_database_url
|
||||
|
||||
@dj_database_url.register("mssql", "sql_server.pyodbc")
|
||||
def stringify_port(config):
|
||||
config["PORT"] = str(config["PORT"])
|
||||
|
||||
@dj_database_url.register("redshift", "django_redshift_backend")
|
||||
def apply_current_schema(config):
|
||||
options = config["OPTIONS"]
|
||||
schema = options.pop("currentSchema", None)
|
||||
if schema:
|
||||
options["options"] = f"-c search_path={schema}"
|
||||
|
||||
@dj_database_url.register("snowflake", "django_snowflake")
|
||||
def adjust_snowflake_config(config):
|
||||
config.pop("PORT", None)
|
||||
config["ACCOUNT"] = config.pop("HOST")
|
||||
name, _, schema = config["NAME"].partition("/")
|
||||
if schema:
|
||||
config["SCHEMA"] = schema
|
||||
config["NAME"] = name
|
||||
options = config.get("OPTIONS", {})
|
||||
warehouse = options.pop("warehouse", None)
|
||||
if warehouse:
|
||||
config["WAREHOUSE"] = warehouse
|
||||
role = options.pop("role", None)
|
||||
if role:
|
||||
config["ROLE"] = role
|
||||
|
||||
URL schema
|
||||
----------
|
||||
|
||||
|
|
|
|||
|
|
@ -1,50 +1,15 @@
|
|||
import logging
|
||||
import os
|
||||
import urllib.parse as urlparse
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
DEFAULT_ENV = "DATABASE_URL"
|
||||
|
||||
SCHEMES = {
|
||||
"postgres": "django.db.backends.postgresql",
|
||||
"postgresql": "django.db.backends.postgresql",
|
||||
"pgsql": "django.db.backends.postgresql",
|
||||
"postgis": "django.contrib.gis.db.backends.postgis",
|
||||
"mysql": "django.db.backends.mysql",
|
||||
"mysql2": "django.db.backends.mysql",
|
||||
"mysqlgis": "django.contrib.gis.db.backends.mysql",
|
||||
"mysql-connector": "mysql.connector.django",
|
||||
"mssql": "sql_server.pyodbc",
|
||||
"mssqlms": "mssql",
|
||||
"spatialite": "django.contrib.gis.db.backends.spatialite",
|
||||
"sqlite": "django.db.backends.sqlite3",
|
||||
"oracle": "django.db.backends.oracle",
|
||||
"oraclegis": "django.contrib.gis.db.backends.oracle",
|
||||
"redshift": "django_redshift_backend",
|
||||
"cockroach": "django_cockroachdb",
|
||||
"timescale": "timescale.db.backends.postgresql",
|
||||
"timescalegis": "timescale.db.backends.postgis",
|
||||
}
|
||||
|
||||
SCHEMES_WITH_SEARCH_PATH = [
|
||||
"postgres",
|
||||
"postgresql",
|
||||
"pgsql",
|
||||
"postgis",
|
||||
"redshift",
|
||||
"timescale",
|
||||
"timescalegis",
|
||||
]
|
||||
|
||||
# Register database schemes in URLs.
|
||||
for key in SCHEMES.keys():
|
||||
urlparse.uses_netloc.append(key)
|
||||
del key
|
||||
ENGINE_SCHEMES: Dict[str, "Engine"] = {}
|
||||
|
||||
|
||||
# From https://docs.djangoproject.com/en/4.0/ref/settings/#databases
|
||||
# From https://docs.djangoproject.com/en/stable/ref/settings/#databases
|
||||
class DBConfig(TypedDict, total=False):
|
||||
ATOMIC_REQUESTS: bool
|
||||
AUTOCOMMIT: bool
|
||||
|
|
@ -62,11 +27,109 @@ class DBConfig(TypedDict, total=False):
|
|||
USER: str
|
||||
|
||||
|
||||
PostprocessCallable = Callable[[DBConfig], None]
|
||||
OptionType = Union[int, str, bool]
|
||||
|
||||
|
||||
class ParseError(ValueError):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"This string is not a valid url, possibly because some of its parts"
|
||||
" is not properly urllib.parse.quote()'ed."
|
||||
)
|
||||
|
||||
|
||||
class UnknownSchemeError(ValueError):
|
||||
def __init__(self, scheme: str) -> None:
|
||||
self.scheme = scheme
|
||||
|
||||
def __str__(self) -> str:
|
||||
schemes = ", ".join(sorted(ENGINE_SCHEMES.keys()))
|
||||
return (
|
||||
f"Scheme '{self.scheme}://' is unknown."
|
||||
" Did you forget to register custom backend?"
|
||||
f" Following schemes have registered backends: {schemes}."
|
||||
)
|
||||
|
||||
|
||||
def default_postprocess(parsed_config: DBConfig) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class Engine:
|
||||
def __init__(
|
||||
self,
|
||||
backend: str,
|
||||
postprocess: PostprocessCallable = default_postprocess,
|
||||
) -> None:
|
||||
self.backend = backend
|
||||
self.postprocess = postprocess
|
||||
|
||||
|
||||
def register(
|
||||
scheme: str, backend: str
|
||||
) -> Callable[[PostprocessCallable], PostprocessCallable]:
|
||||
engine = Engine(backend)
|
||||
if scheme not in ENGINE_SCHEMES:
|
||||
urlparse.uses_netloc.append(scheme)
|
||||
ENGINE_SCHEMES[scheme] = engine
|
||||
|
||||
def inner(func: PostprocessCallable) -> PostprocessCallable:
|
||||
engine.postprocess = func
|
||||
return func
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
register("spatialite", "django.contrib.gis.db.backends.spatialite")
|
||||
register("mysql-connector", "mysql.connector.django")
|
||||
register("mysqlgis", "django.contrib.gis.db.backends.mysql")
|
||||
register("oraclegis", "django.contrib.gis.db.backends.oracle")
|
||||
register("cockroach", "django_cockroachdb")
|
||||
|
||||
|
||||
@register("sqlite", "django.db.backends.sqlite3")
|
||||
def default_to_in_memory_db(parsed_config: DBConfig) -> None:
|
||||
# mimic sqlalchemy behaviour
|
||||
if parsed_config["NAME"] == "":
|
||||
parsed_config["NAME"] = ":memory:"
|
||||
|
||||
|
||||
@register("oracle", "django.db.backends.oracle")
|
||||
@register("mssqlms", "mssql")
|
||||
@register("mssql", "sql_server.pyodbc")
|
||||
def stringify_port(parsed_config: DBConfig) -> None:
|
||||
parsed_config["PORT"] = str(parsed_config["PORT"])
|
||||
|
||||
|
||||
@register("mysql", "django.db.backends.mysql")
|
||||
@register("mysql2", "django.db.backends.mysql")
|
||||
def apply_ssl_ca(parsed_config: DBConfig) -> None:
|
||||
options = parsed_config["OPTIONS"]
|
||||
ca = options.pop("ssl-ca", None)
|
||||
if ca:
|
||||
options["ssl"] = {"ca": ca}
|
||||
|
||||
|
||||
@register("postgres", "django.db.backends.postgresql")
|
||||
@register("postgresql", "django.db.backends.postgresql")
|
||||
@register("pgsql", "django.db.backends.postgresql")
|
||||
@register("postgis", "django.contrib.gis.db.backends.postgis")
|
||||
@register("redshift", "django_redshift_backend")
|
||||
@register("timescale", "timescale.db.backends.postgresql")
|
||||
@register("timescalegis", "timescale.db.backends.postgis")
|
||||
def apply_current_schema(parsed_config: DBConfig) -> None:
|
||||
options = parsed_config["OPTIONS"]
|
||||
schema = options.pop("currentSchema", None)
|
||||
if schema:
|
||||
options["options"] = f"-c search_path={schema}"
|
||||
|
||||
|
||||
def config(
|
||||
env: str = DEFAULT_ENV,
|
||||
default: Optional[str] = None,
|
||||
engine: Optional[str] = None,
|
||||
conn_max_age: Optional[int] = 0,
|
||||
conn_max_age: int = 0,
|
||||
conn_health_checks: bool = False,
|
||||
disable_server_side_cursors: bool = False,
|
||||
ssl_require: bool = False,
|
||||
|
|
@ -77,7 +140,7 @@ def config(
|
|||
|
||||
if s is None:
|
||||
logging.warning(
|
||||
"No %s environment variable set, and so no databases setup" % env
|
||||
"No %s environment variable set, and so no databases setup", env
|
||||
)
|
||||
|
||||
if s:
|
||||
|
|
@ -97,107 +160,95 @@ def config(
|
|||
def parse(
|
||||
url: str,
|
||||
engine: Optional[str] = None,
|
||||
conn_max_age: Optional[int] = 0,
|
||||
conn_max_age: int = 0,
|
||||
conn_health_checks: bool = False,
|
||||
disable_server_side_cursors: bool = False,
|
||||
ssl_require: bool = False,
|
||||
test_options: Optional[dict] = None,
|
||||
) -> DBConfig:
|
||||
"""Parses a database URL."""
|
||||
"""Parses a database URL and returns configured DATABASE dictionary."""
|
||||
settings = _convert_to_settings(
|
||||
engine,
|
||||
conn_max_age,
|
||||
conn_health_checks,
|
||||
disable_server_side_cursors,
|
||||
ssl_require,
|
||||
test_options,
|
||||
)
|
||||
|
||||
if url == "sqlite://:memory:":
|
||||
# this is a special case, because if we pass this URL into
|
||||
# urlparse, urlparse will choke trying to interpret "memory"
|
||||
# as a port number
|
||||
return {"ENGINE": SCHEMES["sqlite"], "NAME": ":memory:"}
|
||||
return {"ENGINE": ENGINE_SCHEMES["sqlite"].backend, "NAME": ":memory:"}
|
||||
# note: no other settings are required for sqlite
|
||||
|
||||
# otherwise parse the url as normal
|
||||
parsed_config: DBConfig = {}
|
||||
|
||||
if test_options is None:
|
||||
test_options = {}
|
||||
|
||||
spliturl = urlparse.urlsplit(url)
|
||||
|
||||
# Split query strings from path.
|
||||
path = spliturl.path[1:]
|
||||
query = urlparse.parse_qs(spliturl.query)
|
||||
|
||||
# If we are using sqlite and we have no path, then assume we
|
||||
# want an in-memory database (this is the behaviour of sqlalchemy)
|
||||
if spliturl.scheme == "sqlite" and path == "":
|
||||
path = ":memory:"
|
||||
|
||||
# Handle postgres percent-encoded paths.
|
||||
hostname = spliturl.hostname or ""
|
||||
if "%" in hostname:
|
||||
# Switch to url.netloc to avoid lower cased paths
|
||||
hostname = spliturl.netloc
|
||||
if "@" in hostname:
|
||||
hostname = hostname.rsplit("@", 1)[1]
|
||||
# Use URL Parse library to decode % encodes
|
||||
hostname = urlparse.unquote(hostname)
|
||||
|
||||
# Lookup specified engine.
|
||||
if engine is None:
|
||||
engine = SCHEMES.get(spliturl.scheme)
|
||||
if engine is None:
|
||||
raise ValueError(
|
||||
"No support for '%s'. We support: %s"
|
||||
% (spliturl.scheme, ", ".join(sorted(SCHEMES.keys())))
|
||||
)
|
||||
|
||||
port = (
|
||||
str(spliturl.port)
|
||||
if spliturl.port
|
||||
and engine in (SCHEMES["oracle"], SCHEMES["mssql"], SCHEMES["mssqlms"])
|
||||
else spliturl.port
|
||||
)
|
||||
|
||||
# Update with environment configuration.
|
||||
parsed_config.update(
|
||||
{
|
||||
"NAME": urlparse.unquote(path or ""),
|
||||
"USER": urlparse.unquote(spliturl.username or ""),
|
||||
"PASSWORD": urlparse.unquote(spliturl.password or ""),
|
||||
"HOST": hostname,
|
||||
"PORT": port or "",
|
||||
"CONN_MAX_AGE": conn_max_age,
|
||||
"CONN_HEALTH_CHECKS": conn_health_checks,
|
||||
"DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors,
|
||||
"ENGINE": engine,
|
||||
try:
|
||||
split_result = urlparse.urlsplit(url)
|
||||
engine_obj = ENGINE_SCHEMES.get(split_result.scheme)
|
||||
if engine_obj is None:
|
||||
raise UnknownSchemeError(split_result.scheme)
|
||||
path = split_result.path[1:]
|
||||
query = urlparse.parse_qs(split_result.query)
|
||||
options = {k: _parse_option_values(v) for k, v in query.items()}
|
||||
parsed_config: DBConfig = {
|
||||
"ENGINE": engine_obj.backend,
|
||||
"USER": urlparse.unquote(split_result.username or ""),
|
||||
"PASSWORD": urlparse.unquote(split_result.password or ""),
|
||||
"HOST": urlparse.unquote(split_result.hostname or ""),
|
||||
"PORT": split_result.port or "",
|
||||
"NAME": urlparse.unquote(path),
|
||||
"OPTIONS": options,
|
||||
}
|
||||
)
|
||||
if test_options:
|
||||
parsed_config.update(
|
||||
{
|
||||
'TEST': test_options,
|
||||
}
|
||||
)
|
||||
except UnknownSchemeError:
|
||||
raise
|
||||
except ValueError:
|
||||
raise ParseError() from None
|
||||
|
||||
# Pass the query string into OPTIONS.
|
||||
options: Dict[str, Any] = {}
|
||||
for key, values in query.items():
|
||||
if spliturl.scheme == "mysql" and key == "ssl-ca":
|
||||
options["ssl"] = {"ca": values[-1]}
|
||||
continue
|
||||
# Guarantee that config has options, possibly empty, when postprocess() is called
|
||||
assert isinstance(parsed_config["OPTIONS"], dict)
|
||||
engine_obj.postprocess(parsed_config)
|
||||
|
||||
value = values[-1]
|
||||
if value.isdigit():
|
||||
options[key] = int(value)
|
||||
elif value.lower() in ("true", "false"):
|
||||
options[key] = value.lower() == "true"
|
||||
else:
|
||||
options[key] = value
|
||||
|
||||
if ssl_require:
|
||||
options["sslmode"] = "require"
|
||||
|
||||
# Support for Postgres Schema URLs
|
||||
if "currentSchema" in options and spliturl.scheme in SCHEMES_WITH_SEARCH_PATH:
|
||||
options["options"] = "-c search_path={0}".format(options.pop("currentSchema"))
|
||||
|
||||
if options:
|
||||
parsed_config["OPTIONS"] = options
|
||||
# Update the final config with any settings passed in explicitly.
|
||||
parsed_config["OPTIONS"].update(settings.pop("OPTIONS", {}))
|
||||
parsed_config.update(settings)
|
||||
|
||||
if not parsed_config["OPTIONS"]:
|
||||
parsed_config.pop("OPTIONS")
|
||||
return parsed_config
|
||||
|
||||
|
||||
def _parse_option_values(values: List[str]) -> Union[OptionType, List[OptionType]]:
|
||||
parsed_values = [_parse_value(v) for v in values]
|
||||
return parsed_values[0] if len(parsed_values) == 1 else parsed_values
|
||||
|
||||
|
||||
def _parse_value(value: str) -> OptionType:
|
||||
if value.isdigit():
|
||||
return int(value)
|
||||
if value.lower() in ("true", "false"):
|
||||
return value.lower() == "true"
|
||||
return value
|
||||
|
||||
|
||||
def _convert_to_settings(
|
||||
engine: Optional[str],
|
||||
conn_max_age: int,
|
||||
conn_health_checks: bool,
|
||||
disable_server_side_cursors: bool,
|
||||
ssl_require: bool,
|
||||
test_options: Optional[dict],
|
||||
) -> DBConfig:
|
||||
settings: DBConfig = {
|
||||
"CONN_MAX_AGE": conn_max_age,
|
||||
"CONN_HEALTH_CHECKS": conn_health_checks,
|
||||
"DISABLE_SERVER_SIDE_CURSORS": disable_server_side_cursors,
|
||||
}
|
||||
if engine:
|
||||
settings["ENGINE"] = engine
|
||||
if ssl_require:
|
||||
settings["OPTIONS"] = {}
|
||||
settings["OPTIONS"]["sslmode"] = "require"
|
||||
if test_options:
|
||||
settings["TEST"] = test_options
|
||||
return settings
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import re
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from urllib.parse import uses_netloc
|
||||
|
||||
import dj_database_url
|
||||
|
||||
|
|
@ -203,6 +205,24 @@ class DatabaseTestSuite(unittest.TestCase):
|
|||
assert url["ENGINE"] == "django.db.backends.sqlite3"
|
||||
assert url["NAME"] == ":memory:"
|
||||
|
||||
def test_sqlite_relative_url(self):
|
||||
url = "sqlite:///db.sqlite3"
|
||||
config = dj_database_url.parse(url)
|
||||
|
||||
assert config["ENGINE"] == "django.db.backends.sqlite3"
|
||||
assert config["NAME"] == "db.sqlite3"
|
||||
|
||||
def test_sqlite_absolute_url(self):
|
||||
# 4 slashes are needed:
|
||||
# two are part of scheme
|
||||
# one separates host:port from path
|
||||
# and the fourth goes to "NAME" value
|
||||
url = "sqlite:////db.sqlite3"
|
||||
config = dj_database_url.parse(url)
|
||||
|
||||
assert config["ENGINE"] == "django.db.backends.sqlite3"
|
||||
assert config["NAME"] == "/db.sqlite3"
|
||||
|
||||
def test_parse_engine_setting(self):
|
||||
engine = "django_mysqlpool.backends.mysqlpool"
|
||||
url = "mysql://bea6eb025ca0d8:69772142@us-cdbr-east.cleardb.com/heroku_97681db3eff7580?reconnect=true"
|
||||
|
|
@ -588,9 +608,36 @@ class DatabaseTestSuite(unittest.TestCase):
|
|||
'WARNING:root:No DATABASE_URL environment variable set, and so no databases setup'
|
||||
], cm.output
|
||||
|
||||
def test_bad_url_parsing(self):
|
||||
with self.assertRaisesRegex(ValueError, "No support for 'foo'. We support: "):
|
||||
dj_database_url.parse("foo://bar")
|
||||
def test_credentials_unquoted__raise_value_error(self):
|
||||
expected_message = (
|
||||
"This string is not a valid url, possibly because some of its parts "
|
||||
r"is not properly urllib.parse.quote()'ed."
|
||||
)
|
||||
with self.assertRaisesRegex(ValueError, re.escape(expected_message)):
|
||||
dj_database_url.parse("postgres://user:passw#ord!@localhost/foobar")
|
||||
|
||||
def test_credentials_quoted__ok(self):
|
||||
url = "postgres://user%40domain:p%23ssword!@localhost/foobar"
|
||||
config = dj_database_url.parse(url)
|
||||
assert config["USER"] == "user@domain"
|
||||
assert config["PASSWORD"] == "p#ssword!"
|
||||
|
||||
def test_unknown_scheme__raise_value_error(self):
|
||||
expected_message = (
|
||||
"Scheme 'unknown-scheme://' is unknown. "
|
||||
"Did you forget to register custom backend? Following schemes have registered backends:"
|
||||
)
|
||||
with self.assertRaisesRegex(ValueError, re.escape(expected_message)):
|
||||
dj_database_url.parse("unknown-scheme://user:password@localhost/foobar")
|
||||
|
||||
def test_register_multiple_times__no_duplicates_in_uses_netloc(self):
|
||||
# make sure that when register() function is misused,
|
||||
# it won't pollute urllib.parse.uses_netloc list with duplicates.
|
||||
# Otherwise, it might cause performance issue if some code assumes that
|
||||
# that list is short and performs linear search on it.
|
||||
dj_database_url.register("django.contrib.db.backends.bag_end", "bag-end")
|
||||
dj_database_url.register("django.contrib.db.backends.bag_end", "bag-end")
|
||||
assert len(uses_netloc) == len(set(uses_netloc))
|
||||
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
|
|
|
|||
Loading…
Reference in a new issue