Improving type support for math.prod (#13572)

This commit is contained in:
cake-monotone
2025-03-04 00:41:34 +09:00
committed by GitHub
parent 175e700656
commit 9f11db4296
3 changed files with 92 additions and 3 deletions
+63
View File
@@ -0,0 +1,63 @@
from __future__ import annotations
from decimal import Decimal
from fractions import Fraction
from math import prod
from typing import Any, Literal, Union
from typing_extensions import assert_type
class SupportsMul:
def __mul__(self, other: Any) -> SupportsMul:
return SupportsMul()
class SupportsRMul:
def __rmul__(self, other: Any) -> SupportsRMul:
return SupportsRMul()
class SupportsMulAndRMul:
def __mul__(self, other: Any) -> SupportsMulAndRMul:
return SupportsMulAndRMul()
def __rmul__(self, other: Any) -> SupportsMulAndRMul:
return SupportsMulAndRMul()
literal_list: list[Literal[0, 1]] = [0, 1, 1]
assert_type(prod([2, 4]), int)
assert_type(prod([3, 5], start=4), int)
assert_type(prod([True, False]), int)
assert_type(prod([True, False], start=True), int)
assert_type(prod(literal_list), int)
assert_type(prod([SupportsMul(), SupportsMul()], start=SupportsMul()), SupportsMul)
assert_type(prod([SupportsMulAndRMul(), SupportsMulAndRMul()]), Union[SupportsMulAndRMul, Literal[1]])
assert_type(prod([5.6, 3.2]), Union[float, Literal[1]])
assert_type(prod([5.6, 3.2], start=3), Union[float, int])
assert_type(prod([Fraction(7, 2), Fraction(3, 5)]), Union[Fraction, Literal[1]])
assert_type(prod([Fraction(7, 2), Fraction(3, 5)], start=Fraction(1)), Fraction)
assert_type(prod([Decimal("3.14"), Decimal("2.71")]), Union[Decimal, Literal[1]])
assert_type(prod([Decimal("3.14"), Decimal("2.71")], start=Decimal("1.00")), Decimal)
assert_type(prod([complex(7, 2), complex(3, 5)]), Union[complex, Literal[1]])
assert_type(prod([complex(7, 2), complex(3, 5)], start=complex(1, 0)), complex)
# mypy and pyright infer the types differently for these, so we can't use assert_type
# Just test that no error is emitted for any of these
prod([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]`
prod([2.5, 5.8], start=5) # mypy: `float`; pyright: `float | int`
# These all fail at runtime
prod([SupportsMul(), SupportsMul()]) # type: ignore
prod([SupportsRMul(), SupportsRMul()], start=SupportsRMul()) # type: ignore
prod([SupportsRMul(), SupportsRMul()]) # type: ignore
# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error:
# prod([3, Fraction(7, 22), complex(8, 0), 9.83])
# prod([3, Decimal("0.98")])
+6
View File
@@ -117,6 +117,12 @@ class SupportsSub(Protocol[_T_contra, _T_co]):
class SupportsRSub(Protocol[_T_contra, _T_co]):
def __rsub__(self, x: _T_contra, /) -> _T_co: ...
class SupportsMul(Protocol[_T_contra, _T_co]):
def __mul__(self, x: _T_contra, /) -> _T_co: ...
class SupportsRMul(Protocol[_T_contra, _T_co]):
def __rmul__(self, x: _T_contra, /) -> _T_co: ...
class SupportsDivMod(Protocol[_T_contra, _T_co]):
def __divmod__(self, other: _T_contra, /) -> _T_co: ...
+23 -3
View File
@@ -1,6 +1,7 @@
import sys
from _typeshed import SupportsMul, SupportsRMul
from collections.abc import Iterable
from typing import Final, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload
from typing import Any, Final, Literal, Protocol, SupportsFloat, SupportsIndex, TypeVar, overload
from typing_extensions import TypeAlias
_T = TypeVar("_T")
@@ -99,10 +100,29 @@ elif sys.version_info >= (3, 9):
def perm(n: SupportsIndex, k: SupportsIndex | None = None, /) -> int: ...
def pow(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ...
_PositiveInteger: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
_NegativeInteger: TypeAlias = Literal[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20]
_LiteralInteger = _PositiveInteger | _NegativeInteger | Literal[0] # noqa: Y026 # TODO: Use TypeAlias once mypy bugs are fixed
_MultiplicableT1 = TypeVar("_MultiplicableT1", bound=SupportsMul[Any, Any])
_MultiplicableT2 = TypeVar("_MultiplicableT2", bound=SupportsMul[Any, Any])
class _SupportsProdWithNoDefaultGiven(SupportsMul[Any, Any], SupportsRMul[int, Any], Protocol): ...
_SupportsProdNoDefaultT = TypeVar("_SupportsProdNoDefaultT", bound=_SupportsProdWithNoDefaultGiven)
# This stub is based on the type stub for `builtins.sum`.
# Like `builtins.sum`, it cannot be precisely represented in a type stub
# without introducing many false positives.
# For more details on its limitations and false positives, see #13572.
# Instead, just like `builtins.sum`, we explicitly handle several useful cases.
@overload
def prod(iterable: Iterable[SupportsIndex], /, *, start: SupportsIndex = 1) -> int: ... # type: ignore[overload-overlap]
def prod(iterable: Iterable[bool | _LiteralInteger], /, *, start: int = 1) -> int: ... # type: ignore[overload-overlap]
@overload
def prod(iterable: Iterable[_SupportsFloatOrIndex], /, *, start: _SupportsFloatOrIndex = 1) -> float: ...
def prod(iterable: Iterable[_SupportsProdNoDefaultT], /) -> _SupportsProdNoDefaultT | Literal[1]: ...
@overload
def prod(iterable: Iterable[_MultiplicableT1], /, *, start: _MultiplicableT2) -> _MultiplicableT1 | _MultiplicableT2: ...
def radians(x: _SupportsFloatOrIndex, /) -> float: ...
def remainder(x: _SupportsFloatOrIndex, y: _SupportsFloatOrIndex, /) -> float: ...
def sin(x: _SupportsFloatOrIndex, /) -> float: ...