diff --git a/stdlib/_typeshed/__init__.pyi b/stdlib/_typeshed/__init__.pyi index 3e5601e9b..aef6b553e 100644 --- a/stdlib/_typeshed/__init__.pyi +++ b/stdlib/_typeshed/__init__.pyi @@ -69,8 +69,11 @@ SupportsRichComparisonT = TypeVar("SupportsRichComparisonT", bound=SupportsRichC # Dunder protocols -class SupportsAdd(Protocol): - def __add__(self, __x: Any) -> Any: ... +class SupportsAdd(Protocol[_T_contra, _T_co]): + def __add__(self, __x: _T_contra) -> _T_co: ... + +class SupportsRAdd(Protocol[_T_contra, _T_co]): + def __radd__(self, __x: _T_contra) -> _T_co: ... class SupportsDivMod(Protocol[_T_contra, _T_co]): def __divmod__(self, __other: _T_contra) -> _T_co: ... diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 34b5a9e61..577d5fd99 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -20,6 +20,7 @@ from _typeshed import ( SupportsKeysAndGetItem, SupportsLenAndGetItem, SupportsNext, + SupportsRAdd, SupportsRDivMod, SupportsRichComparison, SupportsRichComparisonT, @@ -1637,8 +1638,12 @@ def sorted( @overload def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> list[_T]: ... -_SumT = TypeVar("_SumT", bound=SupportsAdd) -_SumS = TypeVar("_SumS", bound=SupportsAdd) +_AddableT1 = TypeVar("_AddableT1", bound=SupportsAdd[Any, Any]) +_AddableT2 = TypeVar("_AddableT2", bound=SupportsAdd[Any, Any]) + +class _SupportsSumWithNoDefaultGiven(SupportsAdd[Any, Any], SupportsRAdd[int, Any], Protocol): ... + +_SupportsSumNoDefaultT = TypeVar("_SupportsSumNoDefaultT", bound=_SupportsSumWithNoDefaultGiven) # In general, the return type of `x + x` is *not* guaranteed to be the same type as x. # However, we can't express that in the stub for `sum()` @@ -1653,15 +1658,15 @@ else: def sum(__iterable: Iterable[bool], __start: int = ...) -> int: ... # type: ignore[misc] @overload -def sum(__iterable: Iterable[_SumT]) -> _SumT | Literal[0]: ... +def sum(__iterable: Iterable[_SupportsSumNoDefaultT]) -> _SupportsSumNoDefaultT | Literal[0]: ... if sys.version_info >= (3, 8): @overload - def sum(__iterable: Iterable[_SumT], start: _SumS) -> _SumT | _SumS: ... + def sum(__iterable: Iterable[_AddableT1], start: _AddableT2) -> _AddableT1 | _AddableT2: ... else: @overload - def sum(__iterable: Iterable[_SumT], __start: _SumS) -> _SumT | _SumS: ... + def sum(__iterable: Iterable[_AddableT1], __start: _AddableT2) -> _AddableT1 | _AddableT2: ... # The argument to `vars()` has to have a `__dict__` attribute, so can't be annotated with `object` # (A "SupportsDunderDict" protocol doesn't work) diff --git a/test_cases/stdlib/builtins/test_sum.py b/test_cases/stdlib/builtins/test_sum.py new file mode 100644 index 000000000..c41f895fe --- /dev/null +++ b/test_cases/stdlib/builtins/test_sum.py @@ -0,0 +1,52 @@ +# pyright: reportUnnecessaryTypeIgnoreComment=true + +from typing import Any, List, Union +from typing_extensions import Literal, assert_type + + +class Foo: + def __add__(self, other: Any) -> "Foo": + return Foo() + + +class Bar: + def __radd__(self, other: Any) -> "Bar": + return Bar() + + +class Baz: + def __add__(self, other: Any) -> "Baz": + return Baz() + + def __radd__(self, other: Any) -> "Baz": + return Baz() + + +assert_type(sum([2, 4]), int) +assert_type(sum([3, 5], 4), int) + +assert_type(sum([True, False]), int) +assert_type(sum([True, False], True), int) + +assert_type(sum([["foo"], ["bar"]], ["baz"]), List[str]) + +assert_type(sum([Foo(), Foo()], Foo()), Foo) +assert_type(sum([Baz(), Baz()]), Union[Baz, Literal[0]]) + +# mypy and pyright infer the types differently for these, so we can't use assert_type +# Just test that no error is emitted for any of these +sum([("foo",), ("bar", "baz")], ()) # mypy: `tuple[str, ...]`; pyright: `tuple[()] | tuple[str] | tuple[str, str]` +sum([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]` +sum([2.5, 5.8], 5) # mypy: `float`; pyright: `float | int` + +# These all fail at runtime +sum("abcde") # type: ignore[arg-type] +sum([["foo"], ["bar"]]) # type: ignore[list-item] +sum([("foo",), ("bar", "baz")]) # type: ignore[list-item] +sum([Foo(), Foo()]) # type: ignore[list-item] +sum([Bar(), Bar()], Bar()) # type: ignore[call-overload] +sum([Bar(), Bar()]) # type: ignore[list-item] + +# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error: +# sum([3, Fraction(7, 22), complex(8, 0), 9.83]) +# sum([3, Decimal('0.98')])