diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index f7b1eb4b0..f25777eea 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -50,7 +50,6 @@ from typing import ( # noqa: Y027 SupportsComplex, SupportsFloat, SupportsInt, - SupportsRound, TypeVar, overload, type_check_only, @@ -1619,12 +1618,21 @@ class reversed(Iterator[_T], Generic[_T]): def __length_hint__(self) -> int: ... def repr(__obj: object) -> str: ... + +# See https://github.com/python/typeshed/pull/9141 +# and https://github.com/python/typeshed/pull/9151 +# on why we don't use `SupportsRound` from `typing.pyi` + +class _SupportsRound1(Protocol[_T_co]): + def __round__(self) -> _T_co: ... + +class _SupportsRound2(Protocol[_T_co]): + def __round__(self, __ndigits: int) -> _T_co: ... + @overload -def round(number: SupportsRound[Any]) -> int: ... +def round(number: _SupportsRound1[_T], ndigits: None = ...) -> _T: ... @overload -def round(number: SupportsRound[Any], ndigits: None) -> int: ... -@overload -def round(number: SupportsRound[_T], ndigits: SupportsIndex) -> _T: ... +def round(number: _SupportsRound2[_T], ndigits: SupportsIndex) -> _T: ... # See https://github.com/python/typeshed/pull/6292#discussion_r748875189 # for why arg 3 of `setattr` should be annotated with `Any` and not `object` diff --git a/test_cases/stdlib/builtins/check_round.py b/test_cases/stdlib/builtins/check_round.py new file mode 100644 index 000000000..2648208bf --- /dev/null +++ b/test_cases/stdlib/builtins/check_round.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from typing import overload +from typing_extensions import assert_type + + +class CustomIndex: + def __index__(self) -> int: + return 1 + + +# float: + +assert_type(round(5.5), int) +assert_type(round(5.5, None), int) +assert_type(round(5.5, 0), float) +assert_type(round(5.5, 1), float) +assert_type(round(5.5, 5), float) +assert_type(round(5.5, CustomIndex()), float) + +# int: + +assert_type(round(1), int) +assert_type(round(1, 1), int) +assert_type(round(1, None), int) +assert_type(round(1, CustomIndex()), int) + +# Protocols: + + +class WithCustomRound1: + def __round__(self) -> str: + return "a" + + +assert_type(round(WithCustomRound1()), str) +assert_type(round(WithCustomRound1(), None), str) +# Errors: +round(WithCustomRound1(), 1) # type: ignore +round(WithCustomRound1(), CustomIndex()) # type: ignore + + +class WithCustomRound2: + def __round__(self, digits: int) -> str: + return "a" + + +assert_type(round(WithCustomRound2(), 1), str) +assert_type(round(WithCustomRound2(), CustomIndex()), str) +# Errors: +round(WithCustomRound2(), None) # type: ignore +round(WithCustomRound2()) # type: ignore + + +class WithOverloadedRound: + @overload + def __round__(self, ndigits: None = ...) -> str: + ... + + @overload + def __round__(self, ndigits: int) -> bytes: + ... + + def __round__(self, ndigits: int | None = None) -> str | bytes: + return b"" if ndigits is None else "" + + +assert_type(round(WithOverloadedRound()), str) +assert_type(round(WithOverloadedRound(), None), str) +assert_type(round(WithOverloadedRound(), 1), bytes)