mirror of
https://github.com/davidhalter/typeshed.git
synced 2026-05-04 20:45:49 +08:00
Enable Ruff PLC (Pylint Convention) (#13306)
This commit is contained in:
+16
-8
@@ -45,6 +45,7 @@ select = [
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"PGH", # pygrep-hooks
|
||||
"PLC", # Pylint Convention
|
||||
"RUF", # Ruff-specific and unused-noqa
|
||||
"TRY", # tryceratops
|
||||
"UP", # pyupgrade
|
||||
@@ -159,19 +160,26 @@ ignore = [
|
||||
# A lot of stubs are incomplete on purpose, and that's configured through pyright
|
||||
# Some ANN204 (special method) are autofixable in stubs, but not all.
|
||||
"ANN2", # Missing return type annotation for ...
|
||||
# Most pep8-naming rules don't apply for third-party stubs like typeshed.
|
||||
# N811 to N814 could apply, but we often use them to disambiguate a name whilst making it look like a more common one
|
||||
"N8",
|
||||
# Ruff 0.8.0 added sorting of __all__ and __slots_.
|
||||
# There is no consensus on whether we want to apply this to stubs, so keeping the status quo.
|
||||
# See https://github.com/python/typeshed/pull/13108
|
||||
"RUF022", # `__all__` is not sorted
|
||||
"RUF023", # `{}.__slots__` is not sorted
|
||||
###
|
||||
# Rules that are out of the control of stub authors:
|
||||
###
|
||||
"F403", # `from . import *` used; unable to detect undefined names
|
||||
# Stubs can sometimes re-export entire modules.
|
||||
# Issues with using a star-imported name will be caught by type-checkers.
|
||||
"F405", # may be undefined, or defined from star imports
|
||||
# Ruff 0.8.0 added sorting of __all__ and __slots_.
|
||||
# There is no consensus on whether we want to apply this to stubs, so keeping the status quo.
|
||||
# See https://github.com/python/typeshed/pull/13108
|
||||
"RUF022",
|
||||
"RUF023",
|
||||
# Most pep8-naming rules don't apply for third-party stubs like typeshed.
|
||||
# N811 to N814 could apply, but we often use them to disambiguate a name whilst making it look like a more common one
|
||||
"N8", # pep8-naming
|
||||
"PLC2701", # Private name import from external module
|
||||
]
|
||||
"lib/ts_utils/**" = [
|
||||
# Doesn't affect stubs. The only re-exports we have should be in our local lib ts_utils
|
||||
"PLC0414", # Import alias does not rename original package
|
||||
]
|
||||
"*_pb2.pyi" = [
|
||||
# Leave the docstrings as-is, matching source
|
||||
|
||||
+15
-15
@@ -89,8 +89,8 @@ _T2 = TypeVar("_T2")
|
||||
_T3 = TypeVar("_T3")
|
||||
_T4 = TypeVar("_T4")
|
||||
_T5 = TypeVar("_T5")
|
||||
_SupportsNextT = TypeVar("_SupportsNextT", bound=SupportsNext[Any], covariant=True)
|
||||
_SupportsAnextT = TypeVar("_SupportsAnextT", bound=SupportsAnext[Any], covariant=True)
|
||||
_SupportsNextT_co = TypeVar("_SupportsNextT_co", bound=SupportsNext[Any], covariant=True)
|
||||
_SupportsAnextT_co = TypeVar("_SupportsAnextT_co", bound=SupportsAnext[Any], covariant=True)
|
||||
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
|
||||
_AwaitableT_co = TypeVar("_AwaitableT_co", bound=Awaitable[Any], covariant=True)
|
||||
_P = ParamSpec("_P")
|
||||
@@ -1319,7 +1319,7 @@ class _PathLike(Protocol[AnyStr_co]):
|
||||
def __fspath__(self) -> AnyStr_co: ...
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
def aiter(async_iterable: SupportsAiter[_SupportsAnextT], /) -> _SupportsAnextT: ...
|
||||
def aiter(async_iterable: SupportsAiter[_SupportsAnextT_co], /) -> _SupportsAnextT_co: ...
|
||||
|
||||
class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
|
||||
def __anext__(self) -> _AwaitableT_co: ...
|
||||
@@ -1481,7 +1481,7 @@ class _GetItemIterable(Protocol[_T_co]):
|
||||
def __getitem__(self, i: int, /) -> _T_co: ...
|
||||
|
||||
@overload
|
||||
def iter(object: SupportsIter[_SupportsNextT], /) -> _SupportsNextT: ...
|
||||
def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ...
|
||||
@overload
|
||||
def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ...
|
||||
@overload
|
||||
@@ -1688,17 +1688,17 @@ def print(
|
||||
*values: object, sep: str | None = " ", end: str | None = "\n", file: _SupportsWriteAndFlush[str] | None = None, flush: bool
|
||||
) -> None: ...
|
||||
|
||||
_E = TypeVar("_E", contravariant=True)
|
||||
_M = TypeVar("_M", contravariant=True)
|
||||
_E_contra = TypeVar("_E_contra", contravariant=True)
|
||||
_M_contra = TypeVar("_M_contra", contravariant=True)
|
||||
|
||||
class _SupportsPow2(Protocol[_E, _T_co]):
|
||||
def __pow__(self, other: _E, /) -> _T_co: ...
|
||||
class _SupportsPow2(Protocol[_E_contra, _T_co]):
|
||||
def __pow__(self, other: _E_contra, /) -> _T_co: ...
|
||||
|
||||
class _SupportsPow3NoneOnly(Protocol[_E, _T_co]):
|
||||
def __pow__(self, other: _E, modulo: None = None, /) -> _T_co: ...
|
||||
class _SupportsPow3NoneOnly(Protocol[_E_contra, _T_co]):
|
||||
def __pow__(self, other: _E_contra, modulo: None = None, /) -> _T_co: ...
|
||||
|
||||
class _SupportsPow3(Protocol[_E, _M, _T_co]):
|
||||
def __pow__(self, other: _E, modulo: _M, /) -> _T_co: ...
|
||||
class _SupportsPow3(Protocol[_E_contra, _M_contra, _T_co]):
|
||||
def __pow__(self, other: _E_contra, modulo: _M_contra, /) -> _T_co: ...
|
||||
|
||||
_SupportsSomeKindOfPow = ( # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed
|
||||
_SupportsPow2[Any, Any] | _SupportsPow3NoneOnly[Any, Any] | _SupportsPow3[Any, Any, Any]
|
||||
@@ -1734,11 +1734,11 @@ def pow(base: float, exp: complex | _SupportsSomeKindOfPow, mod: None = None) ->
|
||||
@overload
|
||||
def pow(base: complex, exp: complex | _SupportsSomeKindOfPow, mod: None = None) -> complex: ...
|
||||
@overload
|
||||
def pow(base: _SupportsPow2[_E, _T_co], exp: _E, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
|
||||
def pow(base: _SupportsPow2[_E_contra, _T_co], exp: _E_contra, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
|
||||
@overload
|
||||
def pow(base: _SupportsPow3NoneOnly[_E, _T_co], exp: _E, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
|
||||
def pow(base: _SupportsPow3NoneOnly[_E_contra, _T_co], exp: _E_contra, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
|
||||
@overload
|
||||
def pow(base: _SupportsPow3[_E, _M, _T_co], exp: _E, mod: _M) -> _T_co: ...
|
||||
def pow(base: _SupportsPow3[_E_contra, _M_contra, _T_co], exp: _E_contra, mod: _M_contra) -> _T_co: ...
|
||||
@overload
|
||||
def pow(base: _SupportsSomeKindOfPow, exp: float, mod: None = None) -> Any: ...
|
||||
@overload
|
||||
|
||||
@@ -33,7 +33,7 @@ _T_co = TypeVar("_T_co", covariant=True)
|
||||
_T_io = TypeVar("_T_io", bound=IO[str] | None)
|
||||
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound=bool | None, default=bool | None)
|
||||
_F = TypeVar("_F", bound=Callable[..., Any])
|
||||
_G = TypeVar("_G", bound=Generator[Any, Any, Any] | AsyncGenerator[Any, Any], covariant=True)
|
||||
_G_co = TypeVar("_G_co", bound=Generator[Any, Any, Any] | AsyncGenerator[Any, Any], covariant=True)
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
|
||||
@@ -68,11 +68,11 @@ class ContextDecorator:
|
||||
def _recreate_cm(self) -> Self: ...
|
||||
def __call__(self, func: _F) -> _F: ...
|
||||
|
||||
class _GeneratorContextManagerBase(Generic[_G]):
|
||||
class _GeneratorContextManagerBase(Generic[_G_co]):
|
||||
# Ideally this would use ParamSpec, but that requires (*args, **kwargs), which this isn't. see #6676
|
||||
def __init__(self, func: Callable[..., _G], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
|
||||
gen: _G
|
||||
func: Callable[..., _G]
|
||||
def __init__(self, func: Callable[..., _G_co], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
|
||||
gen: _G_co
|
||||
func: Callable[..., _G_co]
|
||||
args: tuple[Any, ...]
|
||||
kwds: dict[str, Any]
|
||||
|
||||
|
||||
+6
-6
@@ -143,8 +143,8 @@ if sys.version_info >= (3, 11):
|
||||
_P = ParamSpec("_P")
|
||||
_T = TypeVar("_T")
|
||||
_F = TypeVar("_F", bound=Callable[..., Any])
|
||||
_T_cont = TypeVar("_T_cont", contravariant=True)
|
||||
_V_cont = TypeVar("_V_cont", contravariant=True)
|
||||
_T_contra = TypeVar("_T_contra", contravariant=True)
|
||||
_V_contra = TypeVar("_V_contra", contravariant=True)
|
||||
|
||||
#
|
||||
# Types and members
|
||||
@@ -228,11 +228,11 @@ def isasyncgenfunction(obj: Callable[_P, Any]) -> TypeGuard[Callable[_P, AsyncGe
|
||||
@overload
|
||||
def isasyncgenfunction(obj: object) -> TypeGuard[Callable[..., AsyncGeneratorType[Any, Any]]]: ...
|
||||
|
||||
class _SupportsSet(Protocol[_T_cont, _V_cont]):
|
||||
def __set__(self, instance: _T_cont, value: _V_cont, /) -> None: ...
|
||||
class _SupportsSet(Protocol[_T_contra, _V_contra]):
|
||||
def __set__(self, instance: _T_contra, value: _V_contra, /) -> None: ...
|
||||
|
||||
class _SupportsDelete(Protocol[_T_cont]):
|
||||
def __delete__(self, instance: _T_cont, /) -> None: ...
|
||||
class _SupportsDelete(Protocol[_T_contra]):
|
||||
def __delete__(self, instance: _T_contra, /) -> None: ...
|
||||
|
||||
def isasyncgen(object: object) -> TypeIs[AsyncGeneratorType[Any, Any]]: ...
|
||||
def istraceback(object: object) -> TypeIs[TracebackType]: ...
|
||||
|
||||
@@ -12,10 +12,10 @@ __all__ = ["Client", "Listener", "Pipe", "wait"]
|
||||
_Address: TypeAlias = str | tuple[str, int]
|
||||
|
||||
# Defaulting to Any to avoid forcing generics on a lot of pre-existing code
|
||||
_SendT = TypeVar("_SendT", contravariant=True, default=Any)
|
||||
_RecvT = TypeVar("_RecvT", covariant=True, default=Any)
|
||||
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=Any)
|
||||
_RecvT_co = TypeVar("_RecvT_co", covariant=True, default=Any)
|
||||
|
||||
class _ConnectionBase(Generic[_SendT, _RecvT]):
|
||||
class _ConnectionBase(Generic[_SendT_contra, _RecvT_co]):
|
||||
def __init__(self, handle: SupportsIndex, readable: bool = True, writable: bool = True) -> None: ...
|
||||
@property
|
||||
def closed(self) -> bool: ... # undocumented
|
||||
@@ -26,10 +26,10 @@ class _ConnectionBase(Generic[_SendT, _RecvT]):
|
||||
def fileno(self) -> int: ...
|
||||
def close(self) -> None: ...
|
||||
def send_bytes(self, buf: ReadableBuffer, offset: int = 0, size: int | None = None) -> None: ...
|
||||
def send(self, obj: _SendT) -> None: ...
|
||||
def send(self, obj: _SendT_contra) -> None: ...
|
||||
def recv_bytes(self, maxlength: int | None = None) -> bytes: ...
|
||||
def recv_bytes_into(self, buf: Any, offset: int = 0) -> int: ...
|
||||
def recv(self) -> _RecvT: ...
|
||||
def recv(self) -> _RecvT_co: ...
|
||||
def poll(self, timeout: float | None = 0.0) -> bool: ...
|
||||
def __enter__(self) -> Self: ...
|
||||
def __exit__(
|
||||
@@ -37,10 +37,10 @@ class _ConnectionBase(Generic[_SendT, _RecvT]):
|
||||
) -> None: ...
|
||||
def __del__(self) -> None: ...
|
||||
|
||||
class Connection(_ConnectionBase[_SendT, _RecvT]): ...
|
||||
class Connection(_ConnectionBase[_SendT_contra, _RecvT_co]): ...
|
||||
|
||||
if sys.platform == "win32":
|
||||
class PipeConnection(_ConnectionBase[_SendT, _RecvT]): ...
|
||||
class PipeConnection(_ConnectionBase[_SendT_contra, _RecvT_co]): ...
|
||||
|
||||
class Listener:
|
||||
def __init__(
|
||||
@@ -66,8 +66,8 @@ else:
|
||||
|
||||
def answer_challenge(connection: Connection[Any, Any], authkey: bytes) -> None: ...
|
||||
def wait(
|
||||
object_list: Iterable[Connection[_SendT, _RecvT] | socket.socket | int], timeout: float | None = None
|
||||
) -> list[Connection[_SendT, _RecvT] | socket.socket | int]: ...
|
||||
object_list: Iterable[Connection[_SendT_contra, _RecvT_co] | socket.socket | int], timeout: float | None = None
|
||||
) -> list[Connection[_SendT_contra, _RecvT_co] | socket.socket | int]: ...
|
||||
def Client(address: _Address, family: str | None = None, authkey: bytes | None = None) -> Connection[Any, Any]: ...
|
||||
|
||||
# N.B. Keep this in sync with multiprocessing.context.BaseContext.Pipe.
|
||||
|
||||
+7
-7
@@ -510,15 +510,15 @@ class Awaitable(Protocol[_T_co]):
|
||||
def __await__(self) -> Generator[Any, Any, _T_co]: ...
|
||||
|
||||
# Non-default variations to accommodate couroutines, and `AwaitableGenerator` having a 4th type parameter.
|
||||
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
|
||||
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)
|
||||
_SendT_nd_contra = TypeVar("_SendT_nd_contra", contravariant=True)
|
||||
_ReturnT_nd_co = TypeVar("_ReturnT_nd_co", covariant=True)
|
||||
|
||||
class Coroutine(Awaitable[_ReturnT_co_nd], Generic[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]):
|
||||
class Coroutine(Awaitable[_ReturnT_nd_co], Generic[_YieldT_co, _SendT_nd_contra, _ReturnT_nd_co]):
|
||||
__name__: str
|
||||
__qualname__: str
|
||||
|
||||
@abstractmethod
|
||||
def send(self, value: _SendT_contra_nd, /) -> _YieldT_co: ...
|
||||
def send(self, value: _SendT_nd_contra, /) -> _YieldT_co: ...
|
||||
@overload
|
||||
@abstractmethod
|
||||
def throw(
|
||||
@@ -534,9 +534,9 @@ class Coroutine(Awaitable[_ReturnT_co_nd], Generic[_YieldT_co, _SendT_contra_nd,
|
||||
# The parameters correspond to Generator, but the 4th is the original type.
|
||||
@type_check_only
|
||||
class AwaitableGenerator(
|
||||
Awaitable[_ReturnT_co_nd],
|
||||
Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
|
||||
Generic[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd, _S],
|
||||
Awaitable[_ReturnT_nd_co],
|
||||
Generator[_YieldT_co, _SendT_nd_contra, _ReturnT_nd_co],
|
||||
Generic[_YieldT_co, _SendT_nd_contra, _ReturnT_nd_co, _S],
|
||||
metaclass=ABCMeta,
|
||||
): ...
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ __all__ = (
|
||||
"Disabled",
|
||||
)
|
||||
|
||||
_ValuesT = TypeVar("_ValuesT", bound=Collection[Any], contravariant=True)
|
||||
_ValuesT_contra = TypeVar("_ValuesT_contra", bound=Collection[Any], contravariant=True)
|
||||
|
||||
class ValidationError(ValueError):
|
||||
def __init__(self, message: str = "", *args: object) -> None: ...
|
||||
@@ -150,9 +150,13 @@ class AnyOf:
|
||||
@overload
|
||||
def __init__(self, values: Collection[Any], message: str | None = None, values_formatter: None = None) -> None: ...
|
||||
@overload
|
||||
def __init__(self, values: _ValuesT, message: str | None, values_formatter: Callable[[_ValuesT], str]) -> None: ...
|
||||
def __init__(
|
||||
self, values: _ValuesT_contra, message: str | None, values_formatter: Callable[[_ValuesT_contra], str]
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, values: _ValuesT, message: str | None = None, *, values_formatter: Callable[[_ValuesT], str]) -> None: ...
|
||||
def __init__(
|
||||
self, values: _ValuesT_contra, message: str | None = None, *, values_formatter: Callable[[_ValuesT_contra], str]
|
||||
) -> None: ...
|
||||
def __call__(self, form: BaseForm, field: Field) -> None: ...
|
||||
@staticmethod
|
||||
def default_values_formatter(values: Iterable[object]) -> str: ...
|
||||
@@ -164,9 +168,13 @@ class NoneOf:
|
||||
@overload
|
||||
def __init__(self, values: Collection[Any], message: str | None = None, values_formatter: None = None) -> None: ...
|
||||
@overload
|
||||
def __init__(self, values: _ValuesT, message: str | None, values_formatter: Callable[[_ValuesT], str]) -> None: ...
|
||||
def __init__(
|
||||
self, values: _ValuesT_contra, message: str | None, values_formatter: Callable[[_ValuesT_contra], str]
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __init__(self, values: _ValuesT, message: str | None = None, *, values_formatter: Callable[[_ValuesT], str]) -> None: ...
|
||||
def __init__(
|
||||
self, values: _ValuesT_contra, message: str | None = None, *, values_formatter: Callable[[_ValuesT_contra], str]
|
||||
) -> None: ...
|
||||
def __call__(self, form: BaseForm, field: Field) -> None: ...
|
||||
@staticmethod
|
||||
def default_values_formatter(v: Iterable[object]) -> str: ...
|
||||
|
||||
@@ -22,31 +22,31 @@ class Callpoint:
|
||||
def from_tb(cls, tb: TracebackType) -> Self: ...
|
||||
def tb_frame_str(self) -> str: ...
|
||||
|
||||
_CallpointT = TypeVar("_CallpointT", bound=Callpoint, covariant=True, default=Callpoint)
|
||||
_CallpointT_co = TypeVar("_CallpointT_co", bound=Callpoint, covariant=True, default=Callpoint)
|
||||
|
||||
class TracebackInfo(Generic[_CallpointT]):
|
||||
callpoint_type: type[_CallpointT]
|
||||
frames: list[_CallpointT]
|
||||
def __init__(self, frames: list[_CallpointT]) -> None: ...
|
||||
class TracebackInfo(Generic[_CallpointT_co]):
|
||||
callpoint_type: type[_CallpointT_co]
|
||||
frames: list[_CallpointT_co]
|
||||
def __init__(self, frames: list[_CallpointT_co]) -> None: ...
|
||||
@classmethod
|
||||
def from_frame(cls, frame: FrameType | None = None, level: int = 1, limit: int | None = None) -> Self: ...
|
||||
@classmethod
|
||||
def from_traceback(cls, tb: TracebackType | None = None, limit: int | None = None) -> Self: ...
|
||||
@classmethod
|
||||
def from_dict(cls, d: Mapping[Literal["frames"], list[_CallpointT]]) -> Self: ...
|
||||
def to_dict(self) -> dict[str, list[dict[str, _CallpointT]]]: ...
|
||||
def from_dict(cls, d: Mapping[Literal["frames"], list[_CallpointT_co]]) -> Self: ...
|
||||
def to_dict(self) -> dict[str, list[dict[str, _CallpointT_co]]]: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __iter__(self) -> Iterator[_CallpointT]: ...
|
||||
def __iter__(self) -> Iterator[_CallpointT_co]: ...
|
||||
def get_formatted(self) -> str: ...
|
||||
|
||||
_TracebackInfoT = TypeVar("_TracebackInfoT", bound=TracebackInfo, covariant=True, default=TracebackInfo)
|
||||
_TracebackInfoT_co = TypeVar("_TracebackInfoT_co", bound=TracebackInfo, covariant=True, default=TracebackInfo)
|
||||
|
||||
class ExceptionInfo(Generic[_TracebackInfoT]):
|
||||
tb_info_type: type[_TracebackInfoT]
|
||||
class ExceptionInfo(Generic[_TracebackInfoT_co]):
|
||||
tb_info_type: type[_TracebackInfoT_co]
|
||||
exc_type: str
|
||||
exc_msg: str
|
||||
tb_info: _TracebackInfoT
|
||||
def __init__(self, exc_type: str, exc_msg: str, tb_info: _TracebackInfoT) -> None: ...
|
||||
tb_info: _TracebackInfoT_co
|
||||
def __init__(self, exc_type: str, exc_msg: str, tb_info: _TracebackInfoT_co) -> None: ...
|
||||
@classmethod
|
||||
def from_exc_info(cls, exc_type: type[BaseException], exc_value: BaseException, traceback: TracebackType) -> Self: ...
|
||||
@classmethod
|
||||
|
||||
@@ -15,10 +15,10 @@ def difference(G: Graph[_Node], H: Graph[_Node]): ...
|
||||
@_dispatchable
|
||||
def symmetric_difference(G: Graph[_Node], H: Graph[_Node]): ...
|
||||
|
||||
_X = TypeVar("_X", bound=Hashable, covariant=True)
|
||||
_Y = TypeVar("_Y", bound=Hashable, covariant=True)
|
||||
_X_co = TypeVar("_X_co", bound=Hashable, covariant=True)
|
||||
_Y_co = TypeVar("_Y_co", bound=Hashable, covariant=True)
|
||||
|
||||
@_dispatchable
|
||||
def compose(G: Graph[_X], H: Graph[_Y]) -> DiGraph[_X | _Y]: ...
|
||||
def compose(G: Graph[_X_co], H: Graph[_Y_co]) -> DiGraph[_X_co | _Y_co]: ...
|
||||
@_dispatchable
|
||||
def union(G: Graph[_X], H: Graph[_Y], rename: Iterable[Incomplete] | None = ()) -> DiGraph[_X | _Y]: ...
|
||||
def union(G: Graph[_X_co], H: Graph[_Y_co], rename: Iterable[Incomplete] | None = ()) -> DiGraph[_X_co | _Y_co]: ...
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, ClassVar, Generic, TypedDict, TypeVar
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_AbstractListener_T = TypeVar("_AbstractListener_T", bound=AbstractListener)
|
||||
_AbstractListenerT = TypeVar("_AbstractListenerT", bound=AbstractListener)
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
class _RESOLUTIONS(TypedDict):
|
||||
@@ -49,15 +49,15 @@ class AbstractListener(threading.Thread):
|
||||
def _stop_platform(self) -> None: ... # undocumented
|
||||
def join(self, timeout: float | None = None, *args: Any) -> None: ...
|
||||
|
||||
class Events(Generic[_T, _AbstractListener_T]):
|
||||
_Listener: type[_AbstractListener_T] | None # undocumented
|
||||
class Events(Generic[_T, _AbstractListenerT]):
|
||||
_Listener: type[_AbstractListenerT] | None # undocumented
|
||||
|
||||
class Event:
|
||||
def __eq__(self, other: object) -> bool: ...
|
||||
|
||||
_event_queue: Queue[_T] # undocumented
|
||||
_sentinel: object # undocumented
|
||||
_listener: _AbstractListener_T # undocumented
|
||||
_listener: _AbstractListenerT # undocumented
|
||||
start: Callable[[], None]
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def __enter__(self) -> Self: ...
|
||||
|
||||
@@ -8,10 +8,10 @@ from typing_extensions import Self
|
||||
|
||||
from serial import Serial
|
||||
|
||||
_AnyStr_T = TypeVar("_AnyStr_T", contravariant=True)
|
||||
_AnyStrT_contra = TypeVar("_AnyStrT_contra", contravariant=True)
|
||||
|
||||
@type_check_only
|
||||
class _SupportsWriteAndFlush(SupportsWrite[_AnyStr_T], SupportsFlush, Protocol): ...
|
||||
class _SupportsWriteAndFlush(SupportsWrite[_AnyStrT_contra], SupportsFlush, Protocol): ...
|
||||
|
||||
@type_check_only
|
||||
class _SupportsRead(Protocol):
|
||||
|
||||
@@ -14,21 +14,21 @@ from tensorflow.dtypes import DType
|
||||
from tensorflow.io import _CompressionTypes
|
||||
from tensorflow.python.trackable.base import Trackable
|
||||
|
||||
_T1 = TypeVar("_T1", covariant=True)
|
||||
_T1_co = TypeVar("_T1_co", covariant=True)
|
||||
_T2 = TypeVar("_T2")
|
||||
_T3 = TypeVar("_T3")
|
||||
|
||||
class Iterator(_Iterator[_T1], Trackable, ABC):
|
||||
class Iterator(_Iterator[_T1_co], Trackable, ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
|
||||
@abstractmethod
|
||||
def get_next(self) -> _T1: ...
|
||||
def get_next(self) -> _T1_co: ...
|
||||
@abstractmethod
|
||||
def get_next_as_optional(self) -> tf.experimental.Optional[_T1]: ...
|
||||
def get_next_as_optional(self) -> tf.experimental.Optional[_T1_co]: ...
|
||||
|
||||
class Dataset(ABC, Generic[_T1]):
|
||||
def apply(self, transformation_func: Callable[[Dataset[_T1]], Dataset[_T2]]) -> Dataset[_T2]: ...
|
||||
class Dataset(ABC, Generic[_T1_co]):
|
||||
def apply(self, transformation_func: Callable[[Dataset[_T1_co]], Dataset[_T2]]) -> Dataset[_T2]: ...
|
||||
def as_numpy_iterator(self) -> Iterator[np.ndarray[Any, Any]]: ...
|
||||
def batch(
|
||||
self,
|
||||
@@ -37,10 +37,10 @@ class Dataset(ABC, Generic[_T1]):
|
||||
num_parallel_calls: int | None = None,
|
||||
deterministic: bool | None = None,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def bucket_by_sequence_length(
|
||||
self,
|
||||
element_length_func: Callable[[_T1], ScalarTensorCompatible],
|
||||
element_length_func: Callable[[_T1_co], ScalarTensorCompatible],
|
||||
bucket_boundaries: Sequence[int],
|
||||
bucket_batch_sizes: Sequence[int],
|
||||
padded_shapes: ContainerGeneric[tf.TensorShape | TensorCompatible] | None = None,
|
||||
@@ -49,14 +49,14 @@ class Dataset(ABC, Generic[_T1]):
|
||||
no_padding: bool = False,
|
||||
drop_remainder: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
def cache(self, filename: str = "", name: str | None = None) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def cache(self, filename: str = "", name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def cardinality(self) -> int: ...
|
||||
@staticmethod
|
||||
def choose_from_datasets(
|
||||
datasets: Sequence[Dataset[_T2]], choice_dataset: Dataset[tf.Tensor], stop_on_empty_dataset: bool = True
|
||||
) -> Dataset[_T2]: ...
|
||||
def concatenate(self, dataset: Dataset[_T1], name: str | None = None) -> Dataset[_T1]: ...
|
||||
def concatenate(self, dataset: Dataset[_T1_co], name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
@staticmethod
|
||||
def counter(
|
||||
start: ScalarTensorCompatible = 0, step: ScalarTensorCompatible = 1, dtype: DType = ..., name: str | None = None
|
||||
@@ -64,9 +64,9 @@ class Dataset(ABC, Generic[_T1]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
|
||||
def enumerate(self, start: ScalarTensorCompatible = 0, name: str | None = None) -> Dataset[tuple[int, _T1]]: ...
|
||||
def filter(self, predicate: Callable[[_T1], bool | tf.Tensor], name: str | None = None) -> Dataset[_T1]: ...
|
||||
def flat_map(self, map_func: Callable[[_T1], Dataset[_T2]], name: str | None = None) -> Dataset[_T2]: ...
|
||||
def enumerate(self, start: ScalarTensorCompatible = 0, name: str | None = None) -> Dataset[tuple[int, _T1_co]]: ...
|
||||
def filter(self, predicate: Callable[[_T1_co], bool | tf.Tensor], name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def flat_map(self, map_func: Callable[[_T1_co], Dataset[_T2]], name: str | None = None) -> Dataset[_T2]: ...
|
||||
# PEP 646 can be used here for a more precise type when better supported.
|
||||
@staticmethod
|
||||
def from_generator(
|
||||
@@ -81,26 +81,26 @@ class Dataset(ABC, Generic[_T1]):
|
||||
def from_tensors(tensors: Any, name: str | None = None) -> Dataset[Any]: ...
|
||||
@staticmethod
|
||||
def from_tensor_slices(tensors: TensorCompatible, name: str | None = None) -> Dataset[Any]: ...
|
||||
def get_single_element(self, name: str | None = None) -> _T1: ...
|
||||
def get_single_element(self, name: str | None = None) -> _T1_co: ...
|
||||
def group_by_window(
|
||||
self,
|
||||
key_func: Callable[[_T1], tf.Tensor],
|
||||
reduce_func: Callable[[tf.Tensor, Dataset[_T1]], Dataset[_T2]],
|
||||
key_func: Callable[[_T1_co], tf.Tensor],
|
||||
reduce_func: Callable[[tf.Tensor, Dataset[_T1_co]], Dataset[_T2]],
|
||||
window_size: ScalarTensorCompatible | None = None,
|
||||
window_size_func: Callable[[tf.Tensor], tf.Tensor] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T2]: ...
|
||||
def ignore_errors(self, log_warning: bool = False, name: str | None = None) -> Dataset[_T1]: ...
|
||||
def ignore_errors(self, log_warning: bool = False, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def interleave(
|
||||
self,
|
||||
map_func: Callable[[_T1], Dataset[_T2]],
|
||||
map_func: Callable[[_T1_co], Dataset[_T2]],
|
||||
cycle_length: int | None = None,
|
||||
block_length: int | None = None,
|
||||
num_parallel_calls: int | None = None,
|
||||
deterministic: bool | None = None,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T2]: ...
|
||||
def __iter__(self) -> Iterator[_T1]: ...
|
||||
def __iter__(self) -> Iterator[_T1_co]: ...
|
||||
@staticmethod
|
||||
def list_files(
|
||||
file_pattern: str | Sequence[str] | TensorCompatible,
|
||||
@@ -134,8 +134,8 @@ class Dataset(ABC, Generic[_T1]):
|
||||
padding_values: ContainerGeneric[ScalarTensorCompatible] | None = None,
|
||||
drop_remainder: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
def prefetch(self, buffer_size: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def prefetch(self, buffer_size: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def ragged_batch(
|
||||
self,
|
||||
batch_size: ScalarTensorCompatible,
|
||||
@@ -162,62 +162,62 @@ class Dataset(ABC, Generic[_T1]):
|
||||
) -> Dataset[tf.Tensor]: ...
|
||||
def rebatch(
|
||||
self, batch_size: ScalarTensorCompatible, drop_remainder: bool = False, name: str | None = None
|
||||
) -> Dataset[_T1]: ...
|
||||
def reduce(self, initial_state: _T2, reduce_func: Callable[[_T2, _T1], _T2], name: str | None = None) -> _T2: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def reduce(self, initial_state: _T2, reduce_func: Callable[[_T2, _T1_co], _T2], name: str | None = None) -> _T2: ...
|
||||
def rejection_resample(
|
||||
self,
|
||||
class_func: Callable[[_T1], ScalarTensorCompatible],
|
||||
class_func: Callable[[_T1_co], ScalarTensorCompatible],
|
||||
target_dist: TensorCompatible,
|
||||
initial_dist: TensorCompatible | None = None,
|
||||
seed: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
def repeat(self, count: ScalarTensorCompatible | None = None, name: str | None = None) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def repeat(self, count: ScalarTensorCompatible | None = None, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
@staticmethod
|
||||
def sample_from_datasets(
|
||||
datasets: Sequence[Dataset[_T1]],
|
||||
datasets: Sequence[Dataset[_T1_co]],
|
||||
weights: TensorCompatible | None = None,
|
||||
seed: int | None = None,
|
||||
stop_on_empty_dataset: bool = False,
|
||||
rerandomize_each_iteration: bool | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
# Incomplete as tf.train.CheckpointOptions not yet covered.
|
||||
def save(
|
||||
self,
|
||||
path: str,
|
||||
compression: _CompressionTypes = None,
|
||||
shard_func: Callable[[_T1], int] | None = None,
|
||||
shard_func: Callable[[_T1_co], int] | None = None,
|
||||
checkpoint_args: Incomplete | None = None,
|
||||
) -> None: ...
|
||||
def scan(
|
||||
self, initial_state: _T2, scan_func: Callable[[_T2, _T1], tuple[_T2, _T3]], name: str | None = None
|
||||
self, initial_state: _T2, scan_func: Callable[[_T2, _T1_co], tuple[_T2, _T3]], name: str | None = None
|
||||
) -> Dataset[_T3]: ...
|
||||
def shard(
|
||||
self, num_shards: ScalarTensorCompatible, index: ScalarTensorCompatible, name: str | None = None
|
||||
) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def shuffle(
|
||||
self,
|
||||
buffer_size: ScalarTensorCompatible,
|
||||
seed: int | None = None,
|
||||
reshuffle_each_iteration: bool = True,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
def skip(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def skip(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def snapshot(
|
||||
self,
|
||||
path: str,
|
||||
compression: _CompressionTypes = "AUTO",
|
||||
reader_func: Callable[[Dataset[Dataset[_T1]]], Dataset[_T1]] | None = None,
|
||||
shard_func: Callable[[_T1], ScalarTensorCompatible] | None = None,
|
||||
reader_func: Callable[[Dataset[Dataset[_T1_co]]], Dataset[_T1_co]] | None = None,
|
||||
shard_func: Callable[[_T1_co], ScalarTensorCompatible] | None = None,
|
||||
name: str | None = None,
|
||||
) -> Dataset[_T1]: ...
|
||||
) -> Dataset[_T1_co]: ...
|
||||
def sparse_batch(
|
||||
self, batch_size: ScalarTensorCompatible, row_shape: tf.TensorShape | TensorCompatible, name: str | None = None
|
||||
) -> Dataset[tf.SparseTensor]: ...
|
||||
def take(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
|
||||
def take_while(self, predicate: Callable[[_T1], ScalarTensorCompatible], name: str | None = None) -> Dataset[_T1]: ...
|
||||
def unbatch(self, name: str | None = None) -> Dataset[_T1]: ...
|
||||
def unique(self, name: str | None = None) -> Dataset[_T1]: ...
|
||||
def take(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def take_while(self, predicate: Callable[[_T1_co], ScalarTensorCompatible], name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def unbatch(self, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def unique(self, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
def window(
|
||||
self,
|
||||
size: ScalarTensorCompatible,
|
||||
@@ -225,8 +225,8 @@ class Dataset(ABC, Generic[_T1]):
|
||||
stride: ScalarTensorCompatible = 1,
|
||||
drop_remainder: bool = False,
|
||||
name: str | None = None,
|
||||
) -> Dataset[Dataset[_T1]]: ...
|
||||
def with_options(self, options: Options, name: str | None = None) -> Dataset[_T1]: ...
|
||||
) -> Dataset[Dataset[_T1_co]]: ...
|
||||
def with_options(self, options: Options, name: str | None = None) -> Dataset[_T1_co]: ...
|
||||
@overload
|
||||
@staticmethod
|
||||
def zip(
|
||||
|
||||
@@ -3,10 +3,10 @@ from typing import Generic, TypeVar
|
||||
|
||||
from tensorflow._aliases import AnyArray
|
||||
|
||||
_Value = TypeVar("_Value", covariant=True)
|
||||
_Value_co = TypeVar("_Value_co", covariant=True)
|
||||
|
||||
class RemoteValue(Generic[_Value]):
|
||||
class RemoteValue(Generic[_Value_co]):
|
||||
def fetch(self) -> AnyArray: ...
|
||||
def get(self) -> _Value: ...
|
||||
def get(self) -> _Value_co: ...
|
||||
|
||||
def __getattr__(name: str) -> Incomplete: ...
|
||||
|
||||
@@ -11,8 +11,8 @@ from tensorflow.keras.constraints import Constraint
|
||||
from tensorflow.keras.initializers import _Initializer
|
||||
from tensorflow.keras.regularizers import Regularizer, _Regularizer
|
||||
|
||||
_InputT = TypeVar("_InputT", contravariant=True)
|
||||
_OutputT = TypeVar("_OutputT", covariant=True)
|
||||
_InputT_contra = TypeVar("_InputT_contra", contravariant=True)
|
||||
_OutputT_co = TypeVar("_OutputT_co", covariant=True)
|
||||
|
||||
class InputSpec:
|
||||
dtype: str | None
|
||||
@@ -39,9 +39,9 @@ class InputSpec:
|
||||
|
||||
# Most layers have input and output type of just Tensor and when we support default type variables,
|
||||
# maybe worth trying.
|
||||
class Layer(tf.Module, Generic[_InputT, _OutputT]):
|
||||
class Layer(tf.Module, Generic[_InputT_contra, _OutputT_co]):
|
||||
# The most general type is ContainerGeneric[InputSpec] as it really
|
||||
# depends on _InputT. For most Layers it is just InputSpec
|
||||
# depends on _InputT_contra. For most Layers it is just InputSpec
|
||||
# though. Maybe describable with HKT?
|
||||
input_spec: InputSpec | Any
|
||||
|
||||
@@ -65,11 +65,13 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):
|
||||
# *args/**kwargs are allowed, but have obscure footguns and tensorflow documentation discourages their usage.
|
||||
# First argument will automatically be cast to layer's compute dtype, but any other tensor arguments will not be.
|
||||
# Also various tensorflow tools/apis can misbehave if they encounter a layer with *args/**kwargs.
|
||||
def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
|
||||
def call(self, inputs: _InputT, /) -> _OutputT: ...
|
||||
def __call__(
|
||||
self, inputs: _InputT_contra, *, training: bool = False, mask: TensorCompatible | None = None
|
||||
) -> _OutputT_co: ...
|
||||
def call(self, inputs: _InputT_contra, /) -> _OutputT_co: ...
|
||||
|
||||
# input_shape's real type depends on _InputT, but we can't express that without HKT.
|
||||
# For example _InputT tf.Tensor -> tf.TensorShape, _InputT dict[str, tf.Tensor] -> dict[str, tf.TensorShape].
|
||||
# input_shape's real type depends on _InputT_contra, but we can't express that without HKT.
|
||||
# For example _InputT_contra tf.Tensor -> tf.TensorShape, _InputT_contra dict[str, tf.Tensor] -> dict[str, tf.TensorShape].
|
||||
def build(self, input_shape: Any, /) -> None: ...
|
||||
@overload
|
||||
def compute_output_shape(self: Layer[tf.Tensor, tf.Tensor], input_shape: tf.TensorShape, /) -> tf.TensorShape: ...
|
||||
|
||||
@@ -9,14 +9,14 @@ import numpy.typing as npt
|
||||
import tensorflow as tf
|
||||
from tensorflow import Variable
|
||||
from tensorflow._aliases import ContainerGeneric, ShapeLike, TensorCompatible
|
||||
from tensorflow.keras.layers import Layer, _InputT, _OutputT
|
||||
from tensorflow.keras.layers import Layer, _InputT_contra, _OutputT_co
|
||||
from tensorflow.keras.optimizers import Optimizer
|
||||
|
||||
_Loss: TypeAlias = str | tf.keras.losses.Loss | Callable[[TensorCompatible, TensorCompatible], tf.Tensor]
|
||||
_Metric: TypeAlias = str | tf.keras.metrics.Metric | Callable[[TensorCompatible, TensorCompatible], tf.Tensor] | None
|
||||
|
||||
# Missing keras.src.backend.tensorflow.trainer.TensorFlowTrainer as a base class, which is not exposed by tensorflow
|
||||
class Model(Layer[_InputT, _OutputT]):
|
||||
class Model(Layer[_InputT_contra, _OutputT_co]):
|
||||
_train_counter: tf.Variable
|
||||
_test_counter: tf.Variable
|
||||
optimizer: Optimizer | None
|
||||
@@ -27,13 +27,15 @@ class Model(Layer[_InputT, _OutputT]):
|
||||
) -> tf.Tensor | None: ...
|
||||
stop_training: bool
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT, _OutputT]: ...
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT_contra, _OutputT_co]: ...
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
|
||||
def __setattr__(self, name: str, value: Any) -> None: ...
|
||||
def __reduce__(self): ...
|
||||
def build(self, input_shape: ShapeLike) -> None: ...
|
||||
def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
|
||||
def call(self, inputs: _InputT, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT: ...
|
||||
def __call__(
|
||||
self, inputs: _InputT_contra, *, training: bool = False, mask: TensorCompatible | None = None
|
||||
) -> _OutputT_co: ...
|
||||
def call(self, inputs: _InputT_contra, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT_co: ...
|
||||
# Ideally loss/metrics/output would share the same structure but higher kinded types are not supported.
|
||||
def compile(
|
||||
self,
|
||||
@@ -106,8 +108,8 @@ class Model(Layer[_InputT, _OutputT]):
|
||||
return_dict: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> float | list[float]: ...
|
||||
def predict_step(self, data: _InputT) -> _OutputT: ...
|
||||
def make_predict_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], _OutputT]: ...
|
||||
def predict_step(self, data: _InputT_contra) -> _OutputT_co: ...
|
||||
def make_predict_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], _OutputT_co]: ...
|
||||
def predict(
|
||||
self,
|
||||
x: TensorCompatible | tf.data.Dataset[Incomplete],
|
||||
@@ -115,7 +117,7 @@ class Model(Layer[_InputT, _OutputT]):
|
||||
verbose: Literal["auto", 0, 1, 2] = "auto",
|
||||
steps: int | None = None,
|
||||
callbacks: list[tf.keras.callbacks.Callback] | None = None,
|
||||
) -> _OutputT: ...
|
||||
) -> _OutputT_co: ...
|
||||
def reset_metrics(self) -> None: ...
|
||||
def train_on_batch(
|
||||
self,
|
||||
@@ -132,7 +134,7 @@ class Model(Layer[_InputT, _OutputT]):
|
||||
sample_weight: npt.NDArray[np.float64] | None = None,
|
||||
return_dict: bool = False,
|
||||
) -> float | list[float]: ...
|
||||
def predict_on_batch(self, x: Iterator[_InputT]) -> npt.NDArray[Incomplete]: ...
|
||||
def predict_on_batch(self, x: Iterator[_InputT_contra]) -> npt.NDArray[Incomplete]: ...
|
||||
@property
|
||||
def trainable_weights(self) -> list[Variable]: ...
|
||||
@property
|
||||
|
||||
@@ -10,7 +10,7 @@ from tensorflow.saved_model.experimental import VariablePolicy
|
||||
from tensorflow.types.experimental import ConcreteFunction, PolymorphicFunction
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R", covariant=True)
|
||||
_R_co = TypeVar("_R_co", covariant=True)
|
||||
|
||||
class Asset:
|
||||
@property
|
||||
@@ -77,10 +77,10 @@ class SaveOptions:
|
||||
|
||||
def contains_saved_model(export_dir: str | Path) -> bool: ...
|
||||
|
||||
class _LoadedAttributes(Generic[_P, _R]):
|
||||
signatures: Mapping[str, ConcreteFunction[_P, _R]]
|
||||
class _LoadedAttributes(Generic[_P, _R_co]):
|
||||
signatures: Mapping[str, ConcreteFunction[_P, _R_co]]
|
||||
|
||||
class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R]):
|
||||
class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R_co]):
|
||||
variables: list[tf.Variable]
|
||||
trainable_variables: list[tf.Variable]
|
||||
# TF1 model artifact specific
|
||||
|
||||
@@ -7,23 +7,23 @@ import tensorflow as tf
|
||||
from tensorflow._aliases import ContainerGeneric
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R", covariant=True)
|
||||
_R_co = TypeVar("_R_co", covariant=True)
|
||||
|
||||
class Callable(Generic[_P, _R], metaclass=abc.ABCMeta):
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
|
||||
class Callable(Generic[_P, _R_co], metaclass=abc.ABCMeta):
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...
|
||||
|
||||
class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
|
||||
class ConcreteFunction(Callable[_P, _R_co], metaclass=abc.ABCMeta):
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...
|
||||
|
||||
class PolymorphicFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
|
||||
class PolymorphicFunction(Callable[_P, _R_co], metaclass=abc.ABCMeta):
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ...
|
||||
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R_co]: ...
|
||||
@overload
|
||||
@abc.abstractmethod
|
||||
def get_concrete_function(
|
||||
self, *args: ContainerGeneric[tf.TypeSpec[Any]], **kwargs: ContainerGeneric[tf.TypeSpec[Any]]
|
||||
) -> ConcreteFunction[_P, _R]: ...
|
||||
) -> ConcreteFunction[_P, _R_co]: ...
|
||||
def experimental_get_compiler_ir(self, *args, **kwargs): ...
|
||||
|
||||
GenericFunction = PolymorphicFunction
|
||||
|
||||
Reference in New Issue
Block a user