From 7ccbbdb30acabfb272e35d3bd1c2d4a1a32aa3d9 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 2 Feb 2022 18:14:57 +0000 Subject: [PATCH] stdlib: Improve many `__iter__` and constructor methods (#7112) --- stdlib/_py_abc.pyi | 3 ++- stdlib/builtins.pyi | 10 +++++----- stdlib/decimal.pyi | 2 +- stdlib/dis.pyi | 3 ++- stdlib/email/headerregistry.pyi | 3 ++- stdlib/fractions.pyi | 4 ++-- stdlib/sqlite3/dbapi2.pyi | 2 +- stdlib/tarfile.pyi | 5 +++-- stdlib/traceback.pyi | 10 +++++----- stdlib/unittest/mock.pyi | 12 +++++++++--- stdlib/zipfile.pyi | 6 ++++-- 11 files changed, 36 insertions(+), 24 deletions(-) diff --git a/stdlib/_py_abc.pyi b/stdlib/_py_abc.pyi index c7794079d..ddf04364a 100644 --- a/stdlib/_py_abc.pyi +++ b/stdlib/_py_abc.pyi @@ -1,3 +1,4 @@ +from _typeshed import Self from typing import Any, NewType, TypeVar _T = TypeVar("_T") @@ -7,5 +8,5 @@ _CacheToken = NewType("_CacheToken", int) def get_cache_token() -> _CacheToken: ... class ABCMeta(type): - def __new__(__mcls, __name: str, __bases: tuple[type[Any], ...], __namespace: dict[str, Any]) -> ABCMeta: ... + def __new__(__mcls: type[Self], __name: str, __bases: tuple[type[Any], ...], __namespace: dict[str, Any]) -> Self: ... def register(cls, subclass: type[_T]) -> type[_T]: ... diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index a42cd0423..3aa74fa53 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -964,7 +964,7 @@ class frozenset(AbstractSet[_T_co], Generic[_T_co]): class enumerate(Iterator[tuple[int, _T]], Generic[_T]): def __init__(self, iterable: Iterable[_T], start: int = ...) -> None: ... - def __iter__(self) -> Iterator[tuple[int, _T]]: ... + def __iter__(self: Self) -> Self: ... def __next__(self) -> tuple[int, _T]: ... if sys.version_info >= (3, 9): def __class_getitem__(cls, __item: Any) -> GenericAlias: ... @@ -1095,7 +1095,7 @@ class filter(Iterator[_T], Generic[_T]): def __init__(self, __function: Callable[[_S], TypeGuard[_T]], __iterable: Iterable[_S]) -> None: ... @overload def __init__(self, __function: Callable[[_T], Any], __iterable: Iterable[_T]) -> None: ... - def __iter__(self) -> Iterator[_T]: ... + def __iter__(self: Self) -> Self: ... def __next__(self) -> _T: ... def format(__value: object, __format_spec: str = ...) -> str: ... # TODO unicode @@ -1186,7 +1186,7 @@ class map(Iterator[_S], Generic[_S]): __iter6: Iterable[Any], *iterables: Iterable[Any], ) -> None: ... - def __iter__(self) -> Iterator[_S]: ... + def __iter__(self: Self) -> Self: ... def __next__(self) -> _S: ... @overload @@ -1699,7 +1699,7 @@ if sys.version_info >= (3, 11): _SplitCondition = type[BaseException] | tuple[type[BaseException], ...] | Callable[[BaseException], bool] class BaseExceptionGroup(BaseException): - def __new__(cls, __message: str, __exceptions: Sequence[BaseException]) -> BaseExceptionGroup | ExceptionGroup: ... + def __new__(cls: type[Self], __message: str, __exceptions: Sequence[BaseException]) -> Self: ... @property def message(self) -> str: ... @property @@ -1709,6 +1709,6 @@ if sys.version_info >= (3, 11): def derive(self: Self, __excs: Sequence[BaseException]) -> Self: ... class ExceptionGroup(BaseExceptionGroup, Exception): - def __new__(cls, __message: str, __exceptions: Sequence[Exception]) -> ExceptionGroup: ... + def __new__(cls: type[Self], __message: str, __exceptions: Sequence[Exception]) -> Self: ... @property def exceptions(self) -> tuple[Exception, ...]: ... diff --git a/stdlib/decimal.pyi b/stdlib/decimal.pyi index a7fe060a4..c888e5874 100644 --- a/stdlib/decimal.pyi +++ b/stdlib/decimal.pyi @@ -52,7 +52,7 @@ def localcontext(ctx: Context | None = ...) -> _ContextManager: ... class Decimal: def __new__(cls: type[Self], value: _DecimalNew = ..., context: Context | None = ...) -> Self: ... @classmethod - def from_float(cls, __f: float) -> Decimal: ... + def from_float(cls: type[Self], __f: float) -> Self: ... def __bool__(self) -> bool: ... def compare(self, other: _Decimal, context: Context | None = ...) -> Decimal: ... def __hash__(self) -> int: ... diff --git a/stdlib/dis.pyi b/stdlib/dis.pyi index 8c423a424..dcb82cb7a 100644 --- a/stdlib/dis.pyi +++ b/stdlib/dis.pyi @@ -1,5 +1,6 @@ import sys import types +from _typeshed import Self from opcode import * # `dis` re-exports it as a part of public API from typing import IO, Any, Callable, Iterator, NamedTuple, Union @@ -54,7 +55,7 @@ class Bytecode: def info(self) -> str: ... def dis(self) -> str: ... @classmethod - def from_traceback(cls, tb: types.TracebackType) -> Bytecode: ... + def from_traceback(cls: type[Self], tb: types.TracebackType) -> Self: ... COMPILER_FLAG_NAMES: dict[int, str] diff --git a/stdlib/email/headerregistry.pyi b/stdlib/email/headerregistry.pyi index 1dabf6a4a..39fe82a18 100644 --- a/stdlib/email/headerregistry.pyi +++ b/stdlib/email/headerregistry.pyi @@ -1,5 +1,6 @@ import sys import types +from _typeshed import Self from collections.abc import Iterable, Mapping from datetime import datetime as _datetime from email._header_value_parser import ( @@ -23,7 +24,7 @@ class BaseHeader(str): def name(self) -> str: ... @property def defects(self) -> tuple[MessageDefect, ...]: ... - def __new__(cls, name: str, value: Any) -> BaseHeader: ... + def __new__(cls: type[Self], name: str, value: Any) -> Self: ... def init(self, name: str, *, parse_tree: TokenList, defects: Iterable[MessageDefect]) -> None: ... def fold(self, *, policy: Policy) -> str: ... diff --git a/stdlib/fractions.pyi b/stdlib/fractions.pyi index f9b98765e..5f404b015 100644 --- a/stdlib/fractions.pyi +++ b/stdlib/fractions.pyi @@ -25,9 +25,9 @@ class Fraction(Rational): @overload def __new__(cls: type[Self], __value: float | Decimal | str, *, _normalize: bool = ...) -> Self: ... @classmethod - def from_float(cls, f: float) -> Fraction: ... + def from_float(cls: type[Self], f: float) -> Self: ... @classmethod - def from_decimal(cls, dec: Decimal) -> Fraction: ... + def from_decimal(cls: type[Self], dec: Decimal) -> Self: ... def limit_denominator(self, max_denominator: int = ...) -> Fraction: ... if sys.version_info >= (3, 8): def as_integer_ratio(self) -> tuple[int, int]: ... diff --git a/stdlib/sqlite3/dbapi2.pyi b/stdlib/sqlite3/dbapi2.pyi index 1c5960ad5..820b40baa 100644 --- a/stdlib/sqlite3/dbapi2.pyi +++ b/stdlib/sqlite3/dbapi2.pyi @@ -204,7 +204,7 @@ class Cursor(Iterator[Any]): def fetchone(self) -> Any: ... def setinputsizes(self, __sizes: object) -> None: ... # does nothing def setoutputsize(self, __size: object, __column: object = ...) -> None: ... # does nothing - def __iter__(self) -> Cursor: ... + def __iter__(self: Self) -> Self: ... def __next__(self) -> Any: ... class DataError(DatabaseError): ... diff --git a/stdlib/tarfile.pyi b/stdlib/tarfile.pyi index 640cd5188..d6282494d 100644 --- a/stdlib/tarfile.pyi +++ b/stdlib/tarfile.pyi @@ -2,6 +2,7 @@ import bz2 import io import sys from _typeshed import Self, StrOrBytesPath, StrPath +from builtins import type as Type # alias to avoid name clashes with fields named "type" from collections.abc import Callable, Iterable, Iterator, Mapping from gzip import _ReadableFileobj as _GzipReadableFileobj, _WritableFileobj as _GzipWritableFileobj from types import TracebackType @@ -339,9 +340,9 @@ class TarInfo: pax_headers: Mapping[str, str] def __init__(self, name: str = ...) -> None: ... @classmethod - def frombuf(cls, buf: bytes, encoding: str, errors: str) -> TarInfo: ... + def frombuf(cls: Type[Self], buf: bytes, encoding: str, errors: str) -> Self: ... @classmethod - def fromtarfile(cls, tarfile: TarFile) -> TarInfo: ... + def fromtarfile(cls: Type[Self], tarfile: TarFile) -> Self: ... @property def linkpath(self) -> str: ... @linkpath.setter diff --git a/stdlib/traceback.pyi b/stdlib/traceback.pyi index 30d67d8dc..0e946762f 100644 --- a/stdlib/traceback.pyi +++ b/stdlib/traceback.pyi @@ -1,5 +1,5 @@ import sys -from _typeshed import SupportsWrite +from _typeshed import Self, SupportsWrite from types import FrameType, TracebackType from typing import IO, Any, Generator, Iterable, Iterator, Mapping, Optional, overload @@ -98,14 +98,14 @@ class TracebackException: ) -> None: ... @classmethod def from_exception( - cls, + cls: type[Self], exc: BaseException, *, limit: int | None = ..., lookup_lines: bool = ..., capture_locals: bool = ..., compact: bool = ..., - ) -> TracebackException: ... + ) -> Self: ... else: def __init__( self, @@ -120,8 +120,8 @@ class TracebackException: ) -> None: ... @classmethod def from_exception( - cls, exc: BaseException, *, limit: int | None = ..., lookup_lines: bool = ..., capture_locals: bool = ... - ) -> TracebackException: ... + cls: type[Self], exc: BaseException, *, limit: int | None = ..., lookup_lines: bool = ..., capture_locals: bool = ... + ) -> Self: ... def format(self, *, chain: bool = ...) -> Generator[str, None, None]: ... def format_exception_only(self) -> Generator[str, None, None]: ... diff --git a/stdlib/unittest/mock.pyi b/stdlib/unittest/mock.pyi index 4e26ac01f..1f0a02dfe 100644 --- a/stdlib/unittest/mock.pyi +++ b/stdlib/unittest/mock.pyi @@ -1,4 +1,5 @@ import sys +from _typeshed import Self from typing import Any, Awaitable, Callable, Generic, Iterable, Mapping, Sequence, TypeVar, overload from typing_extensions import Literal @@ -76,8 +77,13 @@ DEFAULT: Any class _Call(tuple[Any, ...]): def __new__( - cls, value: Any = ..., name: Any | None = ..., parent: Any | None = ..., two: bool = ..., from_kall: bool = ... - ) -> Any: ... + cls: type[Self], + value: Any = ..., + name: Any | None = ..., + parent: Any | None = ..., + two: bool = ..., + from_kall: bool = ..., + ) -> Self: ... name: Any parent: Any from_kall: Any @@ -105,7 +111,7 @@ class Base: def __init__(self, *args: Any, **kwargs: Any) -> None: ... class NonCallableMock(Base, Any): - def __new__(__cls, *args: Any, **kw: Any) -> NonCallableMock: ... + def __new__(__cls: type[Self], *args: Any, **kw: Any) -> Self: ... def __init__( self, spec: list[str] | object | type[object] | None = ..., diff --git a/stdlib/zipfile.pyi b/stdlib/zipfile.pyi index a31a06c16..ec64b02e9 100644 --- a/stdlib/zipfile.pyi +++ b/stdlib/zipfile.pyi @@ -212,10 +212,12 @@ class ZipInfo: def __init__(self, filename: str = ..., date_time: _DateTuple = ...) -> None: ... if sys.version_info >= (3, 8): @classmethod - def from_file(cls, filename: StrPath, arcname: StrPath | None = ..., *, strict_timestamps: bool = ...) -> ZipInfo: ... + def from_file( + cls: type[Self], filename: StrPath, arcname: StrPath | None = ..., *, strict_timestamps: bool = ... + ) -> Self: ... else: @classmethod - def from_file(cls, filename: StrPath, arcname: StrPath | None = ...) -> ZipInfo: ... + def from_file(cls: type[Self], filename: StrPath, arcname: StrPath | None = ...) -> Self: ... def is_dir(self) -> bool: ... def FileHeader(self, zip64: bool | None = ...) -> bytes: ...