diff --git a/model_utils/tracker.py b/model_utils/tracker.py index d89f4a8..266ff3b 100644 --- a/model_utils/tracker.py +++ b/model_utils/tracker.py @@ -296,14 +296,30 @@ class FieldTracker: # finalize_class() will replace None; pretend it is never None. self.fields = cast(Iterable[str], fields) + @overload def __call__( self, - func: Callable | None = None, + func: None = None, fields: Iterable[str] | None = None - ) -> Any: - def decorator(f: Callable) -> Callable: + ) -> Callable[[Callable[..., T]], Callable[..., T]]: + ... + + @overload + def __call__( + self, + func: Callable[..., T], + fields: Iterable[str] | None = None + ) -> Callable[..., T]: + ... + + def __call__( + self, + func: Callable[..., T] | None = None, + fields: Iterable[str] | None = None + ) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]: + def decorator(f: Callable[..., T]) -> Callable[..., T]: @wraps(f) - def inner(obj: models.Model, *args: object, **kwargs: object) -> object: + def inner(obj: models.Model, *args: object, **kwargs: object) -> T: tracker = getattr(obj, self.attname) field_list = tracker.fields if fields is None else fields with tracker(*field_list):