From 2519a2a6868c73bd0e78eb701e4e2032c0143e1b Mon Sep 17 00:00:00 2001 From: kasium <15907922+kasium@users.noreply.github.com> Date: Wed, 19 Oct 2022 13:30:35 +0200 Subject: [PATCH] Add types to `invoke.Task.__call__` (#8918) --- stubs/invoke/invoke/executor.pyi | 4 +-- stubs/invoke/invoke/tasks.pyi | 47 ++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/stubs/invoke/invoke/executor.pyi b/stubs/invoke/invoke/executor.pyi index cb79f60c1..fc9b6c147 100644 --- a/stubs/invoke/invoke/executor.pyi +++ b/stubs/invoke/invoke/executor.pyi @@ -11,7 +11,7 @@ class Executor: config: Config core: ParseResult | None def __init__(self, collection: Collection, config: Config | None = ..., core: ParseResult | None = ...) -> None: ... - def execute(self, *tasks: str | tuple[str, dict[str, Any]] | ParserContext) -> dict[Task, Any]: ... + def execute(self, *tasks: str | tuple[str, dict[str, Any]] | ParserContext) -> dict[Task[..., Any], Any]: ... def normalize(self, tasks: Iterable[str | tuple[str, dict[str, Any]] | ParserContext]): ... def dedupe(self, calls: Iterable[Call]) -> list[Call]: ... - def expand_calls(self, calls: Iterable[Call | Task]) -> list[Call]: ... + def expand_calls(self, calls: Iterable[Call | Task[..., Any]]) -> list[Call]: ... diff --git a/stubs/invoke/invoke/tasks.pyi b/stubs/invoke/invoke/tasks.pyi index 9e60b0b5f..ca1379723 100644 --- a/stubs/invoke/invoke/tasks.pyi +++ b/stubs/invoke/invoke/tasks.pyi @@ -1,17 +1,20 @@ from _typeshed import Self from collections.abc import Callable, Iterable -from typing import Any, TypeVar, overload +from typing import Any, Generic, TypeVar, overload +from typing_extensions import ParamSpec from .config import Config from .context import Context from .parser import Argument -_TaskT = TypeVar("_TaskT", bound=Task) +_P = ParamSpec("_P") +_R_co = TypeVar("_R_co", covariant=True) +_TaskT = TypeVar("_TaskT", bound=Task[..., Any]) NO_DEFAULT: object -class Task: - body: Callable[..., Any] +class Task(Generic[_P, _R_co]): + body: Callable[_P, _R_co] __doc__: str | None __name__: str __module__: str @@ -23,8 +26,8 @@ class Task: incrementable: Iterable[str] auto_shortflags: bool help: dict[str, str] - pre: Iterable[Task] - post: Iterable[Task] + pre: Iterable[Task[..., Any]] + post: Iterable[Task[..., Any]] times_called: int autoprint: bool def __init__( @@ -37,8 +40,8 @@ class Task: default: bool = ..., auto_shortflags: bool = ..., help: dict[str, str] | None = ..., - pre: Iterable[Task] | None = ..., - post: Iterable[Task] | None = ..., + pre: Iterable[Task[..., Any]] | None = ..., + post: Iterable[Task[..., Any]] | None = ..., autoprint: bool = ..., iterable: Iterable[str] | None = ..., incrementable: Iterable[str] | None = ..., @@ -47,7 +50,7 @@ class Task: def name(self): ... def __eq__(self, other: Task) -> bool: ... # type: ignore[override] def __hash__(self) -> int: ... - def __call__(self, *args, **kwargs): ... + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ... @property def called(self) -> bool: ... def argspec(self, body): ... @@ -56,10 +59,10 @@ class Task: def get_arguments(self, ignore_unknown_help: bool | None = ...) -> list[Argument]: ... @overload -def task(__func: Callable[..., Any]) -> Task: ... +def task(__func: Callable[_P, _R_co]) -> Task[_P, _R_co]: ... @overload def task( - *args: Task, + *args: Task[..., Any], name: str | None = ..., aliases: tuple[str, ...] = ..., positional: Iterable[str] | None = ..., @@ -67,15 +70,15 @@ def task( default: bool = ..., auto_shortflags: bool = ..., help: dict[str, str] | None = ..., - pre: list[Task] | None = ..., - post: list[Task] | None = ..., + pre: list[Task[..., Any]] | None = ..., + post: list[Task[..., Any]] | None = ..., autoprint: bool = ..., iterable: Iterable[str] | None = ..., incrementable: Iterable[str] | None = ..., -) -> Callable[[Callable[..., Any]], Task]: ... +) -> Callable[[Callable[_P, _R_co]], Task[_P, _R_co]]: ... @overload def task( - *args: Task, + *args: Task[..., Any], name: str | None = ..., aliases: tuple[str, ...] = ..., positional: Iterable[str] | None = ..., @@ -83,8 +86,8 @@ def task( default: bool = ..., auto_shortflags: bool = ..., help: dict[str, str] | None = ..., - pre: list[Task] | None = ..., - post: list[Task] | None = ..., + pre: list[Task[..., Any]] | None = ..., + post: list[Task[..., Any]] | None = ..., autoprint: bool = ..., iterable: Iterable[str] | None = ..., incrementable: Iterable[str] | None = ..., @@ -92,12 +95,16 @@ def task( ) -> Callable[[Callable[..., Any]], _TaskT]: ... class Call: - task: Task + task: Task[..., Any] called_as: str | None args: tuple[Any, ...] kwargs: dict[str, Any] def __init__( - self, task: Task, called_as: str | None = ..., args: tuple[Any, ...] | None = ..., kwargs: dict[str, Any] | None = ... + self, + task: Task[..., Any], + called_as: str | None = ..., + args: tuple[Any, ...] | None = ..., + kwargs: dict[str, Any] | None = ..., ) -> None: ... def __getattr__(self, name: str) -> Any: ... def __deepcopy__(self: Self, memo: Any) -> Self: ... @@ -107,4 +114,4 @@ class Call: # TODO use overload def clone(self, into: type[Call] | None = ..., with_: dict[str, Any] | None = ...) -> Call: ... -def call(task: Task, *args: Any, **kwargs: Any) -> Call: ... +def call(task: Task[..., Any], *args: Any, **kwargs: Any) -> Call: ...