Preserve tracked function's return type in FieldTracker

This commit is contained in:
Maarten ter Huurne 2024-04-17 16:14:33 +02:00
parent 23a756e13e
commit f4653f08e5

View file

@ -296,14 +296,30 @@ class FieldTracker:
# finalize_class() will replace None; pretend it is never None.
self.fields = cast(Iterable[str], fields)
@overload
def __call__(
self,
func: Callable | None = None,
func: None = None,
fields: Iterable[str] | None = None
) -> Any:
def decorator(f: Callable) -> Callable:
) -> Callable[[Callable[..., T]], Callable[..., T]]:
...
@overload
def __call__(
self,
func: Callable[..., T],
fields: Iterable[str] | None = None
) -> Callable[..., T]:
...
def __call__(
self,
func: Callable[..., T] | None = None,
fields: Iterable[str] | None = None
) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]:
def decorator(f: Callable[..., T]) -> Callable[..., T]:
@wraps(f)
def inner(obj: models.Model, *args: object, **kwargs: object) -> object:
def inner(obj: models.Model, *args: object, **kwargs: object) -> T:
tracker = getattr(obj, self.attname)
field_list = tracker.fields if fields is None else fields
with tracker(*field_list):