From f8dd877e4869147e23fd06cf837fa4d41e176dd9 Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Thu, 10 Nov 2022 17:45:07 +0300 Subject: [PATCH] Improve `math.{ceil,floor,trunc}` (#9141) Co-authored-by: Alex Waygood --- stdlib/math.pyi | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/stdlib/math.pyi b/stdlib/math.pyi index 58eda98d8..ca30acd7e 100644 --- a/stdlib/math.pyi +++ b/stdlib/math.pyi @@ -1,9 +1,11 @@ import sys -from _typeshed import SupportsTrunc from collections.abc import Iterable -from typing import SupportsFloat, overload +from typing import Protocol, SupportsFloat, TypeVar, overload from typing_extensions import SupportsIndex, TypeAlias +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + if sys.version_info >= (3, 8): _SupportsFloatOrIndex: TypeAlias = SupportsFloat | SupportsIndex else: @@ -26,6 +28,12 @@ def atanh(__x: _SupportsFloatOrIndex) -> float: ... if sys.version_info >= (3, 11): def cbrt(__x: _SupportsFloatOrIndex) -> float: ... +class _SupportsCeil(Protocol[_T_co]): + def __ceil__(self) -> _T_co: ... + +@overload +def ceil(__x: _SupportsCeil[_T]) -> _T: ... +@overload def ceil(__x: _SupportsFloatOrIndex) -> int: ... if sys.version_info >= (3, 8): @@ -55,6 +63,12 @@ if sys.version_info >= (3, 8): else: def factorial(__x: int) -> int: ... +class _SupportsFloor(Protocol[_T_co]): + def __floor__(self) -> _T_co: ... + +@overload +def floor(__x: _SupportsFloor[_T]) -> _T: ... +@overload def floor(__x: _SupportsFloatOrIndex) -> int: ... def fmod(__x: _SupportsFloatOrIndex, __y: _SupportsFloatOrIndex) -> float: ... def frexp(__x: _SupportsFloatOrIndex) -> tuple[float, int]: ... @@ -119,7 +133,12 @@ def sinh(__x: _SupportsFloatOrIndex) -> float: ... def sqrt(__x: _SupportsFloatOrIndex) -> float: ... def tan(__x: _SupportsFloatOrIndex) -> float: ... def tanh(__x: _SupportsFloatOrIndex) -> float: ... -def trunc(__x: SupportsTrunc) -> int: ... + +# Is different from `_typeshed.SupportsTrunc`, which is not generic +class _SupportsTrunc(Protocol[_T_co]): + def __trunc__(self) -> _T_co: ... + +def trunc(__x: _SupportsTrunc[_T]) -> _T: ... if sys.version_info >= (3, 9): def ulp(__x: _SupportsFloatOrIndex) -> float: ...