use contextvar instead of threadlocal

This commit is contained in:
hamed 2023-11-01 14:33:24 +03:00
parent 2bf675fceb
commit dd0173ca54
2 changed files with 28 additions and 18 deletions

View file

@ -1,32 +1,35 @@
import contextlib
import threading
import time
from contextvars import ContextVar
from functools import partial
from django.contrib.auth import get_user_model
from django.db.models.signals import pre_save
from auditlog.models import LogEntry
threadlocal = threading.local()
auditlog_value = ContextVar('auditlog_value')
auditlog_disabled = ContextVar('auditlog_disabled', default=False)
@contextlib.contextmanager
def set_actor(actor, remote_addr=None):
"""Connect a signal receiver with current user attached."""
# Initialize thread local storage
threadlocal.auditlog = {
context_data = {
"signal_duid": ("set_actor", time.time()),
"remote_addr": remote_addr,
}
auditlog_value.set(context_data)
# Connect signal for automatic logging
set_actor = partial(
_set_actor, user=actor, signal_duid=threadlocal.auditlog["signal_duid"]
_set_actor, user=actor, signal_duid=context_data["signal_duid"]
)
pre_save.connect(
set_actor,
sender=LogEntry,
dispatch_uid=threadlocal.auditlog["signal_duid"],
dispatch_uid=context_data["signal_duid"],
weak=False,
)
@ -34,12 +37,11 @@ def set_actor(actor, remote_addr=None):
yield
finally:
try:
auditlog = threadlocal.auditlog
except AttributeError:
auditlog = auditlog_value.get()
except LookupError:
pass
else:
pre_save.disconnect(sender=LogEntry, dispatch_uid=auditlog["signal_duid"])
del threadlocal.auditlog
def _set_actor(user, sender, instance, signal_duid, **kwargs):
@ -48,14 +50,18 @@ def _set_actor(user, sender, instance, signal_duid, **kwargs):
This function becomes a valid signal receiver when it is curried with the actor and a dispatch id.
"""
try:
auditlog = threadlocal.auditlog
except AttributeError:
auditlog = auditlog_value.get()
except LookupError as ex:
pass
else:
if signal_duid != auditlog["signal_duid"]:
return
if sender == LogEntry and instance.actor is None:
auth_user_model = get_user_model()
if (
sender == LogEntry
and isinstance(user, auth_user_model)
and instance.actor is None
):
instance.actor = user
instance.remote_addr = auditlog["remote_addr"]
@ -63,11 +69,11 @@ def _set_actor(user, sender, instance, signal_duid, **kwargs):
@contextlib.contextmanager
def disable_auditlog():
threadlocal.auditlog_disabled = True
token = auditlog_disabled.set(True)
try:
yield
finally:
try:
del threadlocal.auditlog_disabled
except AttributeError:
auditlog_disabled.reset(token)
except LookupError:
pass

View file

@ -2,7 +2,7 @@ from functools import wraps
from django.conf import settings
from auditlog.context import threadlocal
from auditlog.context import auditlog_disabled
from auditlog.diff import model_instance_diff
from auditlog.models import LogEntry
from auditlog.signals import post_log, pre_log
@ -17,8 +17,12 @@ def check_disable(signal_handler):
@wraps(signal_handler)
def wrapper(*args, **kwargs):
if not getattr(threadlocal, "auditlog_disabled", False) and not (
kwargs.get("raw") and settings.AUDITLOG_DISABLE_ON_RAW_SAVE
try:
auditlog_disabled_value = auditlog_disabled.get()
except LookupError:
auditlog_disabled_value = False
if not auditlog_disabled_value and not (
kwargs.get("raw") and settings.AUDITLOG_DISABLE_ON_RAW_SAVE
):
signal_handler(*args, **kwargs)