Annotate the tracker module

This commit is contained in:
Maarten ter Huurne 2023-03-17 12:35:37 +01:00
parent 632441ea53
commit 56ea527286

View file

@ -2,11 +2,23 @@ from __future__ import annotations
from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload
from django.core.exceptions import FieldError
from django.db import models
from django.db.models.fields.files import FieldFile
if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from types import TracebackType
class _AugmentedModel(models.Model):
_instance_initialized: bool
_deferred_fields: set[str]
T = TypeVar("T")
class LightStateFieldFile(FieldFile):
"""
@ -18,7 +30,7 @@ class LightStateFieldFile(FieldFile):
Django 3.1+ can make the app unusable, as CPU and memory usage gets easily
multiplied by magnitudes.
"""
def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
"""
We don't need to deepcopy the instance, so nullify if provided.
"""
@ -28,27 +40,35 @@ class LightStateFieldFile(FieldFile):
return state
def lightweight_deepcopy(value):
def lightweight_deepcopy(value: T) -> T:
"""
Use our lightweight class to avoid copying the instance on a FieldFile deepcopy.
"""
if isinstance(value, FieldFile):
value = LightStateFieldFile(
value = cast(T, LightStateFieldFile(
instance=value.instance,
field=value.field,
name=value.name,
)
))
return deepcopy(value)
class DescriptorWrapper:
def __init__(self, field_name, descriptor, tracker_attname):
def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str):
self.field_name = field_name
self.descriptor = descriptor
self.tracker_attname = tracker_attname
def __get__(self, instance, owner):
@overload
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper:
...
@overload
def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field:
...
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field:
if instance is None:
return self
was_deferred = self.field_name in instance.get_deferred_fields()
@ -58,7 +78,7 @@ class DescriptorWrapper:
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
return value
def __set__(self, instance, value):
def __set__(self, instance: models.Model, value: models.Field) -> None:
initialized = hasattr(instance, '_instance_initialized')
was_deferred = self.field_name in instance.get_deferred_fields()
@ -81,11 +101,11 @@ class DescriptorWrapper:
else:
instance.__dict__[self.field_name] = value
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> models.Field:
return getattr(self.descriptor, attr)
@staticmethod
def cls_for_descriptor(descriptor):
def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]:
if hasattr(descriptor, '__delete__'):
return FullDescriptorWrapper
else:
@ -96,8 +116,8 @@ class FullDescriptorWrapper(DescriptorWrapper):
"""
Wrapper for descriptors with all three descriptor methods.
"""
def __delete__(self, obj):
self.descriptor.__delete__(obj)
def __delete__(self, obj: models.Field) -> None:
self.descriptor.__delete__(obj) # type: ignore[attr-defined]
class FieldsContext:
@ -121,7 +141,12 @@ class FieldsContext:
"""
def __init__(self, tracker, *fields, state=None):
def __init__(
self,
tracker: FieldInstanceTracker,
*fields: str,
state: dict[str, int] | None = None
):
"""
:param tracker: FieldInstanceTracker instance to be reset after
context exit
@ -139,7 +164,7 @@ class FieldsContext:
self.fields = fields
self.state = state
def __enter__(self):
def __enter__(self) -> FieldsContext:
"""
Increments tracked fields occurrences count in shared state.
"""
@ -148,7 +173,12 @@ class FieldsContext:
self.state[f] += 1
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
"""
Decrements tracked fields occurrences count in shared state.
@ -166,29 +196,34 @@ class FieldsContext:
class FieldInstanceTracker:
def __init__(self, instance, fields, field_map):
self.instance = instance
def __init__(self, instance: models.Model, fields: Iterable[str], field_map: Mapping[str, str]):
self.instance = cast('_AugmentedModel', instance)
self.fields = fields
self.field_map = field_map
self.context = FieldsContext(self, *self.fields)
def __enter__(self):
def __enter__(self) -> FieldsContext:
return self.context.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None
) -> None:
return self.context.__exit__(exc_type, exc_val, exc_tb)
def __call__(self, *fields):
def __call__(self, *fields: str) -> FieldsContext:
return FieldsContext(self, *fields, state=self.context.state)
@property
def deferred_fields(self):
def deferred_fields(self) -> set[str]:
return self.instance.get_deferred_fields()
def get_field_value(self, field):
def get_field_value(self, field: str) -> Any:
return getattr(self.instance, self.field_map[field])
def set_saved_fields(self, fields=None):
def set_saved_fields(self, fields: Iterable[str] | None = None) -> None:
if not self.instance.pk:
self.saved_data = {}
elif fields is None:
@ -200,7 +235,7 @@ class FieldInstanceTracker:
for field, field_value in self.saved_data.items():
self.saved_data[field] = lightweight_deepcopy(field_value)
def current(self, fields=None):
def current(self, fields: Iterable[str] | None = None) -> dict[str, Any]:
"""Returns dict of current values for all tracked fields"""
if fields is None:
deferred_fields = self.deferred_fields
@ -214,7 +249,7 @@ class FieldInstanceTracker:
return {f: self.get_field_value(f) for f in fields}
def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if field in self.fields:
# deferred fields haven't changed
@ -224,7 +259,7 @@ class FieldInstanceTracker:
else:
raise FieldError('field "%s" not tracked' % field)
def previous(self, field):
def previous(self, field: str) -> Any:
"""Returns currently saved value of given field"""
# handle deferred fields that have not yet been loaded from the database
@ -244,7 +279,7 @@ class FieldInstanceTracker:
return self.saved_data.get(field)
def changed(self):
def changed(self) -> dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
return {
field: self.previous(field)
@ -257,13 +292,18 @@ class FieldTracker:
tracker_class = FieldInstanceTracker
def __init__(self, fields=None):
self.fields = fields
def __init__(self, fields: Iterable[str] | None = None):
# finalize_class() will replace None; pretend it is never None.
self.fields = cast(Iterable[str], fields)
def __call__(self, func=None, fields=None):
def decorator(f):
def __call__(
self,
func: Callable | None = None,
fields: Iterable[str] | None = None
) -> Any:
def decorator(f: Callable) -> Callable:
@wraps(f)
def inner(obj, *args, **kwargs):
def inner(obj: models.Model, *args: object, **kwargs: object) -> object:
tracker = getattr(obj, self.attname)
field_list = tracker.fields if fields is None else fields
with tracker(*field_list):
@ -274,7 +314,7 @@ class FieldTracker:
return decorator
return decorator(func)
def get_field_map(self, cls):
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
"""Returns dict mapping fields names to model attribute names"""
field_map = {field: field for field in self.fields}
all_fields = {f.name: f.attname for f in cls._meta.fields}
@ -282,17 +322,17 @@ class FieldTracker:
if k in field_map})
return field_map
def contribute_to_class(self, cls, name):
def contribute_to_class(self, cls: type[models.Model], name: str) -> None:
self.name = name
self.attname = '_%s' % name
models.signals.class_prepared.connect(self.finalize_class, sender=cls)
def finalize_class(self, sender, **kwargs):
if self.fields is None:
def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
if self.fields is None or TYPE_CHECKING:
self.fields = (field.attname for field in sender._meta.fields)
self.fields = set(self.fields)
for field_name in self.fields:
descriptor = getattr(sender, field_name)
descriptor: models.Field = getattr(sender, field_name)
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
setattr(sender, field_name, wrapped_descriptor)
@ -302,34 +342,39 @@ class FieldTracker:
setattr(sender, self.name, self)
self.patch_save(sender)
def initialize_tracker(self, sender, instance, **kwargs):
def initialize_tracker(
self,
sender: type[models.Model],
instance: models.Model,
**kwargs: object
) -> None:
if not isinstance(instance, self.model_class):
return # Only init instances of given model (including children)
tracker = self.tracker_class(instance, self.fields, self.field_map)
setattr(instance, self.attname, tracker)
tracker.set_saved_fields()
instance._instance_initialized = True
cast('_AugmentedModel', instance)._instance_initialized = True
def patch_init(self, model):
def patch_init(self, model: type[models.Model]) -> None:
original = getattr(model, '__init__')
@wraps(original)
def inner(instance, *args, **kwargs):
def inner(instance: models.Model, *args: Any, **kwargs: Any) -> None:
original(instance, *args, **kwargs)
self.initialize_tracker(model, instance)
setattr(model, '__init__', inner)
def patch_save(self, model):
def patch_save(self, model: type[models.Model]) -> None:
self._patch(model, 'save_base', 'update_fields')
self._patch(model, 'refresh_from_db', 'fields')
def _patch(self, model, method, fields_kwarg):
def _patch(self, model: type[models.Model], method: str, fields_kwarg: str) -> None:
original = getattr(model, method)
@wraps(original)
def inner(instance, *args, **kwargs):
update_fields = kwargs.get(fields_kwarg)
def inner(instance: models.Model, *args: object, **kwargs: Any) -> object:
update_fields: Iterable[str] | None = kwargs.get(fields_kwarg)
if update_fields is None:
fields = self.fields
else:
@ -343,7 +388,15 @@ class FieldTracker:
setattr(model, method, inner)
def __get__(self, instance, owner):
@overload
def __get__(self, instance: None, owner: type[models.Model]) -> FieldTracker:
...
@overload
def __get__(self, instance: models.Model, owner: type[models.Model]) -> FieldInstanceTracker:
...
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> FieldTracker | FieldInstanceTracker:
if instance is None:
return self
else:
@ -352,7 +405,7 @@ class FieldTracker:
class ModelInstanceTracker(FieldInstanceTracker):
def has_changed(self, field):
def has_changed(self, field: str) -> bool:
"""Returns ``True`` if field has changed from currently saved value"""
if not self.instance.pk:
return True
@ -361,7 +414,7 @@ class ModelInstanceTracker(FieldInstanceTracker):
else:
raise FieldError('field "%s" not tracked' % field)
def changed(self):
def changed(self) -> dict[str, Any]:
"""Returns dict of fields that changed since save (with old values)"""
if not self.instance.pk:
return {}
@ -373,5 +426,5 @@ class ModelInstanceTracker(FieldInstanceTracker):
class ModelTracker(FieldTracker):
tracker_class = ModelInstanceTracker
def get_field_map(self, cls):
def get_field_map(self, cls: type[models.Model]) -> dict[str, str]:
return {field: field for field in self.fields}