diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 568c8931f..9e04ac4e5 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1202,9 +1202,15 @@ def help(request: object = ...) -> None: ... def hex(__number: int | SupportsIndex) -> str: ... def id(__obj: object) -> int: ... def input(__prompt: object = ...) -> str: ... + +class _GetItemIterable(Protocol[_T_co]): + def __getitem__(self, __i: int) -> _T_co: ... + @overload def iter(__iterable: SupportsIter[_SupportsNextT]) -> _SupportsNextT: ... @overload +def iter(__iterable: _GetItemIterable[_T]) -> Iterator[_T]: ... +@overload def iter(__function: Callable[[], _T | None], __sentinel: None) -> Iterator[_T]: ... @overload def iter(__function: Callable[[], _T], __sentinel: object) -> Iterator[_T]: ... diff --git a/test_cases/stdlib/builtins/test_iteration.py b/test_cases/stdlib/builtins/test_iteration.py new file mode 100644 index 000000000..7195c8bcd --- /dev/null +++ b/test_cases/stdlib/builtins/test_iteration.py @@ -0,0 +1,14 @@ +from typing import Iterator +from typing_extensions import assert_type + + +class OldStyleIter: + def __getitem__(self, index: int) -> str: + return str(index) + + +for x in iter(OldStyleIter()): + assert_type(x, str) + +assert_type(iter(OldStyleIter()), Iterator[str]) +assert_type(next(iter(OldStyleIter())), str)