django-cachalot/benchmark.py
Andrew-Chen-Wang 8a55454557 Support benchmarks for MacOS
* Added how to run benchmarks in docs and README
2020-07-09 18:33:05 -04:00

365 lines
12 KiB
Python
Executable file
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import io
import os
import platform
import re
import sqlite3
from collections import OrderedDict
from datetime import datetime
from random import choice
from subprocess import check_output
from time import time
# Django must be configured *before* the cachalot/model imports below,
# which is why this setup code sits in the middle of the import block.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "settings")
import django
django.setup()
import matplotlib.pyplot as plt
import pandas as pd
import psycopg2
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.cache import caches
from django.db import connection, connections
from django.test.utils import CaptureQueriesContext, override_settings
from django.utils.encoding import force_text
from MySQLdb import _mysql
import cachalot
from cachalot.api import invalidate
from cachalot.tests.models import Test

# Dated output directory where the CSV, RST fragments and SVG charts land.
RESULTS_PATH = f"benchmark/docs/{datetime.now().date()}/"
# The three benchmark phases: cachalot off / cachalot with an invalidated
# cache / cachalot with a warm cache.
CONTEXTS = ("Control", "Cold cache", "Hot cache")
DIVIDER = "divider"
# Path used to identify which physical disk backs the data on Linux.
LINUX_DATA_PATH = "/var/lib/"
# Parses one line of `lsblk -Po MODEL,MOUNTPOINT` output.
DISK_DATA_RE = re.compile(r'^MODEL="(.*)" MOUNTPOINT="(.*)"$')
def get_disk_model_for_path_linux(path):
    """Return the model name of the disk backing *path*, or None.

    Parses ``lsblk -Po MODEL,MOUNTPOINT`` output.  lsblk prints the model
    on the device line and empty models on its partition lines, so the
    last non-empty model seen is carried forward to each mount point.
    The longest (most specific) mount point that prefixes *path* wins.
    """
    output = force_text(check_output(["lsblk", "-Po", "MODEL,MOUNTPOINT"]))
    model_by_mount = []
    last_model = None
    for line in output.split("\n"):
        if not line:
            continue
        model, mount = DISK_DATA_RE.match(line).groups()
        if model:
            last_model = model.strip()
        if mount:
            model_by_mount.append((last_model, mount))
    # Longest mount point first, so the deepest matching mount wins.
    model_by_mount.sort(key=lambda pair: -len(pair[1]))
    for model, mount in model_by_mount:
        if path.startswith(mount):
            return model
def write_conditions():
    """Probe the host environment and write ``conditions.rst``.

    Collects hardware (CPU, RAM, disk), OS, and software versions
    (Python, Django, cachalot, sqlite, PostgreSQL, MySQL, Redis,
    memcached, drivers) and writes them as an RST table to RESULTS_PATH.
    Raises django.db.utils.OperationalError with setup instructions when
    the "cachalot" PostgreSQL/MySQL databases are missing.
    """
    versions = OrderedDict()
    distribution = platform.uname()
    # Linux
    if distribution.system == "Linux":
        # CPU — first "model name" entry of /proc/cpuinfo.
        with open("/proc/cpuinfo") as f:
            versions["CPU"] = re.search(
                r"^model name\s+: (.+)$", f.read(), flags=re.MULTILINE
            ).group(1)
        # RAM — MemTotal line of /proc/meminfo.
        with open("/proc/meminfo") as f:
            versions["RAM"] = re.search(
                r"^MemTotal:\s+(.+)$", f.read(), flags=re.MULTILINE
            ).group(1)
        # Disk Model — disk backing /var/lib (where the DBs live).
        versions.update((("Disk", get_disk_model_for_path_linux(LINUX_DATA_PATH)),))
        # OS
        versions["Linux distribution"] = f"{distribution.system} {distribution.release}"
    # Darwin — NOTE(review): this branch assumes any non-Linux host is
    # macOS (sysctl/diskutil); it would fail on e.g. Windows.
    else:
        # CPU
        versions["CPU"] = os.popen("sysctl -n machdep.cpu.brand_string").read().rstrip("\n")
        # RAM (reported in bytes by hw.memsize)
        versions["RAM"] = os.popen("sysctl -n hw.memsize").read().rstrip("\n")
        # Disk Model — assumes the relevant disk is /dev/disk0.
        versions["DISK"] = os.popen(
            "diskutil info /dev/disk0 | grep 'Device / Media Name'"
        ).read().split(":")[1].rstrip("\n").lstrip(" ")
        # OS
        versions["OS"] = f"{distribution.system} {distribution.release}"
    versions.update(
        (
            ("Python", platform.python_version()),
            ("Django", django.__version__),
            ("cachalot", cachalot.__version__),
            ("sqlite", sqlite3.sqlite_version),
        )
    )
    # PostgreSQL — version string looks like "PostgreSQL 12.3 on ...".
    try:
        with connections["postgresql"].cursor() as cursor:
            cursor.execute("SELECT version();")
            versions["PostgreSQL"] = re.match(
                r"^PostgreSQL\s+(\S+)\s", cursor.fetchone()[0]
            ).group(1)
    except django.db.utils.OperationalError:
        raise django.db.utils.OperationalError(
            "You need a PostgreSQL DB called \"cachalot\" first. "
            "Login with \"psql -U postgres -h localhost\" and run: "
            "CREATE DATABASE cachalot;"
        )
    # MySQL — strip any "-MariaDB"/"-log" suffix from the version.
    try:
        with connections["mysql"].cursor() as cursor:
            cursor.execute("SELECT version();")
            versions["MySQL"] = cursor.fetchone()[0].split("-")[0]
    except django.db.utils.OperationalError:
        raise django.db.utils.OperationalError(
            "You need a MySQL DB called \"cachalot\" first. "
            "Login with \"mysql -u root\" and run: CREATE DATABASE cachalot;"
        )
    # Redis — version from `redis-cli INFO server` (CRLF-terminated).
    out = force_text(check_output(["redis-cli", "INFO", "server"])).replace("\r", "")
    versions["Redis"] = re.search(
        r"^redis_version:([\d\.]+)$", out, flags=re.MULTILINE
    ).group(1)
    # memcached — first line of `memcached -h` is "memcached <version>".
    out = force_text(check_output(["memcached", "-h"]))
    versions["memcached"] = re.match(
        r"^memcached ([\d\.]+)$", out, flags=re.MULTILINE
    ).group(1)
    versions.update(
        (
            ("psycopg2", psycopg2.__version__.split()[0]),
            ("mysqlclient", _mysql.__version__),
        )
    )
    # Write the collected versions as a simple two-column RST table.
    with io.open(os.path.join(RESULTS_PATH, "conditions.rst"), "w") as f:
        f.write(
            "In this benchmark, a small database is generated, "
            "and each test is executed %s times "
            "under the following conditions:\n\n" % Benchmark.n
        )

        def write_table_sep(char="="):
            f.write((char * 20) + " " + (char * 50) + "\n")

        write_table_sep()
        for k, v in versions.items():
            f.write(k.ljust(20) + " " + v + "\n")
        write_table_sep()
class AssertNumQueries(CaptureQueriesContext):
    """Context manager that checks the number of SQL queries executed.

    Unlike Django's ``assertNumQueries``, a mismatch is only printed
    (not raised) so a benchmark run is never aborted by it.

    :param n: expected number of queries inside the ``with`` block
    :param using: database alias to watch; ``None`` means the default
        connection
    """

    def __init__(self, n, using=None):
        self.n = n
        self.using = using
        # Python-3 zero-argument super(); the file already requires
        # Python 3 (f-strings are used at module level).
        super().__init__(self.get_connection())

    def get_connection(self):
        """Return the connection to capture queries on."""
        if self.using is None:
            return connection
        return connections[self.using]

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        # Report (but do not raise) when the captured count differs.
        if len(self) != self.n:
            print(
                "The amount of queries should be %s, but %s were captured."
                % (self.n, len(self))
            )
class Benchmark(object):
    """Runs every benchmark query against every (database, cache) pair.

    Each query is timed ``n`` times in the three CONTEXTS: cachalot
    disabled ("Control"), cachalot with the cache invalidated before
    each run ("Cold cache"), and cachalot with a warm cache
    ("Hot cache").  Results are written to RESULTS_PATH as a CSV plus
    per-database and per-cache SVG charts and RST result fragments.
    """

    # Number of timed repetitions per (query, context) combination.
    n = 20

    def __init__(self):
        # One dict per timed run; turned into a pandas DataFrame in run().
        self.data = []

    def bench_once(self, context, num_queries, invalidate_before=False):
        """Time the current query ``self.n`` times under *context*.

        ``num_queries`` is the number of SQL queries expected per run
        (a mismatch is printed by AssertNumQueries, not raised).  When
        ``invalidate_before`` is true, the cachalot cache for the
        current DB is invalidated before every run (cold-cache case).
        """
        for _ in range(self.n):
            if invalidate_before:
                invalidate(db_alias=self.db_alias)
            with AssertNumQueries(num_queries, using=self.db_alias):
                start = time()
                self.query_function(self.db_alias)
                end = time()
            self.data.append(
                {
                    "query": self.query_name,
                    "time": end - start,
                    "context": context,
                    "db": self.db_vendor,
                    "cache": self.cache_name,
                }
            )

    def benchmark(self, query_str, to_list=True, num_queries=1):
        """Benchmark one queryset expression in all three contexts.

        ``query_str`` is a queryset suffix such as ``".count()"``; it is
        compiled into ``lambda using: Test.objects.using(using)<suffix>``
        (wrapped in ``list(...)`` unless ``to_list`` is false, so lazy
        querysets are actually evaluated).
        """
        # Clears the cache before a single benchmark to ensure the same
        # conditions across single benchmarks.
        caches[settings.CACHALOT_CACHE].clear()
        self.query_name = query_str
        query_str = "Test.objects.using(using)" + query_str
        if to_list:
            query_str = "list(%s)" % query_str
        # NOTE: eval() on a trusted, hard-coded string supplied by
        # execute_benchmark() below — never on external input.
        self.query_function = eval("lambda using: " + query_str)
        with override_settings(CACHALOT_ENABLED=False):
            self.bench_once(CONTEXTS[0], num_queries)
        self.bench_once(CONTEXTS[1], num_queries, invalidate_before=True)
        # Warm cache: zero SQL queries are expected.
        self.bench_once(CONTEXTS[2], 0)

    def execute_benchmark(self):
        """Run the full suite of queryset benchmarks."""
        self.benchmark(".count()", to_list=False)
        self.benchmark(".first()", to_list=False)
        self.benchmark("[:10]")
        self.benchmark("[5000:5010]")
        self.benchmark(".filter(name__icontains='e')[0:10]")
        self.benchmark(".filter(name__icontains='e')[5000:5010]")
        self.benchmark(".order_by('owner')[0:10]")
        self.benchmark(".order_by('owner')[5000:5010]")
        self.benchmark(".select_related('owner')[0:10]")
        self.benchmark(".select_related('owner')[5000:5010]")
        # prefetch_related issues the main query plus the prefetch ones.
        self.benchmark(".prefetch_related('owner__groups')[0:10]", num_queries=3)
        self.benchmark(".prefetch_related('owner__groups')[5000:5010]", num_queries=3)

    def run(self):
        """Entry point: benchmark every configured DB × cache, then report."""
        for db_alias in settings.DATABASES:
            self.db_alias = db_alias
            self.db_vendor = connections[self.db_alias].vendor
            print("Benchmarking %s" % self.db_vendor)
            for cache_alias in settings.CACHES:
                cache = caches[cache_alias]
                # e.g. "LocMemCache" -> "locmem" ([:-5] strips "Cache").
                self.cache_name = cache.__class__.__name__[:-5].lower()
                with override_settings(CACHALOT_CACHE=cache_alias):
                    self.execute_benchmark()
        self.df = pd.DataFrame.from_records(self.data)
        if not os.path.exists(RESULTS_PATH):
            os.mkdir(RESULTS_PATH)
        self.df.to_csv(os.path.join(RESULTS_PATH, "data.csv"))
        # Common x-axis limit so all charts are visually comparable.
        self.xlim = (0, self.df["time"].max() * 1.01)
        self.output("db")
        self.output("cache")

    def output(self, param):
        """Write charts and results aggregated by *param* ("db"/"cache")."""
        # Detailed view: one mean per (context, query, param) triple.
        gp = self.df.groupby(["context", "query", param])["time"]
        self.means = gp.mean().unstack().unstack().reindex(CONTEXTS)
        # Asymmetric error bars: distance from mean down to the min (los)
        # and up to the max (ups).
        los = self.means - gp.min().unstack().unstack().reindex(CONTEXTS)
        ups = gp.max().unstack().unstack().reindex(CONTEXTS) - self.means
        self.errors = dict(
            (
                key,
                dict(
                    (
                        subkey,
                        [
                            [los[key][subkey][context] for context in self.means.index],
                            [ups[key][subkey][context] for context in self.means.index],
                        ],
                    )
                    for subkey in self.means.columns.levels[1]
                ),
            )
            for key in self.means.columns.levels[0]
        )
        self.get_perfs(param)
        self.plot_detail(param)
        # General view: means per (context, param), all queries pooled.
        gp = self.df.groupby(["context", param])["time"]
        self.means = gp.mean().unstack().reindex(CONTEXTS)
        los = self.means - gp.min().unstack().reindex(CONTEXTS)
        ups = gp.max().unstack().reindex(CONTEXTS) - self.means
        self.errors = [
            [
                [los[key][context] for context in self.means.index],
                [ups[key][context] for context in self.means.index],
            ]
            for key in self.means
        ]
        self.plot_general(param)

    def get_perfs(self, param):
        """Print and write per-*param* slowdown/speedup summary lines."""
        with io.open(os.path.join(RESULTS_PATH, param + "_results.rst"), "w") as f:
            for v in self.means.columns.levels[0]:
                g = self.means[v].mean(axis=1)
                # Cold-cache slowdown vs control, then hot-cache speedup
                # vs control.
                perf = "%s is %.1f× slower then %.1f× faster" % (
                    v.ljust(10),
                    g[CONTEXTS[1]] / g[CONTEXTS[0]],
                    g[CONTEXTS[0]] / g[CONTEXTS[2]],
                )
                print(perf)
                f.write("- %s\n" % perf)

    def plot_detail(self, param):
        """Save one SVG per *param* value, one subplot per query."""
        for v in self.means.columns.levels[0]:
            plt.figure()
            axes = self.means[v].plot(
                kind="barh",
                xerr=self.errors[v],
                xlim=self.xlim,
                figsize=(15, 15),
                subplots=True,
                layout=(6, 2),
                sharey=True,
                legend=False,
            )
            plt.gca().invert_yaxis()
            for row in axes:
                for ax in row:
                    ax.xaxis.grid(True)
                    ax.set_ylabel("")
                    ax.set_xlabel("Time (s)")
            plt.savefig(os.path.join(RESULTS_PATH, "%s_%s.svg" % (param, v)))

    def plot_general(self, param):
        """Save the pooled (all queries) bar chart for *param*."""
        plt.figure()
        ax = self.means.plot(kind="barh", xerr=self.errors, xlim=self.xlim)
        ax.invert_yaxis()
        ax.xaxis.grid(True)
        ax.set_ylabel("")
        ax.set_xlabel("Time (s)")
        plt.savefig(os.path.join(RESULTS_PATH, "%s.svg" % param))
def create_data(using):
    """Populate the *using* database with the benchmark fixtures.

    Creates 50 users and 10 groups, attaches two randomly chosen groups
    to every user, then bulk-creates 10,000 Test rows, each owned by a
    random user.
    """
    new_users = [User(username="user%d" % i) for i in range(50)]
    User.objects.using(using).bulk_create(new_users)

    new_groups = [Group(name="test%d" % i) for i in range(10)]
    Group.objects.using(using).bulk_create(new_groups)

    all_groups = list(Group.objects.using(using))
    for user in User.objects.using(using):
        # choice() may pick the same group twice; m2m add() deduplicates.
        user.groups.add(choice(all_groups), choice(all_groups))

    all_users = list(User.objects.using(using))
    new_tests = [
        Test(name="test%d" % i, owner=choice(all_users)) for i in range(10000)
    ]
    Test.objects.using(using).bulk_create(new_tests)
if __name__ == "__main__":
    # RESULTS_PATH is a nested, dated directory ("benchmark/docs/<date>/"):
    # os.mkdir would fail with FileNotFoundError when "benchmark/docs"
    # does not exist yet, so create the whole tree idempotently.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    write_conditions()
    # Create throwaway test databases, remembering the original names so
    # they can be restored/destroyed afterwards.
    old_db_names = {}
    for alias in connections:
        conn = connections[alias]
        old_db_names[alias] = conn.settings_dict["NAME"]
        conn.creation.create_test_db(autoclobber=True)
        print("Populating %s" % connections[alias].vendor)
        create_data(alias)
    Benchmark().run()
    # Tear down every test database, whatever happened above.
    for alias in connections:
        connections[alias].creation.destroy_test_db(old_db_names[alias])