Add @disjoint_base decorator in the stdlib (#14599)

And fix some other new stubtest finds.
This commit is contained in:
Jelle Zijlstra
2025-08-24 07:27:14 -07:00
committed by GitHub
parent 2565f34946
commit e8ba06f710
55 changed files with 701 additions and 307 deletions
+20 -2
View File
@@ -3,7 +3,7 @@ from _typeshed import MaybeNone
from collections.abc import Callable, Iterable, Iterator
from types import GenericAlias
from typing import Any, Generic, Literal, SupportsComplex, SupportsFloat, SupportsIndex, SupportsInt, TypeVar, overload
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, disjoint_base
_T = TypeVar("_T")
_S = TypeVar("_S")
@@ -27,6 +27,7 @@ _Predicate: TypeAlias = Callable[[_T], object]
# Technically count can take anything that implements a number protocol and has an add method
# but we can't enforce the add method
@disjoint_base
class count(Generic[_N]):
@overload
def __new__(cls) -> count[int]: ...
@@ -37,11 +38,13 @@ class count(Generic[_N]):
def __next__(self) -> _N: ...
def __iter__(self) -> Self: ...
@disjoint_base
class cycle(Generic[_T]):
def __new__(cls, iterable: Iterable[_T], /) -> Self: ...
def __next__(self) -> _T: ...
def __iter__(self) -> Self: ...
@disjoint_base
class repeat(Generic[_T]):
@overload
def __new__(cls, object: _T) -> Self: ...
@@ -51,6 +54,7 @@ class repeat(Generic[_T]):
def __iter__(self) -> Self: ...
def __length_hint__(self) -> int: ...
@disjoint_base
class accumulate(Generic[_T]):
@overload
def __new__(cls, iterable: Iterable[_T], func: None = None, *, initial: _T | None = ...) -> Self: ...
@@ -59,6 +63,7 @@ class accumulate(Generic[_T]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
@disjoint_base
class chain(Generic[_T]):
def __new__(cls, *iterables: Iterable[_T]) -> Self: ...
def __next__(self) -> _T: ...
@@ -68,21 +73,25 @@ class chain(Generic[_T]):
def from_iterable(cls: type[Any], iterable: Iterable[Iterable[_S]], /) -> chain[_S]: ...
def __class_getitem__(cls, item: Any, /) -> GenericAlias: ...
@disjoint_base
class compress(Generic[_T]):
def __new__(cls, data: Iterable[_T], selectors: Iterable[Any]) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
@disjoint_base
class dropwhile(Generic[_T]):
def __new__(cls, predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
@disjoint_base
class filterfalse(Generic[_T]):
def __new__(cls, function: _Predicate[_T] | None, iterable: Iterable[_T], /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
@disjoint_base
class groupby(Generic[_T_co, _S_co]):
@overload
def __new__(cls, iterable: Iterable[_T1], key: None = None) -> groupby[_T1, _T1]: ...
@@ -91,6 +100,7 @@ class groupby(Generic[_T_co, _S_co]):
def __iter__(self) -> Self: ...
def __next__(self) -> tuple[_T_co, Iterator[_S_co]]: ...
@disjoint_base
class islice(Generic[_T]):
@overload
def __new__(cls, iterable: Iterable[_T], stop: int | None, /) -> Self: ...
@@ -99,18 +109,20 @@ class islice(Generic[_T]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
@disjoint_base
class starmap(Generic[_T_co]):
def __new__(cls, function: Callable[..., _T], iterable: Iterable[Iterable[Any]], /) -> starmap[_T]: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...
@disjoint_base
class takewhile(Generic[_T]):
def __new__(cls, predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T: ...
def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]: ...
@disjoint_base
class zip_longest(Generic[_T_co]):
# one iterable (fillvalue doesn't matter)
@overload
@@ -189,6 +201,7 @@ class zip_longest(Generic[_T_co]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...
@disjoint_base
class product(Generic[_T_co]):
@overload
def __new__(cls, iter1: Iterable[_T1], /) -> product[tuple[_T1]]: ...
@@ -274,6 +287,7 @@ class product(Generic[_T_co]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...
@disjoint_base
class permutations(Generic[_T_co]):
@overload
def __new__(cls, iterable: Iterable[_T], r: Literal[2]) -> permutations[tuple[_T, _T]]: ...
@@ -288,6 +302,7 @@ class permutations(Generic[_T_co]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...
@disjoint_base
class combinations(Generic[_T_co]):
@overload
def __new__(cls, iterable: Iterable[_T], r: Literal[2]) -> combinations[tuple[_T, _T]]: ...
@@ -302,6 +317,7 @@ class combinations(Generic[_T_co]):
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...
@disjoint_base
class combinations_with_replacement(Generic[_T_co]):
@overload
def __new__(cls, iterable: Iterable[_T], r: Literal[2]) -> combinations_with_replacement[tuple[_T, _T]]: ...
@@ -317,12 +333,14 @@ class combinations_with_replacement(Generic[_T_co]):
def __next__(self) -> _T_co: ...
if sys.version_info >= (3, 10):
@disjoint_base
class pairwise(Generic[_T_co]):
def __new__(cls, iterable: Iterable[_T], /) -> pairwise[tuple[_T, _T]]: ...
def __iter__(self) -> Self: ...
def __next__(self) -> _T_co: ...
if sys.version_info >= (3, 12):
@disjoint_base
class batched(Generic[_T_co]):
if sys.version_info >= (3, 13):
def __new__(cls, iterable: Iterable[_T_co], n: int, *, strict: bool = False) -> Self: ...