Enable Ruff PLC (Pylint Convention) (#13306)

This commit is contained in:
Avasam
2025-03-03 09:39:40 -05:00
committed by GitHub
parent 738cc5046a
commit 6d6e858e63
17 changed files with 173 additions and 153 deletions
+16 -8
View File
@@ -45,6 +45,7 @@ select = [
"I", # isort
"N", # pep8-naming
"PGH", # pygrep-hooks
"PLC", # Pylint Convention
"RUF", # Ruff-specific and unused-noqa
"TRY", # tryceratops
"UP", # pyupgrade
@@ -159,19 +160,26 @@ ignore = [
# A lot of stubs are incomplete on purpose, and that's configured through pyright
# Some ANN204 (special method) are autofixable in stubs, but not all.
"ANN2", # Missing return type annotation for ...
# Most pep8-naming rules don't apply for third-party stubs like typeshed.
# N811 to N814 could apply, but we often use them to disambiguate a name whilst making it look like a more common one
"N8",
# Ruff 0.8.0 added sorting of __all__ and __slots__.
# There is no consensus on whether we want to apply this to stubs, so keeping the status quo.
# See https://github.com/python/typeshed/pull/13108
"RUF022", # `__all__` is not sorted
"RUF023", # `{}.__slots__` is not sorted
###
# Rules that are out of the control of stub authors:
###
"F403", # `from . import *` used; unable to detect undefined names
# Stubs can sometimes re-export entire modules.
# Issues with using a star-imported name will be caught by type-checkers.
"F405", # may be undefined, or defined from star imports
# Ruff 0.8.0 added sorting of __all__ and __slots__.
# There is no consensus on whether we want to apply this to stubs, so keeping the status quo.
# See https://github.com/python/typeshed/pull/13108
"RUF022",
"RUF023",
# Most pep8-naming rules don't apply for third-party stubs like typeshed.
# N811 to N814 could apply, but we often use them to disambiguate a name whilst making it look like a more common one
"N8", # pep8-naming
"PLC2701", # Private name import from external module
]
"lib/ts_utils/**" = [
# Doesn't affect stubs. The only re-exports we have should be in our local lib ts_utils
"PLC0414", # Import alias does not rename original package
]
"*_pb2.pyi" = [
# Leave the docstrings as-is, matching source
+15 -15
View File
@@ -89,8 +89,8 @@ _T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
_T4 = TypeVar("_T4")
_T5 = TypeVar("_T5")
_SupportsNextT = TypeVar("_SupportsNextT", bound=SupportsNext[Any], covariant=True)
_SupportsAnextT = TypeVar("_SupportsAnextT", bound=SupportsAnext[Any], covariant=True)
_SupportsNextT_co = TypeVar("_SupportsNextT_co", bound=SupportsNext[Any], covariant=True)
_SupportsAnextT_co = TypeVar("_SupportsAnextT_co", bound=SupportsAnext[Any], covariant=True)
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
_AwaitableT_co = TypeVar("_AwaitableT_co", bound=Awaitable[Any], covariant=True)
_P = ParamSpec("_P")
@@ -1319,7 +1319,7 @@ class _PathLike(Protocol[AnyStr_co]):
def __fspath__(self) -> AnyStr_co: ...
if sys.version_info >= (3, 10):
def aiter(async_iterable: SupportsAiter[_SupportsAnextT], /) -> _SupportsAnextT: ...
def aiter(async_iterable: SupportsAiter[_SupportsAnextT_co], /) -> _SupportsAnextT_co: ...
class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
def __anext__(self) -> _AwaitableT_co: ...
@@ -1481,7 +1481,7 @@ class _GetItemIterable(Protocol[_T_co]):
def __getitem__(self, i: int, /) -> _T_co: ...
@overload
def iter(object: SupportsIter[_SupportsNextT], /) -> _SupportsNextT: ...
def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ...
@overload
def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ...
@overload
@@ -1688,17 +1688,17 @@ def print(
*values: object, sep: str | None = " ", end: str | None = "\n", file: _SupportsWriteAndFlush[str] | None = None, flush: bool
) -> None: ...
_E = TypeVar("_E", contravariant=True)
_M = TypeVar("_M", contravariant=True)
_E_contra = TypeVar("_E_contra", contravariant=True)
_M_contra = TypeVar("_M_contra", contravariant=True)
class _SupportsPow2(Protocol[_E, _T_co]):
def __pow__(self, other: _E, /) -> _T_co: ...
class _SupportsPow2(Protocol[_E_contra, _T_co]):
def __pow__(self, other: _E_contra, /) -> _T_co: ...
class _SupportsPow3NoneOnly(Protocol[_E, _T_co]):
def __pow__(self, other: _E, modulo: None = None, /) -> _T_co: ...
class _SupportsPow3NoneOnly(Protocol[_E_contra, _T_co]):
def __pow__(self, other: _E_contra, modulo: None = None, /) -> _T_co: ...
class _SupportsPow3(Protocol[_E, _M, _T_co]):
def __pow__(self, other: _E, modulo: _M, /) -> _T_co: ...
class _SupportsPow3(Protocol[_E_contra, _M_contra, _T_co]):
def __pow__(self, other: _E_contra, modulo: _M_contra, /) -> _T_co: ...
_SupportsSomeKindOfPow = ( # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed
_SupportsPow2[Any, Any] | _SupportsPow3NoneOnly[Any, Any] | _SupportsPow3[Any, Any, Any]
@@ -1734,11 +1734,11 @@ def pow(base: float, exp: complex | _SupportsSomeKindOfPow, mod: None = None) ->
@overload
def pow(base: complex, exp: complex | _SupportsSomeKindOfPow, mod: None = None) -> complex: ...
@overload
def pow(base: _SupportsPow2[_E, _T_co], exp: _E, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
def pow(base: _SupportsPow2[_E_contra, _T_co], exp: _E_contra, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
@overload
def pow(base: _SupportsPow3NoneOnly[_E, _T_co], exp: _E, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
def pow(base: _SupportsPow3NoneOnly[_E_contra, _T_co], exp: _E_contra, mod: None = None) -> _T_co: ... # type: ignore[overload-overlap]
@overload
def pow(base: _SupportsPow3[_E, _M, _T_co], exp: _E, mod: _M) -> _T_co: ...
def pow(base: _SupportsPow3[_E_contra, _M_contra, _T_co], exp: _E_contra, mod: _M_contra) -> _T_co: ...
@overload
def pow(base: _SupportsSomeKindOfPow, exp: float, mod: None = None) -> Any: ...
@overload
+5 -5
View File
@@ -33,7 +33,7 @@ _T_co = TypeVar("_T_co", covariant=True)
_T_io = TypeVar("_T_io", bound=IO[str] | None)
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound=bool | None, default=bool | None)
_F = TypeVar("_F", bound=Callable[..., Any])
_G = TypeVar("_G", bound=Generator[Any, Any, Any] | AsyncGenerator[Any, Any], covariant=True)
_G_co = TypeVar("_G_co", bound=Generator[Any, Any, Any] | AsyncGenerator[Any, Any], covariant=True)
_P = ParamSpec("_P")
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=None)
@@ -68,11 +68,11 @@ class ContextDecorator:
def _recreate_cm(self) -> Self: ...
def __call__(self, func: _F) -> _F: ...
class _GeneratorContextManagerBase(Generic[_G]):
class _GeneratorContextManagerBase(Generic[_G_co]):
# Ideally this would use ParamSpec, but that requires (*args, **kwargs), which this isn't. see #6676
def __init__(self, func: Callable[..., _G], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
gen: _G
func: Callable[..., _G]
def __init__(self, func: Callable[..., _G_co], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
gen: _G_co
func: Callable[..., _G_co]
args: tuple[Any, ...]
kwds: dict[str, Any]
+6 -6
View File
@@ -143,8 +143,8 @@ if sys.version_info >= (3, 11):
_P = ParamSpec("_P")
_T = TypeVar("_T")
_F = TypeVar("_F", bound=Callable[..., Any])
_T_cont = TypeVar("_T_cont", contravariant=True)
_V_cont = TypeVar("_V_cont", contravariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)
_V_contra = TypeVar("_V_contra", contravariant=True)
#
# Types and members
@@ -228,11 +228,11 @@ def isasyncgenfunction(obj: Callable[_P, Any]) -> TypeGuard[Callable[_P, AsyncGe
@overload
def isasyncgenfunction(obj: object) -> TypeGuard[Callable[..., AsyncGeneratorType[Any, Any]]]: ...
class _SupportsSet(Protocol[_T_cont, _V_cont]):
def __set__(self, instance: _T_cont, value: _V_cont, /) -> None: ...
class _SupportsSet(Protocol[_T_contra, _V_contra]):
def __set__(self, instance: _T_contra, value: _V_contra, /) -> None: ...
class _SupportsDelete(Protocol[_T_cont]):
def __delete__(self, instance: _T_cont, /) -> None: ...
class _SupportsDelete(Protocol[_T_contra]):
def __delete__(self, instance: _T_contra, /) -> None: ...
def isasyncgen(object: object) -> TypeIs[AsyncGeneratorType[Any, Any]]: ...
def istraceback(object: object) -> TypeIs[TracebackType]: ...
+9 -9
View File
@@ -12,10 +12,10 @@ __all__ = ["Client", "Listener", "Pipe", "wait"]
_Address: TypeAlias = str | tuple[str, int]
# Defaulting to Any to avoid forcing generics on a lot of pre-existing code
_SendT = TypeVar("_SendT", contravariant=True, default=Any)
_RecvT = TypeVar("_RecvT", covariant=True, default=Any)
_SendT_contra = TypeVar("_SendT_contra", contravariant=True, default=Any)
_RecvT_co = TypeVar("_RecvT_co", covariant=True, default=Any)
class _ConnectionBase(Generic[_SendT, _RecvT]):
class _ConnectionBase(Generic[_SendT_contra, _RecvT_co]):
def __init__(self, handle: SupportsIndex, readable: bool = True, writable: bool = True) -> None: ...
@property
def closed(self) -> bool: ... # undocumented
@@ -26,10 +26,10 @@ class _ConnectionBase(Generic[_SendT, _RecvT]):
def fileno(self) -> int: ...
def close(self) -> None: ...
def send_bytes(self, buf: ReadableBuffer, offset: int = 0, size: int | None = None) -> None: ...
def send(self, obj: _SendT) -> None: ...
def send(self, obj: _SendT_contra) -> None: ...
def recv_bytes(self, maxlength: int | None = None) -> bytes: ...
def recv_bytes_into(self, buf: Any, offset: int = 0) -> int: ...
def recv(self) -> _RecvT: ...
def recv(self) -> _RecvT_co: ...
def poll(self, timeout: float | None = 0.0) -> bool: ...
def __enter__(self) -> Self: ...
def __exit__(
@@ -37,10 +37,10 @@ class _ConnectionBase(Generic[_SendT, _RecvT]):
) -> None: ...
def __del__(self) -> None: ...
class Connection(_ConnectionBase[_SendT, _RecvT]): ...
class Connection(_ConnectionBase[_SendT_contra, _RecvT_co]): ...
if sys.platform == "win32":
class PipeConnection(_ConnectionBase[_SendT, _RecvT]): ...
class PipeConnection(_ConnectionBase[_SendT_contra, _RecvT_co]): ...
class Listener:
def __init__(
@@ -66,8 +66,8 @@ else:
def answer_challenge(connection: Connection[Any, Any], authkey: bytes) -> None: ...
def wait(
object_list: Iterable[Connection[_SendT, _RecvT] | socket.socket | int], timeout: float | None = None
) -> list[Connection[_SendT, _RecvT] | socket.socket | int]: ...
object_list: Iterable[Connection[_SendT_contra, _RecvT_co] | socket.socket | int], timeout: float | None = None
) -> list[Connection[_SendT_contra, _RecvT_co] | socket.socket | int]: ...
def Client(address: _Address, family: str | None = None, authkey: bytes | None = None) -> Connection[Any, Any]: ...
# N.B. Keep this in sync with multiprocessing.context.BaseContext.Pipe.
+7 -7
View File
@@ -510,15 +510,15 @@ class Awaitable(Protocol[_T_co]):
def __await__(self) -> Generator[Any, Any, _T_co]: ...
# Non-default variations to accommodate coroutines, and `AwaitableGenerator` having a 4th type parameter.
_SendT_contra_nd = TypeVar("_SendT_contra_nd", contravariant=True)
_ReturnT_co_nd = TypeVar("_ReturnT_co_nd", covariant=True)
_SendT_nd_contra = TypeVar("_SendT_nd_contra", contravariant=True)
_ReturnT_nd_co = TypeVar("_ReturnT_nd_co", covariant=True)
class Coroutine(Awaitable[_ReturnT_co_nd], Generic[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd]):
class Coroutine(Awaitable[_ReturnT_nd_co], Generic[_YieldT_co, _SendT_nd_contra, _ReturnT_nd_co]):
__name__: str
__qualname__: str
@abstractmethod
def send(self, value: _SendT_contra_nd, /) -> _YieldT_co: ...
def send(self, value: _SendT_nd_contra, /) -> _YieldT_co: ...
@overload
@abstractmethod
def throw(
@@ -534,9 +534,9 @@ class Coroutine(Awaitable[_ReturnT_co_nd], Generic[_YieldT_co, _SendT_contra_nd,
# The parameters correspond to Generator, but the 4th is the original type.
@type_check_only
class AwaitableGenerator(
Awaitable[_ReturnT_co_nd],
Generator[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd],
Generic[_YieldT_co, _SendT_contra_nd, _ReturnT_co_nd, _S],
Awaitable[_ReturnT_nd_co],
Generator[_YieldT_co, _SendT_nd_contra, _ReturnT_nd_co],
Generic[_YieldT_co, _SendT_nd_contra, _ReturnT_nd_co, _S],
metaclass=ABCMeta,
): ...
+13 -5
View File
@@ -42,7 +42,7 @@ __all__ = (
"Disabled",
)
_ValuesT = TypeVar("_ValuesT", bound=Collection[Any], contravariant=True)
_ValuesT_contra = TypeVar("_ValuesT_contra", bound=Collection[Any], contravariant=True)
class ValidationError(ValueError):
def __init__(self, message: str = "", *args: object) -> None: ...
@@ -150,9 +150,13 @@ class AnyOf:
@overload
def __init__(self, values: Collection[Any], message: str | None = None, values_formatter: None = None) -> None: ...
@overload
def __init__(self, values: _ValuesT, message: str | None, values_formatter: Callable[[_ValuesT], str]) -> None: ...
def __init__(
self, values: _ValuesT_contra, message: str | None, values_formatter: Callable[[_ValuesT_contra], str]
) -> None: ...
@overload
def __init__(self, values: _ValuesT, message: str | None = None, *, values_formatter: Callable[[_ValuesT], str]) -> None: ...
def __init__(
self, values: _ValuesT_contra, message: str | None = None, *, values_formatter: Callable[[_ValuesT_contra], str]
) -> None: ...
def __call__(self, form: BaseForm, field: Field) -> None: ...
@staticmethod
def default_values_formatter(values: Iterable[object]) -> str: ...
@@ -164,9 +168,13 @@ class NoneOf:
@overload
def __init__(self, values: Collection[Any], message: str | None = None, values_formatter: None = None) -> None: ...
@overload
def __init__(self, values: _ValuesT, message: str | None, values_formatter: Callable[[_ValuesT], str]) -> None: ...
def __init__(
self, values: _ValuesT_contra, message: str | None, values_formatter: Callable[[_ValuesT_contra], str]
) -> None: ...
@overload
def __init__(self, values: _ValuesT, message: str | None = None, *, values_formatter: Callable[[_ValuesT], str]) -> None: ...
def __init__(
self, values: _ValuesT_contra, message: str | None = None, *, values_formatter: Callable[[_ValuesT_contra], str]
) -> None: ...
def __call__(self, form: BaseForm, field: Field) -> None: ...
@staticmethod
def default_values_formatter(v: Iterable[object]) -> str: ...
+13 -13
View File
@@ -22,31 +22,31 @@ class Callpoint:
def from_tb(cls, tb: TracebackType) -> Self: ...
def tb_frame_str(self) -> str: ...
_CallpointT = TypeVar("_CallpointT", bound=Callpoint, covariant=True, default=Callpoint)
_CallpointT_co = TypeVar("_CallpointT_co", bound=Callpoint, covariant=True, default=Callpoint)
class TracebackInfo(Generic[_CallpointT]):
callpoint_type: type[_CallpointT]
frames: list[_CallpointT]
def __init__(self, frames: list[_CallpointT]) -> None: ...
class TracebackInfo(Generic[_CallpointT_co]):
callpoint_type: type[_CallpointT_co]
frames: list[_CallpointT_co]
def __init__(self, frames: list[_CallpointT_co]) -> None: ...
@classmethod
def from_frame(cls, frame: FrameType | None = None, level: int = 1, limit: int | None = None) -> Self: ...
@classmethod
def from_traceback(cls, tb: TracebackType | None = None, limit: int | None = None) -> Self: ...
@classmethod
def from_dict(cls, d: Mapping[Literal["frames"], list[_CallpointT]]) -> Self: ...
def to_dict(self) -> dict[str, list[dict[str, _CallpointT]]]: ...
def from_dict(cls, d: Mapping[Literal["frames"], list[_CallpointT_co]]) -> Self: ...
def to_dict(self) -> dict[str, list[dict[str, _CallpointT_co]]]: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_CallpointT]: ...
def __iter__(self) -> Iterator[_CallpointT_co]: ...
def get_formatted(self) -> str: ...
_TracebackInfoT = TypeVar("_TracebackInfoT", bound=TracebackInfo, covariant=True, default=TracebackInfo)
_TracebackInfoT_co = TypeVar("_TracebackInfoT_co", bound=TracebackInfo, covariant=True, default=TracebackInfo)
class ExceptionInfo(Generic[_TracebackInfoT]):
tb_info_type: type[_TracebackInfoT]
class ExceptionInfo(Generic[_TracebackInfoT_co]):
tb_info_type: type[_TracebackInfoT_co]
exc_type: str
exc_msg: str
tb_info: _TracebackInfoT
def __init__(self, exc_type: str, exc_msg: str, tb_info: _TracebackInfoT) -> None: ...
tb_info: _TracebackInfoT_co
def __init__(self, exc_type: str, exc_msg: str, tb_info: _TracebackInfoT_co) -> None: ...
@classmethod
def from_exc_info(cls, exc_type: type[BaseException], exc_value: BaseException, traceback: TracebackType) -> Self: ...
@classmethod
@@ -15,10 +15,10 @@ def difference(G: Graph[_Node], H: Graph[_Node]): ...
@_dispatchable
def symmetric_difference(G: Graph[_Node], H: Graph[_Node]): ...
_X = TypeVar("_X", bound=Hashable, covariant=True)
_Y = TypeVar("_Y", bound=Hashable, covariant=True)
_X_co = TypeVar("_X_co", bound=Hashable, covariant=True)
_Y_co = TypeVar("_Y_co", bound=Hashable, covariant=True)
@_dispatchable
def compose(G: Graph[_X], H: Graph[_Y]) -> DiGraph[_X | _Y]: ...
def compose(G: Graph[_X_co], H: Graph[_Y_co]) -> DiGraph[_X_co | _Y_co]: ...
@_dispatchable
def union(G: Graph[_X], H: Graph[_Y], rename: Iterable[Incomplete] | None = ()) -> DiGraph[_X | _Y]: ...
def union(G: Graph[_X_co], H: Graph[_Y_co], rename: Iterable[Incomplete] | None = ()) -> DiGraph[_X_co | _Y_co]: ...
+4 -4
View File
@@ -7,7 +7,7 @@ from typing import Any, ClassVar, Generic, TypedDict, TypeVar
from typing_extensions import ParamSpec, Self
_T = TypeVar("_T")
_AbstractListener_T = TypeVar("_AbstractListener_T", bound=AbstractListener)
_AbstractListenerT = TypeVar("_AbstractListenerT", bound=AbstractListener)
_P = ParamSpec("_P")
class _RESOLUTIONS(TypedDict):
@@ -49,15 +49,15 @@ class AbstractListener(threading.Thread):
def _stop_platform(self) -> None: ... # undocumented
def join(self, timeout: float | None = None, *args: Any) -> None: ...
class Events(Generic[_T, _AbstractListener_T]):
_Listener: type[_AbstractListener_T] | None # undocumented
class Events(Generic[_T, _AbstractListenerT]):
_Listener: type[_AbstractListenerT] | None # undocumented
class Event:
def __eq__(self, other: object) -> bool: ...
_event_queue: Queue[_T] # undocumented
_sentinel: object # undocumented
_listener: _AbstractListener_T # undocumented
_listener: _AbstractListenerT # undocumented
start: Callable[[], None]
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __enter__(self) -> Self: ...
+2 -2
View File
@@ -8,10 +8,10 @@ from typing_extensions import Self
from serial import Serial
_AnyStr_T = TypeVar("_AnyStr_T", contravariant=True)
_AnyStrT_contra = TypeVar("_AnyStrT_contra", contravariant=True)
@type_check_only
class _SupportsWriteAndFlush(SupportsWrite[_AnyStr_T], SupportsFlush, Protocol): ...
class _SupportsWriteAndFlush(SupportsWrite[_AnyStrT_contra], SupportsFlush, Protocol): ...
@type_check_only
class _SupportsRead(Protocol):
+43 -43
View File
@@ -14,21 +14,21 @@ from tensorflow.dtypes import DType
from tensorflow.io import _CompressionTypes
from tensorflow.python.trackable.base import Trackable
_T1 = TypeVar("_T1", covariant=True)
_T1_co = TypeVar("_T1_co", covariant=True)
_T2 = TypeVar("_T2")
_T3 = TypeVar("_T3")
class Iterator(_Iterator[_T1], Trackable, ABC):
class Iterator(_Iterator[_T1_co], Trackable, ABC):
@property
@abstractmethod
def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
@abstractmethod
def get_next(self) -> _T1: ...
def get_next(self) -> _T1_co: ...
@abstractmethod
def get_next_as_optional(self) -> tf.experimental.Optional[_T1]: ...
def get_next_as_optional(self) -> tf.experimental.Optional[_T1_co]: ...
class Dataset(ABC, Generic[_T1]):
def apply(self, transformation_func: Callable[[Dataset[_T1]], Dataset[_T2]]) -> Dataset[_T2]: ...
class Dataset(ABC, Generic[_T1_co]):
def apply(self, transformation_func: Callable[[Dataset[_T1_co]], Dataset[_T2]]) -> Dataset[_T2]: ...
def as_numpy_iterator(self) -> Iterator[np.ndarray[Any, Any]]: ...
def batch(
self,
@@ -37,10 +37,10 @@ class Dataset(ABC, Generic[_T1]):
num_parallel_calls: int | None = None,
deterministic: bool | None = None,
name: str | None = None,
) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def bucket_by_sequence_length(
self,
element_length_func: Callable[[_T1], ScalarTensorCompatible],
element_length_func: Callable[[_T1_co], ScalarTensorCompatible],
bucket_boundaries: Sequence[int],
bucket_batch_sizes: Sequence[int],
padded_shapes: ContainerGeneric[tf.TensorShape | TensorCompatible] | None = None,
@@ -49,14 +49,14 @@ class Dataset(ABC, Generic[_T1]):
no_padding: bool = False,
drop_remainder: bool = False,
name: str | None = None,
) -> Dataset[_T1]: ...
def cache(self, filename: str = "", name: str | None = None) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def cache(self, filename: str = "", name: str | None = None) -> Dataset[_T1_co]: ...
def cardinality(self) -> int: ...
@staticmethod
def choose_from_datasets(
datasets: Sequence[Dataset[_T2]], choice_dataset: Dataset[tf.Tensor], stop_on_empty_dataset: bool = True
) -> Dataset[_T2]: ...
def concatenate(self, dataset: Dataset[_T1], name: str | None = None) -> Dataset[_T1]: ...
def concatenate(self, dataset: Dataset[_T1_co], name: str | None = None) -> Dataset[_T1_co]: ...
@staticmethod
def counter(
start: ScalarTensorCompatible = 0, step: ScalarTensorCompatible = 1, dtype: DType = ..., name: str | None = None
@@ -64,9 +64,9 @@ class Dataset(ABC, Generic[_T1]):
@property
@abstractmethod
def element_spec(self) -> ContainerGeneric[TypeSpec[Any]]: ...
def enumerate(self, start: ScalarTensorCompatible = 0, name: str | None = None) -> Dataset[tuple[int, _T1]]: ...
def filter(self, predicate: Callable[[_T1], bool | tf.Tensor], name: str | None = None) -> Dataset[_T1]: ...
def flat_map(self, map_func: Callable[[_T1], Dataset[_T2]], name: str | None = None) -> Dataset[_T2]: ...
def enumerate(self, start: ScalarTensorCompatible = 0, name: str | None = None) -> Dataset[tuple[int, _T1_co]]: ...
def filter(self, predicate: Callable[[_T1_co], bool | tf.Tensor], name: str | None = None) -> Dataset[_T1_co]: ...
def flat_map(self, map_func: Callable[[_T1_co], Dataset[_T2]], name: str | None = None) -> Dataset[_T2]: ...
# PEP 646 can be used here for a more precise type when better supported.
@staticmethod
def from_generator(
@@ -81,26 +81,26 @@ class Dataset(ABC, Generic[_T1]):
def from_tensors(tensors: Any, name: str | None = None) -> Dataset[Any]: ...
@staticmethod
def from_tensor_slices(tensors: TensorCompatible, name: str | None = None) -> Dataset[Any]: ...
def get_single_element(self, name: str | None = None) -> _T1: ...
def get_single_element(self, name: str | None = None) -> _T1_co: ...
def group_by_window(
self,
key_func: Callable[[_T1], tf.Tensor],
reduce_func: Callable[[tf.Tensor, Dataset[_T1]], Dataset[_T2]],
key_func: Callable[[_T1_co], tf.Tensor],
reduce_func: Callable[[tf.Tensor, Dataset[_T1_co]], Dataset[_T2]],
window_size: ScalarTensorCompatible | None = None,
window_size_func: Callable[[tf.Tensor], tf.Tensor] | None = None,
name: str | None = None,
) -> Dataset[_T2]: ...
def ignore_errors(self, log_warning: bool = False, name: str | None = None) -> Dataset[_T1]: ...
def ignore_errors(self, log_warning: bool = False, name: str | None = None) -> Dataset[_T1_co]: ...
def interleave(
self,
map_func: Callable[[_T1], Dataset[_T2]],
map_func: Callable[[_T1_co], Dataset[_T2]],
cycle_length: int | None = None,
block_length: int | None = None,
num_parallel_calls: int | None = None,
deterministic: bool | None = None,
name: str | None = None,
) -> Dataset[_T2]: ...
def __iter__(self) -> Iterator[_T1]: ...
def __iter__(self) -> Iterator[_T1_co]: ...
@staticmethod
def list_files(
file_pattern: str | Sequence[str] | TensorCompatible,
@@ -134,8 +134,8 @@ class Dataset(ABC, Generic[_T1]):
padding_values: ContainerGeneric[ScalarTensorCompatible] | None = None,
drop_remainder: bool = False,
name: str | None = None,
) -> Dataset[_T1]: ...
def prefetch(self, buffer_size: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def prefetch(self, buffer_size: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1_co]: ...
def ragged_batch(
self,
batch_size: ScalarTensorCompatible,
@@ -162,62 +162,62 @@ class Dataset(ABC, Generic[_T1]):
) -> Dataset[tf.Tensor]: ...
def rebatch(
self, batch_size: ScalarTensorCompatible, drop_remainder: bool = False, name: str | None = None
) -> Dataset[_T1]: ...
def reduce(self, initial_state: _T2, reduce_func: Callable[[_T2, _T1], _T2], name: str | None = None) -> _T2: ...
) -> Dataset[_T1_co]: ...
def reduce(self, initial_state: _T2, reduce_func: Callable[[_T2, _T1_co], _T2], name: str | None = None) -> _T2: ...
def rejection_resample(
self,
class_func: Callable[[_T1], ScalarTensorCompatible],
class_func: Callable[[_T1_co], ScalarTensorCompatible],
target_dist: TensorCompatible,
initial_dist: TensorCompatible | None = None,
seed: int | None = None,
name: str | None = None,
) -> Dataset[_T1]: ...
def repeat(self, count: ScalarTensorCompatible | None = None, name: str | None = None) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def repeat(self, count: ScalarTensorCompatible | None = None, name: str | None = None) -> Dataset[_T1_co]: ...
@staticmethod
def sample_from_datasets(
datasets: Sequence[Dataset[_T1]],
datasets: Sequence[Dataset[_T1_co]],
weights: TensorCompatible | None = None,
seed: int | None = None,
stop_on_empty_dataset: bool = False,
rerandomize_each_iteration: bool | None = None,
) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
# Incomplete as tf.train.CheckpointOptions not yet covered.
def save(
self,
path: str,
compression: _CompressionTypes = None,
shard_func: Callable[[_T1], int] | None = None,
shard_func: Callable[[_T1_co], int] | None = None,
checkpoint_args: Incomplete | None = None,
) -> None: ...
def scan(
self, initial_state: _T2, scan_func: Callable[[_T2, _T1], tuple[_T2, _T3]], name: str | None = None
self, initial_state: _T2, scan_func: Callable[[_T2, _T1_co], tuple[_T2, _T3]], name: str | None = None
) -> Dataset[_T3]: ...
def shard(
self, num_shards: ScalarTensorCompatible, index: ScalarTensorCompatible, name: str | None = None
) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def shuffle(
self,
buffer_size: ScalarTensorCompatible,
seed: int | None = None,
reshuffle_each_iteration: bool = True,
name: str | None = None,
) -> Dataset[_T1]: ...
def skip(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def skip(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1_co]: ...
def snapshot(
self,
path: str,
compression: _CompressionTypes = "AUTO",
reader_func: Callable[[Dataset[Dataset[_T1]]], Dataset[_T1]] | None = None,
shard_func: Callable[[_T1], ScalarTensorCompatible] | None = None,
reader_func: Callable[[Dataset[Dataset[_T1_co]]], Dataset[_T1_co]] | None = None,
shard_func: Callable[[_T1_co], ScalarTensorCompatible] | None = None,
name: str | None = None,
) -> Dataset[_T1]: ...
) -> Dataset[_T1_co]: ...
def sparse_batch(
self, batch_size: ScalarTensorCompatible, row_shape: tf.TensorShape | TensorCompatible, name: str | None = None
) -> Dataset[tf.SparseTensor]: ...
def take(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1]: ...
def take_while(self, predicate: Callable[[_T1], ScalarTensorCompatible], name: str | None = None) -> Dataset[_T1]: ...
def unbatch(self, name: str | None = None) -> Dataset[_T1]: ...
def unique(self, name: str | None = None) -> Dataset[_T1]: ...
def take(self, count: ScalarTensorCompatible, name: str | None = None) -> Dataset[_T1_co]: ...
def take_while(self, predicate: Callable[[_T1_co], ScalarTensorCompatible], name: str | None = None) -> Dataset[_T1_co]: ...
def unbatch(self, name: str | None = None) -> Dataset[_T1_co]: ...
def unique(self, name: str | None = None) -> Dataset[_T1_co]: ...
def window(
self,
size: ScalarTensorCompatible,
@@ -225,8 +225,8 @@ class Dataset(ABC, Generic[_T1]):
stride: ScalarTensorCompatible = 1,
drop_remainder: bool = False,
name: str | None = None,
) -> Dataset[Dataset[_T1]]: ...
def with_options(self, options: Options, name: str | None = None) -> Dataset[_T1]: ...
) -> Dataset[Dataset[_T1_co]]: ...
def with_options(self, options: Options, name: str | None = None) -> Dataset[_T1_co]: ...
@overload
@staticmethod
def zip(
@@ -3,10 +3,10 @@ from typing import Generic, TypeVar
from tensorflow._aliases import AnyArray
_Value = TypeVar("_Value", covariant=True)
_Value_co = TypeVar("_Value_co", covariant=True)
class RemoteValue(Generic[_Value]):
class RemoteValue(Generic[_Value_co]):
def fetch(self) -> AnyArray: ...
def get(self) -> _Value: ...
def get(self) -> _Value_co: ...
def __getattr__(name: str) -> Incomplete: ...
@@ -11,8 +11,8 @@ from tensorflow.keras.constraints import Constraint
from tensorflow.keras.initializers import _Initializer
from tensorflow.keras.regularizers import Regularizer, _Regularizer
_InputT = TypeVar("_InputT", contravariant=True)
_OutputT = TypeVar("_OutputT", covariant=True)
_InputT_contra = TypeVar("_InputT_contra", contravariant=True)
_OutputT_co = TypeVar("_OutputT_co", covariant=True)
class InputSpec:
dtype: str | None
@@ -39,9 +39,9 @@ class InputSpec:
# Most layers have input and output type of just Tensor and when we support default type variables,
# maybe worth trying.
class Layer(tf.Module, Generic[_InputT, _OutputT]):
class Layer(tf.Module, Generic[_InputT_contra, _OutputT_co]):
# The most general type is ContainerGeneric[InputSpec] as it really
# depends on _InputT. For most Layers it is just InputSpec
# depends on _InputT_contra. For most Layers it is just InputSpec
# though. Maybe describable with HKT?
input_spec: InputSpec | Any
@@ -65,11 +65,13 @@ class Layer(tf.Module, Generic[_InputT, _OutputT]):
# *args/**kwargs are allowed, but have obscure footguns and tensorflow documentation discourages their usage.
# First argument will automatically be cast to layer's compute dtype, but any other tensor arguments will not be.
# Also various tensorflow tools/apis can misbehave if they encounter a layer with *args/**kwargs.
def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
def call(self, inputs: _InputT, /) -> _OutputT: ...
def __call__(
self, inputs: _InputT_contra, *, training: bool = False, mask: TensorCompatible | None = None
) -> _OutputT_co: ...
def call(self, inputs: _InputT_contra, /) -> _OutputT_co: ...
# input_shape's real type depends on _InputT, but we can't express that without HKT.
# For example _InputT tf.Tensor -> tf.TensorShape, _InputT dict[str, tf.Tensor] -> dict[str, tf.TensorShape].
# input_shape's real type depends on _InputT_contra, but we can't express that without HKT.
# For example _InputT_contra tf.Tensor -> tf.TensorShape, _InputT_contra dict[str, tf.Tensor] -> dict[str, tf.TensorShape].
def build(self, input_shape: Any, /) -> None: ...
@overload
def compute_output_shape(self: Layer[tf.Tensor, tf.Tensor], input_shape: tf.TensorShape, /) -> tf.TensorShape: ...
+11 -9
View File
@@ -9,14 +9,14 @@ import numpy.typing as npt
import tensorflow as tf
from tensorflow import Variable
from tensorflow._aliases import ContainerGeneric, ShapeLike, TensorCompatible
from tensorflow.keras.layers import Layer, _InputT, _OutputT
from tensorflow.keras.layers import Layer, _InputT_contra, _OutputT_co
from tensorflow.keras.optimizers import Optimizer
_Loss: TypeAlias = str | tf.keras.losses.Loss | Callable[[TensorCompatible, TensorCompatible], tf.Tensor]
_Metric: TypeAlias = str | tf.keras.metrics.Metric | Callable[[TensorCompatible, TensorCompatible], tf.Tensor] | None
# Missing keras.src.backend.tensorflow.trainer.TensorFlowTrainer as a base class, which is not exposed by tensorflow
class Model(Layer[_InputT, _OutputT]):
class Model(Layer[_InputT_contra, _OutputT_co]):
_train_counter: tf.Variable
_test_counter: tf.Variable
optimizer: Optimizer | None
@@ -27,13 +27,15 @@ class Model(Layer[_InputT, _OutputT]):
) -> tf.Tensor | None: ...
stop_training: bool
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT, _OutputT]: ...
def __new__(cls, *args: Any, **kwargs: Any) -> Model[_InputT_contra, _OutputT_co]: ...
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __setattr__(self, name: str, value: Any) -> None: ...
def __reduce__(self): ...
def build(self, input_shape: ShapeLike) -> None: ...
def __call__(self, inputs: _InputT, *, training: bool = False, mask: TensorCompatible | None = None) -> _OutputT: ...
def call(self, inputs: _InputT, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT: ...
def __call__(
self, inputs: _InputT_contra, *, training: bool = False, mask: TensorCompatible | None = None
) -> _OutputT_co: ...
def call(self, inputs: _InputT_contra, training: bool | None = None, mask: TensorCompatible | None = None) -> _OutputT_co: ...
# Ideally loss/metrics/output would share the same structure but higher kinded types are not supported.
def compile(
self,
@@ -106,8 +108,8 @@ class Model(Layer[_InputT, _OutputT]):
return_dict: bool = False,
**kwargs: Any,
) -> float | list[float]: ...
def predict_step(self, data: _InputT) -> _OutputT: ...
def make_predict_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], _OutputT]: ...
def predict_step(self, data: _InputT_contra) -> _OutputT_co: ...
def make_predict_function(self, force: bool = False) -> Callable[[tf.data.Iterator[Incomplete]], _OutputT_co]: ...
def predict(
self,
x: TensorCompatible | tf.data.Dataset[Incomplete],
@@ -115,7 +117,7 @@ class Model(Layer[_InputT, _OutputT]):
verbose: Literal["auto", 0, 1, 2] = "auto",
steps: int | None = None,
callbacks: list[tf.keras.callbacks.Callback] | None = None,
) -> _OutputT: ...
) -> _OutputT_co: ...
def reset_metrics(self) -> None: ...
def train_on_batch(
self,
@@ -132,7 +134,7 @@ class Model(Layer[_InputT, _OutputT]):
sample_weight: npt.NDArray[np.float64] | None = None,
return_dict: bool = False,
) -> float | list[float]: ...
def predict_on_batch(self, x: Iterator[_InputT]) -> npt.NDArray[Incomplete]: ...
def predict_on_batch(self, x: Iterator[_InputT_contra]) -> npt.NDArray[Incomplete]: ...
@property
def trainable_weights(self) -> list[Variable]: ...
@property
@@ -10,7 +10,7 @@ from tensorflow.saved_model.experimental import VariablePolicy
from tensorflow.types.experimental import ConcreteFunction, PolymorphicFunction
_P = ParamSpec("_P")
_R = TypeVar("_R", covariant=True)
_R_co = TypeVar("_R_co", covariant=True)
class Asset:
@property
@@ -77,10 +77,10 @@ class SaveOptions:
def contains_saved_model(export_dir: str | Path) -> bool: ...
class _LoadedAttributes(Generic[_P, _R]):
signatures: Mapping[str, ConcreteFunction[_P, _R]]
class _LoadedAttributes(Generic[_P, _R_co]):
signatures: Mapping[str, ConcreteFunction[_P, _R_co]]
class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R]):
class _LoadedModel(AutoTrackable, _LoadedAttributes[_P, _R_co]):
variables: list[tf.Variable]
trainable_variables: list[tf.Variable]
# TF1 model artifact specific
@@ -7,23 +7,23 @@ import tensorflow as tf
from tensorflow._aliases import ContainerGeneric
_P = ParamSpec("_P")
_R = TypeVar("_R", covariant=True)
_R_co = TypeVar("_R_co", covariant=True)
class Callable(Generic[_P, _R], metaclass=abc.ABCMeta):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
class Callable(Generic[_P, _R_co], metaclass=abc.ABCMeta):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...
class ConcreteFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
class ConcreteFunction(Callable[_P, _R_co], metaclass=abc.ABCMeta):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...
class PolymorphicFunction(Callable[_P, _R], metaclass=abc.ABCMeta):
class PolymorphicFunction(Callable[_P, _R_co], metaclass=abc.ABCMeta):
@overload
@abc.abstractmethod
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R]: ...
def get_concrete_function(self, *args: _P.args, **kwargs: _P.kwargs) -> ConcreteFunction[_P, _R_co]: ...
@overload
@abc.abstractmethod
def get_concrete_function(
self, *args: ContainerGeneric[tf.TypeSpec[Any]], **kwargs: ContainerGeneric[tf.TypeSpec[Any]]
) -> ConcreteFunction[_P, _R]: ...
) -> ConcreteFunction[_P, _R_co]: ...
def experimental_get_compiler_ir(self, *args, **kwargs): ...
GenericFunction = PolymorphicFunction