mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Annotate the tracker module
This commit is contained in:
parent
632441ea53
commit
56ea527286
1 changed files with 101 additions and 48 deletions
|
|
@ -2,11 +2,23 @@ from __future__ import annotations
|
|||
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload
|
||||
|
||||
from django.core.exceptions import FieldError
|
||||
from django.db import models
|
||||
from django.db.models.fields.files import FieldFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Mapping
|
||||
from types import TracebackType
|
||||
|
||||
class _AugmentedModel(models.Model):
|
||||
_instance_initialized: bool
|
||||
_deferred_fields: set[str]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class LightStateFieldFile(FieldFile):
|
||||
"""
|
||||
|
|
@ -18,7 +30,7 @@ class LightStateFieldFile(FieldFile):
|
|||
Django 3.1+ can make the app unusable, as CPU and memory usage gets easily
|
||||
multiplied by magnitudes.
|
||||
"""
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
"""
|
||||
We don't need to deepcopy the instance, so nullify if provided.
|
||||
"""
|
||||
|
|
@ -28,27 +40,35 @@ class LightStateFieldFile(FieldFile):
|
|||
return state
|
||||
|
||||
|
||||
def lightweight_deepcopy(value):
|
||||
def lightweight_deepcopy(value: T) -> T:
|
||||
"""
|
||||
Use our lightweight class to avoid copying the instance on a FieldFile deepcopy.
|
||||
"""
|
||||
if isinstance(value, FieldFile):
|
||||
value = LightStateFieldFile(
|
||||
value = cast(T, LightStateFieldFile(
|
||||
instance=value.instance,
|
||||
field=value.field,
|
||||
name=value.name,
|
||||
)
|
||||
))
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
class DescriptorWrapper:
|
||||
|
||||
def __init__(self, field_name, descriptor, tracker_attname):
|
||||
def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str):
|
||||
self.field_name = field_name
|
||||
self.descriptor = descriptor
|
||||
self.tracker_attname = tracker_attname
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field:
|
||||
...
|
||||
|
||||
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field:
|
||||
if instance is None:
|
||||
return self
|
||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||
|
|
@ -58,7 +78,7 @@ class DescriptorWrapper:
|
|||
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
|
||||
return value
|
||||
|
||||
def __set__(self, instance, value):
|
||||
def __set__(self, instance: models.Model, value: models.Field) -> None:
|
||||
initialized = hasattr(instance, '_instance_initialized')
|
||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||
|
||||
|
|
@ -81,11 +101,11 @@ class DescriptorWrapper:
|
|||
else:
|
||||
instance.__dict__[self.field_name] = value
|
||||
|
||||
def __getattr__(self, attr):
|
||||
def __getattr__(self, attr: str) -> models.Field:
|
||||
return getattr(self.descriptor, attr)
|
||||
|
||||
@staticmethod
|
||||
def cls_for_descriptor(descriptor):
|
||||
def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]:
|
||||
if hasattr(descriptor, '__delete__'):
|
||||
return FullDescriptorWrapper
|
||||
else:
|
||||
|
|
@ -96,8 +116,8 @@ class FullDescriptorWrapper(DescriptorWrapper):
|
|||
"""
|
||||
Wrapper for descriptors with all three descriptor methods.
|
||||
"""
|
||||
def __delete__(self, obj):
|
||||
self.descriptor.__delete__(obj)
|
||||
def __delete__(self, obj: models.Field) -> None:
|
||||
self.descriptor.__delete__(obj) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class FieldsContext:
|
||||
|
|
@ -121,7 +141,12 @@ class FieldsContext:
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, tracker, *fields, state=None):
|
||||
def __init__(
|
||||
self,
|
||||
tracker: FieldInstanceTracker,
|
||||
*fields: str,
|
||||
state: dict[str, int] | None = None
|
||||
):
|
||||
"""
|
||||
:param tracker: FieldInstanceTracker instance to be reset after
|
||||
context exit
|
||||
|
|
@ -139,7 +164,7 @@ class FieldsContext:
|
|||
self.fields = fields
|
||||
self.state = state
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> FieldsContext:
|
||||
"""
|
||||
Increments tracked fields occurrences count in shared state.
|
||||
"""
|
||||
|
|
@ -148,7 +173,12 @@ class FieldsContext:
|
|||
self.state[f] += 1
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
"""
|
||||
Decrements tracked fields occurrences count in shared state.
|
||||
|
||||
|
|
@ -166,29 +196,34 @@ class FieldsContext:
|
|||
|
||||
|
||||
class FieldInstanceTracker:
|
||||
def __init__(self, instance, fields, field_map):
|
||||
self.instance = instance
|
||||
def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]):
|
||||
self.instance = cast('_AugmentedModel', instance)
|
||||
self.fields = fields
|
||||
self.field_map = field_map
|
||||
self.context = FieldsContext(self, *self.fields)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> FieldsContext:
|
||||
return self.context.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
return self.context.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def __call__(self, *fields):
|
||||
def __call__(self, *fields: str) -> FieldsContext:
|
||||
return FieldsContext(self, *fields, state=self.context.state)
|
||||
|
||||
@property
|
||||
def deferred_fields(self):
|
||||
def deferred_fields(self) -> set[str]:
|
||||
return self.instance.get_deferred_fields()
|
||||
|
||||
def get_field_value(self, field):
|
||||
def get_field_value(self, field: str) -> Any:
|
||||
return getattr(self.instance, self.field_map[field])
|
||||
|
||||
def set_saved_fields(self, fields=None):
|
||||
def set_saved_fields(self, fields: Iterable[str] | None = None) -> None:
|
||||
if not self.instance.pk:
|
||||
self.saved_data = {}
|
||||
elif fields is None:
|
||||
|
|
@ -200,7 +235,7 @@ class FieldInstanceTracker:
|
|||
for field, field_value in self.saved_data.items():
|
||||
self.saved_data[field] = lightweight_deepcopy(field_value)
|
||||
|
||||
def current(self, fields=None):
|
||||
def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]:
|
||||
"""Returns dict of current values for all tracked fields"""
|
||||
if fields is None:
|
||||
deferred_fields = self.deferred_fields
|
||||
|
|
@ -214,7 +249,7 @@ class FieldInstanceTracker:
|
|||
|
||||
return {f: self.get_field_value(f) for f in fields}
|
||||
|
||||
def has_changed(self, field):
|
||||
def has_changed(self, field: str) -> bool:
|
||||
"""Returns ``True`` if field has changed from currently saved value"""
|
||||
if field in self.fields:
|
||||
# deferred fields haven't changed
|
||||
|
|
@ -224,7 +259,7 @@ class FieldInstanceTracker:
|
|||
else:
|
||||
raise FieldError('field "%s" not tracked' % field)
|
||||
|
||||
def previous(self, field):
|
||||
def previous(self, field: str) -> Any:
|
||||
"""Returns currently saved value of given field"""
|
||||
|
||||
# handle deferred fields that have not yet been loaded from the database
|
||||
|
|
@ -244,7 +279,7 @@ class FieldInstanceTracker:
|
|||
|
||||
return self.saved_data.get(field)
|
||||
|
||||
def changed(self):
|
||||
def changed(self) -> dict[str, Any]:
|
||||
"""Returns dict of fields that changed since save (with old values)"""
|
||||
return {
|
||||
field: self.previous(field)
|
||||
|
|
@ -257,13 +292,18 @@ class FieldTracker:
|
|||
|
||||
tracker_class = FieldInstanceTracker
|
||||
|
||||
def __init__(self, fields=None):
|
||||
self.fields = fields
|
||||
def __init__(self, fields: Iterable[str] | None = None):
|
||||
# finalize_class() will replace None; pretend it is never None.
|
||||
self.fields = cast(Iterable[str], fields)
|
||||
|
||||
def __call__(self, func=None, fields=None):
|
||||
def decorator(f):
|
||||
def __call__(
|
||||
self,
|
||||
func: Callable | None = None,
|
||||
fields: Iterable[str] | None = None
|
||||
) -> Any:
|
||||
def decorator(f: Callable) -> Callable:
|
||||
@wraps(f)
|
||||
def inner(obj, *args, **kwargs):
|
||||
def inner(obj: models.Model, *args: object, **kwargs: object) -> object:
|
||||
tracker = getattr(obj, self.attname)
|
||||
field_list = tracker.fields if fields is None else fields
|
||||
with tracker(*field_list):
|
||||
|
|
@ -274,7 +314,7 @@ class FieldTracker:
|
|||
return decorator
|
||||
return decorator(func)
|
||||
|
||||
def get_field_map(self, cls):
|
||||
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
|
||||
"""Returns dict mapping fields names to model attribute names"""
|
||||
field_map = {field: field for field in self.fields}
|
||||
all_fields = {f.name: f.attname for f in cls._meta.fields}
|
||||
|
|
@ -282,17 +322,17 @@ class FieldTracker:
|
|||
if k in field_map})
|
||||
return field_map
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
def contribute_to_class(self, cls: type[models.Model], name: str) -> None:
|
||||
self.name = name
|
||||
self.attname = '_%s' % name
|
||||
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
|
||||
|
||||
def finalize_class(self, sender, **kwargs):
|
||||
if self.fields is None:
|
||||
def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
|
||||
if self.fields is None or TYPE_CHECKING:
|
||||
self.fields = (field.attname for field in sender._meta.fields)
|
||||
self.fields = set(self.fields)
|
||||
for field_name in self.fields:
|
||||
descriptor = getattr(sender, field_name)
|
||||
descriptor: models.Field = getattr(sender, field_name)
|
||||
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
|
||||
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
|
||||
setattr(sender, field_name, wrapped_descriptor)
|
||||
|
|
@ -302,34 +342,39 @@ class FieldTracker:
|
|||
setattr(sender, self.name, self)
|
||||
self.patch_save(sender)
|
||||
|
||||
def initialize_tracker(self, sender, instance, **kwargs):
|
||||
def initialize_tracker(
|
||||
self,
|
||||
sender: type[models.Model],
|
||||
instance: models.Model,
|
||||
**kwargs: object
|
||||
) -> None:
|
||||
if not isinstance(instance, self.model_class):
|
||||
return # Only init instances of given model (including children)
|
||||
tracker = self.tracker_class(instance, self.fields, self.field_map)
|
||||
setattr(instance, self.attname, tracker)
|
||||
tracker.set_saved_fields()
|
||||
instance._instance_initialized = True
|
||||
cast('_AugmentedModel', instance)._instance_initialized = True
|
||||
|
||||
def patch_init(self, model):
|
||||
def patch_init(self, model: type[models.Model]) -> None:
|
||||
original = getattr(model, '__init__')
|
||||
|
||||
@wraps(original)
|
||||
def inner(instance, *args, **kwargs):
|
||||
def inner(instance: models.Model, *args: Any, **kwargs: Any) -> None:
|
||||
original(instance, *args, **kwargs)
|
||||
self.initialize_tracker(model, instance)
|
||||
|
||||
setattr(model, '__init__', inner)
|
||||
|
||||
def patch_save(self, model):
|
||||
def patch_save(self, model: type[models.Model]) -> None:
|
||||
self._patch(model, 'save_base', 'update_fields')
|
||||
self._patch(model, 'refresh_from_db', 'fields')
|
||||
|
||||
def _patch(self, model, method, fields_kwarg):
|
||||
def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None:
|
||||
original = getattr(model, method)
|
||||
|
||||
@wraps(original)
|
||||
def inner(instance, *args, **kwargs):
|
||||
update_fields = kwargs.get(fields_kwarg)
|
||||
def inner(instance: models.Model, *args: object, **kwargs: Any) -> object:
|
||||
update_fields: Iterable[str] | None = kwargs.get(fields_kwarg)
|
||||
if update_fields is None:
|
||||
fields = self.fields
|
||||
else:
|
||||
|
|
@ -343,7 +388,15 @@ class FieldTracker:
|
|||
|
||||
setattr(model, method, inner)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
@overload
|
||||
def __get__(self, instance: None, owner: type[models.Model]) -> FieldTracker:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: models.Model, owner: type[models.Model]) -> FieldInstanceTracker:
|
||||
...
|
||||
|
||||
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker | FieldInstanceTracker:
|
||||
if instance is None:
|
||||
return self
|
||||
else:
|
||||
|
|
@ -352,7 +405,7 @@ class FieldTracker:
|
|||
|
||||
class ModelInstanceTracker(FieldInstanceTracker):
|
||||
|
||||
def has_changed(self, field):
|
||||
def has_changed(self, field: str) -> bool:
|
||||
"""Returns ``True`` if field has changed from currently saved value"""
|
||||
if not self.instance.pk:
|
||||
return True
|
||||
|
|
@ -361,7 +414,7 @@ class ModelInstanceTracker(FieldInstanceTracker):
|
|||
else:
|
||||
raise FieldError('field "%s" not tracked' % field)
|
||||
|
||||
def changed(self):
|
||||
def changed(self) -> dict[str, Any]:
|
||||
"""Returns dict of fields that changed since save (with old values)"""
|
||||
if not self.instance.pk:
|
||||
return {}
|
||||
|
|
@ -373,5 +426,5 @@ class ModelInstanceTracker(FieldInstanceTracker):
|
|||
class ModelTracker(FieldTracker):
|
||||
tracker_class = ModelInstanceTracker
|
||||
|
||||
def get_field_map(self, cls):
|
||||
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
|
||||
return {field: field for field in self.fields}
|
||||
|
|
|
|||
Loading…
Reference in a new issue