mirror of
https://github.com/Hopiu/django-model-utils.git
synced 2026-03-16 20:00:23 +00:00
Add type argument to DescriptorWrapper
This preserves the type of the wrapped descriptor (usually a field). Maybe this is overkill, as `DescriptorWrapper` seems to only be used as part of the `FieldTracker` implementation and is not documented and barely tested. But technically, it is public API.
This commit is contained in:
parent
1db7d6ba33
commit
00937608fa
1 changed files with 41 additions and 16 deletions
|
|
@ -2,7 +2,16 @@ from __future__ import annotations
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
Protocol,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
from django.core.exceptions import FieldError
|
from django.core.exceptions import FieldError
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
@ -16,10 +25,22 @@ if TYPE_CHECKING:
|
||||||
_instance_initialized: bool
|
_instance_initialized: bool
|
||||||
_deferred_fields: set[str]
|
_deferred_fields: set[str]
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Descriptor(Protocol[T]):
|
||||||
|
def __get__(self, instance: object, owner: type[object]) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __set__(self, instance: object, value: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class FullDescriptor(Descriptor[T]):
|
||||||
|
def __delete__(self, instance: object) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class LightStateFieldFile(FieldFile):
|
class LightStateFieldFile(FieldFile):
|
||||||
"""
|
"""
|
||||||
FieldFile subclass with the only aim to remove the instance from the state.
|
FieldFile subclass with the only aim to remove the instance from the state.
|
||||||
|
|
@ -53,22 +74,22 @@ def lightweight_deepcopy(value: T) -> T:
|
||||||
return deepcopy(value)
|
return deepcopy(value)
|
||||||
|
|
||||||
|
|
||||||
class DescriptorWrapper:
|
class DescriptorWrapper(Generic[T]):
|
||||||
|
|
||||||
def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str):
|
def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str):
|
||||||
self.field_name = field_name
|
self.field_name = field_name
|
||||||
self.descriptor = descriptor
|
self.descriptor = descriptor
|
||||||
self.tracker_attname = tracker_attname
|
self.tracker_attname = tracker_attname
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper:
|
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field:
|
def __get__(self, instance: models.Model, owner: type[models.Model]) -> T:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field:
|
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T:
|
||||||
if instance is None:
|
if instance is None:
|
||||||
return self
|
return self
|
||||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||||
|
|
@ -78,7 +99,7 @@ class DescriptorWrapper:
|
||||||
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
|
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __set__(self, instance: models.Model, value: models.Field) -> None:
|
def __set__(self, instance: models.Model, value: T) -> None:
|
||||||
initialized = hasattr(instance, '_instance_initialized')
|
initialized = hasattr(instance, '_instance_initialized')
|
||||||
was_deferred = self.field_name in instance.get_deferred_fields()
|
was_deferred = self.field_name in instance.get_deferred_fields()
|
||||||
|
|
||||||
|
|
@ -101,23 +122,23 @@ class DescriptorWrapper:
|
||||||
else:
|
else:
|
||||||
instance.__dict__[self.field_name] = value
|
instance.__dict__[self.field_name] = value
|
||||||
|
|
||||||
def __getattr__(self, attr: str) -> models.Field:
|
def __getattr__(self, attr: str) -> T:
|
||||||
return getattr(self.descriptor, attr)
|
return getattr(self.descriptor, attr)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]:
|
def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]:
|
||||||
if hasattr(descriptor, '__delete__'):
|
if hasattr(descriptor, '__delete__'):
|
||||||
return FullDescriptorWrapper
|
return FullDescriptorWrapper
|
||||||
else:
|
else:
|
||||||
return DescriptorWrapper
|
return DescriptorWrapper
|
||||||
|
|
||||||
|
|
||||||
class FullDescriptorWrapper(DescriptorWrapper):
|
class FullDescriptorWrapper(DescriptorWrapper[T]):
|
||||||
"""
|
"""
|
||||||
Wrapper for descriptors with all three descriptor methods.
|
Wrapper for descriptors with all three descriptor methods.
|
||||||
"""
|
"""
|
||||||
def __delete__(self, obj: models.Field) -> None:
|
def __delete__(self, obj: models.Model) -> None:
|
||||||
self.descriptor.__delete__(obj) # type: ignore[attr-defined]
|
cast(FullDescriptor[T], self.descriptor).__delete__(obj)
|
||||||
|
|
||||||
|
|
||||||
class FieldsContext:
|
class FieldsContext:
|
||||||
|
|
@ -255,7 +276,9 @@ class FieldInstanceTracker:
|
||||||
# deferred fields haven't changed
|
# deferred fields haven't changed
|
||||||
if field in self.deferred_fields and field not in self.instance.__dict__:
|
if field in self.deferred_fields and field not in self.instance.__dict__:
|
||||||
return False
|
return False
|
||||||
return self.previous(field) != self.get_field_value(field)
|
prev: object = self.previous(field)
|
||||||
|
curr: object = self.get_field_value(field)
|
||||||
|
return prev != curr
|
||||||
else:
|
else:
|
||||||
raise FieldError('field "%s" not tracked' % field)
|
raise FieldError('field "%s" not tracked' % field)
|
||||||
|
|
||||||
|
|
@ -348,7 +371,7 @@ class FieldTracker:
|
||||||
self.fields = (field.attname for field in sender._meta.fields)
|
self.fields = (field.attname for field in sender._meta.fields)
|
||||||
self.fields = set(self.fields)
|
self.fields = set(self.fields)
|
||||||
for field_name in self.fields:
|
for field_name in self.fields:
|
||||||
descriptor: models.Field = getattr(sender, field_name)
|
descriptor: models.Field[Any, Any] = getattr(sender, field_name)
|
||||||
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
|
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
|
||||||
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
|
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
|
||||||
setattr(sender, field_name, wrapped_descriptor)
|
setattr(sender, field_name, wrapped_descriptor)
|
||||||
|
|
@ -426,7 +449,9 @@ class ModelInstanceTracker(FieldInstanceTracker):
|
||||||
if not self.instance.pk:
|
if not self.instance.pk:
|
||||||
return True
|
return True
|
||||||
elif field in self.saved_data:
|
elif field in self.saved_data:
|
||||||
return self.previous(field) != self.get_field_value(field)
|
prev: object = self.previous(field)
|
||||||
|
curr: object = self.get_field_value(field)
|
||||||
|
return prev != curr
|
||||||
else:
|
else:
|
||||||
raise FieldError('field "%s" not tracked' % field)
|
raise FieldError('field "%s" not tracked' % field)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue