diff --git a/model_utils/tracker.py b/model_utils/tracker.py index 266ff3b..6109380 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -2,7 +2,16 @@ from __future__ import annotations from copy import deepcopy from functools import wraps -from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Protocol, + TypeVar, + cast, + overload, +) from django.core.exceptions import FieldError from django.db import models @@ -16,10 +25,22 @@ if TYPE_CHECKING: _instance_initialized: bool _deferred_fields: set[str] - T = TypeVar("T") +class Descriptor(Protocol[T]): + def __get__(self, instance: object, owner: type[object]) -> T: + ... + + def __set__(self, instance: object, value: T) -> None: + ... + + +class FullDescriptor(Descriptor[T]): + def __delete__(self, instance: object) -> None: + ... + + class LightStateFieldFile(FieldFile): """ FieldFile subclass with the only aim to remove the instance from the state. @@ -53,22 +74,22 @@ def lightweight_deepcopy(value: T) -> T: return deepcopy(value) -class DescriptorWrapper: +class DescriptorWrapper(Generic[T]): - def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str): + def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str): self.field_name = field_name self.descriptor = descriptor self.tracker_attname = tracker_attname @overload - def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper: + def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]: ... @overload - def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field: + def __get__(self, instance: models.Model, owner: type[models.Model]) -> T: ... - def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field: + def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T: if instance is None: return self was_deferred = self.field_name in instance.get_deferred_fields() @@ -78,7 +99,7 @@ class DescriptorWrapper: tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value) return value - def __set__(self, instance: models.Model, value: models.Field) -> None: + def __set__(self, instance: models.Model, value: T) -> None: initialized = hasattr(instance, '_instance_initialized') was_deferred = self.field_name in instance.get_deferred_fields() @@ -101,23 +122,23 @@ class DescriptorWrapper: else: instance.__dict__[self.field_name] = value - def __getattr__(self, attr: str) -> models.Field: + def __getattr__(self, attr: str) -> T: return getattr(self.descriptor, attr) @staticmethod - def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]: + def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]: if hasattr(descriptor, '__delete__'): return FullDescriptorWrapper else: return DescriptorWrapper -class FullDescriptorWrapper(DescriptorWrapper): +class FullDescriptorWrapper(DescriptorWrapper[T]): """ Wrapper for descriptors with all three descriptor methods. """ - def __delete__(self, obj: models.Field) -> None: - self.descriptor.__delete__(obj) # type: ignore[attr-defined] + def __delete__(self, obj: models.Model) -> None: + cast(FullDescriptor[T], self.descriptor).__delete__(obj) class FieldsContext: @@ -255,7 +276,9 @@ class FieldInstanceTracker: # deferred fields haven't changed if field in self.deferred_fields and field not in self.instance.__dict__: return False - return self.previous(field) != self.get_field_value(field) + prev: object = self.previous(field) + curr: object = self.get_field_value(field) + return prev != curr else: raise FieldError('field "%s" not tracked' % field) @@ -348,7 +371,7 @@ class FieldTracker: self.fields = (field.attname for field in sender._meta.fields) self.fields = set(self.fields) for field_name in self.fields: - descriptor: models.Field = getattr(sender, field_name) + descriptor: models.Field[Any, Any] = getattr(sender, field_name) wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor) wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname) setattr(sender, field_name, wrapped_descriptor) @@ -426,7 +449,9 @@ class ModelInstanceTracker(FieldInstanceTracker): if not self.instance.pk: return True elif field in self.saved_data: - return self.previous(field) != self.get_field_value(field) + prev: object = self.previous(field) + curr: object = self.get_field_value(field) + return prev != curr else: raise FieldError('field "%s" not tracked' % field)